leo_passes/destructuring/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::DestructuringVisitor;
18use leo_ast::*;
19use leo_span::Symbol;
20
21use itertools::{Itertools, izip};
22
23impl AstReconstructor for DestructuringVisitor<'_> {
24    type AdditionalInput = ();
25    type AdditionalOutput = Vec<Statement>;
26
27    /// Reconstructs a binary expression, expanding equality and inequality over
28    /// tuples into elementwise comparisons. When both sides are tuples and the
29    /// operator is `==` or `!=`, it generates per-element comparisons and folds
30    /// them with AND/OR; otherwise the expression is rebuilt normally.
31    ///
32    /// Example: `(a, b) == (c, d)` → `(a == c) && (b == d)`
33    /// Example: `(a, b, c) != (x, y, z)` → `(a != x) || (b != y) || (c != z)`
34    fn reconstruct_binary(
35        &mut self,
36        input: BinaryExpression,
37        _additional: &Self::AdditionalInput,
38    ) -> (Expression, Self::AdditionalOutput) {
39        let (left, mut statements) = self.reconstruct_expression_tuple(input.left);
40        let (right, statements2) = self.reconstruct_expression_tuple(input.right);
41        statements.extend(statements2);
42
43        use BinaryOperation::*;
44
45        // Tuple equality / inequality expansion
46        if let (Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) = (&left, &right)
47            && matches!(input.op, Eq | Neq)
48        {
49            assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());
50
51            // Directly build elementwise (l OP r)
52            let pieces: Vec<Expression> = tuple_left
53                .elements
54                .iter()
55                .zip(&tuple_right.elements)
56                .map(|(l, r)| {
57                    let expr: Expression = BinaryExpression {
58                        op: input.op,
59                        left: l.clone(),
60                        right: r.clone(),
61                        span: Default::default(),
62                        id: self.state.node_builder.next_id(),
63                    }
64                    .into();
65
66                    self.state.type_table.insert(expr.id(), Type::Boolean);
67                    expr
68                })
69                .collect();
70
71            // Fold appropriately
72            let op = match input.op {
73                Eq => BinaryOperation::And,
74                Neq => BinaryOperation::Or,
75                _ => unreachable!(),
76            };
77
78            return (self.fold_with_op(op, pieces.into_iter()), statements);
79        }
80
81        // Fallback
82        (BinaryExpression { op: input.op, left, right, ..input }.into(), Default::default())
83    }
84
85    /// Replaces a tuple access expression with the appropriate expression.
86    fn reconstruct_tuple_access(
87        &mut self,
88        input: TupleAccess,
89        _additional: &(),
90    ) -> (Expression, Self::AdditionalOutput) {
91        let Expression::Path(path) = &input.tuple else {
92            panic!("SSA guarantees that subexpressions are identifiers or literals.");
93        };
94
95        // Look up the expression in the tuple map.
96        match self.tuples.get(&path.identifier().name).and_then(|tuple_names| tuple_names.get(input.index.value())) {
97            Some(id) => (Path::from(*id).into_absolute().into(), Default::default()),
98            None => {
99                if !matches!(self.state.type_table.get(&path.id), Some(Type::Future(_))) {
100                    panic!("Type checking guarantees that all tuple accesses are declared and indices are valid.");
101                }
102
103                let index = Literal::integer(
104                    IntegerType::U32,
105                    input.index.to_string(),
106                    input.span,
107                    self.state.node_builder.next_id(),
108                );
109                self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
110
111                let expr =
112                    ArrayAccess { array: path.clone().into(), index: index.into(), span: input.span, id: input.id }
113                        .into();
114
115                (expr, Default::default())
116            }
117        }
118    }
119
120    /// If this is a ternary expression on tuples of length `n`, we'll need to change it into
121    /// `n` ternary expressions on the members.
122    fn reconstruct_ternary(
123        &mut self,
124        mut input: TernaryExpression,
125        _additional: &(),
126    ) -> (Expression, Self::AdditionalOutput) {
127        let (condition, mut statements) =
128            self.reconstruct_expression(std::mem::take(&mut input.condition), &Default::default());
129        let (if_true, statements2) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_true));
130        statements.extend(statements2);
131        let (if_false, statements3) = self.reconstruct_expression_tuple(std::mem::take(&mut input.if_false));
132        statements.extend(statements3);
133
134        match (if_true, if_false) {
135            (Expression::Tuple(tuple_true), Expression::Tuple(tuple_false)) => {
136                // Aleo's `ternary` opcode doesn't know about tuples, so we have to handle this.
137                let Some(Type::Tuple(tuple_type)) = self.state.type_table.get(&tuple_true.id()) else {
138                    panic!("Should have tuple type");
139                };
140
141                // We'll be reusing `condition`, so assign it to a variable.
142                let cond = if let Expression::Path(..) = condition {
143                    condition
144                } else {
145                    let place = Identifier::new(
146                        self.state.assigner.unique_symbol("cond", "$$"),
147                        self.state.node_builder.next_id(),
148                    );
149
150                    let definition =
151                        self.state.assigner.simple_definition(place, condition, self.state.node_builder.next_id());
152
153                    statements.push(definition);
154
155                    self.state.type_table.insert(place.id(), Type::Boolean);
156
157                    Expression::Path(Path::from(place).into_absolute())
158                };
159
160                // These will be the `elements` of our resulting tuple.
161                let mut elements = Vec::with_capacity(tuple_true.elements.len());
162
163                // Create an individual `ternary` for each tuple member and assign the
164                // result to a new variable.
165                for (i, (lhs, rhs, ty)) in
166                    izip!(tuple_true.elements, tuple_false.elements, tuple_type.elements()).enumerate()
167                {
168                    let identifier = Identifier::new(
169                        self.state.assigner.unique_symbol(format_args!("ternary_{i}"), "$$"),
170                        self.state.node_builder.next_id(),
171                    );
172
173                    let expression: Expression = TernaryExpression {
174                        condition: cond.clone(),
175                        if_true: lhs,
176                        if_false: rhs,
177                        span: Default::default(),
178                        id: self.state.node_builder.next_id(),
179                    }
180                    .into();
181
182                    self.state.type_table.insert(identifier.id(), ty.clone());
183                    self.state.type_table.insert(expression.id(), ty.clone());
184
185                    let definition = self.state.assigner.simple_definition(
186                        identifier,
187                        expression,
188                        self.state.node_builder.next_id(),
189                    );
190
191                    statements.push(definition);
192                    elements.push(Path::from(identifier).into_absolute().into());
193                }
194
195                let expr: Expression =
196                    TupleExpression { elements, span: Default::default(), id: self.state.node_builder.next_id() }
197                        .into();
198
199                self.state.type_table.insert(expr.id(), Type::Tuple(tuple_type.clone()));
200
201                (expr, statements)
202            }
203            (if_true, if_false) => {
204                // This isn't a tuple. Just rebuild it and otherwise leave it alone.
205                (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
206            }
207        }
208    }
209
210    /* Statements */
211    /// `assert_eq` and `assert_neq` comparing tuples should be expanded to as many asserts as
212    /// the length of each tuple.
213    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
214        match input.variant {
215            AssertVariant::Assert(expr) => {
216                // Simple assert, just reconstruct the expression.
217                let (expr, _) = self.reconstruct_expression(expr, &Default::default());
218                (AssertStatement { variant: AssertVariant::Assert(expr), ..input }.into(), Default::default())
219            }
220            AssertVariant::AssertEq(ref left, ref right) | AssertVariant::AssertNeq(ref left, ref right) => {
221                let (left, mut statements) = self.reconstruct_expression_tuple(left.clone());
222                let (right, statements2) = self.reconstruct_expression_tuple(right.clone());
223                statements.extend(statements2);
224
225                match (&left, &right) {
226                    (Expression::Tuple(tuple_left), Expression::Tuple(tuple_right)) => {
227                        // Ensure the tuple lengths match
228                        assert_eq!(tuple_left.elements.len(), tuple_right.elements.len());
229
230                        for (l, r) in tuple_left.elements.iter().zip(&tuple_right.elements) {
231                            let assert_variant = match input.variant {
232                                AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(l.clone(), r.clone()),
233                                AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(l.clone(), r.clone()),
234                                _ => unreachable!(),
235                            };
236
237                            let stmt = AssertStatement { variant: assert_variant, ..input.clone() }.into();
238                            statements.push(stmt);
239                        }
240
241                        // We don't need the original statement, just the ones we've created.
242                        (Statement::dummy(), statements)
243                    }
244                    _ => {
245                        // Not tuples, just keep the original assert
246                        let variant = match input.variant {
247                            AssertVariant::AssertEq(_, _) => AssertVariant::AssertEq(left, right),
248                            AssertVariant::AssertNeq(_, _) => AssertVariant::AssertNeq(left, right),
249                            _ => unreachable!(),
250                        };
251                        (AssertStatement { variant, ..input }.into(), Default::default())
252                    }
253                }
254            }
255        }
256    }
257
258    /// Modify assignments to tuples to become assignments to the corresponding variables.
259    ///
260    /// There are two cases we handle:
261    /// 1. An assignment to a tuple x, like `x = rhs;`.
262    ///    This we need to transform into individual assignments
263    ///    `x_i = rhs_i;`
264    ///    of the variables corresponding to members of `x` and `rhs`.
265    /// 2. An assignment to a tuple member, like `x.2[i].member = rhs;`.
266    ///    This we need to change into
267    ///    `x_2[i].member = rhs;`
268    ///    where `x_2` is the variable corresponding to `x.2`.
269    fn reconstruct_assign(&mut self, mut assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
270        let (value, mut statements) = self.reconstruct_expression(assign.value, &());
271
272        if let Expression::Path(path) = &assign.place
273            && let Type::Tuple(..) = self.state.type_table.get(&value.id()).expect("Expressions should have types.")
274        {
275            // This is the first case, assigning to a variable of tuple type.
276            let identifiers = self.tuples.get(&path.identifier().name).expect("Tuple should have been encountered.");
277
278            let Expression::Path(rhs) = value else {
279                panic!("SSA should have ensured this is an identifier.");
280            };
281
282            let rhs_identifiers = self.tuples.get(&rhs.identifier().name).expect("Tuple should have been encountered.");
283
284            // Again, make an assignment for each identifier.
285            for (&identifier, &rhs_identifier) in identifiers.iter().zip_eq(rhs_identifiers) {
286                let stmt = AssignStatement {
287                    place: Path::from(identifier).into_absolute().into(),
288                    value: Path::from(rhs_identifier).into_absolute().into(),
289                    id: self.state.node_builder.next_id(),
290                    span: Default::default(),
291                }
292                .into();
293
294                statements.push(stmt);
295            }
296
297            // We don't need the original assignment, just the ones we've created.
298            return (Statement::dummy(), statements);
299        }
300
301        // We need to check for case 2, so we loop and see if we find a tuple access.
302
303        assign.value = value;
304        let mut place = &mut assign.place;
305
306        loop {
307            // Loop through the places in the assignment to the top-level expression until an identifier or tuple access is reached.
308            match place {
309                Expression::TupleAccess(access) => {
310                    // We're assigning to a tuple member, case 2 mentioned above.
311                    let Expression::Path(path) = &access.tuple else {
312                        panic!("SSA should have ensured this is an identifier.");
313                    };
314
315                    let tuple_ids =
316                        self.tuples.get(&path.identifier().name).expect("Tuple should have been encountered.");
317
318                    // This is the corresponding variable name of the member we're assigning to.
319                    let identifier = tuple_ids[access.index.value()];
320
321                    *place = Path::from(identifier).into_absolute().into();
322
323                    return (assign.into(), statements);
324                }
325
326                Expression::ArrayAccess(access) => {
327                    // We need to investigate the array, as maybe it's inside a tuple access, like `tupl.0[1u8]`.
328                    place = &mut access.array;
329                }
330
331                Expression::MemberAccess(access) => {
332                    // We need to investigate the struct, as maybe it's inside a tuple access, like `tupl.0.mem`.
333                    place = &mut access.inner;
334                }
335
336                Expression::Path(..) => {
337                    // There was no tuple access, so this is neither case 1 nor 2; there's nothing to do.
338                    return (assign.into(), statements);
339                }
340
341                _ => panic!("Type checking should have prevented this."),
342            }
343        }
344    }
345
346    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
347        let mut statements = Vec::with_capacity(block.statements.len());
348
349        // Reconstruct the statements in the block, accumulating any additional statements.
350        for statement in block.statements {
351            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
352            statements.extend(additional_statements);
353            if !reconstructed_statement.is_empty() {
354                statements.push(reconstructed_statement);
355            }
356        }
357
358        (Block { statements, ..block }, Default::default())
359    }
360
361    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
362        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
363        let (then, statements2) = self.reconstruct_block(input.then);
364        statements.extend(statements2);
365        let otherwise = input.otherwise.map(|oth| {
366            let (expr, statements3) = self.reconstruct_statement(*oth);
367            statements.extend(statements3);
368            Box::new(expr)
369        });
370        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
371    }
372
373    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
374        use DefinitionPlace::*;
375
376        let make_identifiers = |slf: &mut Self, single: Symbol, count: usize| -> Vec<Identifier> {
377            (0..count)
378                .map(|i| {
379                    Identifier::new(
380                        slf.state.assigner.unique_symbol(format_args!("{single}#tuple{i}"), "$"),
381                        slf.state.node_builder.next_id(),
382                    )
383                })
384                .collect()
385        };
386
387        let (value, mut statements) = self.reconstruct_expression(definition.value, &());
388        let ty = self.state.type_table.get(&value.id()).expect("Expressions should have a type.");
389        match (definition.place, value, ty) {
390            (Single(identifier), Expression::Path(rhs), Type::Tuple(tuple_type)) => {
391                // We need to give the members new names, in case they are assigned to.
392                let identifiers = make_identifiers(self, identifier.name, tuple_type.length());
393
394                let rhs_identifiers = self.tuples.get(&rhs.identifier().name).unwrap();
395
396                for (identifier, rhs_identifier, ty) in izip!(&identifiers, rhs_identifiers, tuple_type.elements()) {
397                    // Make a definition for each.
398                    let stmt = DefinitionStatement {
399                        place: Single(*identifier),
400                        type_: Some(ty.clone()),
401                        value: Expression::Path(Path::from(*rhs_identifier).into_absolute()),
402                        span: Default::default(),
403                        id: self.state.node_builder.next_id(),
404                    }
405                    .into();
406                    statements.push(stmt);
407
408                    // Put each into the type table.
409                    self.state.type_table.insert(identifier.id(), ty.clone());
410                }
411
412                // Put the identifier in `self.tuples`. We don't need to keep our definition.
413                self.tuples.insert(identifier.name, identifiers);
414                (Statement::dummy(), statements)
415            }
416            (Single(identifier), Expression::Tuple(tuple), Type::Tuple(tuple_type)) => {
417                // Name each of the expressions on the right.
418                let identifiers = make_identifiers(self, identifier.name, tuple_type.length());
419
420                for (identifier, expr, ty) in izip!(&identifiers, tuple.elements, tuple_type.elements()) {
421                    // Make a definition for each.
422                    let stmt = DefinitionStatement {
423                        place: Single(*identifier),
424                        type_: Some(ty.clone()),
425                        value: expr,
426                        span: Default::default(),
427                        id: self.state.node_builder.next_id(),
428                    }
429                    .into();
430                    statements.push(stmt);
431
432                    // Put each into the type table.
433                    self.state.type_table.insert(identifier.id(), ty.clone());
434                }
435
436                // Put the identifier in `self.tuples`. We don't need to keep our definition.
437                self.tuples.insert(identifier.name, identifiers);
438                (Statement::dummy(), statements)
439            }
440            (Single(identifier), rhs @ Expression::Call(..), Type::Tuple(tuple_type)) => {
441                let definition_stmt = self.assign_tuple(rhs, identifier.name);
442
443                let Statement::Definition(DefinitionStatement {
444                    place: DefinitionPlace::Multiple(identifiers), ..
445                }) = &definition_stmt
446                else {
447                    panic!("assign_tuple creates `Multiple`.");
448                };
449
450                // Put it into `self.tuples`.
451                self.tuples.insert(identifier.name, identifiers.clone());
452
453                // Put each into the type table.
454                for (identifier, ty) in identifiers.iter().zip(tuple_type.elements()) {
455                    self.state.type_table.insert(identifier.id(), ty.clone());
456                }
457
458                (definition_stmt, statements)
459            }
460            (Multiple(identifiers), Expression::Tuple(tuple), Type::Tuple(..)) => {
461                // Just make a definition for each tuple element.
462                for (identifier, expr) in identifiers.into_iter().zip_eq(tuple.elements) {
463                    let stmt = DefinitionStatement {
464                        place: Single(identifier),
465                        type_: None,
466                        value: expr,
467                        span: Default::default(),
468                        id: self.state.node_builder.next_id(),
469                    }
470                    .into();
471                    statements.push(stmt);
472                }
473
474                // We don't need to keep the original definition.
475                (Statement::dummy(), statements)
476            }
477            (Multiple(identifiers), Expression::Path(rhs), Type::Tuple(..)) => {
478                // Again, make a definition for each tuple element.
479                let rhs_identifiers =
480                    self.tuples.get(&rhs.identifier().name).expect("We should have encountered this tuple by now");
481                for (identifier, rhs_identifier) in identifiers.into_iter().zip_eq(rhs_identifiers.iter()) {
482                    let stmt = DefinitionStatement {
483                        place: Single(identifier),
484                        type_: None,
485                        value: Expression::Path(Path::from(*rhs_identifier).into_absolute()),
486                        span: Default::default(),
487                        id: self.state.node_builder.next_id(),
488                    }
489                    .into();
490                    statements.push(stmt);
491                }
492
493                // We don't need to keep the original definition.
494                (Statement::dummy(), statements)
495            }
496            (m @ Multiple(..), value @ Expression::Call(..), Type::Tuple(..)) => {
497                // Just reconstruct the statement.
498                let stmt =
499                    DefinitionStatement { place: m, type_: None, value, span: definition.span, id: definition.id }
500                        .into();
501                (stmt, statements)
502            }
503            (_, _, Type::Tuple(..)) => {
504                panic!("Expressions of tuple type can only be tuple literals, identifiers, or calls.");
505            }
506            (s @ Single(..), rhs, _) => {
507                // This isn't a tuple. Just build the definition again.
508                let stmt = DefinitionStatement {
509                    place: s,
510                    type_: None,
511                    value: rhs,
512                    span: Default::default(),
513                    id: definition.id,
514                }
515                .into();
516                (stmt, statements)
517            }
518            (Multiple(_), _, _) => panic!("A definition with multiple identifiers must have tuple type"),
519        }
520    }
521
522    fn reconstruct_iteration(&mut self, _: IterationStatement) -> (Statement, Self::AdditionalOutput) {
523        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
524    }
525
526    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
527        let (expression, statements) = self.reconstruct_expression_tuple(input.expression);
528        (ReturnStatement { expression, ..input }.into(), statements)
529    }
530}