leo_ast/passes/
reconstructor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

//! This module contains the reconstructor traits for the AST.
//! Their default methods build each new node from the information in the old node.
20
21use crate::*;
22
23/// A Reconstructor trait for types in the AST.
24pub trait TypeReconstructor: ExpressionReconstructor {
25    fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
26        match input {
27            Type::Array(array_type) => self.reconstruct_array_type(array_type),
28            Type::Future(future_type) => self.reconstruct_future_type(future_type),
29            Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
30            Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
31            Type::Address
32            | Type::Boolean
33            | Type::Composite(_)
34            | Type::Field
35            | Type::Group
36            | Type::Identifier(_)
37            | Type::Integer(_)
38            | Type::Scalar
39            | Type::Signature
40            | Type::String
41            | Type::Numeric
42            | Type::Unit
43            | Type::Err => (input.clone(), Default::default()),
44        }
45    }
46
47    fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
48        (
49            Type::Array(ArrayType {
50                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
51                length: Box::new(self.reconstruct_expression(*input.length).0),
52            }),
53            Default::default(),
54        )
55    }
56
57    fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
58        (
59            Type::Future(FutureType {
60                inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
61                ..input
62            }),
63            Default::default(),
64        )
65    }
66
67    fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
68        (
69            Type::Mapping(MappingType {
70                key: Box::new(self.reconstruct_type(*input.key).0),
71                value: Box::new(self.reconstruct_type(*input.value).0),
72                ..input
73            }),
74            Default::default(),
75        )
76    }
77
78    fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
79        (
80            Type::Tuple(TupleType {
81                elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
82            }),
83            Default::default(),
84        )
85    }
86}
87
88/// A Reconstructor trait for expressions in the AST.
89pub trait ExpressionReconstructor {
90    type AdditionalOutput: Default;
91
92    fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
93        match input {
94            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
95            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
96            Expression::Array(array) => self.reconstruct_array(array),
97            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
98            Expression::Binary(binary) => self.reconstruct_binary(*binary),
99            Expression::Call(call) => self.reconstruct_call(*call),
100            Expression::Cast(cast) => self.reconstruct_cast(*cast),
101            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
102            Expression::Err(err) => self.reconstruct_err(err),
103            Expression::Identifier(identifier) => self.reconstruct_identifier(identifier),
104            Expression::Literal(value) => self.reconstruct_literal(value),
105            Expression::Locator(locator) => self.reconstruct_locator(locator),
106            Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
107            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
108            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
109            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
110            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
111            Expression::Unary(unary) => self.reconstruct_unary(*unary),
112            Expression::Unit(unit) => self.reconstruct_unit(unit),
113        }
114    }
115
116    fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
117        (
118            ArrayAccess {
119                array: self.reconstruct_expression(input.array).0,
120                index: self.reconstruct_expression(input.index).0,
121                ..input
122            }
123            .into(),
124            Default::default(),
125        )
126    }
127
128    fn reconstruct_associated_constant(
129        &mut self,
130        input: AssociatedConstantExpression,
131    ) -> (Expression, Self::AdditionalOutput) {
132        (input.into(), Default::default())
133    }
134
135    fn reconstruct_associated_function(
136        &mut self,
137        input: AssociatedFunctionExpression,
138    ) -> (Expression, Self::AdditionalOutput) {
139        (
140            AssociatedFunctionExpression {
141                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
142                ..input
143            }
144            .into(),
145            Default::default(),
146        )
147    }
148
149    fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
150        (MemberAccess { inner: self.reconstruct_expression(input.inner).0, ..input }.into(), Default::default())
151    }
152
153    fn reconstruct_repeat(&mut self, input: RepeatExpression) -> (Expression, Self::AdditionalOutput) {
154        (
155            RepeatExpression {
156                expr: self.reconstruct_expression(input.expr).0,
157                count: self.reconstruct_expression(input.count).0,
158                ..input
159            }
160            .into(),
161            Default::default(),
162        )
163    }
164
165    fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
166        (TupleAccess { tuple: self.reconstruct_expression(input.tuple).0, ..input }.into(), Default::default())
167    }
168
169    fn reconstruct_array(&mut self, input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
170        (
171            ArrayExpression {
172                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
173                ..input
174            }
175            .into(),
176            Default::default(),
177        )
178    }
179
180    fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
181        (
182            BinaryExpression {
183                left: self.reconstruct_expression(input.left).0,
184                right: self.reconstruct_expression(input.right).0,
185                ..input
186            }
187            .into(),
188            Default::default(),
189        )
190    }
191
192    fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
193        (
194            CallExpression {
195                const_arguments: input
196                    .const_arguments
197                    .into_iter()
198                    .map(|arg| self.reconstruct_expression(arg).0)
199                    .collect(),
200                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
201                ..input
202            }
203            .into(),
204            Default::default(),
205        )
206    }
207
208    fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
209        (
210            CastExpression { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
211            Default::default(),
212        )
213    }
214
215    fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
216        (
217            StructExpression {
218                members: input
219                    .members
220                    .into_iter()
221                    .map(|member| StructVariableInitializer {
222                        identifier: member.identifier,
223                        expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
224                        span: member.span,
225                        id: member.id,
226                    })
227                    .collect(),
228                ..input
229            }
230            .into(),
231            Default::default(),
232        )
233    }
234
235    fn reconstruct_err(&mut self, _input: ErrExpression) -> (Expression, Self::AdditionalOutput) {
236        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
237    }
238
239    fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
240        (input.into(), Default::default())
241    }
242
243    fn reconstruct_literal(&mut self, input: Literal) -> (Expression, Self::AdditionalOutput) {
244        (input.into(), Default::default())
245    }
246
247    fn reconstruct_locator(&mut self, input: LocatorExpression) -> (Expression, Self::AdditionalOutput) {
248        (input.into(), Default::default())
249    }
250
251    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
252        (
253            TernaryExpression {
254                condition: self.reconstruct_expression(input.condition).0,
255                if_true: self.reconstruct_expression(input.if_true).0,
256                if_false: self.reconstruct_expression(input.if_false).0,
257                span: input.span,
258                id: input.id,
259            }
260            .into(),
261            Default::default(),
262        )
263    }
264
265    fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) {
266        (
267            TupleExpression {
268                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
269                ..input
270            }
271            .into(),
272            Default::default(),
273        )
274    }
275
276    fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
277        (
278            UnaryExpression { receiver: self.reconstruct_expression(input.receiver).0, ..input }.into(),
279            Default::default(),
280        )
281    }
282
283    fn reconstruct_unit(&mut self, input: UnitExpression) -> (Expression, Self::AdditionalOutput) {
284        (input.into(), Default::default())
285    }
286}
287
288/// A Reconstructor trait for statements in the AST.
289pub trait StatementReconstructor: TypeReconstructor {
290    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
291        match input {
292            Statement::Assert(assert) => self.reconstruct_assert(assert),
293            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
294            Statement::Block(stmt) => {
295                let (stmt, output) = self.reconstruct_block(stmt);
296                (stmt.into(), output)
297            }
298            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
299            Statement::Const(stmt) => self.reconstruct_const(stmt),
300            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
301            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
302            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
303            Statement::Return(stmt) => self.reconstruct_return(stmt),
304        }
305    }
306
307    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
308        (
309            AssertStatement {
310                variant: match input.variant {
311                    AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
312                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
313                        self.reconstruct_expression(left).0,
314                        self.reconstruct_expression(right).0,
315                    ),
316                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
317                        self.reconstruct_expression(left).0,
318                        self.reconstruct_expression(right).0,
319                    ),
320                },
321                ..input
322            }
323            .into(),
324            Default::default(),
325        )
326    }
327
328    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
329        (AssignStatement { value: self.reconstruct_expression(input.value).0, ..input }.into(), Default::default())
330    }
331
332    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
333        (
334            Block {
335                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
336                span: input.span,
337                id: input.id,
338            },
339            Default::default(),
340        )
341    }
342
343    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
344        (
345            ConditionalStatement {
346                condition: self.reconstruct_expression(input.condition).0,
347                then: self.reconstruct_block(input.then).0,
348                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
349                ..input
350            }
351            .into(),
352            Default::default(),
353        )
354    }
355
356    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
357        (
358            ConstDeclaration {
359                type_: self.reconstruct_type(input.type_).0,
360                value: self.reconstruct_expression(input.value).0,
361                ..input
362            }
363            .into(),
364            Default::default(),
365        )
366    }
367
368    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
369        (
370            DefinitionStatement {
371                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
372                value: self.reconstruct_expression(input.value).0,
373                ..input
374            }
375            .into(),
376            Default::default(),
377        )
378    }
379
380    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
381        (
382            ExpressionStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
383            Default::default(),
384        )
385    }
386
387    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
388        (
389            IterationStatement {
390                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
391                start: self.reconstruct_expression(input.start).0,
392                stop: self.reconstruct_expression(input.stop).0,
393                block: self.reconstruct_block(input.block).0,
394                ..input
395            }
396            .into(),
397            Default::default(),
398        )
399    }
400
401    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
402        (
403            ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
404            Default::default(),
405        )
406    }
407}
408
409/// A Reconstructor trait for the program represented by the AST.
410pub trait ProgramReconstructor: StatementReconstructor {
411    fn reconstruct_program(&mut self, input: Program) -> Program {
412        Program {
413            imports: input
414                .imports
415                .into_iter()
416                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
417                .collect(),
418            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
419            program_scopes: input
420                .program_scopes
421                .into_iter()
422                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
423                .collect(),
424        }
425    }
426
427    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
428        Stub {
429            imports: input.imports,
430            stub_id: input.stub_id,
431            consts: input.consts,
432            structs: input.structs,
433            mappings: input.mappings,
434            span: input.span,
435            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
436        }
437    }
438
439    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
440        ProgramScope {
441            program_id: input.program_id,
442            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
443            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
444            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
445            consts: input
446                .consts
447                .into_iter()
448                .map(|(i, c)| match self.reconstruct_const(c) {
449                    (Statement::Const(declaration), _) => (i, declaration),
450                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
451                })
452                .collect(),
453            span: input.span,
454        }
455    }
456
457    fn reconstruct_function(&mut self, input: Function) -> Function {
458        Function {
459            annotations: input.annotations,
460            variant: input.variant,
461            identifier: input.identifier,
462            const_parameters: input
463                .const_parameters
464                .iter()
465                .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
466                .collect(),
467            input: input
468                .input
469                .iter()
470                .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
471                .collect(),
472            output: input
473                .output
474                .iter()
475                .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
476                .collect(),
477            output_type: self.reconstruct_type(input.output_type).0,
478            block: self.reconstruct_block(input.block).0,
479            span: input.span,
480            id: input.id,
481        }
482    }
483
484    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
485        input
486    }
487
488    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
489        Composite {
490            members: input
491                .members
492                .iter()
493                .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
494                .collect(),
495            ..input
496        }
497    }
498
499    fn reconstruct_import(&mut self, input: Program) -> Program {
500        self.reconstruct_program(input)
501    }
502
503    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
504        Mapping {
505            key_type: self.reconstruct_type(input.key_type).0,
506            value_type: self.reconstruct_type(input.value_type).0,
507            ..input
508        }
509    }
510}