leo_ast/passes/
reconstructor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! This module contains the reconstructor traits for the AST:
//! `AstReconstructor` for types, expressions, and statements, and
//! `ProgramReconstructor` for whole programs. Their default methods
//! rebuild each node from the contents of the old node.
20
21use crate::*;
22
23/// A Reconstructor trait for types in the AST.
24pub trait AstReconstructor {
25    type AdditionalOutput: Default;
26    type AdditionalInput: Default;
27
28    /* Types */
29    fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30        match input {
31            Type::Array(array_type) => self.reconstruct_array_type(array_type),
32            Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33            Type::Future(future_type) => self.reconstruct_future_type(future_type),
34            Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35            Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36            Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37            Type::Vector(vector_type) => self.reconstruct_vector_type(vector_type),
38            Type::Address
39            | Type::Boolean
40            | Type::Field
41            | Type::Group
42            | Type::Identifier(_)
43            | Type::Integer(_)
44            | Type::Scalar
45            | Type::Signature
46            | Type::String
47            | Type::Numeric
48            | Type::Unit
49            | Type::Err => (input.clone(), Default::default()),
50        }
51    }
52
53    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
54        (
55            Type::Array(ArrayType {
56                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
57                length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
58            }),
59            Default::default(),
60        )
61    }
62
63    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
64        (
65            Type::Composite(CompositeType {
66                const_arguments: input
67                    .const_arguments
68                    .into_iter()
69                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
70                    .collect(),
71                ..input
72            }),
73            Default::default(),
74        )
75    }
76
77    fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
78        (
79            Type::Future(FutureType {
80                inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
81                ..input
82            }),
83            Default::default(),
84        )
85    }
86
87    fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
88        (
89            Type::Mapping(MappingType {
90                key: Box::new(self.reconstruct_type(*input.key).0),
91                value: Box::new(self.reconstruct_type(*input.value).0),
92                ..input
93            }),
94            Default::default(),
95        )
96    }
97
98    fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
99        (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
100    }
101
102    fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
103        (
104            Type::Tuple(TupleType {
105                elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
106            }),
107            Default::default(),
108        )
109    }
110
111    fn reconstruct_vector_type(&mut self, input: VectorType) -> (Type, Self::AdditionalOutput) {
112        (
113            Type::Vector(VectorType { element_type: Box::new(self.reconstruct_type(*input.element_type).0) }),
114            Default::default(),
115        )
116    }
117
118    /* Expressions */
119    fn reconstruct_expression(
120        &mut self,
121        input: Expression,
122        additional: &Self::AdditionalInput,
123    ) -> (Expression, Self::AdditionalOutput) {
124        match input {
125            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, additional),
126            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, additional),
127            Expression::Async(async_) => self.reconstruct_async(async_, additional),
128            Expression::Array(array) => self.reconstruct_array(array, additional),
129            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
130            Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
131            Expression::Call(call) => self.reconstruct_call(*call, additional),
132            Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
133            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
134            Expression::Err(err) => self.reconstruct_err(err, additional),
135            Expression::Path(path) => self.reconstruct_path(path, additional),
136            Expression::Literal(value) => self.reconstruct_literal(value, additional),
137            Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
138            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
139            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
140            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
141            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
142            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
143            Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
144            Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
145        }
146    }
147
148    fn reconstruct_array_access(
149        &mut self,
150        input: ArrayAccess,
151        _additional: &Self::AdditionalInput,
152    ) -> (Expression, Self::AdditionalOutput) {
153        (
154            ArrayAccess {
155                array: self.reconstruct_expression(input.array, &Default::default()).0,
156                index: self.reconstruct_expression(input.index, &Default::default()).0,
157                ..input
158            }
159            .into(),
160            Default::default(),
161        )
162    }
163
164    fn reconstruct_associated_constant(
165        &mut self,
166        input: AssociatedConstantExpression,
167        _additional: &Self::AdditionalInput,
168    ) -> (Expression, Self::AdditionalOutput) {
169        (input.into(), Default::default())
170    }
171
172    fn reconstruct_associated_function(
173        &mut self,
174        input: AssociatedFunctionExpression,
175        _additional: &Self::AdditionalInput,
176    ) -> (Expression, Self::AdditionalOutput) {
177        (
178            AssociatedFunctionExpression {
179                arguments: input
180                    .arguments
181                    .into_iter()
182                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
183                    .collect(),
184                ..input
185            }
186            .into(),
187            Default::default(),
188        )
189    }
190
191    fn reconstruct_async(
192        &mut self,
193        input: AsyncExpression,
194        _additional: &Self::AdditionalInput,
195    ) -> (Expression, Self::AdditionalOutput) {
196        (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
197    }
198
199    fn reconstruct_member_access(
200        &mut self,
201        input: MemberAccess,
202        _additional: &Self::AdditionalInput,
203    ) -> (Expression, Self::AdditionalOutput) {
204        (
205            MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
206            Default::default(),
207        )
208    }
209
210    fn reconstruct_repeat(
211        &mut self,
212        input: RepeatExpression,
213        _additional: &Self::AdditionalInput,
214    ) -> (Expression, Self::AdditionalOutput) {
215        (
216            RepeatExpression {
217                expr: self.reconstruct_expression(input.expr, &Default::default()).0,
218                count: self.reconstruct_expression(input.count, &Default::default()).0,
219                ..input
220            }
221            .into(),
222            Default::default(),
223        )
224    }
225
226    fn reconstruct_tuple_access(
227        &mut self,
228        input: TupleAccess,
229        _additional: &Self::AdditionalInput,
230    ) -> (Expression, Self::AdditionalOutput) {
231        (
232            TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
233            Default::default(),
234        )
235    }
236
237    fn reconstruct_array(
238        &mut self,
239        input: ArrayExpression,
240        _additional: &Self::AdditionalInput,
241    ) -> (Expression, Self::AdditionalOutput) {
242        (
243            ArrayExpression {
244                elements: input
245                    .elements
246                    .into_iter()
247                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
248                    .collect(),
249                ..input
250            }
251            .into(),
252            Default::default(),
253        )
254    }
255
256    fn reconstruct_binary(
257        &mut self,
258        input: BinaryExpression,
259        _additional: &Self::AdditionalInput,
260    ) -> (Expression, Self::AdditionalOutput) {
261        (
262            BinaryExpression {
263                left: self.reconstruct_expression(input.left, &Default::default()).0,
264                right: self.reconstruct_expression(input.right, &Default::default()).0,
265                ..input
266            }
267            .into(),
268            Default::default(),
269        )
270    }
271
272    fn reconstruct_call(
273        &mut self,
274        input: CallExpression,
275        _additional: &Self::AdditionalInput,
276    ) -> (Expression, Self::AdditionalOutput) {
277        (
278            CallExpression {
279                const_arguments: input
280                    .const_arguments
281                    .into_iter()
282                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
283                    .collect(),
284                arguments: input
285                    .arguments
286                    .into_iter()
287                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
288                    .collect(),
289                ..input
290            }
291            .into(),
292            Default::default(),
293        )
294    }
295
296    fn reconstruct_cast(
297        &mut self,
298        input: CastExpression,
299        _additional: &Self::AdditionalInput,
300    ) -> (Expression, Self::AdditionalOutput) {
301        (
302            CastExpression {
303                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
304                ..input
305            }
306            .into(),
307            Default::default(),
308        )
309    }
310
311    fn reconstruct_struct_init(
312        &mut self,
313        input: StructExpression,
314        _additional: &Self::AdditionalInput,
315    ) -> (Expression, Self::AdditionalOutput) {
316        (
317            StructExpression {
318                const_arguments: input
319                    .const_arguments
320                    .into_iter()
321                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
322                    .collect(),
323                members: input
324                    .members
325                    .into_iter()
326                    .map(|member| StructVariableInitializer {
327                        identifier: member.identifier,
328                        expression: member
329                            .expression
330                            .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
331                        span: member.span,
332                        id: member.id,
333                    })
334                    .collect(),
335                ..input
336            }
337            .into(),
338            Default::default(),
339        )
340    }
341
342    fn reconstruct_err(
343        &mut self,
344        _input: ErrExpression,
345        _additional: &Self::AdditionalInput,
346    ) -> (Expression, Self::AdditionalOutput) {
347        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
348    }
349
350    fn reconstruct_path(
351        &mut self,
352        input: Path,
353        _additional: &Self::AdditionalInput,
354    ) -> (Expression, Self::AdditionalOutput) {
355        (input.into(), Default::default())
356    }
357
358    fn reconstruct_literal(
359        &mut self,
360        input: Literal,
361        _additional: &Self::AdditionalInput,
362    ) -> (Expression, Self::AdditionalOutput) {
363        (input.into(), Default::default())
364    }
365
366    fn reconstruct_locator(
367        &mut self,
368        input: LocatorExpression,
369        _additional: &Self::AdditionalInput,
370    ) -> (Expression, Self::AdditionalOutput) {
371        (input.into(), Default::default())
372    }
373
374    fn reconstruct_ternary(
375        &mut self,
376        input: TernaryExpression,
377        _additional: &Self::AdditionalInput,
378    ) -> (Expression, Self::AdditionalOutput) {
379        (
380            TernaryExpression {
381                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
382                if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
383                if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
384                span: input.span,
385                id: input.id,
386            }
387            .into(),
388            Default::default(),
389        )
390    }
391
392    fn reconstruct_tuple(
393        &mut self,
394        input: TupleExpression,
395        _additional: &Self::AdditionalInput,
396    ) -> (Expression, Self::AdditionalOutput) {
397        (
398            TupleExpression {
399                elements: input
400                    .elements
401                    .into_iter()
402                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
403                    .collect(),
404                ..input
405            }
406            .into(),
407            Default::default(),
408        )
409    }
410
411    fn reconstruct_unary(
412        &mut self,
413        input: UnaryExpression,
414        _additional: &Self::AdditionalInput,
415    ) -> (Expression, Self::AdditionalOutput) {
416        (
417            UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
418                .into(),
419            Default::default(),
420        )
421    }
422
423    fn reconstruct_unit(
424        &mut self,
425        input: UnitExpression,
426        _additional: &Self::AdditionalInput,
427    ) -> (Expression, Self::AdditionalOutput) {
428        (input.into(), Default::default())
429    }
430
431    /* Statements */
432    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
433        match input {
434            Statement::Assert(assert) => self.reconstruct_assert(assert),
435            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
436            Statement::Block(stmt) => {
437                let (stmt, output) = self.reconstruct_block(stmt);
438                (stmt.into(), output)
439            }
440            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
441            Statement::Const(stmt) => self.reconstruct_const(stmt),
442            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
443            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
444            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
445            Statement::Return(stmt) => self.reconstruct_return(stmt),
446        }
447    }
448
449    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
450        (
451            AssertStatement {
452                variant: match input.variant {
453                    AssertVariant::Assert(expr) => {
454                        AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
455                    }
456                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
457                        self.reconstruct_expression(left, &Default::default()).0,
458                        self.reconstruct_expression(right, &Default::default()).0,
459                    ),
460                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
461                        self.reconstruct_expression(left, &Default::default()).0,
462                        self.reconstruct_expression(right, &Default::default()).0,
463                    ),
464                },
465                ..input
466            }
467            .into(),
468            Default::default(),
469        )
470    }
471
472    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
473        (
474            AssignStatement {
475                place: self.reconstruct_expression(input.place, &Default::default()).0,
476                value: self.reconstruct_expression(input.value, &Default::default()).0,
477                ..input
478            }
479            .into(),
480            Default::default(),
481        )
482    }
483
484    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
485        (
486            Block {
487                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
488                span: input.span,
489                id: input.id,
490            },
491            Default::default(),
492        )
493    }
494
495    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
496        (
497            ConditionalStatement {
498                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
499                then: self.reconstruct_block(input.then).0,
500                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
501                ..input
502            }
503            .into(),
504            Default::default(),
505        )
506    }
507
508    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
509        (
510            ConstDeclaration {
511                type_: self.reconstruct_type(input.type_).0,
512                value: self.reconstruct_expression(input.value, &Default::default()).0,
513                ..input
514            }
515            .into(),
516            Default::default(),
517        )
518    }
519
520    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
521        (
522            DefinitionStatement {
523                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
524                value: self.reconstruct_expression(input.value, &Default::default()).0,
525                ..input
526            }
527            .into(),
528            Default::default(),
529        )
530    }
531
532    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
533        (
534            ExpressionStatement {
535                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
536                ..input
537            }
538            .into(),
539            Default::default(),
540        )
541    }
542
543    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
544        (
545            IterationStatement {
546                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
547                start: self.reconstruct_expression(input.start, &Default::default()).0,
548                stop: self.reconstruct_expression(input.stop, &Default::default()).0,
549                block: self.reconstruct_block(input.block).0,
550                ..input
551            }
552            .into(),
553            Default::default(),
554        )
555    }
556
557    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
558        (
559            ReturnStatement {
560                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
561                ..input
562            }
563            .into(),
564            Default::default(),
565        )
566    }
567}
568
569/// A Reconstructor trait for the program represented by the AST.
570pub trait ProgramReconstructor: AstReconstructor {
571    fn reconstruct_program(&mut self, input: Program) -> Program {
572        let program_scopes =
573            input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
574        Program {
575            imports: input
576                .imports
577                .into_iter()
578                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
579                .collect(),
580            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
581            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
582            program_scopes,
583        }
584    }
585
586    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
587        Stub {
588            imports: input.imports,
589            stub_id: input.stub_id,
590            consts: input.consts,
591            structs: input.structs,
592            mappings: input.mappings,
593            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
594            span: input.span,
595        }
596    }
597
598    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
599        ProgramScope {
600            program_id: input.program_id,
601            consts: input
602                .consts
603                .into_iter()
604                .map(|(i, c)| match self.reconstruct_const(c) {
605                    (Statement::Const(declaration), _) => (i, declaration),
606                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
607                })
608                .collect(),
609            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
610            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
611            storage_variables: input
612                .storage_variables
613                .into_iter()
614                .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
615                .collect(),
616            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
617            constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
618            span: input.span,
619        }
620    }
621
622    fn reconstruct_module(&mut self, input: Module) -> Module {
623        Module {
624            program_name: input.program_name,
625            path: input.path,
626            consts: input
627                .consts
628                .into_iter()
629                .map(|(i, c)| match self.reconstruct_const(c) {
630                    (Statement::Const(declaration), _) => (i, declaration),
631                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
632                })
633                .collect(),
634            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
635            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
636        }
637    }
638
639    fn reconstruct_function(&mut self, input: Function) -> Function {
640        Function {
641            annotations: input.annotations,
642            variant: input.variant,
643            identifier: input.identifier,
644            const_parameters: input
645                .const_parameters
646                .iter()
647                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
648                .collect(),
649            input: input
650                .input
651                .iter()
652                .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
653                .collect(),
654            output: input
655                .output
656                .iter()
657                .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
658                .collect(),
659            output_type: self.reconstruct_type(input.output_type).0,
660            block: self.reconstruct_block(input.block).0,
661            span: input.span,
662            id: input.id,
663        }
664    }
665
666    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
667        Constructor {
668            annotations: input.annotations,
669            block: self.reconstruct_block(input.block).0,
670            span: input.span,
671            id: input.id,
672        }
673    }
674
675    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
676        input
677    }
678
679    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
680        Composite {
681            const_parameters: input
682                .const_parameters
683                .iter()
684                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
685                .collect(),
686            members: input
687                .members
688                .iter()
689                .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
690                .collect(),
691            ..input
692        }
693    }
694
695    fn reconstruct_import(&mut self, input: Program) -> Program {
696        self.reconstruct_program(input)
697    }
698
699    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
700        Mapping {
701            key_type: self.reconstruct_type(input.key_type).0,
702            value_type: self.reconstruct_type(input.value_type).0,
703            ..input
704        }
705    }
706
707    fn reconstruct_storage_variable(&mut self, input: StorageVariable) -> StorageVariable {
708        StorageVariable { type_: self.reconstruct_type(input.type_).0, ..input }
709    }
710}