leo_passes/write_transforming/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::CompilerState;
18
19use leo_ast::{
20    ArrayAccess,
21    AssignStatement,
22    AstVisitor,
23    DefinitionPlace,
24    DefinitionStatement,
25    Expression,
26    Identifier,
27    IntegerType,
28    Literal,
29    Location,
30    MemberAccess,
31    Node as _,
32    Path,
33    Program,
34    Statement,
35    Type,
36};
37use leo_span::Symbol;
38
39use indexmap::IndexMap;
40
41/// This visitor associates a variable for each member of a struct or array that is
42/// written to. Whenever a member of the struct or array is written to, we change the
43/// assignment to access the variable instead. Whenever the struct or array itself
44/// is accessed, we first rebuild the struct or array from its variables.
45pub struct WriteTransformingVisitor<'a> {
46    pub state: &'a mut CompilerState,
47
48    /// For any struct whose members are written to, a map of its field names to variables
49    /// corresponding to the members.
50    pub struct_members: IndexMap<Symbol, IndexMap<Symbol, Identifier>>,
51
52    /// For any array whose members are written to, a vec containing the variables for each index.
53    pub array_members: IndexMap<Symbol, Vec<Identifier>>,
54
55    pub program: Symbol,
56}
57
58impl<'a> WriteTransformingVisitor<'a> {
59    pub fn new(state: &'a mut CompilerState, program: &Program) -> Self {
60        let visitor = WriteTransformingVisitor {
61            state,
62            struct_members: Default::default(),
63            array_members: Default::default(),
64            program: Symbol::intern(""),
65        };
66
67        // We need to do an initial pass through the AST to identify all arrays and structs that are written to.
68        let mut wtf = WriteTransformingFiller(visitor);
69        wtf.fill(program);
70        wtf.0
71    }
72
73    /// If `name` is a struct or array whose members are written to, make
74    /// `DefinitionStatement`s for each of its variables that will correspond to
75    /// the members. Note that we create them for all members; unnecessary ones
76    /// will be removed by DCE.
77    pub fn define_variable_members(&mut self, name: Identifier, accumulate: &mut Vec<Statement>) {
78        // The `cloned` here and in the branch below are unfortunate but we need
79        // to mutably borrow `self` again below.
80        if let Some(members) = self.array_members.get(&name.name).cloned() {
81            for (i, member) in members.iter().cloned().enumerate() {
82                // Create a definition for each array index.
83                let index = Literal::integer(
84                    IntegerType::U8,
85                    i.to_string(),
86                    Default::default(),
87                    self.state.node_builder.next_id(),
88                );
89                self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
90                let access = ArrayAccess {
91                    array: Path::from(name).into_absolute().into(),
92                    index: index.into(),
93                    span: Default::default(),
94                    id: self.state.node_builder.next_id(),
95                };
96                self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
97                let def = DefinitionStatement {
98                    place: DefinitionPlace::Single(member),
99                    type_: None,
100                    value: access.into(),
101                    span: Default::default(),
102                    id: self.state.node_builder.next_id(),
103                };
104                accumulate.push(def.into());
105                // And recurse - maybe its members are also written to.
106                self.define_variable_members(member, accumulate);
107            }
108        } else if let Some(members) = self.struct_members.get(&name.name) {
109            for (&field_name, &member) in members.clone().iter() {
110                // Create a definition for each field.
111                let access = MemberAccess {
112                    inner: Path::from(name).into_absolute().into(),
113                    name: Identifier::new(field_name, self.state.node_builder.next_id()),
114                    span: Default::default(),
115                    id: self.state.node_builder.next_id(),
116                };
117                self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
118                let def = DefinitionStatement {
119                    place: DefinitionPlace::Single(member),
120                    type_: None,
121                    value: access.into(),
122                    span: Default::default(),
123                    id: self.state.node_builder.next_id(),
124                };
125                accumulate.push(def.into());
126                // And recurse - maybe its members are also written to.
127                self.define_variable_members(member, accumulate);
128            }
129        }
130    }
131
132    /// If we're assigning to a struct or array member, find the variable name we're actually writing to,
133    /// recursively if necessary.
134    /// That is, if we have
135    /// `arr[0u32][1u32] = ...`,
136    /// we find the corresponding variable `arr_0_1`.
137    pub fn reconstruct_assign_place(&mut self, input: Expression) -> Identifier {
138        use Expression::*;
139        match input {
140            ArrayAccess(array_access) => {
141                let identifier = self.reconstruct_assign_place(array_access.array);
142                self.get_array_member(identifier.name, &array_access.index).expect("We have visited all array writes.")
143            }
144            Path(path) => leo_ast::Identifier { name: path.identifier().name, span: path.span, id: path.id },
145            MemberAccess(member_access) => {
146                let identifier = self.reconstruct_assign_place(member_access.inner);
147                self.get_struct_member(identifier.name, member_access.name.name)
148                    .expect("We have visited all struct writes.")
149            }
150            TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
151            _ => panic!("Type checking should have ensured there are no other places for assignments"),
152        }
153    }
154
155    /// If we're assigning to a struct or array, create assignments to the individual members, if applicable.
156    pub fn reconstruct_assign_recurse(&self, place: Identifier, value: Expression, accumulate: &mut Vec<Statement>) {
157        if let Some(array_members) = self.array_members.get(&place.name) {
158            if let Expression::Array(value_array) = value {
159                // This was an assignment like
160                // `arr = [a, b, c];`
161                // Change it to this:
162                // `arr_0 = a; arr_1 = b; arr_2 = c`
163                for (&member, rhs_element) in array_members.iter().zip(value_array.elements) {
164                    self.reconstruct_assign_recurse(member, rhs_element, accumulate);
165                }
166            } else {
167                // This was an assignment like
168                // `arr = x;`
169                // Change it to this:
170                // `arr = x; arr_0 = x[0]; arr_1 = x[1]; arr_2 = x[2];`
171                let one_assign = AssignStatement {
172                    place: Path::from(place).into_absolute().into(),
173                    value,
174                    span: Default::default(),
175                    id: self.state.node_builder.next_id(),
176                }
177                .into();
178                accumulate.push(one_assign);
179                for (i, &member) in array_members.iter().enumerate() {
180                    let access = ArrayAccess {
181                        array: Path::from(place).into_absolute().into(),
182                        index: Literal::integer(
183                            IntegerType::U32,
184                            format!("{i}u32"),
185                            Default::default(),
186                            self.state.node_builder.next_id(),
187                        )
188                        .into(),
189                        span: Default::default(),
190                        id: self.state.node_builder.next_id(),
191                    };
192                    self.reconstruct_assign_recurse(member, access.into(), accumulate);
193                }
194            }
195        } else if let Some(struct_members) = self.struct_members.get(&place.name) {
196            if let Expression::Struct(value_struct) = value {
197                // This was an assignment like
198                // `struc = S { field0: a, field1: b };`
199                // Change it to this:
200                // `struc_field0 = a; struc_field1 = b;`
201                for initializer in value_struct.members.into_iter() {
202                    let member_name = struct_members.get(&initializer.identifier.name).expect("Member should exist.");
203                    let rhs_expression =
204                        initializer.expression.expect("This should have been normalized to have a value.");
205                    self.reconstruct_assign_recurse(*member_name, rhs_expression, accumulate);
206                }
207            } else {
208                // This was an assignment like
209                // `struc = x;`
210                // Change it to this:
211                // `struc = x; struc_field0 = x.field0; struc_field1 = x.field1;`
212                let one_assign = AssignStatement {
213                    place: Path::from(place).into_absolute().into(),
214                    value,
215                    span: Default::default(),
216                    id: self.state.node_builder.next_id(),
217                }
218                .into();
219                accumulate.push(one_assign);
220                for (field, member_name) in struct_members.iter() {
221                    let access = MemberAccess {
222                        inner: Path::from(place).into_absolute().into(),
223                        name: Identifier::new(*field, self.state.node_builder.next_id()),
224                        span: Default::default(),
225                        id: self.state.node_builder.next_id(),
226                    };
227                    self.reconstruct_assign_recurse(*member_name, access.into(), accumulate);
228                }
229            }
230        } else {
231            let stmt = AssignStatement {
232                value,
233                place: Path::from(place).into_absolute().into(),
234                id: self.state.node_builder.next_id(),
235                span: Default::default(),
236            }
237            .into();
238            accumulate.push(stmt);
239        }
240    }
241}
242
243struct WriteTransformingFiller<'a>(WriteTransformingVisitor<'a>);
244
245impl AstVisitor for WriteTransformingFiller<'_> {
246    type AdditionalInput = ();
247    type Output = ();
248
249    /* Expressions */
250    fn visit_expression(&mut self, _input: &Expression, _additional: &Self::AdditionalInput) -> Self::Output {}
251
252    /* Statements */
253    fn visit_assign(&mut self, input: &AssignStatement) {
254        self.access_recurse(&input.place);
255    }
256}
257
258impl WriteTransformingFiller<'_> {
259    fn fill(&mut self, program: &Program) {
260        for (_, module) in program.modules.iter() {
261            self.0.program = module.program_name;
262            for (_, function) in module.functions.iter() {
263                self.visit_block(&function.block);
264            }
265        }
266        for (_, scope) in program.program_scopes.iter() {
267            for (_, function) in scope.functions.iter() {
268                self.0.program = scope.program_id.name.name;
269                self.visit_block(&function.block);
270            }
271        }
272    }
273
274    /// Find assignments to arrays and structs and populate `array_members` and `struct_members` with new
275    /// variables names.
276    fn access_recurse(&mut self, place: &Expression) -> Identifier {
277        match place {
278            Expression::Path(path) => Identifier { name: path.identifier().name, span: path.span, id: path.id },
279            Expression::ArrayAccess(array_access) => {
280                let array_name = self.access_recurse(&array_access.array);
281                let members = self.0.array_members.entry(array_name.name).or_insert_with(|| {
282                    let ty = self.0.state.type_table.get(&array_access.array.id()).unwrap();
283                    let Type::Array(arr) = ty else { panic!("Type checking should have prevented this.") };
284                    (0..arr.length.as_u32().expect("length should be known at this point"))
285                        .map(|i| {
286                            let id = self.0.state.node_builder.next_id();
287                            let symbol = self.0.state.assigner.unique_symbol(format_args!("{array_name}#{i}"), "$");
288                            self.0.state.type_table.insert(id, arr.element_type().clone());
289                            Identifier::new(symbol, id)
290                        })
291                        .collect()
292                });
293                let Expression::Literal(lit) = &array_access.index else {
294                    panic!("Const propagation should have ensured this is a literal.");
295                };
296                members[lit
297                    .as_u32()
298                    .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
299                    as usize]
300            }
301            Expression::MemberAccess(member_access) => {
302                let struct_name = self.access_recurse(&member_access.inner);
303                let members = self.0.struct_members.entry(struct_name.name).or_insert_with(|| {
304                    let ty = self.0.state.type_table.get(&member_access.inner.id()).unwrap();
305                    let Type::Composite(comp) = ty else {
306                        panic!("Type checking should have prevented this.");
307                    };
308                    let struct_ = self
309                        .0
310                        .state
311                        .symbol_table
312                        .lookup_struct(&comp.path.absolute_path())
313                        .or_else(|| {
314                            self.0.state.symbol_table.lookup_record(&Location::new(
315                                comp.program.unwrap_or(self.0.program),
316                                comp.path.absolute_path(),
317                            ))
318                        })
319                        .unwrap();
320                    struct_
321                        .members
322                        .iter()
323                        .map(|member| {
324                            let name = member.name();
325                            let id = self.0.state.node_builder.next_id();
326                            let symbol = self.0.state.assigner.unique_symbol(format_args!("{struct_name}#{name}"), "$");
327                            self.0.state.type_table.insert(id, member.type_.clone());
328                            (member.name(), Identifier::new(symbol, id))
329                        })
330                        .collect()
331                });
332                *members.get(&member_access.name.name).expect("Type checking should have ensured this is valid.")
333            }
334            Expression::TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
335            _ => panic!("Type checking should have ensured there are no other places for assignments"),
336        }
337    }
338}