leo_ast/passes/reconstructor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

//! This module contains the reconstructor traits for the AST.
//! They provide default methods that rebuild each node
//! from the information contained in the old node.
20
21use crate::*;
22
23/// A Reconstructor trait for types in the AST.
24pub trait AstReconstructor {
25    type AdditionalOutput: Default;
26
27    /* Types */
28    fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
29        match input {
30            Type::Array(array_type) => self.reconstruct_array_type(array_type),
31            Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
32            Type::Future(future_type) => self.reconstruct_future_type(future_type),
33            Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
34            Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
35            Type::Address
36            | Type::Boolean
37            | Type::Field
38            | Type::Group
39            | Type::Identifier(_)
40            | Type::Integer(_)
41            | Type::Scalar
42            | Type::Signature
43            | Type::String
44            | Type::Numeric
45            | Type::Unit
46            | Type::Err => (input.clone(), Default::default()),
47        }
48    }
49
50    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
51        (
52            Type::Array(ArrayType {
53                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
54                length: Box::new(self.reconstruct_expression(*input.length).0),
55            }),
56            Default::default(),
57        )
58    }
59
60    fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
61        (
62            Type::Composite(CompositeType {
63                const_arguments: input
64                    .const_arguments
65                    .into_iter()
66                    .map(|arg| self.reconstruct_expression(arg).0)
67                    .collect(),
68                ..input
69            }),
70            Default::default(),
71        )
72    }
73
74    fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
75        (
76            Type::Future(FutureType {
77                inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
78                ..input
79            }),
80            Default::default(),
81        )
82    }
83
84    fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
85        (
86            Type::Mapping(MappingType {
87                key: Box::new(self.reconstruct_type(*input.key).0),
88                value: Box::new(self.reconstruct_type(*input.value).0),
89                ..input
90            }),
91            Default::default(),
92        )
93    }
94
95    fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
96        (
97            Type::Tuple(TupleType {
98                elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
99            }),
100            Default::default(),
101        )
102    }
103
104    /* Expressions */
105    fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
106        match input {
107            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
108            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
109            Expression::Async(async_) => self.reconstruct_async(async_),
110            Expression::Array(array) => self.reconstruct_array(array),
111            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
112            Expression::Binary(binary) => self.reconstruct_binary(*binary),
113            Expression::Call(call) => self.reconstruct_call(*call),
114            Expression::Cast(cast) => self.reconstruct_cast(*cast),
115            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
116            Expression::Err(err) => self.reconstruct_err(err),
117            Expression::Path(path) => self.reconstruct_path(path),
118            Expression::Literal(value) => self.reconstruct_literal(value),
119            Expression::Locator(locator) => self.reconstruct_locator(locator),
120            Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
121            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
122            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
123            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
124            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
125            Expression::Unary(unary) => self.reconstruct_unary(*unary),
126            Expression::Unit(unit) => self.reconstruct_unit(unit),
127        }
128    }
129
130    fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
131        (
132            ArrayAccess {
133                array: self.reconstruct_expression(input.array).0,
134                index: self.reconstruct_expression(input.index).0,
135                ..input
136            }
137            .into(),
138            Default::default(),
139        )
140    }
141
142    fn reconstruct_associated_constant(
143        &mut self,
144        input: AssociatedConstantExpression,
145    ) -> (Expression, Self::AdditionalOutput) {
146        (input.into(), Default::default())
147    }
148
149    fn reconstruct_associated_function(
150        &mut self,
151        input: AssociatedFunctionExpression,
152    ) -> (Expression, Self::AdditionalOutput) {
153        (
154            AssociatedFunctionExpression {
155                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
156                ..input
157            }
158            .into(),
159            Default::default(),
160        )
161    }
162
163    fn reconstruct_async(&mut self, input: AsyncExpression) -> (Expression, Self::AdditionalOutput) {
164        (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
165    }
166
167    fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
168        (MemberAccess { inner: self.reconstruct_expression(input.inner).0, ..input }.into(), Default::default())
169    }
170
171    fn reconstruct_repeat(&mut self, input: RepeatExpression) -> (Expression, Self::AdditionalOutput) {
172        (
173            RepeatExpression {
174                expr: self.reconstruct_expression(input.expr).0,
175                count: self.reconstruct_expression(input.count).0,
176                ..input
177            }
178            .into(),
179            Default::default(),
180        )
181    }
182
183    fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
184        (TupleAccess { tuple: self.reconstruct_expression(input.tuple).0, ..input }.into(), Default::default())
185    }
186
187    fn reconstruct_array(&mut self, input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
188        (
189            ArrayExpression {
190                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
191                ..input
192            }
193            .into(),
194            Default::default(),
195        )
196    }
197
198    fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
199        (
200            BinaryExpression {
201                left: self.reconstruct_expression(input.left).0,
202                right: self.reconstruct_expression(input.right).0,
203                ..input
204            }
205            .into(),
206            Default::default(),
207        )
208    }
209
210    fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
211        (
212            CallExpression {
213                const_arguments: input
214                    .const_arguments
215                    .into_iter()
216                    .map(|arg| self.reconstruct_expression(arg).0)
217                    .collect(),
218                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
219                ..input
220            }
221            .into(),
222            Default::default(),
223        )
224    }
225
226    fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
227        (
228            CastExpression { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
229            Default::default(),
230        )
231    }
232
233    fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
234        (
235            StructExpression {
236                const_arguments: input
237                    .const_arguments
238                    .into_iter()
239                    .map(|arg| self.reconstruct_expression(arg).0)
240                    .collect(),
241                members: input
242                    .members
243                    .into_iter()
244                    .map(|member| StructVariableInitializer {
245                        identifier: member.identifier,
246                        expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
247                        span: member.span,
248                        id: member.id,
249                    })
250                    .collect(),
251                ..input
252            }
253            .into(),
254            Default::default(),
255        )
256    }
257
258    fn reconstruct_err(&mut self, _input: ErrExpression) -> (Expression, Self::AdditionalOutput) {
259        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
260    }
261
262    fn reconstruct_path(&mut self, input: Path) -> (Expression, Self::AdditionalOutput) {
263        (input.into(), Default::default())
264    }
265
266    fn reconstruct_literal(&mut self, input: Literal) -> (Expression, Self::AdditionalOutput) {
267        (input.into(), Default::default())
268    }
269
270    fn reconstruct_locator(&mut self, input: LocatorExpression) -> (Expression, Self::AdditionalOutput) {
271        (input.into(), Default::default())
272    }
273
274    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
275        (
276            TernaryExpression {
277                condition: self.reconstruct_expression(input.condition).0,
278                if_true: self.reconstruct_expression(input.if_true).0,
279                if_false: self.reconstruct_expression(input.if_false).0,
280                span: input.span,
281                id: input.id,
282            }
283            .into(),
284            Default::default(),
285        )
286    }
287
288    fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) {
289        (
290            TupleExpression {
291                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
292                ..input
293            }
294            .into(),
295            Default::default(),
296        )
297    }
298
299    fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
300        (
301            UnaryExpression { receiver: self.reconstruct_expression(input.receiver).0, ..input }.into(),
302            Default::default(),
303        )
304    }
305
306    fn reconstruct_unit(&mut self, input: UnitExpression) -> (Expression, Self::AdditionalOutput) {
307        (input.into(), Default::default())
308    }
309
310    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
311        match input {
312            Statement::Assert(assert) => self.reconstruct_assert(assert),
313            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
314            Statement::Block(stmt) => {
315                let (stmt, output) = self.reconstruct_block(stmt);
316                (stmt.into(), output)
317            }
318            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
319            Statement::Const(stmt) => self.reconstruct_const(stmt),
320            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
321            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
322            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
323            Statement::Return(stmt) => self.reconstruct_return(stmt),
324        }
325    }
326
327    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
328        (
329            AssertStatement {
330                variant: match input.variant {
331                    AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
332                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
333                        self.reconstruct_expression(left).0,
334                        self.reconstruct_expression(right).0,
335                    ),
336                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
337                        self.reconstruct_expression(left).0,
338                        self.reconstruct_expression(right).0,
339                    ),
340                },
341                ..input
342            }
343            .into(),
344            Default::default(),
345        )
346    }
347
348    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
349        (
350            AssignStatement {
351                place: self.reconstruct_expression(input.place).0,
352                value: self.reconstruct_expression(input.value).0,
353                ..input
354            }
355            .into(),
356            Default::default(),
357        )
358    }
359
360    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
361        (
362            Block {
363                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
364                span: input.span,
365                id: input.id,
366            },
367            Default::default(),
368        )
369    }
370
371    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
372        (
373            ConditionalStatement {
374                condition: self.reconstruct_expression(input.condition).0,
375                then: self.reconstruct_block(input.then).0,
376                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
377                ..input
378            }
379            .into(),
380            Default::default(),
381        )
382    }
383
384    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
385        (
386            ConstDeclaration {
387                type_: self.reconstruct_type(input.type_).0,
388                value: self.reconstruct_expression(input.value).0,
389                ..input
390            }
391            .into(),
392            Default::default(),
393        )
394    }
395
396    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
397        (
398            DefinitionStatement {
399                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
400                value: self.reconstruct_expression(input.value).0,
401                ..input
402            }
403            .into(),
404            Default::default(),
405        )
406    }
407
408    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
409        (
410            ExpressionStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
411            Default::default(),
412        )
413    }
414
415    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
416        (
417            IterationStatement {
418                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
419                start: self.reconstruct_expression(input.start).0,
420                stop: self.reconstruct_expression(input.stop).0,
421                block: self.reconstruct_block(input.block).0,
422                ..input
423            }
424            .into(),
425            Default::default(),
426        )
427    }
428
429    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
430        (
431            ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
432            Default::default(),
433        )
434    }
435}
436
437/// A Reconstructor trait for the program represented by the AST.
438pub trait ProgramReconstructor: AstReconstructor {
439    fn reconstruct_program(&mut self, input: Program) -> Program {
440        Program {
441            imports: input
442                .imports
443                .into_iter()
444                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
445                .collect(),
446            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
447            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
448            program_scopes: input
449                .program_scopes
450                .into_iter()
451                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
452                .collect(),
453        }
454    }
455
456    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
457        Stub {
458            imports: input.imports,
459            stub_id: input.stub_id,
460            consts: input.consts,
461            structs: input.structs,
462            mappings: input.mappings,
463            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
464            span: input.span,
465        }
466    }
467
468    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
469        ProgramScope {
470            program_id: input.program_id,
471            consts: input
472                .consts
473                .into_iter()
474                .map(|(i, c)| match self.reconstruct_const(c) {
475                    (Statement::Const(declaration), _) => (i, declaration),
476                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
477                })
478                .collect(),
479            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
480            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
481            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
482            constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
483            span: input.span,
484        }
485    }
486
487    fn reconstruct_module(&mut self, input: Module) -> Module {
488        Module {
489            program_name: input.program_name,
490            path: input.path,
491            consts: input
492                .consts
493                .into_iter()
494                .map(|(i, c)| match self.reconstruct_const(c) {
495                    (Statement::Const(declaration), _) => (i, declaration),
496                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
497                })
498                .collect(),
499            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
500            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
501        }
502    }
503
504    fn reconstruct_function(&mut self, input: Function) -> Function {
505        Function {
506            annotations: input.annotations,
507            variant: input.variant,
508            identifier: input.identifier,
509            const_parameters: input
510                .const_parameters
511                .iter()
512                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
513                .collect(),
514            input: input
515                .input
516                .iter()
517                .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
518                .collect(),
519            output: input
520                .output
521                .iter()
522                .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
523                .collect(),
524            output_type: self.reconstruct_type(input.output_type).0,
525            block: self.reconstruct_block(input.block).0,
526            span: input.span,
527            id: input.id,
528        }
529    }
530
531    fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
532        Constructor {
533            annotations: input.annotations,
534            block: self.reconstruct_block(input.block).0,
535            span: input.span,
536            id: input.id,
537        }
538    }
539
540    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
541        input
542    }
543
544    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
545        Composite {
546            const_parameters: input
547                .const_parameters
548                .iter()
549                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
550                .collect(),
551            members: input
552                .members
553                .iter()
554                .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
555                .collect(),
556            ..input
557        }
558    }
559
560    fn reconstruct_import(&mut self, input: Program) -> Program {
561        self.reconstruct_program(input)
562    }
563
564    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
565        Mapping {
566            key_type: self.reconstruct_type(input.key_type).0,
567            value_type: self.reconstruct_type(input.value_type).0,
568            ..input
569        }
570    }
571}