leo_passes/flattening/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::CompilerState;
18
19use leo_ast::{
20    ArrayAccess,
21    ArrayExpression,
22    ArrayType,
23    AstReconstructor,
24    BinaryExpression,
25    BinaryOperation,
26    Block,
27    Composite,
28    CompositeType,
29    Expression,
30    Identifier,
31    IntegerType,
32    Literal,
33    Member,
34    MemberAccess,
35    Node,
36    NonNegativeNumber,
37    ReturnStatement,
38    Statement,
39    StructExpression,
40    StructVariableInitializer,
41    TernaryExpression,
42    TupleAccess,
43    TupleExpression,
44    TupleType,
45    Type,
46    UnitExpression,
47};
48use leo_span::Symbol;
49
50/// An expression representing a conditional to reach the current
51/// point in the AST.
52#[derive(Clone, Copy)]
53pub enum Guard {
54    /// An Unconstructed guard is one representing a single conditional
55    /// on the stack of conditions.
56    Unconstructed(Identifier),
57
58    /// A Constructed guard is one which as been `And`ed with all previous
59    /// conditions on the stack.
60    ///
61    /// We cache this so that we don't have to evaluate the same chain
62    /// of conditions repeatedly.
63    Constructed(Identifier),
64}
65
66#[derive(Clone, Copy)]
67pub enum ReturnGuard {
68    /// There were no conditionals on the path to this return statement.
69    None,
70
71    /// There was a chain of conditionals on the path to this return statement,
72    /// and they are true iff this Identifier is true.
73    Unconstructed(Identifier),
74
75    /// There was a chain of conditionals on the path to this return statement.`
76    Constructed {
77        /// True iff the conditionals on the path to this return statement are true.
78        plain: Identifier,
79
80        /// True iff any of the guards to return statements so far encountered
81        /// are true. We cache this to guard asserts against early returns.
82        any_return: Identifier,
83    },
84}
85
86impl Guard {
87    fn identifier(self) -> Identifier {
88        match self {
89            Guard::Constructed(id) | Guard::Unconstructed(id) => id,
90        }
91    }
92}
93
94pub struct FlatteningVisitor<'a> {
95    pub state: &'a mut CompilerState,
96
97    /// A stack of condition `Expression`s visited up to the current point in the AST.
98    pub condition_stack: Vec<Guard>,
99
100    /// A list containing tuples of guards and expressions associated `ReturnStatement`s.
101    /// A guard is an expression that evaluates to true on the execution path of the `ReturnStatement`.
102    /// Note that returns are inserted in the order they are encountered during a pre-order traversal of the AST.
103    /// Note that type checking guarantees that there is at most one return in a basic block.
104    pub returns: Vec<(ReturnGuard, ReturnStatement)>,
105
106    /// The program name.
107    pub program: Symbol,
108
109    /// Whether the function is an async function.
110    pub is_async: bool,
111}
112
113impl FlatteningVisitor<'_> {
114    /// Construct an early return guard.
115    ///
116    /// That is, an Identifier assigned to a boolean that is true iff some early return was taken.
117    pub fn construct_early_return_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
118        if self.returns.is_empty() {
119            return None;
120        }
121
122        if self.returns.iter().any(|g| matches!(g.0, ReturnGuard::None)) {
123            // There was a return with no conditions, so we should simple return True.
124            let place = Identifier {
125                name: self.state.assigner.unique_symbol("true", "$"),
126                span: Default::default(),
127                id: self.state.node_builder.next_id(),
128            };
129            let statement = self.simple_definition(
130                place,
131                Literal::boolean(true, Default::default(), self.state.node_builder.next_id()).into(),
132            );
133            return Some((place, vec![statement]));
134        }
135
136        // All guards up to a certain point in the stack should be constructed.
137        // Find the first unconstructed one.
138        let start_i = (0..self.returns.len())
139            .rev()
140            .take_while(|&i| matches!(self.returns[i].0, ReturnGuard::Unconstructed(_)))
141            .last()
142            .unwrap_or(self.returns.len());
143
144        let mut statements = Vec::with_capacity(self.returns.len() - start_i);
145
146        for i in start_i..self.returns.len() {
147            let ReturnGuard::Unconstructed(identifier) = self.returns[i].0 else {
148                unreachable!("We assured above that all guards after the index are Unconstructed.");
149            };
150            if i == 0 {
151                self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: identifier };
152                continue;
153            }
154
155            let ReturnGuard::Constructed { any_return: previous_identifier, .. } = self.returns[i - 1].0 else {
156                unreachable!("We're always at an index where previous guards were Constructed.");
157            };
158
159            // Construct an Or of the two expressions.
160            let binary = BinaryExpression {
161                op: BinaryOperation::Or,
162                left: previous_identifier.into(),
163                right: identifier.into(),
164                span: Default::default(),
165                id: self.state.node_builder.next_id(),
166            };
167            self.state.type_table.insert(binary.id(), Type::Boolean);
168
169            // Assign that Or to a new Identifier.
170            let place = Identifier {
171                name: self.state.assigner.unique_symbol("guard", "$"),
172                span: Default::default(),
173                id: self.state.node_builder.next_id(),
174            };
175            statements.push(self.simple_definition(place, binary.into()));
176
177            // Make that assigned Identifier the constructed guard.
178            self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: place };
179        }
180
181        let ReturnGuard::Constructed { any_return, .. } = self.returns.last().unwrap().0 else {
182            unreachable!("Above we made all guards Constructed.");
183        };
184
185        Some((any_return, statements))
186    }
187
188    /// Construct a guard from the current state of the condition stack.
189    ///
190    /// That is, a boolean expression which is true iff we've followed the branches
191    /// that led to the current point in the Leo code.
192    pub fn construct_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
193        if self.condition_stack.is_empty() {
194            return None;
195        }
196
197        // All guards up to a certain point in the stack should be constructed.
198        // Find the first unconstructed one. Start the search at the end so we
199        // don't repeatedly traverse the whole stack with repeated calls to this
200        // function.
201        let start_i = (0..self.condition_stack.len())
202            .rev()
203            .take_while(|&i| matches!(self.condition_stack[i], Guard::Unconstructed(_)))
204            .last()
205            .unwrap_or(self.condition_stack.len());
206
207        let mut statements = Vec::with_capacity(self.condition_stack.len() - start_i);
208
209        for i in start_i..self.condition_stack.len() {
210            let identifier = self.condition_stack[i].identifier();
211            if i == 0 {
212                self.condition_stack[0] = Guard::Constructed(identifier);
213                continue;
214            }
215
216            let previous = self.condition_stack[i - 1].identifier();
217
218            // Construct an And of the two expressions.
219            let binary = BinaryExpression {
220                op: BinaryOperation::And,
221                left: previous.into(),
222                right: identifier.into(),
223                span: Default::default(),
224                id: self.state.node_builder.next_id(),
225            };
226            self.state.type_table.insert(binary.id(), Type::Boolean);
227
228            // Assign that And to a new Identifier.
229            let place = Identifier {
230                name: self.state.assigner.unique_symbol("guard", "$"),
231                span: Default::default(),
232                id: self.state.node_builder.next_id(),
233            };
234            statements.push(self.simple_definition(place, binary.into()));
235
236            // Make that assigned Identifier the constructed guard.
237            self.condition_stack[i] = Guard::Constructed(place);
238        }
239
240        Some((self.condition_stack.last().unwrap().identifier(), statements))
241    }
242
243    /// Fold guards and expressions into a single expression.
244    /// Note that this function assumes that at least one guard is present.
245    pub fn fold_guards(
246        &mut self,
247        prefix: &str,
248        mut guards: Vec<(Option<Expression>, Expression)>,
249    ) -> (Expression, Vec<Statement>) {
250        // Type checking guarantees that there exists at least one return statement in the function body.
251        let (_, last_expression) = guards.pop().unwrap();
252
253        match last_expression {
254            // If the expression is a unit expression, then return it directly.
255            Expression::Unit(_) => (last_expression, Vec::new()),
256            // Otherwise, fold the guards and expressions into a single expression.
257            _ => {
258                // Produce a chain of ternary expressions and assignments for the guards.
259                let mut statements = Vec::with_capacity(guards.len());
260
261                // Helper to construct and store ternary assignments. e.g `$ret$0 = $var$0 ? $var$1 : $var$2`
262                let mut construct_ternary_assignment =
263                    |guard: Expression, if_true: Expression, if_false: Expression| {
264                        let place = Identifier {
265                            name: self.state.assigner.unique_symbol(prefix, "$"),
266                            span: Default::default(),
267                            id: self.state.node_builder.next_id(),
268                        };
269                        let Some(type_) = self.state.type_table.get(&if_true.id()) else {
270                            panic!("Type checking guarantees that all expressions have a type.");
271                        };
272                        let ternary = TernaryExpression {
273                            condition: guard,
274                            if_true,
275                            if_false,
276                            span: Default::default(),
277                            id: self.state.node_builder.next_id(),
278                        };
279                        self.state.type_table.insert(ternary.id(), type_);
280                        let (value, stmts) = self.reconstruct_ternary(ternary);
281                        statements.extend(stmts);
282
283                        if let Expression::Tuple(..) = &value {
284                            // If the expression is a tuple, then use it directly.
285                            // This must be done to ensure that intermediate tuple assignments are not created.
286                            value
287                        } else {
288                            // Otherwise, assign the expression to a variable and return the variable.
289                            statements.push(self.simple_definition(place, value));
290                            place.into()
291                        }
292                    };
293
294                let expression = guards.into_iter().rev().fold(last_expression, |acc, (guard, expr)| match guard {
295                    None => unreachable!("All expressions except for the last one must have a guard."),
296                    // Note that type checking guarantees that all expressions have the same type.
297                    Some(guard) => construct_ternary_assignment(guard, expr, acc),
298                });
299
300                (expression, statements)
301            }
302        }
303    }
304
305    /// A wrapper around `assigner.unique_simple_definition` that updates `self.structs`.
306    pub fn unique_simple_definition(&mut self, expr: Expression) -> (Identifier, Statement) {
307        // Create a new variable for the expression.
308        let name = self.state.assigner.unique_symbol("$var", "$");
309        // Construct the lhs of the assignment.
310        let place = Identifier { name, span: Default::default(), id: self.state.node_builder.next_id() };
311        // Construct the assignment statement.
312        let statement = self.simple_definition(place, expr);
313
314        (place, statement)
315    }
316
317    /// A wrapper around `assigner.simple_definition` that tracks the type of the lhs.
318    pub fn simple_definition(&mut self, lhs: Identifier, rhs: Expression) -> Statement {
319        // Update the type table.
320        let type_ = match self.state.type_table.get(&rhs.id()) {
321            Some(type_) => type_,
322            None => unreachable!("Type checking guarantees that all expressions have a type."),
323        };
324        self.state.type_table.insert(lhs.id(), type_);
325        // Construct the statement.
326        self.state.assigner.simple_definition(lhs, rhs, self.state.node_builder.next_id())
327    }
328
329    /// Folds a list of return statements into a single return statement and adds the produced statements to the block.
330    pub fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
331        // If the list of returns is not empty, then fold them into a single return statement.
332        if !returns.is_empty() {
333            let mut return_expressions = Vec::with_capacity(returns.len());
334
335            // Aggregate the return expressions and finalize arguments and their respective guards.
336            for (guard, return_statement) in returns {
337                return_expressions.push((guard.clone(), return_statement.expression));
338            }
339
340            // Fold the return expressions into a single expression.
341            let (expression, stmts) = self.fold_guards("$ret", return_expressions);
342
343            // Add all of the accumulated statements to the end of the block.
344            block.statements.extend(stmts);
345
346            // Add the `ReturnStatement` to the end of the block.
347            block.statements.push(
348                ReturnStatement { expression, span: Default::default(), id: self.state.node_builder.next_id() }.into(),
349            );
350        }
351        // Otherwise, push a dummy return statement to the end of the block.
352        else {
353            block.statements.push(
354                ReturnStatement {
355                    expression: UnitExpression { span: Default::default(), id: self.state.node_builder.next_id() }
356                        .into(),
357                    span: Default::default(),
358                    id: self.state.node_builder.next_id(),
359                }
360                .into(),
361            );
362        }
363    }
364
365    // For use in `ternary_array`.
366    fn make_array_access_definition(
367        &mut self,
368        i: usize,
369        identifier: Identifier,
370        array_type: &ArrayType,
371    ) -> (Identifier, Statement) {
372        let index =
373            Literal::integer(IntegerType::U32, i.to_string(), Default::default(), self.state.node_builder.next_id());
374        self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
375        let access: Expression = ArrayAccess {
376            array: identifier.into(),
377            index: index.into(),
378            span: Default::default(),
379            id: self.state.node_builder.next_id(),
380        }
381        .into();
382        self.state.type_table.insert(access.id(), array_type.element_type().clone());
383        self.unique_simple_definition(access)
384    }
385
386    pub fn ternary_array(
387        &mut self,
388        array: &ArrayType,
389        condition: &Expression,
390        first: &Identifier,
391        second: &Identifier,
392    ) -> (Expression, Vec<Statement>) {
393        // Initialize a vector to accumulate any statements generated.
394        let mut statements = Vec::new();
395        // For each array element, construct a new ternary expression.
396        let elements = (0..array.length.as_u32().expect("length should be known at this point") as usize)
397            .map(|i| {
398                // Create an assignment statement for the first access expression.
399                let (first, stmt) = self.make_array_access_definition(i, *first, array);
400                statements.push(stmt);
401                // Create an assignment statement for the second access expression.
402                let (second, stmt) = self.make_array_access_definition(i, *second, array);
403                statements.push(stmt);
404
405                // Recursively reconstruct the ternary expression.
406                let ternary = TernaryExpression {
407                    condition: condition.clone(),
408                    // Access the member of the first expression.
409                    if_true: first.into(),
410                    // Access the member of the second expression.
411                    if_false: second.into(),
412                    span: Default::default(),
413                    id: self.state.node_builder.next_id(),
414                };
415                self.state.type_table.insert(ternary.id(), array.element_type().clone());
416
417                let (expression, stmts) = self.reconstruct_ternary(ternary);
418
419                // Accumulate any statements generated.
420                statements.extend(stmts);
421
422                expression
423            })
424            .collect();
425
426        // Construct the array expression.
427        let (expr, stmts) = self.reconstruct_array(ArrayExpression {
428            elements,
429            span: Default::default(),
430            id: {
431                // Create a node ID for the array expression.
432                let id = self.state.node_builder.next_id();
433                // Set the type of the node ID.
434                self.state.type_table.insert(id, Type::Array(array.clone()));
435                id
436            },
437        });
438
439        // Accumulate any statements generated.
440        statements.extend(stmts);
441
442        // Create a new assignment statement for the array expression.
443        let (identifier, statement) = self.unique_simple_definition(expr);
444
445        statements.push(statement);
446
447        (identifier.into(), statements)
448    }
449
450    // For use in `ternary_struct`.
451    fn make_struct_access_definition(
452        &mut self,
453        inner: Identifier,
454        name: Identifier,
455        type_: Type,
456    ) -> (Identifier, Statement) {
457        let expr: Expression =
458            MemberAccess { inner: inner.into(), name, span: Default::default(), id: self.state.node_builder.next_id() }
459                .into();
460        self.state.type_table.insert(expr.id(), type_);
461        self.unique_simple_definition(expr)
462    }
463
464    pub fn ternary_struct(
465        &mut self,
466        struct_: &Composite,
467        condition: &Expression,
468        first: &Identifier,
469        second: &Identifier,
470    ) -> (Expression, Vec<Statement>) {
471        // Initialize a vector to accumulate any statements generated.
472        let mut statements = Vec::new();
473        // For each struct member, construct a new ternary expression.
474        let members = struct_
475            .members
476            .iter()
477            .map(|Member { identifier, type_, .. }| {
478                let (first, stmt) = self.make_struct_access_definition(*first, *identifier, type_.clone());
479                statements.push(stmt);
480                let (second, stmt) = self.make_struct_access_definition(*second, *identifier, type_.clone());
481                statements.push(stmt);
482                // Recursively reconstruct the ternary expression.
483                let ternary = TernaryExpression {
484                    condition: condition.clone(),
485                    if_true: first.into(),
486                    if_false: second.into(),
487                    span: Default::default(),
488                    id: self.state.node_builder.next_id(),
489                };
490                self.state.type_table.insert(ternary.id(), type_.clone());
491                let (expression, stmts) = self.reconstruct_ternary(ternary);
492
493                // Accumulate any statements generated.
494                statements.extend(stmts);
495
496                StructVariableInitializer {
497                    identifier: *identifier,
498                    expression: Some(expression),
499                    span: Default::default(),
500                    id: self.state.node_builder.next_id(),
501                }
502            })
503            .collect();
504
505        let (expr, stmts) = self.reconstruct_struct_init(StructExpression {
506            name: struct_.identifier,
507            const_arguments: Vec::new(), // All const arguments should have been resolved by now
508            members,
509            span: Default::default(),
510            id: {
511                // Create a new node ID for the struct expression.
512                let id = self.state.node_builder.next_id();
513                // Set the type of the node ID.
514                self.state.type_table.insert(
515                    id,
516                    Type::Composite(CompositeType {
517                        id: struct_.identifier,
518                        const_arguments: Vec::new(), // all const generics should have been resolved by now
519                        program: struct_.external,
520                    }),
521                );
522                id
523            },
524        });
525
526        // Accumulate any statements generated.
527        statements.extend(stmts);
528
529        // Create a new assignment statement for the struct expression.
530        let (identifier, statement) = self.unique_simple_definition(expr);
531
532        statements.push(statement);
533
534        (identifier.into(), statements)
535    }
536
537    pub fn ternary_tuple(
538        &mut self,
539        tuple_type: &TupleType,
540        condition: &Expression,
541        first: &Expression,
542        second: &Expression,
543    ) -> (Expression, Vec<Statement>) {
544        let make_access = |base_expression: &Expression, i: usize, ty: Type, slf: &mut Self| -> Expression {
545            match base_expression {
546                expr @ Expression::Identifier(..) => {
547                    // Create a new node ID for the access expression.
548                    let id = slf.state.node_builder.next_id();
549                    // Set the type of the node ID.
550                    slf.state.type_table.insert(id, ty);
551                    TupleAccess { tuple: expr.clone(), index: NonNegativeNumber::from(i), span: Default::default(), id }
552                        .into()
553                }
554
555                Expression::Tuple(tuple_expr) => tuple_expr.elements[i].clone(),
556
557                _ => panic!("SSA should have prevented this"),
558            }
559        };
560
561        // Initialize a vector to accumulate any statements generated.
562        let mut statements = Vec::new();
563        // For each tuple element, construct a new ternary expression.
564        let elements = tuple_type
565            .elements()
566            .iter()
567            .enumerate()
568            .map(|(i, type_)| {
569                // Create an assignment statement for the first access expression.
570                let access1 = make_access(first, i, type_.clone(), self);
571                let (first, stmt) = self.unique_simple_definition(access1);
572                statements.push(stmt);
573                let access2 = make_access(second, i, type_.clone(), self);
574                // Create an assignment statement for the second access expression.
575                let (second, stmt) = self.unique_simple_definition(access2);
576                statements.push(stmt);
577
578                // Recursively reconstruct the ternary expression.
579                let ternary = TernaryExpression {
580                    condition: condition.clone(),
581                    if_true: first.into(),
582                    if_false: second.into(),
583                    span: Default::default(),
584                    id: self.state.node_builder.next_id(),
585                };
586                self.state.type_table.insert(ternary.id(), type_.clone());
587                let (expression, stmts) = self.reconstruct_ternary(ternary);
588
589                // Accumulate any statements generated.
590                statements.extend(stmts);
591
592                expression
593            })
594            .collect();
595
596        // Construct the tuple expression.
597        let tuple = TupleExpression {
598            elements,
599            span: Default::default(),
600            id: {
601                // Create a new node ID for the tuple expression.
602                let id = self.state.node_builder.next_id();
603                // Set the type of the node ID.
604                self.state.type_table.insert(id, Type::Tuple(tuple_type.clone()));
605                id
606            },
607        };
608        let (expr, stmts) = self.reconstruct_tuple(tuple);
609
610        // Accumulate any statements generated.
611        statements.extend(stmts);
612
613        if let Expression::Identifier(..) = first {
614            // Create a new assignment statement for the tuple expression.
615            let (identifier, statement) = self.unique_simple_definition(expr);
616
617            statements.push(statement);
618
619            (identifier.into(), statements)
620        } else {
621            // Just use the tuple we just made.
622            (expr, statements)
623        }
624    }
625}