leo_passes/storage_lowering/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::StorageLoweringVisitor;
18
19use leo_ast::*;
20use leo_span::{Span, Symbol, sym};
21
22impl leo_ast::AstReconstructor for StorageLoweringVisitor<'_> {
23    type AdditionalInput = ();
24    type AdditionalOutput = Vec<Statement>;
25
26    /* Types */
27    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
28        let (length, stmts) = self.reconstruct_expression(*input.length, &());
29        (
30            Type::Array(ArrayType {
31                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
32                length: Box::new(length),
33            }),
34            stmts,
35        )
36    }
37
38    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
39        let mut statements = Vec::new();
40
41        let const_arguments = input
42            .const_arguments
43            .into_iter()
44            .map(|arg| {
45                let (expr, stmts) = self.reconstruct_expression(arg, &Default::default());
46                statements.extend(stmts);
47                expr
48            })
49            .collect();
50
51        (Type::Composite(CompositeType { const_arguments, ..input }), statements)
52    }
53
54    /* Expressions */
55    fn reconstruct_array_access(
56        &mut self,
57        mut input: ArrayAccess,
58        _additional: &(),
59    ) -> (Expression, Self::AdditionalOutput) {
60        let (array, mut stmts_array) = self.reconstruct_expression(input.array, &());
61        let (index, mut stmts_index) = self.reconstruct_expression(input.index, &());
62
63        input.array = array;
64        input.index = index;
65
66        // Merge side effects
67        stmts_array.append(&mut stmts_index);
68
69        (input.into(), stmts_array)
70    }
71
72    fn reconstruct_associated_function(
73        &mut self,
74        mut input: AssociatedFunctionExpression,
75        _additional: &(),
76    ) -> (Expression, Self::AdditionalOutput) {
77        match CoreFunction::from_symbols(input.variant.name, input.name.name) {
78            Some(CoreFunction::VectorPush) => {
79                // Input:
80                //   Vector::push(v, 42u32)
81                //
82                // Lowered reconstruction:
83                //   let $len_var = Mapping::get_or_use(len_map, false, 0u32);
84                //   Mapping::set(vec_map, $len_var, 42u32);
85                //   Mapping::set(len_map, false, $len_var + 1);
86
87                // Unpack arguments
88                let [vector_expr, value_expr] = &mut input.arguments[..] else {
89                    panic!("Vector::push should have 2 arguments");
90                };
91
92                // Validate vector type
93                assert!(matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))));
94
95                // Reconstruct value
96                let (value, stmts) = self.reconstruct_expression(value_expr.clone(), &());
97
98                let (vec_values_mapping_name, vec_length_mapping_name) =
99                    self.generate_mapping_names_for_vector(vector_expr);
100                let vec_path_expr = self.symbol_to_path_expr(vec_values_mapping_name);
101                let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
102
103                // let $len_var = Mapping::get_or_use(len_map, false, 0u32)
104                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
105                let len_var_ident =
106                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
107                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
108                let len_stmt = self.state.assigner.simple_definition(
109                    len_var_ident,
110                    get_len_expr,
111                    self.state.node_builder.next_id(),
112                );
113                let len_var_expr: Expression = len_var_ident.into();
114
115                // index + 1
116                let literal_one = self.literal_one_u32();
117                let increment_expr = self.binary_expr(len_var_expr.clone(), BinaryOperation::Add, literal_one);
118
119                // Mapping::set(vec__, $len_var, value)
120                let set_vec_stmt_expr = self.set_mapping_expr(vec_path_expr, len_var_expr.clone(), value, input.span);
121
122                // Mapping::set(len_map, false, $len_var + 1)
123                let literal_false = self.literal_false();
124                let set_len_stmt = Statement::Expression(ExpressionStatement {
125                    expression: self.set_mapping_expr(len_path_expr, literal_false, increment_expr, input.span),
126                    span: input.span,
127                    id: self.state.node_builder.next_id(),
128                });
129
130                (set_vec_stmt_expr, [stmts, vec![len_stmt, set_len_stmt]].concat())
131            }
132
133            Some(CoreFunction::VectorLen) => {
134                // Input:
135                //   Vector::len(v)
136                //
137                // Lowered reconstruction:
138                //   Mapping::get_or_use(len_map, false, 0u32)
139
140                //  Unpack arguments
141                let [vector_expr] = &mut input.arguments[..] else {
142                    panic!("Vector::len should have 1 argument");
143                };
144
145                // Validate vector type
146                assert!(matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))));
147
148                let (_vec_values_mapping_name, vec_length_mapping_name) =
149                    self.generate_mapping_names_for_vector(vector_expr);
150                let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
151
152                let get_len_expr = self.get_vector_len_expr(len_path_expr, input.span);
153                (get_len_expr, vec![])
154            }
155
156            Some(CoreFunction::VectorPop) => {
157                // Input:
158                //   Vector::pop(v)
159                //
160                // Lowered reconstruction:
161                //   let $len_var = Mapping::get_or_use(len_map, false, 0u32);
162                //   Mapping::set(len_map, false, $len_var > 0 ? $len_var - 1 : $len_var);
163                //   $len_var > 0 ? Mapping::get_or_use(vec_map, $len_var - 1, zero_value) : None
164
165                // Unpack argument
166                let [vector_expr] = &mut input.arguments[..] else {
167                    panic!("Vector::pop should have 1 argument");
168                };
169
170                // Validate vector type
171                let Some(Type::Vector(VectorType { element_type })) = self.state.type_table.get(&vector_expr.id())
172                else {
173                    panic!("argument to Vector::pop should be of type `Vector`.");
174                };
175
176                let (vec_values_mapping_name, vec_length_mapping_name) =
177                    self.generate_mapping_names_for_vector(vector_expr);
178                let vec_path_expr = self.symbol_to_path_expr(vec_values_mapping_name);
179                let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
180
181                // let $len_var = Mapping::get_or_use(len_map, false, 0u32)
182                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
183                let len_var_ident =
184                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
185                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
186                let len_stmt = self.state.assigner.simple_definition(
187                    len_var_ident,
188                    get_len_expr,
189                    self.state.node_builder.next_id(),
190                );
191                let len_var_expr: Expression = len_var_ident.into();
192
193                // $len_var > 0
194                let literal_zero = self.literal_zero_u32();
195                let len_gt_zero_expr = self.binary_expr(len_var_expr.clone(), BinaryOperation::Gt, literal_zero);
196
197                // $len_var - 1
198                let literal_one = self.literal_one_u32();
199                let len_minus_one_expr =
200                    self.binary_expr(len_var_expr.clone(), BinaryOperation::SubWrapped, literal_one);
201
202                // ternary for new length: ($len_var > 0 ? $len_var - 1 : $len_var)
203                let new_len_expr = self.ternary_expr(
204                    len_gt_zero_expr.clone(),
205                    len_minus_one_expr.clone(),
206                    len_var_expr.clone(),
207                    input.span,
208                );
209
210                // Mapping::set(len_map, false, new_len)
211                let literal_false = self.literal_false();
212                let set_len_stmt = Statement::Expression(ExpressionStatement {
213                    expression: self.set_mapping_expr(len_path_expr.clone(), literal_false, new_len_expr, input.span),
214                    span: input.span,
215                    id: self.state.node_builder.next_id(),
216                });
217
218                // zero value for element type (used as default in get_or_use)
219                let zero = self.zero(&element_type);
220
221                // Mapping::get_or_use(vec_map, $len_var - 1, zero)
222                let get_or_use_expr =
223                    self.get_or_use_mapping_expr(vec_path_expr, len_minus_one_expr.clone(), zero, input.span);
224
225                // ternary: $len_var > 0 ? get(vec, len-1) : None
226                let none_expr: Expression = Literal::none(Span::default(), self.state.node_builder.next_id()).into();
227                let ternary_expr = self.ternary_expr(len_gt_zero_expr, get_or_use_expr, none_expr, input.span);
228
229                (ternary_expr, vec![len_stmt, set_len_stmt])
230            }
231
232            Some(CoreFunction::Get) => {
233                // Unpack arguments (container, index/key)
234                let [container_expr, key_expr] = &mut input.arguments[..] else {
235                    panic!("Get should have 2 arguments");
236                };
237
238                // Reconstruct key/index)
239                let (reconstructed_key_expr, key_stmts) =
240                    self.reconstruct_expression(key_expr.clone(), &Default::default());
241
242                match self.state.type_table.get(&container_expr.id()) {
243                    Some(Type::Vector(VectorType { element_type })) => {
244                        // Input:
245                        //   Get(v, index)
246                        //
247                        // Lowered reconstruction:
248                        //   let $len_var = Mapping::get_or_use(len_map, false, 0u32);
249                        //   index < $len_var
250                        //       ? Mapping::get_or_use(vec_map, index, zero_value)
251                        //       : None
252
253                        let (vec_values_mapping_name, vec_length_mapping_name) =
254                            self.generate_mapping_names_for_vector(container_expr);
255                        let vec_path_expr = self.symbol_to_path_expr(vec_values_mapping_name);
256                        let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
257
258                        // let $len_var = Mapping::get_or_use(len_map, false, 0u32)
259                        let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
260                        let len_var_ident = Identifier {
261                            name: len_var_sym,
262                            span: Default::default(),
263                            id: self.state.node_builder.next_id(),
264                        };
265                        let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
266                        let len_stmt = self.state.assigner.simple_definition(
267                            len_var_ident,
268                            get_len_expr,
269                            self.state.node_builder.next_id(),
270                        );
271                        let len_var_expr: Expression = len_var_ident.into();
272
273                        // index < len
274                        let index_lt_len_expr =
275                            self.binary_expr(reconstructed_key_expr.clone(), BinaryOperation::Lt, len_var_expr.clone());
276
277                        // zero value for element type (used as default in get_or_use)
278                        let zero = self.zero(&element_type);
279
280                        // Mapping::get(vec_map, index)
281                        let get_or_use_expr = self.get_or_use_mapping_expr(
282                            vec_path_expr,
283                            reconstructed_key_expr.clone(),
284                            zero,
285                            input.span,
286                        );
287
288                        // ternary: index < len ? get(vec, index) : None
289                        let none_expr: Expression =
290                            Literal::none(Span::default(), self.state.node_builder.next_id()).into();
291                        let ternary_expr = self.ternary_expr(index_lt_len_expr, get_or_use_expr, none_expr, input.span);
292
293                        (ternary_expr, [key_stmts, vec![len_stmt]].concat())
294                    }
295
296                    Some(Type::Mapping(_)) => {
297                        // Update the `key` argument of the `Get` as well as the variant name from
298                        // `__unresolved` to `Mapping`.
299                        input.arguments[1] = reconstructed_key_expr;
300                        input.variant.name = sym::Mapping;
301                        (input.into(), key_stmts)
302                    }
303
304                    _ => {
305                        panic!("type checking should guarantee that no other type is expected here.")
306                    }
307                }
308            }
309
310            Some(CoreFunction::Set) => {
311                // Unpack arguments (container, index/key, value)
312                let [container_expr, index_expr, value_expr] = &mut input.arguments[..] else {
313                    panic!("Set should have 3 arguments");
314                };
315
316                // Reconstruct key/index and value
317                let (reconstructed_key_expr, key_stmts) =
318                    self.reconstruct_expression(index_expr.clone(), &Default::default());
319                let (reconstructed_value_expr, value_stmts) =
320                    self.reconstruct_expression(value_expr.clone(), &Default::default());
321
322                match self.state.type_table.get(&container_expr.id()) {
323                    Some(Type::Vector(_)) => {
324                        // Input:
325                        //   Set(v, index, value)
326                        //
327                        // Lowered reconstruction (conceptually):
328                        //   let $len_var = Mapping::get_or_use(len_map, false, 0u32);
329                        //   assert(index < $len_var);
330                        //   Mapping::set(vec_map, index, value);
331
332                        let (vec_values_mapping_name, vec_length_mapping_name) =
333                            self.generate_mapping_names_for_vector(container_expr);
334                        let vec_path_expr = self.symbol_to_path_expr(vec_values_mapping_name);
335                        let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
336
337                        // let $len_var = Mapping::get_or_use(len_map, false, 0u32)
338                        let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
339                        let len_var_ident = Identifier {
340                            name: len_var_sym,
341                            span: Default::default(),
342                            id: self.state.node_builder.next_id(),
343                        };
344                        let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
345                        let len_stmt = self.state.assigner.simple_definition(
346                            len_var_ident,
347                            get_len_expr,
348                            self.state.node_builder.next_id(),
349                        );
350                        let len_var_expr: Expression = len_var_ident.into();
351
352                        // index < $len_var
353                        let index_lt_len_expr =
354                            self.binary_expr(reconstructed_key_expr.clone(), BinaryOperation::Lt, len_var_expr.clone());
355
356                        // Mapping::set(vec_map, index, value)
357                        let set_stmt_expr = self.set_mapping_expr(
358                            vec_path_expr.clone(),
359                            reconstructed_key_expr.clone(),
360                            reconstructed_value_expr.clone(),
361                            input.span,
362                        );
363
364                        // assert(index < len)
365                        let assert_stmt = Statement::Assert(AssertStatement {
366                            variant: AssertVariant::Assert(index_lt_len_expr.clone()),
367                            span: Span::default(),
368                            id: self.state.node_builder.next_id(),
369                        });
370
371                        // Emit assert then set
372                        (set_stmt_expr, [key_stmts, value_stmts, vec![len_stmt, assert_stmt]].concat())
373                    }
374
375                    Some(Type::Mapping(_)) => {
376                        // Update the `key` and `value` arguments as well as the variant name from
377                        // `__unresolved` to `Mapping`.
378                        input.arguments[1] = reconstructed_key_expr;
379                        input.arguments[2] = reconstructed_value_expr;
380                        input.variant.name = sym::Mapping;
381                        (input.into(), [key_stmts, value_stmts].concat())
382                    }
383
384                    _ => {
385                        panic!("type checking should guarantee that no other type is expected here.")
386                    }
387                }
388            }
389
390            Some(CoreFunction::VectorClear) => {
391                // Input:
392                //   Vector::clear(v)
393                //
394                // Lowered reconstruction (conceptually):
395                //   Mapping::set(len_map, false, 0u32);
396                //
397                // Note: `VectorClear` does not actually remove any elements from the mapping of
398                // vector values.
399
400                // Unpack arguments
401                let [vector_expr] = &mut input.arguments[..] else {
402                    panic!("Vector::clear should have 1 argument");
403                };
404
405                // Validate vector type
406                assert!(matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))));
407
408                let (_vec_values_mapping_name, vec_length_mapping_name) =
409                    self.generate_mapping_names_for_vector(vector_expr);
410                let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
411
412                // Mapping::set(len_map, false, 0u32)
413                let literal_false = self.literal_false();
414                let literal_zero = self.literal_zero_u32();
415                let set_len_stmt_expr = self.set_mapping_expr(len_path_expr, literal_false, literal_zero, input.span);
416
417                (set_len_stmt_expr, vec![])
418            }
419
420            Some(CoreFunction::VectorSwapRemove) => {
421                // Input:
422                //   Vector::swap_remove(v, index)
423                //
424                // Lowered reconstruction (conceptually):
425                //   let $len_var = Mapping::get_or_use(len_map, false, 0u32);
426                //   assert(index < $len_var);
427                //   let $removed = Mapping::get(vec_map, index);
428                //   Mapping::set(vec_map, index, Mapping::get(vec_map, $len_var - 1));
429                //   Mapping::set(len_map, false, $len_var - 1);
430                //   $removed
431
432                let [vector_expr, index_expr] = &mut input.arguments[..] else {
433                    panic!("Vector::swap_remove should have 2 arguments");
434                };
435
436                // Validate vector type
437                assert!(matches!(self.state.type_table.get(&vector_expr.id()), Some(Type::Vector(_))));
438
439                // Reconstruct index
440                let (reconstructed_index_expr, index_stmts) =
441                    self.reconstruct_expression(index_expr.clone(), &Default::default());
442
443                let (vec_values_mapping_name, vec_length_mapping_name) =
444                    self.generate_mapping_names_for_vector(vector_expr);
445                let vec_path_expr = self.symbol_to_path_expr(vec_values_mapping_name);
446                let len_path_expr = self.symbol_to_path_expr(vec_length_mapping_name);
447
448                // let $len_var = Mapping::get_or_use(len_map, false, 0u32)
449                let len_var_sym = self.state.assigner.unique_symbol("$len_var", "$");
450                let len_var_ident =
451                    Identifier { name: len_var_sym, span: Default::default(), id: self.state.node_builder.next_id() };
452                let get_len_expr = self.get_vector_len_expr(len_path_expr.clone(), input.span);
453                let len_stmt = self.state.assigner.simple_definition(
454                    len_var_ident,
455                    get_len_expr,
456                    self.state.node_builder.next_id(),
457                );
458                let len_var_expr: Expression = len_var_ident.into();
459
460                // assert(index < $len_var);
461                let index_lt_len_expr =
462                    self.binary_expr(reconstructed_index_expr.clone(), BinaryOperation::Lt, len_var_expr.clone());
463                let assert_stmt = Statement::Assert(AssertStatement {
464                    variant: AssertVariant::Assert(index_lt_len_expr.clone()),
465                    span: input.span,
466                    id: self.state.node_builder.next_id(),
467                });
468
469                // let $removed = Mapping::get(vec_map, index); // the element to return
470                let get_elem_expr =
471                    self.get_mapping_expr(vec_path_expr.clone(), reconstructed_index_expr.clone(), input.span);
472                let removed_sym = self.state.assigner.unique_symbol("$removed", "$");
473                let removed_ident =
474                    Identifier { name: removed_sym, span: Default::default(), id: self.state.node_builder.next_id() };
475                let removed_stmt = Statement::Definition(DefinitionStatement {
476                    place: DefinitionPlace::Single(removed_ident),
477                    type_: None,
478                    value: get_elem_expr,
479                    span: input.span,
480                    id: self.state.node_builder.next_id(),
481                });
482
483                // len - 1
484                let literal_one = self.literal_one_u32();
485                let len_minus_one_expr = self.binary_expr(len_var_expr.clone(), BinaryOperation::Sub, literal_one);
486
487                // Mapping::set(vec_map, index, Mapping::get(vec_map, len - 1));
488                let get_last_expr =
489                    self.get_mapping_expr(vec_path_expr.clone(), len_minus_one_expr.clone(), input.span);
490                let set_swap_stmt = Statement::Expression(ExpressionStatement {
491                    expression: self.set_mapping_expr(
492                        vec_path_expr.clone(),
493                        reconstructed_index_expr.clone(),
494                        get_last_expr,
495                        input.span,
496                    ),
497                    span: input.span,
498                    id: self.state.node_builder.next_id(),
499                });
500
501                // Mapping::set(len_map, false, len - 1);
502                let literal_false = self.literal_false();
503                let set_len_stmt = Statement::Expression(ExpressionStatement {
504                    expression: self.set_mapping_expr(
505                        len_path_expr.clone(),
506                        literal_false,
507                        len_minus_one_expr,
508                        input.span,
509                    ),
510                    span: input.span,
511                    id: self.state.node_builder.next_id(),
512                });
513
514                // Return `$removed` as the resulting expression
515                (
516                    removed_ident.into(),
517                    [index_stmts, vec![len_stmt, assert_stmt, removed_stmt, set_swap_stmt, set_len_stmt]].concat(),
518                )
519            }
520
521            _ => {
522                // Default: reconstruct all arguments recursively and return the (possibly updated) original call
523                let statements: Vec<_> = input
524                    .arguments
525                    .iter_mut()
526                    .flat_map(|arg| {
527                        let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &());
528                        *arg = expr;
529                        stmts
530                    })
531                    .collect();
532
533                (input.into(), statements)
534            }
535        }
536    }
537
538    fn reconstruct_member_access(
539        &mut self,
540        mut input: MemberAccess,
541        _additional: &(),
542    ) -> (Expression, Self::AdditionalOutput) {
543        let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &());
544
545        input.inner = inner;
546
547        (input.into(), stmts_inner)
548    }
549
550    fn reconstruct_repeat(
551        &mut self,
552        mut input: RepeatExpression,
553        _additional: &(),
554    ) -> (Expression, Self::AdditionalOutput) {
555        // Use expected type (if available) for `expr`
556        let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &());
557        let (count, mut stmts_count) = self.reconstruct_expression(input.count, &());
558
559        input.expr = expr;
560        input.count = count;
561
562        stmts_expr.append(&mut stmts_count);
563
564        (input.into(), stmts_expr)
565    }
566
567    fn reconstruct_tuple_access(
568        &mut self,
569        mut input: TupleAccess,
570        _additional: &(),
571    ) -> (Expression, Self::AdditionalOutput) {
572        let (tuple, stmts) = self.reconstruct_expression(input.tuple, &());
573
574        input.tuple = tuple;
575
576        (input.into(), stmts)
577    }
578
579    fn reconstruct_array(
580        &mut self,
581        mut input: ArrayExpression,
582        _additional: &(),
583    ) -> (Expression, Self::AdditionalOutput) {
584        let mut all_stmts = Vec::new();
585        let mut new_elements = Vec::with_capacity(input.elements.len());
586
587        for element in input.elements.into_iter() {
588            let (expr, mut stmts) = self.reconstruct_expression(element, &());
589            all_stmts.append(&mut stmts);
590            new_elements.push(expr);
591        }
592
593        input.elements = new_elements;
594
595        (input.into(), all_stmts)
596    }
597
598    fn reconstruct_binary(
599        &mut self,
600        mut input: BinaryExpression,
601        _additional: &(),
602    ) -> (Expression, Self::AdditionalOutput) {
603        let (left, mut stmts_left) = self.reconstruct_expression(input.left, &());
604        let (right, mut stmts_right) = self.reconstruct_expression(input.right, &());
605
606        input.left = left;
607        input.right = right;
608
609        // Merge side effects
610        stmts_left.append(&mut stmts_right);
611
612        (input.into(), stmts_left)
613    }
614
615    fn reconstruct_call(&mut self, mut input: CallExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
616        let mut statements = Vec::new();
617        for arg in input.arguments.iter_mut() {
618            let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg), &());
619            statements.extend(statements2);
620            *arg = expr;
621        }
622        (input.into(), statements)
623    }
624
625    fn reconstruct_cast(&mut self, input: CastExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
626        let (expression, statements) = self.reconstruct_expression(input.expression, &());
627        (CastExpression { expression, ..input }.into(), statements)
628    }
629
630    fn reconstruct_struct_init(
631        &mut self,
632        mut input: StructExpression,
633        _additional: &(),
634    ) -> (Expression, Self::AdditionalOutput) {
635        let mut statements = Vec::new();
636
637        // Reconstruct const_arguments and extract statements
638        for const_arg in input.const_arguments.iter_mut() {
639            let (expr, statements2) = self.reconstruct_expression(const_arg.clone(), &());
640            statements.extend(statements2);
641            *const_arg = expr;
642        }
643
644        // Reconstruct members and extract statements
645        for member in input.members.iter_mut() {
646            assert!(member.expression.is_some());
647            let (expr, statements2) = self.reconstruct_expression(member.expression.take().unwrap(), &());
648            statements.extend(statements2);
649            member.expression = Some(expr);
650        }
651
652        (input.into(), statements)
653    }
654
655    fn reconstruct_path(&mut self, input: Path, _additional: &()) -> (Expression, Self::AdditionalOutput) {
656        // Check if this path corresponds to a global symbol.
657        let Some(var) = self.state.symbol_table.lookup_global(&Location::new(self.program, input.absolute_path()))
658        else {
659            // Nothing to do
660            return (input.into(), vec![]);
661        };
662
663        match &var.type_ {
664            Type::Mapping(_) => {
665                // No transformation needed for mappings.
666                (input.into(), vec![])
667            }
668
669            Type::Optional(OptionalType { inner }) => {
670                // Input:
671                //   storage x: field;
672                //   ...
673                //   let y = x;
674                //
675                // Lowered reconstruction:
676                //  mapping x__: bool => field
677                //  let y = x__.contains(false)
678                //      ? x__.get_or_use(false, 0field)
679                //      : None;
680
681                let id = || self.state.node_builder.next_id();
682                let var_name = input.identifier().name;
683
684                // Path to the mapping backing the optional variable: `<var_name>__`
685                let mapping_symbol = Symbol::intern(&format!("{var_name}__"));
686                let mapping_ident = Identifier::new(mapping_symbol, id());
687
688                // === Build expressions ===
689                let mapping_expr: Expression = Path::from(mapping_ident).into_absolute().into();
690                let false_literal: Expression = Literal::boolean(false, Span::default(), id()).into();
691
692                // `<var_name>__.contains(false)`
693                let contains_expr: Expression = AssociatedFunctionExpression {
694                    variant: Identifier::new(sym::Mapping, id()),
695                    name: Identifier::new(Symbol::intern("contains"), id()),
696                    type_parameters: vec![],
697                    arguments: vec![mapping_expr.clone(), false_literal.clone()],
698                    span: Span::default(),
699                    id: id(),
700                }
701                .into();
702
703                // zero value for element type
704                let zero = self.zero(inner);
705
706                // `<var_name>__.get_or_use(false, zero_value)`
707                let get_or_use_expr: Expression = AssociatedFunctionExpression {
708                    variant: Identifier::new(sym::Mapping, id()),
709                    name: Identifier::new(Symbol::intern("get_or_use"), id()),
710                    type_parameters: vec![],
711                    arguments: vec![mapping_expr.clone(), false_literal, zero],
712                    span: Span::default(),
713                    id: id(),
714                }
715                .into();
716
717                // `None`
718                let none_expr =
719                    Expression::Literal(Literal { variant: LiteralVariant::None, span: Span::default(), id: id() });
720
721                // Combine into ternary:
722                // `<var_name>__.contains(false) ? <var_name>__.get_or_use(false, zero_val) : None`
723                let ternary_expr: Expression = TernaryExpression {
724                    condition: contains_expr,
725                    if_true: get_or_use_expr,
726                    if_false: none_expr,
727                    span: Span::default(),
728                    id: id(),
729                }
730                .into();
731
732                (ternary_expr, vec![])
733            }
734
735            _ => {
736                panic!("Expected a non-vector type in reconstruct_path, found {:?}", var.type_);
737            }
738        }
739    }
740
741    fn reconstruct_ternary(
742        &mut self,
743        input: TernaryExpression,
744        _addiional: &(),
745    ) -> (Expression, Self::AdditionalOutput) {
746        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
747        let (if_true, statements2) = self.reconstruct_expression(input.if_true, &());
748        let (if_false, statements3) = self.reconstruct_expression(input.if_false, &());
749        statements.extend(statements2);
750        statements.extend(statements3);
751        (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
752    }
753
754    fn reconstruct_tuple(
755        &mut self,
756        input: leo_ast::TupleExpression,
757        _addiional: &(),
758    ) -> (Expression, Self::AdditionalOutput) {
759        // This should ony appear in a return statement.
760        let mut statements = Vec::new();
761        let elements = input
762            .elements
763            .into_iter()
764            .map(|element| {
765                let (expr, statements2) = self.reconstruct_expression(element, &());
766                statements.extend(statements2);
767                expr
768            })
769            .collect();
770        (TupleExpression { elements, ..input }.into(), statements)
771    }
772
773    fn reconstruct_unary(
774        &mut self,
775        input: leo_ast::UnaryExpression,
776        _addiional: &(),
777    ) -> (Expression, Self::AdditionalOutput) {
778        let (receiver, statements) = self.reconstruct_expression(input.receiver, &());
779        (UnaryExpression { receiver, ..input }.into(), statements)
780    }
781
782    /* Statements */
783    fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
784        let mut statements = Vec::new();
785        let stmt = AssertStatement {
786            variant: match input.variant {
787                AssertVariant::Assert(expr) => {
788                    let (expr, statements2) = self.reconstruct_expression(expr, &());
789                    statements.extend(statements2);
790                    AssertVariant::Assert(expr)
791                }
792                AssertVariant::AssertEq(left, right) => {
793                    let (left, statements2) = self.reconstruct_expression(left, &());
794                    statements.extend(statements2);
795                    let (right, statements3) = self.reconstruct_expression(right, &());
796                    statements.extend(statements3);
797                    AssertVariant::AssertEq(left, right)
798                }
799                AssertVariant::AssertNeq(left, right) => {
800                    let (left, statements2) = self.reconstruct_expression(left, &());
801                    statements.extend(statements2);
802                    let (right, statements3) = self.reconstruct_expression(right, &());
803                    statements.extend(statements3);
804                    AssertVariant::AssertNeq(left, right)
805                }
806            },
807            ..input
808        }
809        .into();
810        (stmt, statements)
811    }
812
813    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
814        let AssignStatement { place, value, span, .. } = input;
815        let mut statements = vec![];
816
817        // Check if `place` is a path
818        if let Expression::Path(path) = &place {
819            // Check if the path corresponds to a global storage variable
820            if let Some(var) = self.state.symbol_table.lookup_global(&Location::new(self.program, path.absolute_path()))
821            {
822                // Storage variables that are not optional nor mappings are implicitly wrapped in an optional.
823                assert!(
824                    var.type_.is_optional(),
825                    "Only storage variables that are not vectors or mappings are expected here."
826                );
827
828                // Reconstruct the RHS
829                let (new_value, mut value_stmts) = self.reconstruct_expression(value, &());
830                statements.append(&mut value_stmts);
831
832                let id = || self.state.node_builder.next_id();
833                let var_name = path.identifier().name;
834
835                // Path to the mapping backing the storage variable: `<var_name>__`
836                let mapping_symbol = Symbol::intern(&format!("{var_name}__"));
837                let mapping_ident = Identifier::new(mapping_symbol, id());
838                let mapping_expr: Expression = Path::from(mapping_ident).into_absolute().into();
839                let false_literal: Expression = Literal::boolean(false, Span::default(), id()).into();
840
841                let stmt = if matches!(new_value, Expression::Literal(Literal { variant: LiteralVariant::None, .. })) {
842                    // Input:
843                    //   storage x: field;
844                    //   ...
845                    //   x = none;
846                    //
847                    // Lowered reconstruction:
848                    //   mapping x__: bool => field;
849                    //   ...
850                    //   Mapping::remove(x__, false);
851                    let remove_expr: Expression = AssociatedFunctionExpression {
852                        variant: Identifier::new(sym::Mapping, id()),
853                        name: Identifier::new(Symbol::intern("remove"), id()),
854                        type_parameters: vec![],
855                        arguments: vec![mapping_expr, false_literal],
856                        span,
857                        id: id(),
858                    }
859                    .into();
860                    Statement::Expression(ExpressionStatement { expression: remove_expr, span, id: id() })
861                } else {
862                    // Input:
863                    //   storage x: field;
864                    //   ...
865                    //   x = 5field;
866                    //
867                    // Lowered reconstruction:
868                    //   mapping x__: bool => field;
869                    //   ...
870                    //   Mapping::set(x__, false, 5field);
871                    let set_expr: Expression = AssociatedFunctionExpression {
872                        variant: Identifier::new(sym::Mapping, id()),
873                        name: Identifier::new(Symbol::intern("set"), id()),
874                        type_parameters: vec![],
875                        arguments: vec![mapping_expr, false_literal, new_value],
876                        span,
877                        id: id(),
878                    }
879                    .into();
880                    Statement::Expression(ExpressionStatement { expression: set_expr, span, id: id() })
881                };
882                return (stmt, statements);
883            }
884        }
885
886        // In all other cases, nothing special to do.
887        let (new_place, mut place_stmts) = self.reconstruct_expression(place, &());
888        let (new_value, mut value_stmts) = self.reconstruct_expression(value, &());
889        statements.append(&mut place_stmts);
890        statements.append(&mut value_stmts);
891
892        let stmt =
893            AssignStatement { place: new_place, value: new_value, span, id: self.state.node_builder.next_id() }.into();
894        (stmt, statements)
895    }
896
897    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
898        let mut statements = Vec::with_capacity(block.statements.len());
899
900        // Flatten each statement, accumulating any new statements produced.
901        for statement in block.statements {
902            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
903            statements.extend(additional_statements);
904            statements.push(reconstructed_statement);
905        }
906
907        (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
908    }
909
910    fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
911        let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
912        let (then, statements2) = self.reconstruct_block(input.then);
913        statements.extend(statements2);
914        let otherwise = input.otherwise.map(|oth| {
915            let (expr, statements3) = self.reconstruct_statement(*oth);
916            statements.extend(statements3);
917            Box::new(expr)
918        });
919        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
920    }
921
922    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
923        let (type_expr, type_statements) = self.reconstruct_type(input.type_);
924        let (value_expr, value_statements) = self.reconstruct_expression(input.value, &Default::default());
925
926        let mut statements = Vec::new();
927        statements.extend(type_statements);
928        statements.extend(value_statements);
929
930        (ConstDeclaration { type_: type_expr, value: value_expr, ..input }.into(), statements)
931    }
932
933    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
934        let (new_value, additional_stmts) = self.reconstruct_expression(input.value, &());
935
936        input.type_ = input.type_.map(|ty| self.reconstruct_type(ty).0);
937        input.value = new_value;
938
939        (input.into(), additional_stmts)
940    }
941
942    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
943        let (reconstructed_expression, statements) = self.reconstruct_expression(input.expression, &Default::default());
944        if !matches!(reconstructed_expression, Expression::Call(_) | Expression::AssociatedFunction(_)) {
945            (
946                ExpressionStatement {
947                    expression: Expression::Unit(UnitExpression {
948                        span: Span::default(),
949                        id: self.state.node_builder.next_id(),
950                    }),
951                    ..input
952                }
953                .into(),
954                statements,
955            )
956        } else {
957            (ExpressionStatement { expression: reconstructed_expression, ..input }.into(), statements)
958        }
959    }
960
961    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
962        panic!("`IterationStatement`s should not be in the AST at this point.");
963    }
964
965    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
966        let (expression, statements) = self.reconstruct_expression(input.expression, &());
967        (ReturnStatement { expression, ..input }.into(), statements)
968    }
969}