leo_passes/flattening/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::{FlatteningVisitor, Guard, ReturnGuard};
18
19use leo_ast::*;
20
21use itertools::Itertools;
22
23impl AstReconstructor for FlatteningVisitor<'_> {
24    type AdditionalInput = ();
25    type AdditionalOutput = Vec<Statement>;
26
27    /* Expressions */
28    /// Reconstructs a struct init expression, flattening any tuples in the expression.
29    fn reconstruct_struct_init(
30        &mut self,
31        input: StructExpression,
32        _additional: &(),
33    ) -> (Expression, Self::AdditionalOutput) {
34        let mut statements = Vec::new();
35        let mut members = Vec::with_capacity(input.members.len());
36
37        // Reconstruct and flatten the argument expressions.
38        for member in input.members.into_iter() {
39            // Note that this unwrap is safe since SSA guarantees that all struct variable initializers are of the form `<name>: <expr>`.
40            let (expr, stmts) = self.reconstruct_expression(member.expression.unwrap(), &());
41            // Accumulate any statements produced.
42            statements.extend(stmts);
43            // Accumulate the struct members.
44            members.push(StructVariableInitializer {
45                identifier: member.identifier,
46                expression: Some(expr),
47                span: member.span,
48                id: member.id,
49            });
50        }
51
52        (StructExpression { members, ..input }.into(), statements)
53    }
54
55    /// Reconstructs ternary expressions over arrays, structs, and tuples, accumulating any statements that are generated.
56    /// This is necessary because Aleo instructions does not support ternary expressions over composite data types.
57    /// For example, the ternary expression `cond ? (a, b) : (c, d)` is flattened into the following:
58    /// ```leo
59    /// let var$0 = cond ? a : c;
60    /// let var$1 = cond ? b : d;
61    /// (var$0, var$1)
62    /// ```
63    /// For structs, the ternary expression `cond ? a : b`, where `a` and `b` are both structs `Foo { bar: u8, baz: u8 }`, is flattened into the following:
64    /// ```leo
65    /// let var$0 = cond ? a.bar : b.bar;
66    /// let var$1 = cond ? a.baz : b.baz;
67    /// let var$2 = Foo { bar: var$0, baz: var$1 };
68    /// var$2
69    /// ```
70    fn reconstruct_ternary(
71        &mut self,
72        input: TernaryExpression,
73        _additional: &(),
74    ) -> (Expression, Self::AdditionalOutput) {
75        let if_true_type = self
76            .state
77            .type_table
78            .get(&input.if_true.id())
79            .expect("Type checking guarantees that all expressions are typed.");
80        let if_false_type = self
81            .state
82            .type_table
83            .get(&input.if_false.id())
84            .expect("Type checking guarantees that all expressions are typed.");
85
86        // Note that type checking guarantees that both expressions have the same same type. This is a sanity check.
87        assert!(if_true_type.eq_flat_relaxed(&if_false_type));
88
89        fn as_identifier(path_expr: Expression) -> Identifier {
90            let Expression::Path(path) = path_expr else {
91                panic!("SSA form should have guaranteed this is a path: {path_expr}.");
92            };
93            Identifier { name: path.identifier().name, span: path.span, id: path.id }
94        }
95
96        match &if_true_type {
97            Type::Array(if_true_type) => self.ternary_array(
98                if_true_type,
99                &input.condition,
100                &as_identifier(input.if_true),
101                &as_identifier(input.if_false),
102            ),
103            Type::Composite(if_true_type) => {
104                // Get the struct definitions.
105                let program = if_true_type.program.unwrap_or(self.program);
106                let composite_path = if_true_type.path.clone();
107                let if_true_type = self
108                    .state
109                    .symbol_table
110                    .lookup_struct(&composite_path.absolute_path())
111                    .or_else(|| {
112                        self.state.symbol_table.lookup_record(&Location::new(program, composite_path.absolute_path()))
113                    })
114                    .expect("This definition should exist")
115                    .clone();
116
117                self.ternary_struct(
118                    &composite_path,
119                    &if_true_type,
120                    &input.condition,
121                    &as_identifier(input.if_true),
122                    &as_identifier(input.if_false),
123                )
124            }
125            Type::Tuple(if_true_type) => {
126                self.ternary_tuple(if_true_type, &input.condition, &input.if_true, &input.if_false)
127            }
128            _ => {
129                // There's nothing to be done - SSA has guaranteed that `if_true` and `if_false` are identifiers,
130                // so there's not even any point in reconstructing them.
131
132                assert!(matches!(&input.if_true, Expression::Path(..)));
133                assert!(matches!(&input.if_false, Expression::Path(..)));
134
135                (input.into(), Default::default())
136            }
137        }
138    }
139
140    /* Statements */
141    /// Rewrites an assert statement into a flattened form.
142    /// Assert statements at the top level only have their arguments flattened.
143    /// Assert statements inside a conditional statement are flattened to such that the check is conditional on
144    /// the execution path being valid.
145    /// For example, the following snippet:
146    /// ```leo
147    /// if condition1 {
148    ///    if condition2 {
149    ///        assert(foo);
150    ///    }
151    /// }
152    /// ```
153    /// is flattened to:
154    /// ```leo
155    /// assert(!(condition1 && condition2) || foo);
156    /// ```
157    /// which is equivalent to the logical formula `(condition1 /\ condition2) ==> foo`.
158    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
159        let mut statements = Vec::new();
160
161        // If we are traversing an async function, then we can return the assert as it.
162        if self.is_async {
163            return (input.into(), statements);
164        }
165
166        // Flatten the arguments of the assert statement.
167        let assert = AssertStatement {
168            span: input.span,
169            id: input.id,
170            variant: match input.variant {
171                AssertVariant::Assert(expression) => {
172                    let (expression, additional_statements) = self.reconstruct_expression(expression, &());
173                    statements.extend(additional_statements);
174                    AssertVariant::Assert(expression)
175                }
176                AssertVariant::AssertEq(left, right) => {
177                    let (left, additional_statements) = self.reconstruct_expression(left, &());
178                    statements.extend(additional_statements);
179                    let (right, additional_statements) = self.reconstruct_expression(right, &());
180                    statements.extend(additional_statements);
181                    AssertVariant::AssertEq(left, right)
182                }
183                AssertVariant::AssertNeq(left, right) => {
184                    let (left, additional_statements) = self.reconstruct_expression(left, &());
185                    statements.extend(additional_statements);
186                    let (right, additional_statements) = self.reconstruct_expression(right, &());
187                    statements.extend(additional_statements);
188                    AssertVariant::AssertNeq(left, right)
189                }
190            },
191        };
192
193        let mut guards: Vec<Expression> = Vec::new();
194
195        if let Some((guard, guard_statements)) = self.construct_guard() {
196            statements.extend(guard_statements);
197
198            // The not_guard is true if we didn't follow the condition chain
199            // that led to this assertion.
200            let not_guard = UnaryExpression {
201                op: UnaryOperation::Not,
202                receiver: Path::from(guard).into(),
203                span: Default::default(),
204                id: {
205                    // Create a new node ID for the unary expression.
206                    let id = self.state.node_builder.next_id();
207                    // Update the type table with the type of the unary expression.
208                    self.state.type_table.insert(id, Type::Boolean);
209                    id
210                },
211            }
212            .into();
213            let (identifier, statement) = self.unique_simple_definition(not_guard);
214            statements.push(statement);
215            guards.push(Path::from(identifier).into());
216        }
217
218        // We also need to guard against early returns.
219        if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
220            guards.push(Path::from(guard).into());
221            statements.extend(guard_statements);
222        }
223
224        if guards.is_empty() {
225            return (assert.into(), statements);
226        }
227
228        let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));
229
230        // We need to `or` the asserted expression with the guards,
231        // so extract an appropriate expression.
232        let mut expression = match assert.variant {
233            // If the assert statement is an `assert`, use the expression as is.
234            AssertVariant::Assert(expression) => expression,
235
236            // For `assert_eq` or `assert_neq`, construct a new expression.
237            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
238                let binary = BinaryExpression {
239                    left,
240                    right,
241                    op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
242                    span: Default::default(),
243                    id: self.state.node_builder.next_id(),
244                };
245                self.state.type_table.insert(binary.id, Type::Boolean);
246                let (identifier, statement) = self.unique_simple_definition(binary.into());
247                statements.push(statement);
248                Path::from(identifier).into()
249            }
250        };
251
252        // The assertion will be that the original assert statement is true or one of the guards is true
253        // (ie, we either didn't follow the condition chain that led to this assert, or else we took an
254        // early return).
255        for guard in guards.into_iter() {
256            let binary = BinaryExpression {
257                left: expression,
258                right: guard,
259                op: BinaryOperation::Or,
260                span: Default::default(),
261                id: self.state.node_builder.next_id(),
262            };
263            self.state.type_table.insert(binary.id(), Type::Boolean);
264            let (identifier, statement) = self.unique_simple_definition(binary.into());
265            statements.push(statement);
266            expression = Path::from(identifier).into();
267        }
268
269        let assert_statement = AssertStatement { variant: AssertVariant::Assert(expression), ..input }.into();
270
271        (assert_statement, statements)
272    }
273
274    fn reconstruct_assign(&mut self, _assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
275        panic!("`AssignStatement`s should not be in the AST at this phase of compilation");
276    }
277
278    // TODO: Do we want to flatten nested blocks? They do not affect code generation but it would regularize the AST structure.
279    /// Flattens the statements inside a basic block.
280    /// The resulting block does not contain any conditional statements.
281    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
282        let mut statements = Vec::with_capacity(block.statements.len());
283
284        // Flatten each statement, accumulating any new statements produced.
285        for statement in block.statements {
286            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
287            statements.extend(additional_statements);
288            statements.push(reconstructed_statement);
289        }
290
291        (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
292    }
293
294    /// Flatten a conditional statement into a list of statements.
295    fn reconstruct_conditional(&mut self, conditional: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
296        let mut statements = Vec::with_capacity(conditional.then.statements.len());
297
298        // If we are traversing an async function, reconstruct the if and else blocks, but do not flatten them.
299        if self.is_async {
300            let then_block = self.reconstruct_block(conditional.then).0;
301            let otherwise_block = match conditional.otherwise {
302                Some(statement) => match *statement {
303                    Statement::Block(block) => self.reconstruct_block(block).0,
304                    _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
305                },
306                None => {
307                    Block { span: Default::default(), statements: Vec::new(), id: self.state.node_builder.next_id() }
308                }
309            };
310
311            return (
312                ConditionalStatement {
313                    then: then_block,
314                    otherwise: Some(Box::new(otherwise_block.into())),
315                    ..conditional
316                }
317                .into(),
318                statements,
319            );
320        }
321
322        // Assign the condition to a variable, as it may be used multiple times.
323        let place = Identifier {
324            name: self.state.assigner.unique_symbol("condition", "$"),
325            span: Default::default(),
326            id: {
327                let id = self.state.node_builder.next_id();
328                self.state.type_table.insert(id, Type::Boolean);
329                id
330            },
331        };
332
333        statements.push(self.simple_definition(place, conditional.condition.clone()));
334
335        // Add condition to the condition stack.
336        self.condition_stack.push(Guard::Unconstructed(place));
337
338        // Reconstruct the then-block and accumulate it constituent statements.
339        statements.extend(self.reconstruct_block(conditional.then).0.statements);
340
341        // Remove condition from the condition stack.
342        self.condition_stack.pop();
343
344        // Consume the otherwise-block and flatten its constituent statements into the current block.
345        if let Some(statement) = conditional.otherwise {
346            // Apply Not to the condition, assign it, and put it on the condition stack.
347            let not_condition = UnaryExpression {
348                op: UnaryOperation::Not,
349                receiver: conditional.condition.clone(),
350                span: conditional.condition.span(),
351                id: conditional.condition.id(),
352            }
353            .into();
354            let not_place = Identifier {
355                name: self.state.assigner.unique_symbol("condition", "$"),
356                span: Default::default(),
357                id: {
358                    let id = self.state.node_builder.next_id();
359                    self.state.type_table.insert(id, Type::Boolean);
360                    id
361                },
362            };
363            statements.push(self.simple_definition(not_place, not_condition));
364            self.condition_stack.push(Guard::Unconstructed(not_place));
365
366            // Reconstruct the otherwise-block and accumulate it constituent statements.
367            match *statement {
368                Statement::Block(block) => statements.extend(self.reconstruct_block(block).0.statements),
369                _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
370            }
371
372            // Remove the negated condition from the condition stack.
373            self.condition_stack.pop();
374        };
375
376        (Statement::dummy(), statements)
377    }
378
379    /// Flattens a definition, if necessary.
380    /// Marks variables as structs as necessary.
381    /// Note that new statements are only produced if the right hand side is a ternary expression over structs.
382    /// Otherwise, the statement is returned as is.
383    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
384        // Flatten the rhs of the assignment.
385        let (value, statements) = self.reconstruct_expression(definition.value, &());
386        match (definition.place, &value) {
387            (DefinitionPlace::Single(identifier), _) => (self.simple_definition(identifier, value), statements),
388            (DefinitionPlace::Multiple(identifiers), expression) => {
389                let output_type = match &self.state.type_table.get(&expression.id()) {
390                    Some(Type::Tuple(tuple_type)) => tuple_type.clone(),
391                    _ => panic!("Type checking guarantees that the output type is a tuple."),
392                };
393
394                for (identifier, type_) in identifiers.iter().zip_eq(output_type.elements().iter()) {
395                    // Add the type of each identifier to the type table.
396                    self.state.type_table.insert(identifier.id, type_.clone());
397                }
398
399                (
400                    DefinitionStatement {
401                        place: DefinitionPlace::Multiple(identifiers),
402                        type_: None,
403                        value,
404                        span: Default::default(),
405                        id: self.state.node_builder.next_id(),
406                    }
407                    .into(),
408                    statements,
409                )
410            }
411        }
412    }
413
414    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
415        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
416    }
417
418    /// Transforms a return statement into an empty block statement.
419    /// Stores the arguments to the return statement, which are later folded into a single return statement at the end of the function.
420    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
421        use Expression::*;
422
423        // If we are traversing an async function, return as is.
424        if self.is_async {
425            return (input.into(), Default::default());
426        }
427        // Construct the associated guard.
428        let (guard_identifier, statements) = self.construct_guard().unzip();
429
430        let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);
431
432        let is_tuple_ids = matches!(&input.expression, Tuple(tuple_expr) if tuple_expr .elements.iter() .all(|expr| matches!(expr, Expression::Path(_))));
433        if !matches!(&input.expression, Unit(_) | Expression::Path(_) | AssociatedConstant(_)) && !is_tuple_ids {
434            panic!("SSA guarantees that the expression is always a Path, unit expression, or tuple literal.")
435        }
436
437        self.returns.push((return_guard, input));
438
439        (Statement::dummy(), statements.unwrap_or_default())
440    }
441}