leo_passes/flattening/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::{FlatteningVisitor, Guard, ReturnGuard};
18
19use leo_ast::*;
20
21use itertools::Itertools;
22
23impl AstReconstructor for FlatteningVisitor<'_> {
24    type AdditionalOutput = Vec<Statement>;
25
26    /* Expressions */
27    /// Reconstructs a struct init expression, flattening any tuples in the expression.
28    fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
29        let mut statements = Vec::new();
30        let mut members = Vec::with_capacity(input.members.len());
31
32        // Reconstruct and flatten the argument expressions.
33        for member in input.members.into_iter() {
34            // Note that this unwrap is safe since SSA guarantees that all struct variable initializers are of the form `<name>: <expr>`.
35            let (expr, stmts) = self.reconstruct_expression(member.expression.unwrap());
36            // Accumulate any statements produced.
37            statements.extend(stmts);
38            // Accumulate the struct members.
39            members.push(StructVariableInitializer {
40                identifier: member.identifier,
41                expression: Some(expr),
42                span: member.span,
43                id: member.id,
44            });
45        }
46
47        (StructExpression { members, ..input }.into(), statements)
48    }
49
50    /// Reconstructs ternary expressions over arrays, structs, and tuples, accumulating any statements that are generated.
51    /// This is necessary because Aleo instructions does not support ternary expressions over composite data types.
52    /// For example, the ternary expression `cond ? (a, b) : (c, d)` is flattened into the following:
53    /// ```leo
54    /// let var$0 = cond ? a : c;
55    /// let var$1 = cond ? b : d;
56    /// (var$0, var$1)
57    /// ```
58    /// For structs, the ternary expression `cond ? a : b`, where `a` and `b` are both structs `Foo { bar: u8, baz: u8 }`, is flattened into the following:
59    /// ```leo
60    /// let var$0 = cond ? a.bar : b.bar;
61    /// let var$1 = cond ? a.baz : b.baz;
62    /// let var$2 = Foo { bar: var$0, baz: var$1 };
63    /// var$2
64    /// ```
65    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
66        let if_true_type = self
67            .state
68            .type_table
69            .get(&input.if_true.id())
70            .expect("Type checking guarantees that all expressions are typed.");
71        let if_false_type = self
72            .state
73            .type_table
74            .get(&input.if_false.id())
75            .expect("Type checking guarantees that all expressions are typed.");
76
77        // Note that type checking guarantees that both expressions have the same same type. This is a sanity check.
78        assert!(if_true_type.eq_flat_relaxed(&if_false_type));
79
80        fn as_identifier(ident_expr: Expression) -> Identifier {
81            let Expression::Identifier(identifier) = ident_expr else {
82                panic!("SSA form should have guaranteed this is an identifier: {}.", ident_expr);
83            };
84            identifier
85        }
86
87        match &if_true_type {
88            Type::Array(if_true_type) => self.ternary_array(
89                if_true_type,
90                &input.condition,
91                &as_identifier(input.if_true),
92                &as_identifier(input.if_false),
93            ),
94            Type::Composite(if_true_type) => {
95                // Get the struct definitions.
96                let program = if_true_type.program.unwrap_or(self.program);
97                let if_true_type = self
98                    .state
99                    .symbol_table
100                    .lookup_struct(if_true_type.id.name)
101                    .or_else(|| self.state.symbol_table.lookup_record(Location::new(program, if_true_type.id.name)))
102                    .expect("This definition should exist")
103                    .clone();
104
105                self.ternary_struct(
106                    &if_true_type,
107                    &input.condition,
108                    &as_identifier(input.if_true),
109                    &as_identifier(input.if_false),
110                )
111            }
112            Type::Tuple(if_true_type) => {
113                self.ternary_tuple(if_true_type, &input.condition, &input.if_true, &input.if_false)
114            }
115            _ => {
116                // There's nothing to be done - SSA has guaranteed that `if_true` and `if_false` are identifiers,
117                // so there's not even any point in reconstructing them.
118
119                assert!(matches!(&input.if_true, Expression::Identifier(..)));
120                assert!(matches!(&input.if_false, Expression::Identifier(..)));
121
122                (input.into(), Default::default())
123            }
124        }
125    }
126
127    /* Statements */
128    /// Rewrites an assert statement into a flattened form.
129    /// Assert statements at the top level only have their arguments flattened.
130    /// Assert statements inside a conditional statement are flattened to such that the check is conditional on
131    /// the execution path being valid.
132    /// For example, the following snippet:
133    /// ```leo
134    /// if condition1 {
135    ///    if condition2 {
136    ///        assert(foo);
137    ///    }
138    /// }
139    /// ```
140    /// is flattened to:
141    /// ```leo
142    /// assert(!(condition1 && condition2) || foo);
143    /// ```
144    /// which is equivalent to the logical formula `(condition1 /\ condition2) ==> foo`.
145    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
146        let mut statements = Vec::new();
147
148        // If we are traversing an async function, then we can return the assert as it.
149        if self.is_async {
150            return (input.into(), statements);
151        }
152
153        // Flatten the arguments of the assert statement.
154        let assert = AssertStatement {
155            span: input.span,
156            id: input.id,
157            variant: match input.variant {
158                AssertVariant::Assert(expression) => {
159                    let (expression, additional_statements) = self.reconstruct_expression(expression);
160                    statements.extend(additional_statements);
161                    AssertVariant::Assert(expression)
162                }
163                AssertVariant::AssertEq(left, right) => {
164                    let (left, additional_statements) = self.reconstruct_expression(left);
165                    statements.extend(additional_statements);
166                    let (right, additional_statements) = self.reconstruct_expression(right);
167                    statements.extend(additional_statements);
168                    AssertVariant::AssertEq(left, right)
169                }
170                AssertVariant::AssertNeq(left, right) => {
171                    let (left, additional_statements) = self.reconstruct_expression(left);
172                    statements.extend(additional_statements);
173                    let (right, additional_statements) = self.reconstruct_expression(right);
174                    statements.extend(additional_statements);
175                    AssertVariant::AssertNeq(left, right)
176                }
177            },
178        };
179
180        let mut guards: Vec<Expression> = Vec::new();
181
182        if let Some((guard, guard_statements)) = self.construct_guard() {
183            statements.extend(guard_statements);
184
185            // The not_guard is true if we didn't follow the condition chain
186            // that led to this assertion.
187            let not_guard = UnaryExpression {
188                op: UnaryOperation::Not,
189                receiver: guard.into(),
190                span: Default::default(),
191                id: {
192                    // Create a new node ID for the unary expression.
193                    let id = self.state.node_builder.next_id();
194                    // Update the type table with the type of the unary expression.
195                    self.state.type_table.insert(id, Type::Boolean);
196                    id
197                },
198            }
199            .into();
200            let (identifier, statement) = self.unique_simple_definition(not_guard);
201            statements.push(statement);
202            guards.push(identifier.into());
203        }
204
205        // We also need to guard against early returns.
206        if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
207            guards.push(guard.into());
208            statements.extend(guard_statements);
209        }
210
211        if guards.is_empty() {
212            return (assert.into(), statements);
213        }
214
215        let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));
216
217        // We need to `or` the asserted expression with the guards,
218        // so extract an appropriate expression.
219        let mut expression = match assert.variant {
220            // If the assert statement is an `assert`, use the expression as is.
221            AssertVariant::Assert(expression) => expression,
222
223            // For `assert_eq` or `assert_neq`, construct a new expression.
224            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
225                let binary = BinaryExpression {
226                    left,
227                    right,
228                    op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
229                    span: Default::default(),
230                    id: self.state.node_builder.next_id(),
231                };
232                self.state.type_table.insert(binary.id, Type::Boolean);
233                let (identifier, statement) = self.unique_simple_definition(binary.into());
234                statements.push(statement);
235                identifier.into()
236            }
237        };
238
239        // The assertion will be that the original assert statement is true or one of the guards is true
240        // (ie, we either didn't follow the condition chain that led to this assert, or else we took an
241        // early return).
242        for guard in guards.into_iter() {
243            let binary = BinaryExpression {
244                left: expression,
245                right: guard,
246                op: BinaryOperation::Or,
247                span: Default::default(),
248                id: self.state.node_builder.next_id(),
249            };
250            self.state.type_table.insert(binary.id(), Type::Boolean);
251            let (identifier, statement) = self.unique_simple_definition(binary.into());
252            statements.push(statement);
253            expression = identifier.into();
254        }
255
256        let assert_statement = AssertStatement { variant: AssertVariant::Assert(expression), ..input }.into();
257
258        (assert_statement, statements)
259    }
260
261    fn reconstruct_assign(&mut self, _assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
262        panic!("`AssignStatement`s should not be in the AST at this phase of compilation");
263    }
264
265    // TODO: Do we want to flatten nested blocks? They do not affect code generation but it would regularize the AST structure.
266    /// Flattens the statements inside a basic block.
267    /// The resulting block does not contain any conditional statements.
268    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
269        let mut statements = Vec::with_capacity(block.statements.len());
270
271        // Flatten each statement, accumulating any new statements produced.
272        for statement in block.statements {
273            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
274            statements.extend(additional_statements);
275            statements.push(reconstructed_statement);
276        }
277
278        (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
279    }
280
281    /// Flatten a conditional statement into a list of statements.
282    fn reconstruct_conditional(&mut self, conditional: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
283        let mut statements = Vec::with_capacity(conditional.then.statements.len());
284
285        // If we are traversing an async function, reconstruct the if and else blocks, but do not flatten them.
286        if self.is_async {
287            let then_block = self.reconstruct_block(conditional.then).0;
288            let otherwise_block = match conditional.otherwise {
289                Some(statement) => match *statement {
290                    Statement::Block(block) => self.reconstruct_block(block).0,
291                    _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
292                },
293                None => {
294                    Block { span: Default::default(), statements: Vec::new(), id: self.state.node_builder.next_id() }
295                }
296            };
297
298            return (
299                ConditionalStatement {
300                    then: then_block,
301                    otherwise: Some(Box::new(otherwise_block.into())),
302                    ..conditional
303                }
304                .into(),
305                statements,
306            );
307        }
308
309        // Assign the condition to a variable, as it may be used multiple times.
310        let place = Identifier {
311            name: self.state.assigner.unique_symbol("condition", "$"),
312            span: Default::default(),
313            id: {
314                let id = self.state.node_builder.next_id();
315                self.state.type_table.insert(id, Type::Boolean);
316                id
317            },
318        };
319
320        statements.push(self.simple_definition(place, conditional.condition.clone()));
321
322        // Add condition to the condition stack.
323        self.condition_stack.push(Guard::Unconstructed(place));
324
325        // Reconstruct the then-block and accumulate it constituent statements.
326        statements.extend(self.reconstruct_block(conditional.then).0.statements);
327
328        // Remove condition from the condition stack.
329        self.condition_stack.pop();
330
331        // Consume the otherwise-block and flatten its constituent statements into the current block.
332        if let Some(statement) = conditional.otherwise {
333            // Apply Not to the condition, assign it, and put it on the condition stack.
334            let not_condition = UnaryExpression {
335                op: UnaryOperation::Not,
336                receiver: conditional.condition.clone(),
337                span: conditional.condition.span(),
338                id: conditional.condition.id(),
339            }
340            .into();
341            let not_place = Identifier {
342                name: self.state.assigner.unique_symbol("condition", "$"),
343                span: Default::default(),
344                id: {
345                    let id = self.state.node_builder.next_id();
346                    self.state.type_table.insert(id, Type::Boolean);
347                    id
348                },
349            };
350            statements.push(self.simple_definition(not_place, not_condition));
351            self.condition_stack.push(Guard::Unconstructed(not_place));
352
353            // Reconstruct the otherwise-block and accumulate it constituent statements.
354            match *statement {
355                Statement::Block(block) => statements.extend(self.reconstruct_block(block).0.statements),
356                _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
357            }
358
359            // Remove the negated condition from the condition stack.
360            self.condition_stack.pop();
361        };
362
363        (Statement::dummy(), statements)
364    }
365
366    /// Flattens a definition, if necessary.
367    /// Marks variables as structs as necessary.
368    /// Note that new statements are only produced if the right hand side is a ternary expression over structs.
369    /// Otherwise, the statement is returned as is.
370    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
371        // Flatten the rhs of the assignment.
372        let (value, statements) = self.reconstruct_expression(definition.value);
373        match (definition.place, &value) {
374            (DefinitionPlace::Single(identifier), _) => (self.simple_definition(identifier, value), statements),
375            (DefinitionPlace::Multiple(identifiers), expression) => {
376                let output_type = match &self.state.type_table.get(&expression.id()) {
377                    Some(Type::Tuple(tuple_type)) => tuple_type.clone(),
378                    _ => panic!("Type checking guarantees that the output type is a tuple."),
379                };
380
381                for (identifier, type_) in identifiers.iter().zip_eq(output_type.elements().iter()) {
382                    // Add the type of each identifier to the type table.
383                    self.state.type_table.insert(identifier.id, type_.clone());
384                }
385
386                (
387                    DefinitionStatement {
388                        place: DefinitionPlace::Multiple(identifiers),
389                        type_: None,
390                        value,
391                        span: Default::default(),
392                        id: self.state.node_builder.next_id(),
393                    }
394                    .into(),
395                    statements,
396                )
397            }
398        }
399    }
400
401    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
402        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
403    }
404
405    /// Transforms a return statement into an empty block statement.
406    /// Stores the arguments to the return statement, which are later folded into a single return statement at the end of the function.
407    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
408        use Expression::*;
409
410        // If we are traversing an async function, return as is.
411        if self.is_async {
412            return (input.into(), Default::default());
413        }
414        // Construct the associated guard.
415        let (guard_identifier, statements) = self.construct_guard().unzip();
416
417        let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);
418
419        let is_tuple_ids = matches!(&input.expression, Tuple(tuple_expr) if tuple_expr .elements.iter() .all(|expr| matches!(expr, Identifier(_))));
420        if !matches!(&input.expression, Unit(_) | Identifier(_) | AssociatedConstant(_)) && !is_tuple_ids {
421            panic!("SSA guarantees that the expression is always an identifier, unit expression, or tuple literal.")
422        }
423
424        self.returns.push((return_guard, input));
425
426        (Statement::dummy(), statements.unwrap_or_default())
427    }
428}