// leo_passes/flattening/visitor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

17use crate::CompilerState;
18
19use leo_ast::{
20    ArrayAccess,
21    ArrayExpression,
22    ArrayType,
23    AstReconstructor,
24    BinaryExpression,
25    BinaryOperation,
26    Block,
27    Composite,
28    CompositeType,
29    Expression,
30    Identifier,
31    IntegerType,
32    Literal,
33    Member,
34    MemberAccess,
35    Node,
36    NonNegativeNumber,
37    Path,
38    ReturnStatement,
39    Statement,
40    StructExpression,
41    StructVariableInitializer,
42    TernaryExpression,
43    TupleAccess,
44    TupleExpression,
45    TupleType,
46    Type,
47    UnitExpression,
48};
49use leo_span::Symbol;
50
51/// An expression representing a conditional to reach the current
52/// point in the AST.
53#[derive(Clone, Copy)]
54pub enum Guard {
55    /// An Unconstructed guard is one representing a single conditional
56    /// on the stack of conditions.
57    Unconstructed(Identifier),
58
59    /// A Constructed guard is one which as been `And`ed with all previous
60    /// conditions on the stack.
61    ///
62    /// We cache this so that we don't have to evaluate the same chain
63    /// of conditions repeatedly.
64    Constructed(Identifier),
65}
66
67#[derive(Clone, Copy)]
68pub enum ReturnGuard {
69    /// There were no conditionals on the path to this return statement.
70    None,
71
72    /// There was a chain of conditionals on the path to this return statement,
73    /// and they are true iff this Identifier is true.
74    Unconstructed(Identifier),
75
76    /// There was a chain of conditionals on the path to this return statement.`
77    Constructed {
78        /// True iff the conditionals on the path to this return statement are true.
79        plain: Identifier,
80
81        /// True iff any of the guards to return statements so far encountered
82        /// are true. We cache this to guard asserts against early returns.
83        any_return: Identifier,
84    },
85}
86
87impl Guard {
88    fn identifier(self) -> Identifier {
89        match self {
90            Guard::Constructed(id) | Guard::Unconstructed(id) => id,
91        }
92    }
93}
94
95pub struct FlatteningVisitor<'a> {
96    pub state: &'a mut CompilerState,
97
98    /// A stack of condition `Expression`s visited up to the current point in the AST.
99    pub condition_stack: Vec<Guard>,
100
101    /// A list containing tuples of guards and expressions associated `ReturnStatement`s.
102    /// A guard is an expression that evaluates to true on the execution path of the `ReturnStatement`.
103    /// Note that returns are inserted in the order they are encountered during a pre-order traversal of the AST.
104    /// Note that type checking guarantees that there is at most one return in a basic block.
105    pub returns: Vec<(ReturnGuard, ReturnStatement)>,
106
107    /// The program name.
108    pub program: Symbol,
109
110    /// Whether the function is an async function.
111    pub is_async: bool,
112}
113
114impl FlatteningVisitor<'_> {
115    /// Construct an early return guard.
116    ///
117    /// That is, an Identifier assigned to a boolean that is true iff some early return was taken.
118    pub fn construct_early_return_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
119        if self.returns.is_empty() {
120            return None;
121        }
122
123        if self.returns.iter().any(|g| matches!(g.0, ReturnGuard::None)) {
124            // There was a return with no conditions, so we should simple return True.
125            let place = Identifier {
126                name: self.state.assigner.unique_symbol("true", "$"),
127                span: Default::default(),
128                id: self.state.node_builder.next_id(),
129            };
130            let statement = self.simple_definition(
131                place,
132                Literal::boolean(true, Default::default(), self.state.node_builder.next_id()).into(),
133            );
134            return Some((place, vec![statement]));
135        }
136
137        // All guards up to a certain point in the stack should be constructed.
138        // Find the first unconstructed one.
139        let start_i = (0..self.returns.len())
140            .rev()
141            .take_while(|&i| matches!(self.returns[i].0, ReturnGuard::Unconstructed(_)))
142            .last()
143            .unwrap_or(self.returns.len());
144
145        let mut statements = Vec::with_capacity(self.returns.len() - start_i);
146
147        for i in start_i..self.returns.len() {
148            let ReturnGuard::Unconstructed(identifier) = self.returns[i].0 else {
149                unreachable!("We assured above that all guards after the index are Unconstructed.");
150            };
151            if i == 0 {
152                self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: identifier };
153                continue;
154            }
155
156            let ReturnGuard::Constructed { any_return: previous_identifier, .. } = self.returns[i - 1].0 else {
157                unreachable!("We're always at an index where previous guards were Constructed.");
158            };
159
160            // Construct an Or of the two expressions.
161            let binary = BinaryExpression {
162                op: BinaryOperation::Or,
163                left: Path::from(previous_identifier).into_absolute().into(),
164                right: Path::from(identifier).into_absolute().into(),
165                span: Default::default(),
166                id: self.state.node_builder.next_id(),
167            };
168            self.state.type_table.insert(binary.id(), Type::Boolean);
169
170            // Assign that Or to a new Identifier.
171            let place = Identifier {
172                name: self.state.assigner.unique_symbol("guard", "$"),
173                span: Default::default(),
174                id: self.state.node_builder.next_id(),
175            };
176            statements.push(self.simple_definition(place, binary.into()));
177
178            // Make that assigned Identifier the constructed guard.
179            self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: place };
180        }
181
182        let ReturnGuard::Constructed { any_return, .. } = self.returns.last().unwrap().0 else {
183            unreachable!("Above we made all guards Constructed.");
184        };
185
186        Some((any_return, statements))
187    }
188
189    /// Construct a guard from the current state of the condition stack.
190    ///
191    /// That is, a boolean expression which is true iff we've followed the branches
192    /// that led to the current point in the Leo code.
193    pub fn construct_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
194        if self.condition_stack.is_empty() {
195            return None;
196        }
197
198        // All guards up to a certain point in the stack should be constructed.
199        // Find the first unconstructed one. Start the search at the end so we
200        // don't repeatedly traverse the whole stack with repeated calls to this
201        // function.
202        let start_i = (0..self.condition_stack.len())
203            .rev()
204            .take_while(|&i| matches!(self.condition_stack[i], Guard::Unconstructed(_)))
205            .last()
206            .unwrap_or(self.condition_stack.len());
207
208        let mut statements = Vec::with_capacity(self.condition_stack.len() - start_i);
209
210        for i in start_i..self.condition_stack.len() {
211            let identifier = self.condition_stack[i].identifier();
212            if i == 0 {
213                self.condition_stack[0] = Guard::Constructed(identifier);
214                continue;
215            }
216
217            let previous = self.condition_stack[i - 1].identifier();
218
219            // Construct an And of the two expressions.
220            let binary = BinaryExpression {
221                op: BinaryOperation::And,
222                left: Path::from(previous).into_absolute().into(),
223                right: Path::from(identifier).into_absolute().into(),
224                span: Default::default(),
225                id: self.state.node_builder.next_id(),
226            };
227            self.state.type_table.insert(binary.id(), Type::Boolean);
228
229            // Assign that And to a new Identifier.
230            let place = Identifier {
231                name: self.state.assigner.unique_symbol("guard", "$"),
232                span: Default::default(),
233                id: self.state.node_builder.next_id(),
234            };
235            statements.push(self.simple_definition(place, binary.into()));
236
237            // Make that assigned Identifier the constructed guard.
238            self.condition_stack[i] = Guard::Constructed(place);
239        }
240
241        Some((self.condition_stack.last().unwrap().identifier(), statements))
242    }
243
244    /// Fold guards and expressions into a single expression.
245    /// Note that this function assumes that at least one guard is present.
246    pub fn fold_guards(
247        &mut self,
248        prefix: &str,
249        mut guards: Vec<(Option<Expression>, Expression)>,
250    ) -> (Expression, Vec<Statement>) {
251        // Type checking guarantees that there exists at least one return statement in the function body.
252        let (_, last_expression) = guards.pop().unwrap();
253
254        match last_expression {
255            // If the expression is a unit expression, then return it directly.
256            Expression::Unit(_) => (last_expression, Vec::new()),
257            // Otherwise, fold the guards and expressions into a single expression.
258            _ => {
259                // Produce a chain of ternary expressions and assignments for the guards.
260                let mut statements = Vec::with_capacity(guards.len());
261
262                // Helper to construct and store ternary assignments. e.g `$ret$0 = $var$0 ? $var$1 : $var$2`
263                let mut construct_ternary_assignment =
264                    |guard: Expression, if_true: Expression, if_false: Expression| {
265                        let place = Identifier {
266                            name: self.state.assigner.unique_symbol(prefix, "$"),
267                            span: Default::default(),
268                            id: self.state.node_builder.next_id(),
269                        };
270                        let Some(type_) = self.state.type_table.get(&if_true.id()) else {
271                            panic!("Type checking guarantees that all expressions have a type.");
272                        };
273                        let ternary = TernaryExpression {
274                            condition: guard,
275                            if_true,
276                            if_false,
277                            span: Default::default(),
278                            id: self.state.node_builder.next_id(),
279                        };
280                        self.state.type_table.insert(ternary.id(), type_);
281                        let (value, stmts) = self.reconstruct_ternary(ternary, &());
282                        statements.extend(stmts);
283
284                        if let Expression::Tuple(..) = &value {
285                            // If the expression is a tuple, then use it directly.
286                            // This must be done to ensure that intermediate tuple assignments are not created.
287                            value
288                        } else {
289                            // Otherwise, assign the expression to a variable and return the variable.
290                            statements.push(self.simple_definition(place, value));
291                            Path::from(place).into_absolute().into()
292                        }
293                    };
294
295                let expression = guards.into_iter().rev().fold(last_expression, |acc, (guard, expr)| match guard {
296                    None => unreachable!("All expressions except for the last one must have a guard."),
297                    // Note that type checking guarantees that all expressions have the same type.
298                    Some(guard) => construct_ternary_assignment(guard, expr, acc),
299                });
300
301                (expression, statements)
302            }
303        }
304    }
305
306    /// A wrapper around `assigner.unique_simple_definition` that updates `self.structs`.
307    pub fn unique_simple_definition(&mut self, expr: Expression) -> (Identifier, Statement) {
308        // Create a new variable for the expression.
309        let name = self.state.assigner.unique_symbol("$var", "$");
310        // Construct the lhs of the assignment.
311        let place = Identifier { name, span: Default::default(), id: self.state.node_builder.next_id() };
312        // Construct the assignment statement.
313        let statement = self.simple_definition(place, expr);
314
315        (place, statement)
316    }
317
318    /// A wrapper around `assigner.simple_definition` that tracks the type of the lhs.
319    pub fn simple_definition(&mut self, lhs: Identifier, rhs: Expression) -> Statement {
320        // Update the type table.
321        let type_ = match self.state.type_table.get(&rhs.id()) {
322            Some(type_) => type_,
323            None => unreachable!("Type checking guarantees that all expressions have a type."),
324        };
325        self.state.type_table.insert(lhs.id(), type_);
326        // Construct the statement.
327        self.state.assigner.simple_definition(lhs, rhs, self.state.node_builder.next_id())
328    }
329
330    /// Folds a list of return statements into a single return statement and adds the produced statements to the block.
331    pub fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
332        // If the list of returns is not empty, then fold them into a single return statement.
333        if !returns.is_empty() {
334            let mut return_expressions = Vec::with_capacity(returns.len());
335
336            // Aggregate the return expressions and finalize arguments and their respective guards.
337            for (guard, return_statement) in returns {
338                return_expressions.push((guard.clone(), return_statement.expression));
339            }
340
341            // Fold the return expressions into a single expression.
342            let (expression, stmts) = self.fold_guards("$ret", return_expressions);
343
344            // Add all of the accumulated statements to the end of the block.
345            block.statements.extend(stmts);
346
347            // Add the `ReturnStatement` to the end of the block.
348            block.statements.push(
349                ReturnStatement { expression, span: Default::default(), id: self.state.node_builder.next_id() }.into(),
350            );
351        }
352        // Otherwise, push a dummy return statement to the end of the block.
353        else {
354            block.statements.push(
355                ReturnStatement {
356                    expression: UnitExpression { span: Default::default(), id: self.state.node_builder.next_id() }
357                        .into(),
358                    span: Default::default(),
359                    id: self.state.node_builder.next_id(),
360                }
361                .into(),
362            );
363        }
364    }
365
366    // For use in `ternary_array`.
367    fn make_array_access_definition(
368        &mut self,
369        i: usize,
370        identifier: Identifier,
371        array_type: &ArrayType,
372    ) -> (Identifier, Statement) {
373        let index =
374            Literal::integer(IntegerType::U32, i.to_string(), Default::default(), self.state.node_builder.next_id());
375        self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
376        let access: Expression = ArrayAccess {
377            array: Path::from(identifier).into_absolute().into(),
378            index: index.into(),
379            span: Default::default(),
380            id: self.state.node_builder.next_id(),
381        }
382        .into();
383        self.state.type_table.insert(access.id(), array_type.element_type().clone());
384        self.unique_simple_definition(access)
385    }
386
387    pub fn ternary_array(
388        &mut self,
389        array: &ArrayType,
390        condition: &Expression,
391        first: &Identifier,
392        second: &Identifier,
393    ) -> (Expression, Vec<Statement>) {
394        // Initialize a vector to accumulate any statements generated.
395        let mut statements = Vec::new();
396        // For each array element, construct a new ternary expression.
397        let elements = (0..array.length.as_u32().expect("length should be known at this point") as usize)
398            .map(|i| {
399                // Create an assignment statement for the first access expression.
400                let (first, stmt) = self.make_array_access_definition(i, *first, array);
401                statements.push(stmt);
402                // Create an assignment statement for the second access expression.
403                let (second, stmt) = self.make_array_access_definition(i, *second, array);
404                statements.push(stmt);
405
406                // Recursively reconstruct the ternary expression.
407                let ternary = TernaryExpression {
408                    condition: condition.clone(),
409                    // Access the member of the first expression.
410                    if_true: Path::from(first).into_absolute().into(),
411                    // Access the member of the second expression.
412                    if_false: Path::from(second).into_absolute().into(),
413                    span: Default::default(),
414                    id: self.state.node_builder.next_id(),
415                };
416                self.state.type_table.insert(ternary.id(), array.element_type().clone());
417
418                let (expression, stmts) = self.reconstruct_ternary(ternary, &());
419
420                // Accumulate any statements generated.
421                statements.extend(stmts);
422
423                expression
424            })
425            .collect();
426
427        // Construct the array expression.
428        let (expr, stmts) = self.reconstruct_array(
429            ArrayExpression {
430                elements,
431                span: Default::default(),
432                id: {
433                    // Create a node ID for the array expression.
434                    let id = self.state.node_builder.next_id();
435                    // Set the type of the node ID.
436                    self.state.type_table.insert(id, Type::Array(array.clone()));
437                    id
438                },
439            },
440            &(),
441        );
442
443        // Accumulate any statements generated.
444        statements.extend(stmts);
445
446        // Create a new assignment statement for the array expression.
447        let (identifier, statement) = self.unique_simple_definition(expr);
448
449        statements.push(statement);
450
451        (Path::from(identifier).into_absolute().into(), statements)
452    }
453
454    // For use in `ternary_struct`.
455    fn make_struct_access_definition(
456        &mut self,
457        inner: Identifier,
458        name: Identifier,
459        type_: Type,
460    ) -> (Identifier, Statement) {
461        let expr: Expression = MemberAccess {
462            inner: Path::from(inner).into_absolute().into(),
463            name,
464            span: Default::default(),
465            id: self.state.node_builder.next_id(),
466        }
467        .into();
468        self.state.type_table.insert(expr.id(), type_);
469        self.unique_simple_definition(expr)
470    }
471
472    pub fn ternary_struct(
473        &mut self,
474        struct_path: &Path,
475        struct_: &Composite,
476        condition: &Expression,
477        first: &Identifier,
478        second: &Identifier,
479    ) -> (Expression, Vec<Statement>) {
480        // Initialize a vector to accumulate any statements generated.
481        let mut statements = Vec::new();
482        // For each struct member, construct a new ternary expression.
483        let members = struct_
484            .members
485            .iter()
486            .map(|Member { identifier, type_, .. }| {
487                let (first, stmt) = self.make_struct_access_definition(*first, *identifier, type_.clone());
488                statements.push(stmt);
489                let (second, stmt) = self.make_struct_access_definition(*second, *identifier, type_.clone());
490                statements.push(stmt);
491                // Recursively reconstruct the ternary expression.
492                let ternary = TernaryExpression {
493                    condition: condition.clone(),
494                    if_true: Path::from(first).into_absolute().into(),
495                    if_false: Path::from(second).into_absolute().into(),
496                    span: Default::default(),
497                    id: self.state.node_builder.next_id(),
498                };
499                self.state.type_table.insert(ternary.id(), type_.clone());
500                let (expression, stmts) = self.reconstruct_ternary(ternary, &());
501
502                // Accumulate any statements generated.
503                statements.extend(stmts);
504
505                StructVariableInitializer {
506                    identifier: *identifier,
507                    expression: Some(expression),
508                    span: Default::default(),
509                    id: self.state.node_builder.next_id(),
510                }
511            })
512            .collect();
513
514        let (expr, stmts) = self.reconstruct_struct_init(
515            StructExpression {
516                path: struct_path.clone(),
517                const_arguments: Vec::new(), // All const arguments should have been resolved by now
518                members,
519                span: Default::default(),
520                id: {
521                    // Create a new node ID for the struct expression.
522                    let id = self.state.node_builder.next_id();
523                    // Set the type of the node ID.
524                    self.state.type_table.insert(
525                        id,
526                        Type::Composite(CompositeType {
527                            path: struct_path.clone(),
528                            const_arguments: Vec::new(), // all const generics should have been resolved by now
529                            program: struct_.external,
530                        }),
531                    );
532                    id
533                },
534            },
535            &(),
536        );
537
538        // Accumulate any statements generated.
539        statements.extend(stmts);
540
541        // Create a new assignment statement for the struct expression.
542        let (identifier, statement) = self.unique_simple_definition(expr);
543
544        statements.push(statement);
545
546        (Path::from(identifier).into_absolute().into(), statements)
547    }
548
549    pub fn ternary_tuple(
550        &mut self,
551        tuple_type: &TupleType,
552        condition: &Expression,
553        first: &Expression,
554        second: &Expression,
555    ) -> (Expression, Vec<Statement>) {
556        let make_access = |base_expression: &Expression, i: usize, ty: Type, slf: &mut Self| -> Expression {
557            match base_expression {
558                expr @ Expression::Path(..) => {
559                    // Create a new node ID for the access expression.
560                    let id = slf.state.node_builder.next_id();
561                    // Set the type of the node ID.
562                    slf.state.type_table.insert(id, ty);
563                    TupleAccess { tuple: expr.clone(), index: NonNegativeNumber::from(i), span: Default::default(), id }
564                        .into()
565                }
566
567                Expression::Tuple(tuple_expr) => tuple_expr.elements[i].clone(),
568
569                _ => panic!("SSA should have prevented this"),
570            }
571        };
572
573        // Initialize a vector to accumulate any statements generated.
574        let mut statements = Vec::new();
575        // For each tuple element, construct a new ternary expression.
576        let elements = tuple_type
577            .elements()
578            .iter()
579            .enumerate()
580            .map(|(i, type_)| {
581                // Create an assignment statement for the first access expression.
582                let access1 = make_access(first, i, type_.clone(), self);
583                let (first, stmt) = self.unique_simple_definition(access1);
584                statements.push(stmt);
585                let access2 = make_access(second, i, type_.clone(), self);
586                // Create an assignment statement for the second access expression.
587                let (second, stmt) = self.unique_simple_definition(access2);
588                statements.push(stmt);
589
590                // Recursively reconstruct the ternary expression.
591                let ternary = TernaryExpression {
592                    condition: condition.clone(),
593                    if_true: Path::from(first).into_absolute().into(),
594                    if_false: Path::from(second).into_absolute().into(),
595                    span: Default::default(),
596                    id: self.state.node_builder.next_id(),
597                };
598                self.state.type_table.insert(ternary.id(), type_.clone());
599                let (expression, stmts) = self.reconstruct_ternary(ternary, &());
600
601                // Accumulate any statements generated.
602                statements.extend(stmts);
603
604                expression
605            })
606            .collect();
607
608        // Construct the tuple expression.
609        let tuple = TupleExpression {
610            elements,
611            span: Default::default(),
612            id: {
613                // Create a new node ID for the tuple expression.
614                let id = self.state.node_builder.next_id();
615                // Set the type of the node ID.
616                self.state.type_table.insert(id, Type::Tuple(tuple_type.clone()));
617                id
618            },
619        };
620        let (expr, stmts) = self.reconstruct_tuple(tuple, &());
621
622        // Accumulate any statements generated.
623        statements.extend(stmts);
624
625        if let Expression::Path(..) = first {
626            // Create a new assignment statement for the tuple expression.
627            let (identifier, statement) = self.unique_simple_definition(expr);
628
629            statements.push(statement);
630
631            (Path::from(identifier).into_absolute().into(), statements)
632        } else {
633            // Just use the tuple we just made.
634            (expr, statements)
635        }
636    }
637}