leo_passes/write_transforming/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::WriteTransformingVisitor;
18use leo_ast::*;
19use leo_span::Symbol;
20
21impl WriteTransformingVisitor<'_> {
22    pub fn get_array_member(&self, array_name: Symbol, index: &Expression) -> Option<Identifier> {
23        let members = self.array_members.get(&array_name)?;
24        let Expression::Literal(lit) = index else {
25            panic!("Const propagation should have ensured this is a literal.");
26        };
27        let index = lit
28            .as_u32()
29            .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
30            as usize;
31        Some(members[index])
32    }
33
34    pub fn get_struct_member(&self, struct_name: Symbol, field_name: Symbol) -> Option<Identifier> {
35        let members = self.struct_members.get(&struct_name)?;
36        members.get(&field_name).cloned()
37    }
38}
39
40impl WriteTransformingVisitor<'_> {
41    fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Vec<Statement>) {
42        let ty = self.state.type_table.get(&input.id()).unwrap();
43        let mut statements = Vec::new();
44        if let Some(array_members) = self.array_members.get(&input.name) {
45            // Build the array expression from the members.
46            let id = self.state.node_builder.next_id();
47            self.state.type_table.insert(id, ty.clone());
48            let expr = ArrayExpression {
49                elements: array_members
50                    // This clone is unfortunate, but both `array_members` and the closure below borrow self.
51                    .clone()
52                    .iter()
53                    .map(|identifier| {
54                        let (expr, statements2) = self.reconstruct_identifier(*identifier);
55                        statements.extend(statements2);
56                        expr
57                    })
58                    .collect(),
59                span: Default::default(),
60                id,
61            };
62            let statement = AssignStatement {
63                place: Path::from(input).into_absolute().into(),
64                value: expr.into(),
65                span: Default::default(),
66                id: self.state.node_builder.next_id(),
67            };
68            statements.push(statement.into());
69            (Path::from(input).into_absolute().into(), statements)
70        } else if let Some(struct_members) = self.struct_members.get(&input.name) {
71            // Build the struct expression from the members.
72            let id = self.state.node_builder.next_id();
73            self.state.type_table.insert(id, ty.clone());
74            let Type::Composite(comp_type) = ty else {
75                panic!("The type of a struct init should be a composite.");
76            };
77            let expr = StructExpression {
78                const_arguments: Vec::new(), // All const arguments should have been resolved by now
79                members: struct_members
80                    // This clone is unfortunate, but both `struct_members` and the closure below borrow self.
81                    .clone()
82                    .iter()
83                    .map(|(field_name, ident)| {
84                        let (expr, statements2) = self.reconstruct_identifier(*ident);
85                        statements.extend(statements2);
86                        StructVariableInitializer {
87                            identifier: Identifier::new(*field_name, self.state.node_builder.next_id()),
88                            expression: Some(expr),
89                            span: Default::default(),
90                            id: self.state.node_builder.next_id(),
91                        }
92                    })
93                    .collect(),
94                path: comp_type.path,
95                span: Default::default(),
96                id,
97            };
98            let statement = AssignStatement {
99                place: Path::from(input).into_absolute().into(),
100                value: expr.into(),
101                span: Default::default(),
102                id: self.state.node_builder.next_id(),
103            };
104            statements.push(statement.into());
105            (Path::from(input).into_absolute().into(), statements)
106        } else {
107            // This is not a struct or array whose members are written to, so there's nothing to do.
108            (Path::from(input).into_absolute().into(), Default::default())
109        }
110    }
111}
112
113impl AstReconstructor for WriteTransformingVisitor<'_> {
114    type AdditionalInput = ();
115    type AdditionalOutput = Vec<Statement>;
116
117    /* Expressions */
118    fn reconstruct_path(&mut self, input: Path, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
119        if input.qualifier().is_empty() {
120            self.reconstruct_identifier(Identifier { name: input.identifier().name, span: input.span, id: input.id })
121        } else {
122            (input.into(), Default::default())
123        }
124    }
125
126    fn reconstruct_array_access(
127        &mut self,
128        input: ArrayAccess,
129        _addiional: &(),
130    ) -> (Expression, Self::AdditionalOutput) {
131        let Expression::Path(ref array_name) = input.array else {
132            panic!("SSA ensures that this is a Path.");
133        };
134        if let Some(member) = self.get_array_member(array_name.identifier().name, &input.index) {
135            self.reconstruct_identifier(member)
136        } else {
137            (input.into(), Default::default())
138        }
139    }
140
141    fn reconstruct_member_access(
142        &mut self,
143        input: MemberAccess,
144        _addiional: &(),
145    ) -> (Expression, Self::AdditionalOutput) {
146        let Expression::Path(ref struct_name) = input.inner else {
147            panic!("SSA ensures that this is a Path.");
148        };
149        if let Some(member) = self.get_struct_member(struct_name.identifier().name, input.name.name) {
150            self.reconstruct_identifier(member)
151        } else {
152            (input.into(), Default::default())
153        }
154    }
155
156    // The rest of the methods below don't do anything but traverse - we only modify their default implementations
157    // to combine the `Vec<Statement>` outputs.
158
159    fn reconstruct_associated_function(
160        &mut self,
161        mut input: AssociatedFunctionExpression,
162        _addiional: &(),
163    ) -> (Expression, Self::AdditionalOutput) {
164        let mut statements = Vec::new();
165        for arg in input.arguments.iter_mut() {
166            let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg), &());
167            statements.extend(statements2);
168            *arg = expr;
169        }
170        (input.into(), statements)
171    }
172
173    fn reconstruct_tuple_access(
174        &mut self,
175        _input: TupleAccess,
176        _addiional: &(),
177    ) -> (Expression, Self::AdditionalOutput) {
178        panic!("`TupleAccess` should not be in the AST at this point.");
179    }
180
181    fn reconstruct_array(
182        &mut self,
183        mut input: ArrayExpression,
184        _addiional: &(),
185    ) -> (Expression, Self::AdditionalOutput) {
186        let mut statements = Vec::new();
187        for element in input.elements.iter_mut() {
188            let (expr, statements2) = self.reconstruct_expression(std::mem::take(element), &());
189            statements.extend(statements2);
190            *element = expr;
191        }
192        (input.into(), statements)
193    }
194
195    fn reconstruct_binary(&mut self, input: BinaryExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
196        let (left, mut statements) = self.reconstruct_expression(input.left, &());
197        let (right, statements2) = self.reconstruct_expression(input.right, &());
198        statements.extend(statements2);
199        (BinaryExpression { left, right, ..input }.into(), statements)
200    }
201
202    fn reconstruct_call(&mut self, mut input: CallExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
203        let mut statements = Vec::new();
204        for arg in input.arguments.iter_mut() {
205            let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg), &());
206            statements.extend(statements2);
207            *arg = expr;
208        }
209        (input.into(), statements)
210    }
211
212    fn reconstruct_cast(&mut self, input: CastExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
213        let (expression, statements) = self.reconstruct_expression(input.expression, &());
214        (CastExpression { expression, ..input }.into(), statements)
215    }
216
217    fn reconstruct_struct_init(
218        &mut self,
219        mut input: StructExpression,
220        _addiional: &(),
221    ) -> (Expression, Self::AdditionalOutput) {
222        let mut statements = Vec::new();
223        for member in input.members.iter_mut() {
224            assert!(member.expression.is_some());
225            let (expr, statements2) = self.reconstruct_expression(member.expression.take().unwrap(), &());
226            statements.extend(statements2);
227            member.expression = Some(expr);
228        }
229
230        (input.into(), statements)
231    }
232
233    fn reconstruct_err(
234        &mut self,
235        _input: leo_ast::ErrExpression,
236        _addiional: &(),
237    ) -> (Expression, Self::AdditionalOutput) {
238        std::panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
239    }
240
241    fn reconstruct_literal(
242        &mut self,
243        input: leo_ast::Literal,
244        _addiional: &(),
245    ) -> (Expression, Self::AdditionalOutput) {
246        (input.into(), Default::default())
247    }
248
249    fn reconstruct_ternary(
250        &mut self,
251        input: TernaryExpression,
252        _addiional: &(),
253    ) -> (Expression, Self::AdditionalOutput) {
254        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
255        let (if_true, statements2) = self.reconstruct_expression(input.if_true, &());
256        let (if_false, statements3) = self.reconstruct_expression(input.if_false, &());
257        statements.extend(statements2);
258        statements.extend(statements3);
259        (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
260    }
261
262    fn reconstruct_tuple(
263        &mut self,
264        input: leo_ast::TupleExpression,
265        _addiional: &(),
266    ) -> (Expression, Self::AdditionalOutput) {
267        // This should ony appear in a return statement.
268        let mut statements = Vec::new();
269        let elements = input
270            .elements
271            .into_iter()
272            .map(|element| {
273                let (expr, statements2) = self.reconstruct_expression(element, &());
274                statements.extend(statements2);
275                expr
276            })
277            .collect();
278        (TupleExpression { elements, ..input }.into(), statements)
279    }
280
281    fn reconstruct_unary(
282        &mut self,
283        input: leo_ast::UnaryExpression,
284        _addiional: &(),
285    ) -> (Expression, Self::AdditionalOutput) {
286        let (receiver, statements) = self.reconstruct_expression(input.receiver, &());
287        (UnaryExpression { receiver, ..input }.into(), statements)
288    }
289
290    fn reconstruct_unit(
291        &mut self,
292        input: leo_ast::UnitExpression,
293        _addiional: &(),
294    ) -> (Expression, Self::AdditionalOutput) {
295        (input.into(), Default::default())
296    }
297
298    /* Statements */
299    /// This is the only reconstructing function where we do anything other than traverse and combine statements,
300    /// by calling `reconstruct_assign_place` and `reconstruct_assign_recurse`.
301    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
302        let (value, mut statements) = self.reconstruct_expression(input.value, &());
303        let place = self.reconstruct_assign_place(input.place);
304        self.reconstruct_assign_recurse(place, value, &mut statements);
305        (Statement::dummy(), statements)
306    }
307
308    fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
309        let mut statements = Vec::new();
310        let stmt = AssertStatement {
311            variant: match input.variant {
312                AssertVariant::Assert(expr) => {
313                    let (expr, statements2) = self.reconstruct_expression(expr, &());
314                    statements.extend(statements2);
315                    AssertVariant::Assert(expr)
316                }
317                AssertVariant::AssertEq(left, right) => {
318                    let (left, statements2) = self.reconstruct_expression(left, &());
319                    statements.extend(statements2);
320                    let (right, statements3) = self.reconstruct_expression(right, &());
321                    statements.extend(statements3);
322                    AssertVariant::AssertEq(left, right)
323                }
324                AssertVariant::AssertNeq(left, right) => {
325                    let (left, statements2) = self.reconstruct_expression(left, &());
326                    statements.extend(statements2);
327                    let (right, statements3) = self.reconstruct_expression(right, &());
328                    statements.extend(statements3);
329                    AssertVariant::AssertNeq(left, right)
330                }
331            },
332            ..input
333        }
334        .into();
335        (stmt, Default::default())
336    }
337
338    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
339        let mut statements = Vec::with_capacity(block.statements.len());
340
341        // Reconstruct the statements in the block, accumulating any additional statements.
342        for statement in block.statements {
343            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
344            statements.extend(additional_statements);
345            if !reconstructed_statement.is_empty() {
346                statements.push(reconstructed_statement);
347            }
348        }
349
350        (Block { statements, ..block }, Default::default())
351    }
352
353    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
354        let (value, mut statements) = self.reconstruct_expression(input.value, &());
355        input.value = value;
356        match input.place.clone() {
357            DefinitionPlace::Single(identifier) => {
358                statements.push(input.into());
359                self.define_variable_members(identifier, &mut statements);
360            }
361            DefinitionPlace::Multiple(identifiers) => {
362                statements.push(input.into());
363                for &identifier in identifiers.iter() {
364                    self.define_variable_members(identifier, &mut statements);
365                }
366            }
367        }
368        (Statement::dummy(), statements)
369    }
370
371    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
372        let (expression, statements) = self.reconstruct_expression(input.expression, &());
373        (ExpressionStatement { expression, ..input }.into(), statements)
374    }
375
376    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
377        panic!("`IterationStatement`s should not be in the AST at this point.");
378    }
379
380    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
381        let (expression, statements) = self.reconstruct_expression(input.expression, &());
382        (ReturnStatement { expression, ..input }.into(), statements)
383    }
384
385    fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
386        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
387        let (then, statements2) = self.reconstruct_block(input.then);
388        statements.extend(statements2);
389        let otherwise = input.otherwise.map(|oth| {
390            let (expr, statements3) = self.reconstruct_statement(*oth);
391            statements.extend(statements3);
392            Box::new(expr)
393        });
394        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
395    }
396}