// leo_ast/passes/reconstructor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

//! This module contains the reconstructor traits for the AST.
//! Their default methods rebuild each node from the parts of the old node,
//! so an implementor only overrides the methods for the nodes it transforms.
20
21use crate::*;
22
23/// A Reconstructor trait for types in the AST.
24pub trait AstReconstructor {
25    type AdditionalOutput: Default;
26    type AdditionalInput: Default;
27
28    /* Types */
29    fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30        match input {
31            Type::Array(array_type) => self.reconstruct_array_type(array_type),
32            Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33            Type::Future(future_type) => self.reconstruct_future_type(future_type),
34            Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35            Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36            Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37            Type::Address
38            | Type::Boolean
39            | Type::Field
40            | Type::Group
41            | Type::Identifier(_)
42            | Type::Integer(_)
43            | Type::Scalar
44            | Type::Signature
45            | Type::String
46            | Type::Numeric
47            | Type::Unit
48            | Type::Err => (input.clone(), Default::default()),
49        }
50    }
51
52    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
53        (
54            Type::Array(ArrayType {
55                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
56                length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
57            }),
58            Default::default(),
59        )
60    }
61
62    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
63        (
64            Type::Composite(CompositeType {
65                const_arguments: input
66                    .const_arguments
67                    .into_iter()
68                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
69                    .collect(),
70                ..input
71            }),
72            Default::default(),
73        )
74    }
75
76    fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
77        (
78            Type::Future(FutureType {
79                inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
80                ..input
81            }),
82            Default::default(),
83        )
84    }
85
86    fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
87        (
88            Type::Mapping(MappingType {
89                key: Box::new(self.reconstruct_type(*input.key).0),
90                value: Box::new(self.reconstruct_type(*input.value).0),
91                ..input
92            }),
93            Default::default(),
94        )
95    }
96
97    fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
98        (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
99    }
100
101    fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
102        (
103            Type::Tuple(TupleType {
104                elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
105            }),
106            Default::default(),
107        )
108    }
109
110    /* Expressions */
111    fn reconstruct_expression(
112        &mut self,
113        input: Expression,
114        additional: &Self::AdditionalInput,
115    ) -> (Expression, Self::AdditionalOutput) {
116        match input {
117            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, additional),
118            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, additional),
119            Expression::Async(async_) => self.reconstruct_async(async_, additional),
120            Expression::Array(array) => self.reconstruct_array(array, additional),
121            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
122            Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
123            Expression::Call(call) => self.reconstruct_call(*call, additional),
124            Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
125            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
126            Expression::Err(err) => self.reconstruct_err(err, additional),
127            Expression::Path(path) => self.reconstruct_path(path, additional),
128            Expression::Literal(value) => self.reconstruct_literal(value, additional),
129            Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
130            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
131            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
132            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
133            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
134            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
135            Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
136            Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
137        }
138    }
139
140    fn reconstruct_array_access(
141        &mut self,
142        input: ArrayAccess,
143        _additional: &Self::AdditionalInput,
144    ) -> (Expression, Self::AdditionalOutput) {
145        (
146            ArrayAccess {
147                array: self.reconstruct_expression(input.array, &Default::default()).0,
148                index: self.reconstruct_expression(input.index, &Default::default()).0,
149                ..input
150            }
151            .into(),
152            Default::default(),
153        )
154    }
155
156    fn reconstruct_associated_constant(
157        &mut self,
158        input: AssociatedConstantExpression,
159        _additional: &Self::AdditionalInput,
160    ) -> (Expression, Self::AdditionalOutput) {
161        (input.into(), Default::default())
162    }
163
164    fn reconstruct_associated_function(
165        &mut self,
166        input: AssociatedFunctionExpression,
167        _additional: &Self::AdditionalInput,
168    ) -> (Expression, Self::AdditionalOutput) {
169        (
170            AssociatedFunctionExpression {
171                arguments: input
172                    .arguments
173                    .into_iter()
174                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
175                    .collect(),
176                ..input
177            }
178            .into(),
179            Default::default(),
180        )
181    }
182
183    fn reconstruct_async(
184        &mut self,
185        input: AsyncExpression,
186        _additional: &Self::AdditionalInput,
187    ) -> (Expression, Self::AdditionalOutput) {
188        (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
189    }
190
191    fn reconstruct_member_access(
192        &mut self,
193        input: MemberAccess,
194        _additional: &Self::AdditionalInput,
195    ) -> (Expression, Self::AdditionalOutput) {
196        (
197            MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
198            Default::default(),
199        )
200    }
201
202    fn reconstruct_repeat(
203        &mut self,
204        input: RepeatExpression,
205        _additional: &Self::AdditionalInput,
206    ) -> (Expression, Self::AdditionalOutput) {
207        (
208            RepeatExpression {
209                expr: self.reconstruct_expression(input.expr, &Default::default()).0,
210                count: self.reconstruct_expression(input.count, &Default::default()).0,
211                ..input
212            }
213            .into(),
214            Default::default(),
215        )
216    }
217
218    fn reconstruct_tuple_access(
219        &mut self,
220        input: TupleAccess,
221        _additional: &Self::AdditionalInput,
222    ) -> (Expression, Self::AdditionalOutput) {
223        (
224            TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
225            Default::default(),
226        )
227    }
228
229    fn reconstruct_array(
230        &mut self,
231        input: ArrayExpression,
232        _additional: &Self::AdditionalInput,
233    ) -> (Expression, Self::AdditionalOutput) {
234        (
235            ArrayExpression {
236                elements: input
237                    .elements
238                    .into_iter()
239                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
240                    .collect(),
241                ..input
242            }
243            .into(),
244            Default::default(),
245        )
246    }
247
248    fn reconstruct_binary(
249        &mut self,
250        input: BinaryExpression,
251        _additional: &Self::AdditionalInput,
252    ) -> (Expression, Self::AdditionalOutput) {
253        (
254            BinaryExpression {
255                left: self.reconstruct_expression(input.left, &Default::default()).0,
256                right: self.reconstruct_expression(input.right, &Default::default()).0,
257                ..input
258            }
259            .into(),
260            Default::default(),
261        )
262    }
263
264    fn reconstruct_call(
265        &mut self,
266        input: CallExpression,
267        _additional: &Self::AdditionalInput,
268    ) -> (Expression, Self::AdditionalOutput) {
269        (
270            CallExpression {
271                const_arguments: input
272                    .const_arguments
273                    .into_iter()
274                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
275                    .collect(),
276                arguments: input
277                    .arguments
278                    .into_iter()
279                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
280                    .collect(),
281                ..input
282            }
283            .into(),
284            Default::default(),
285        )
286    }
287
288    fn reconstruct_cast(
289        &mut self,
290        input: CastExpression,
291        _additional: &Self::AdditionalInput,
292    ) -> (Expression, Self::AdditionalOutput) {
293        (
294            CastExpression {
295                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
296                ..input
297            }
298            .into(),
299            Default::default(),
300        )
301    }
302
303    fn reconstruct_struct_init(
304        &mut self,
305        input: StructExpression,
306        _additional: &Self::AdditionalInput,
307    ) -> (Expression, Self::AdditionalOutput) {
308        (
309            StructExpression {
310                const_arguments: input
311                    .const_arguments
312                    .into_iter()
313                    .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
314                    .collect(),
315                members: input
316                    .members
317                    .into_iter()
318                    .map(|member| StructVariableInitializer {
319                        identifier: member.identifier,
320                        expression: member
321                            .expression
322                            .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
323                        span: member.span,
324                        id: member.id,
325                    })
326                    .collect(),
327                ..input
328            }
329            .into(),
330            Default::default(),
331        )
332    }
333
334    fn reconstruct_err(
335        &mut self,
336        _input: ErrExpression,
337        _additional: &Self::AdditionalInput,
338    ) -> (Expression, Self::AdditionalOutput) {
339        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
340    }
341
342    fn reconstruct_path(
343        &mut self,
344        input: Path,
345        _additional: &Self::AdditionalInput,
346    ) -> (Expression, Self::AdditionalOutput) {
347        (input.into(), Default::default())
348    }
349
350    fn reconstruct_literal(
351        &mut self,
352        input: Literal,
353        _additional: &Self::AdditionalInput,
354    ) -> (Expression, Self::AdditionalOutput) {
355        (input.into(), Default::default())
356    }
357
358    fn reconstruct_locator(
359        &mut self,
360        input: LocatorExpression,
361        _additional: &Self::AdditionalInput,
362    ) -> (Expression, Self::AdditionalOutput) {
363        (input.into(), Default::default())
364    }
365
366    fn reconstruct_ternary(
367        &mut self,
368        input: TernaryExpression,
369        _additional: &Self::AdditionalInput,
370    ) -> (Expression, Self::AdditionalOutput) {
371        (
372            TernaryExpression {
373                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
374                if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
375                if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
376                span: input.span,
377                id: input.id,
378            }
379            .into(),
380            Default::default(),
381        )
382    }
383
384    fn reconstruct_tuple(
385        &mut self,
386        input: TupleExpression,
387        _additional: &Self::AdditionalInput,
388    ) -> (Expression, Self::AdditionalOutput) {
389        (
390            TupleExpression {
391                elements: input
392                    .elements
393                    .into_iter()
394                    .map(|element| self.reconstruct_expression(element, &Default::default()).0)
395                    .collect(),
396                ..input
397            }
398            .into(),
399            Default::default(),
400        )
401    }
402
403    fn reconstruct_unary(
404        &mut self,
405        input: UnaryExpression,
406        _additional: &Self::AdditionalInput,
407    ) -> (Expression, Self::AdditionalOutput) {
408        (
409            UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
410                .into(),
411            Default::default(),
412        )
413    }
414
415    fn reconstruct_unit(
416        &mut self,
417        input: UnitExpression,
418        _additional: &Self::AdditionalInput,
419    ) -> (Expression, Self::AdditionalOutput) {
420        (input.into(), Default::default())
421    }
422
423    /* Statements */
424    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
425        match input {
426            Statement::Assert(assert) => self.reconstruct_assert(assert),
427            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
428            Statement::Block(stmt) => {
429                let (stmt, output) = self.reconstruct_block(stmt);
430                (stmt.into(), output)
431            }
432            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
433            Statement::Const(stmt) => self.reconstruct_const(stmt),
434            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
435            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
436            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
437            Statement::Return(stmt) => self.reconstruct_return(stmt),
438        }
439    }
440
441    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
442        (
443            AssertStatement {
444                variant: match input.variant {
445                    AssertVariant::Assert(expr) => {
446                        AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
447                    }
448                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
449                        self.reconstruct_expression(left, &Default::default()).0,
450                        self.reconstruct_expression(right, &Default::default()).0,
451                    ),
452                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
453                        self.reconstruct_expression(left, &Default::default()).0,
454                        self.reconstruct_expression(right, &Default::default()).0,
455                    ),
456                },
457                ..input
458            }
459            .into(),
460            Default::default(),
461        )
462    }
463
464    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
465        (
466            AssignStatement {
467                place: self.reconstruct_expression(input.place, &Default::default()).0,
468                value: self.reconstruct_expression(input.value, &Default::default()).0,
469                ..input
470            }
471            .into(),
472            Default::default(),
473        )
474    }
475
476    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
477        (
478            Block {
479                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
480                span: input.span,
481                id: input.id,
482            },
483            Default::default(),
484        )
485    }
486
487    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
488        (
489            ConditionalStatement {
490                condition: self.reconstruct_expression(input.condition, &Default::default()).0,
491                then: self.reconstruct_block(input.then).0,
492                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
493                ..input
494            }
495            .into(),
496            Default::default(),
497        )
498    }
499
500    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
501        (
502            ConstDeclaration {
503                type_: self.reconstruct_type(input.type_).0,
504                value: self.reconstruct_expression(input.value, &Default::default()).0,
505                ..input
506            }
507            .into(),
508            Default::default(),
509        )
510    }
511
512    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
513        (
514            DefinitionStatement {
515                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
516                value: self.reconstruct_expression(input.value, &Default::default()).0,
517                ..input
518            }
519            .into(),
520            Default::default(),
521        )
522    }
523
524    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
525        (
526            ExpressionStatement {
527                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
528                ..input
529            }
530            .into(),
531            Default::default(),
532        )
533    }
534
535    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
536        (
537            IterationStatement {
538                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
539                start: self.reconstruct_expression(input.start, &Default::default()).0,
540                stop: self.reconstruct_expression(input.stop, &Default::default()).0,
541                block: self.reconstruct_block(input.block).0,
542                ..input
543            }
544            .into(),
545            Default::default(),
546        )
547    }
548
549    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
550        (
551            ReturnStatement {
552                expression: self.reconstruct_expression(input.expression, &Default::default()).0,
553                ..input
554            }
555            .into(),
556            Default::default(),
557        )
558    }
559}
560
561/// A Reconstructor trait for the program represented by the AST.
562pub trait ProgramReconstructor: AstReconstructor {
563    fn reconstruct_program(&mut self, input: Program) -> Program {
564        let program_scopes =
565            input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
566        Program {
567            imports: input
568                .imports
569                .into_iter()
570                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
571                .collect(),
572            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
573            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
574            program_scopes,
575        }
576    }
577
578    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
579        Stub {
580            imports: input.imports,
581            stub_id: input.stub_id,
582            consts: input.consts,
583            structs: input.structs,
584            mappings: input.mappings,
585            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
586            span: input.span,
587        }
588    }
589
590    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
591        ProgramScope {
592            program_id: input.program_id,
593            consts: input
594                .consts
595                .into_iter()
596                .map(|(i, c)| match self.reconstruct_const(c) {
597                    (Statement::Const(declaration), _) => (i, declaration),
598                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
599                })
600                .collect(),
601            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
602            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
603            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
604            constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
605            span: input.span,
606        }
607    }
608
609    fn reconstruct_module(&mut self, input: Module) -> Module {
610        Module {
611            program_name: input.program_name,
612            path: input.path,
613            consts: input
614                .consts
615                .into_iter()
616                .map(|(i, c)| match self.reconstruct_const(c) {
617                    (Statement::Const(declaration), _) => (i, declaration),
618                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
619                })
620                .collect(),
621            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
622            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
623        }
624    }
625
626    fn reconstruct_function(&mut self, input: Function) -> Function {
627        Function {
628            annotations: input.annotations,
629            variant: input.variant,
630            identifier: input.identifier,
631            const_parameters: input
632                .const_parameters
633                .iter()
634                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
635                .collect(),
636            input: input
637                .input
638                .iter()
639                .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
640                .collect(),
641            output: input
642                .output
643                .iter()
644                .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
645                .collect(),
646            output_type: self.reconstruct_type(input.output_type).0,
647            block: self.reconstruct_block(input.block).0,
648            span: input.span,
649            id: input.id,
650        }
651    }
652
653    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
654        Constructor {
655            annotations: input.annotations,
656            block: self.reconstruct_block(input.block).0,
657            span: input.span,
658            id: input.id,
659        }
660    }
661
662    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
663        input
664    }
665
666    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
667        Composite {
668            const_parameters: input
669                .const_parameters
670                .iter()
671                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
672                .collect(),
673            members: input
674                .members
675                .iter()
676                .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
677                .collect(),
678            ..input
679        }
680    }
681
682    fn reconstruct_import(&mut self, input: Program) -> Program {
683        self.reconstruct_program(input)
684    }
685
686    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
687        Mapping {
688            key_type: self.reconstruct_type(input.key_type).0,
689            value_type: self.reconstruct_type(input.value_type).0,
690            ..input
691        }
692    }
693}