leo_ast/passes/
reconstructor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! This module contains the reconstructor traits for the AST.
//! Their default methods rebuild each node from the information
//! in the old node, recursing into its children.
20
21use crate::*;
22
23/// A Reconstructor trait for types in the AST.
24pub trait AstReconstructor {
25    type AdditionalOutput: Default;
26    type AdditionalInput: Default;
27
28    /* Types */
29    fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30        match input {
31            Type::Array(array_type) => self.reconstruct_array_type(array_type),
32            Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33            Type::Future(future_type) => self.reconstruct_future_type(future_type),
34            Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35            Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36            Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37            Type::Vector(vector_type) => self.reconstruct_vector_type(vector_type),
38            Type::Address
39            | Type::Boolean
40            | Type::Field
41            | Type::Group
42            | Type::Identifier(_)
43            | Type::Integer(_)
44            | Type::Scalar
45            | Type::Signature
46            | Type::String
47            | Type::Numeric
48            | Type::Unit
49            | Type::Err => (input.clone(), Default::default()),
50        }
51    }
52
53    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
54        (
55            Type::Array(ArrayType {
56                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
57                length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
58            }),
59            Default::default(),
60        )
61    }
62
63    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
64        (
65            Type::Composite(CompositeType {
66                const_arguments: input
67                    .const_arguments
68                    .into_iter()
69                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
70                    .collect(),
71                ..input
72            }),
73            Default::default(),
74        )
75    }
76
77    fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
78        (
79            Type::Future(FutureType {
80                inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
81                ..input
82            }),
83            Default::default(),
84        )
85    }
86
87    fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
88        (
89            Type::Mapping(MappingType {
90                key: Box::new(self.reconstruct_type(*input.key).0),
91                value: Box::new(self.reconstruct_type(*input.value).0),
92                ..input
93            }),
94            Default::default(),
95        )
96    }
97
98    fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
99        (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
100    }
101
102    fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
103        (
104            Type::Tuple(TupleType {
105                elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
106            }),
107            Default::default(),
108        )
109    }
110
111    fn reconstruct_vector_type(&mut self, input: VectorType) -> (Type, Self::AdditionalOutput) {
112        (
113            Type::Vector(VectorType { element_type: Box::new(self.reconstruct_type(*input.element_type).0) }),
114            Default::default(),
115        )
116    }
117
118    /* Expressions */
119    fn reconstruct_expression(
120        &mut self,
121        input: Expression,
122        additional: &Self::AdditionalInput,
123    ) -> (Expression, Self::AdditionalOutput) {
124        match input {
125            Expression::Async(async_) => self.reconstruct_async(async_, additional),
126            Expression::Array(array) => self.reconstruct_array(array, additional),
127            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
128            Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
129            Expression::Call(call) => self.reconstruct_call(*call, additional),
130            Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
131            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
132            Expression::Err(err) => self.reconstruct_err(err, additional),
133            Expression::Path(path) => self.reconstruct_path(path, additional),
134            Expression::Literal(value) => self.reconstruct_literal(value, additional),
135            Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
136            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
137            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
138            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
139            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
140            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
141            Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
142            Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
143            Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, additional),
144        }
145    }
146
147    fn reconstruct_array_access(
148        &mut self,
149        input: ArrayAccess,
150        _additional: &Self::AdditionalInput,
151    ) -> (Expression, Self::AdditionalOutput) {
152        (
153            ArrayAccess {
154                array: self.reconstruct_expression(input.array, &Default::default()).0,
155                index: self.reconstruct_expression(input.index, &Default::default()).0,
156                ..input
157            }
158            .into(),
159            Default::default(),
160        )
161    }
162
163    fn reconstruct_async(
164        &mut self,
165        input: AsyncExpression,
166        _additional: &Self::AdditionalInput,
167    ) -> (Expression, Self::AdditionalOutput) {
168        (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
169    }
170
171    fn reconstruct_member_access(
172        &mut self,
173        input: MemberAccess,
174        _additional: &Self::AdditionalInput,
175    ) -> (Expression, Self::AdditionalOutput) {
176        (
177            MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
178            Default::default(),
179        )
180    }
181
182    fn reconstruct_repeat(
183        &mut self,
184        input: RepeatExpression,
185        _additional: &Self::AdditionalInput,
186    ) -> (Expression, Self::AdditionalOutput) {
187        (
188            RepeatExpression {
189                expr: self.reconstruct_expression(input.expr, &Default::default()).0,
190                count: self.reconstruct_expression(input.count, &Default::default()).0,
191                ..input
192            }
193            .into(),
194            Default::default(),
195        )
196    }
197
198    fn reconstruct_intrinsic(
199        &mut self,
200        input: IntrinsicExpression,
201        _additional: &Self::AdditionalInput,
202    ) -> (Expression, Self::AdditionalOutput) {
203        (
204            IntrinsicExpression {
205                arguments: input
206                    .arguments
207                    .into_iter()
208                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
209                    .collect(),
210                ..input
211            }
212            .into(),
213            Default::default(),
214        )
215    }
216
217    fn reconstruct_tuple_access(
218        &mut self,
219        input: TupleAccess,
220        _additional: &Self::AdditionalInput,
221    ) -> (Expression, Self::AdditionalOutput) {
222        (
223            TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
224            Default::default(),
225        )
226    }
227
228    fn reconstruct_array(
229        &mut self,
230        input: ArrayExpression,
231        _additional: &Self::AdditionalInput,
232    ) -> (Expression, Self::AdditionalOutput) {
233        (
234            ArrayExpression {
235                elements: input
236                    .elements
237                    .into_iter()
238                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
239                    .collect(),
240                ..input
241            }
242            .into(),
243            Default::default(),
244        )
245    }
246
247    fn reconstruct_binary(
248        &mut self,
249        input: BinaryExpression,
250        _additional: &Self::AdditionalInput,
251    ) -> (Expression, Self::AdditionalOutput) {
252        (
253            BinaryExpression {
254                left: self.reconstruct_expression(input.left, &Default::default()).0,
255                right: self.reconstruct_expression(input.right, &Default::default()).0,
256                ..input
257            }
258            .into(),
259            Default::default(),
260        )
261    }
262
263    fn reconstruct_call(
264        &mut self,
265        input: CallExpression,
266        _additional: &Self::AdditionalInput,
267    ) -> (Expression, Self::AdditionalOutput) {
268        (
269            CallExpression {
270                const_arguments: input
271                    .const_arguments
272                    .into_iter()
273                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
274                    .collect(),
275                arguments: input
276                    .arguments
277                    .into_iter()
278                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
279                    .collect(),
280                ..input
281            }
282            .into(),
283            Default::default(),
284        )
285    }
286
287    fn reconstruct_cast(
288        &mut self,
289        input: CastExpression,
290        _additional: &Self::AdditionalInput,
291    ) -> (Expression, Self::AdditionalOutput) {
292        (
293            CastExpression {
294                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
295                ..input
296            }
297            .into(),
298            Default::default(),
299        )
300    }
301
302    fn reconstruct_struct_init(
303        &mut self,
304        input: StructExpression,
305        _additional: &Self::AdditionalInput,
306    ) -> (Expression, Self::AdditionalOutput) {
307        (
308            StructExpression {
309                const_arguments: input
310                    .const_arguments
311                    .into_iter()
312                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
313                    .collect(),
314                members: input
315                    .members
316                    .into_iter()
317                    .map(|member| StructVariableInitializer {
318                        identifier: member.identifier,
319                        expression: member
320                            .expression
321                            .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
322                        span: member.span,
323                        id: member.id,
324                    })
325                    .collect(),
326                ..input
327            }
328            .into(),
329            Default::default(),
330        )
331    }
332
333    fn reconstruct_err(
334        &mut self,
335        _input: ErrExpression,
336        _additional: &Self::AdditionalInput,
337    ) -> (Expression, Self::AdditionalOutput) {
338        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
339    }
340
341    fn reconstruct_path(
342        &mut self,
343        input: Path,
344        _additional: &Self::AdditionalInput,
345    ) -> (Expression, Self::AdditionalOutput) {
346        (input.into(), Default::default())
347    }
348
349    fn reconstruct_literal(
350        &mut self,
351        input: Literal,
352        _additional: &Self::AdditionalInput,
353    ) -> (Expression, Self::AdditionalOutput) {
354        (input.into(), Default::default())
355    }
356
357    fn reconstruct_locator(
358        &mut self,
359        input: LocatorExpression,
360        _additional: &Self::AdditionalInput,
361    ) -> (Expression, Self::AdditionalOutput) {
362        (input.into(), Default::default())
363    }
364
365    fn reconstruct_ternary(
366        &mut self,
367        input: TernaryExpression,
368        _additional: &Self::AdditionalInput,
369    ) -> (Expression, Self::AdditionalOutput) {
370        (
371            TernaryExpression {
372                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
373                if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
374                if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
375                span: input.span,
376                id: input.id,
377            }
378            .into(),
379            Default::default(),
380        )
381    }
382
383    fn reconstruct_tuple(
384        &mut self,
385        input: TupleExpression,
386        _additional: &Self::AdditionalInput,
387    ) -> (Expression, Self::AdditionalOutput) {
388        (
389            TupleExpression {
390                elements: input
391                    .elements
392                    .into_iter()
393                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
394                    .collect(),
395                ..input
396            }
397            .into(),
398            Default::default(),
399        )
400    }
401
402    fn reconstruct_unary(
403        &mut self,
404        input: UnaryExpression,
405        _additional: &Self::AdditionalInput,
406    ) -> (Expression, Self::AdditionalOutput) {
407        (
408            UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
409                .into(),
410            Default::default(),
411        )
412    }
413
414    fn reconstruct_unit(
415        &mut self,
416        input: UnitExpression,
417        _additional: &Self::AdditionalInput,
418    ) -> (Expression, Self::AdditionalOutput) {
419        (input.into(), Default::default())
420    }
421
422    /* Statements */
423    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
424        match input {
425            Statement::Assert(assert) => self.reconstruct_assert(assert),
426            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
427            Statement::Block(stmt) => {
428                let (stmt, output) = self.reconstruct_block(stmt);
429                (stmt.into(), output)
430            }
431            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
432            Statement::Const(stmt) => self.reconstruct_const(stmt),
433            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
434            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
435            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
436            Statement::Return(stmt) => self.reconstruct_return(stmt),
437        }
438    }
439
440    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
441        (
442            AssertStatement {
443                variant: match input.variant {
444                    AssertVariant::Assert(expr) => {
445                        AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
446                    }
447                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
448                        self.reconstruct_expression(left, &Default::default()).0,
449                        self.reconstruct_expression(right, &Default::default()).0,
450                    ),
451                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
452                        self.reconstruct_expression(left, &Default::default()).0,
453                        self.reconstruct_expression(right, &Default::default()).0,
454                    ),
455                },
456                ..input
457            }
458            .into(),
459            Default::default(),
460        )
461    }
462
463    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
464        (
465            AssignStatement {
466                place: self.reconstruct_expression(input.place, &Default::default()).0,
467                value: self.reconstruct_expression(input.value, &Default::default()).0,
468                ..input
469            }
470            .into(),
471            Default::default(),
472        )
473    }
474
475    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
476        (
477            Block {
478                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
479                span: input.span,
480                id: input.id,
481            },
482            Default::default(),
483        )
484    }
485
486    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
487        (
488            ConditionalStatement {
489                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
490                then: self.reconstruct_block(input.then).0,
491                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
492                ..input
493            }
494            .into(),
495            Default::default(),
496        )
497    }
498
499    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
500        (
501            ConstDeclaration {
502                type_: self.reconstruct_type(input.type_).0,
503                value: self.reconstruct_expression(input.value, &Default::default()).0,
504                ..input
505            }
506            .into(),
507            Default::default(),
508        )
509    }
510
511    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
512        (
513            DefinitionStatement {
514                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
515                value: self.reconstruct_expression(input.value, &Default::default()).0,
516                ..input
517            }
518            .into(),
519            Default::default(),
520        )
521    }
522
523    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
524        (
525            ExpressionStatement {
526                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
527                ..input
528            }
529            .into(),
530            Default::default(),
531        )
532    }
533
534    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
535        (
536            IterationStatement {
537                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
538                start: self.reconstruct_expression(input.start, &Default::default()).0,
539                stop: self.reconstruct_expression(input.stop, &Default::default()).0,
540                block: self.reconstruct_block(input.block).0,
541                ..input
542            }
543            .into(),
544            Default::default(),
545        )
546    }
547
548    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
549        (
550            ReturnStatement {
551                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
552                ..input
553            }
554            .into(),
555            Default::default(),
556        )
557    }
558}
559
560/// A Reconstructor trait for the program represented by the AST.
561pub trait ProgramReconstructor: AstReconstructor {
562    fn reconstruct_program(&mut self, input: Program) -> Program {
563        let program_scopes =
564            input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
565        Program {
566            imports: input
567                .imports
568                .into_iter()
569                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
570                .collect(),
571            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
572            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
573            program_scopes,
574        }
575    }
576
577    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
578        Stub {
579            imports: input.imports,
580            stub_id: input.stub_id,
581            consts: input.consts,
582            structs: input.structs,
583            mappings: input.mappings,
584            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
585            span: input.span,
586        }
587    }
588
589    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
590        ProgramScope {
591            program_id: input.program_id,
592            consts: input
593                .consts
594                .into_iter()
595                .map(|(i, c)| match self.reconstruct_const(c) {
596                    (Statement::Const(declaration), _) => (i, declaration),
597                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
598                })
599                .collect(),
600            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
601            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
602            storage_variables: input
603                .storage_variables
604                .into_iter()
605                .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
606                .collect(),
607            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
608            constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
609            span: input.span,
610        }
611    }
612
613    fn reconstruct_module(&mut self, input: Module) -> Module {
614        Module {
615            program_name: input.program_name,
616            path: input.path,
617            consts: input
618                .consts
619                .into_iter()
620                .map(|(i, c)| match self.reconstruct_const(c) {
621                    (Statement::Const(declaration), _) => (i, declaration),
622                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
623                })
624                .collect(),
625            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
626            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
627        }
628    }
629
630    fn reconstruct_function(&mut self, input: Function) -> Function {
631        Function {
632            annotations: input.annotations,
633            variant: input.variant,
634            identifier: input.identifier,
635            const_parameters: input
636                .const_parameters
637                .iter()
638                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
639                .collect(),
640            input: input
641                .input
642                .iter()
643                .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
644                .collect(),
645            output: input
646                .output
647                .iter()
648                .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
649                .collect(),
650            output_type: self.reconstruct_type(input.output_type).0,
651            block: self.reconstruct_block(input.block).0,
652            span: input.span,
653            id: input.id,
654        }
655    }
656
657    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
658        Constructor {
659            annotations: input.annotations,
660            block: self.reconstruct_block(input.block).0,
661            span: input.span,
662            id: input.id,
663        }
664    }
665
666    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
667        input
668    }
669
670    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
671        Composite {
672            const_parameters: input
673                .const_parameters
674                .iter()
675                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
676                .collect(),
677            members: input
678                .members
679                .iter()
680                .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
681                .collect(),
682            ..input
683        }
684    }
685
686    fn reconstruct_import(&mut self, input: Program) -> Program {
687        self.reconstruct_program(input)
688    }
689
690    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
691        Mapping {
692            key_type: self.reconstruct_type(input.key_type).0,
693            value_type: self.reconstruct_type(input.value_type).0,
694            ..input
695        }
696    }
697
698    fn reconstruct_storage_variable(&mut self, input: StorageVariable) -> StorageVariable {
699        StorageVariable { type_: self.reconstruct_type(input.type_).0, ..input }
700    }
701}