leo_ast/passes/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! This module contains the Visitor traits for the AST.
//! Each trait provides a default method for every kind of node;
//! the method that runs is chosen by the type of node being visited.
20
21use crate::*;
22
// TODO: The Visitor and Reconstructor patterns need a redesign so that the default implementation can still be invoked easily from an implementation that overrides it.
// Here is a pattern that seems to work:
// trait ProgramVisitor {
//     // The trait method that can be overridden
//     fn visit_program_scope(&mut self);
//
//     // Helper method containing the default implementation (trait methods cannot be private)
//     fn default_visit_program_scope(&mut self) {
//         println!("Do default stuff");
//     }
// }
//
// struct YourStruct;
//
// impl ProgramVisitor for YourStruct {
//     fn visit_program_scope(&mut self) {
//         println!("Do custom stuff.");
//         // Call the default implementation
//         self.default_visit_program_scope();
//     }
// }
44
45/// A Visitor trait for types in the AST.
46pub trait AstVisitor {
47    /* Types */
48    fn visit_type(&mut self, input: &Type) {
49        match input {
50            Type::Array(array_type) => self.visit_array_type(array_type),
51            Type::Composite(composite_type) => self.visit_composite_type(composite_type),
52            Type::Future(future_type) => self.visit_future_type(future_type),
53            Type::Mapping(mapping_type) => self.visit_mapping_type(mapping_type),
54            Type::Optional(optional_type) => self.visit_optional_type(optional_type),
55            Type::Tuple(tuple_type) => self.visit_tuple_type(tuple_type),
56            Type::Address
57            | Type::Boolean
58            | Type::Field
59            | Type::Group
60            | Type::Identifier(_)
61            | Type::Integer(_)
62            | Type::Scalar
63            | Type::Signature
64            | Type::String
65            | Type::Numeric
66            | Type::Unit
67            | Type::Err => {}
68        }
69    }
70
71    fn visit_array_type(&mut self, input: &ArrayType) {
72        self.visit_type(&input.element_type);
73        self.visit_expression(&input.length, &Default::default());
74    }
75
76    fn visit_composite_type(&mut self, input: &CompositeType) {
77        input.const_arguments.iter().for_each(|expr| {
78            self.visit_expression(expr, &Default::default());
79        });
80    }
81
82    fn visit_future_type(&mut self, input: &FutureType) {
83        input.inputs.iter().for_each(|input| self.visit_type(input));
84    }
85
86    fn visit_mapping_type(&mut self, input: &MappingType) {
87        self.visit_type(&input.key);
88        self.visit_type(&input.value);
89    }
90
91    fn visit_optional_type(&mut self, input: &OptionalType) {
92        self.visit_type(&input.inner);
93    }
94
95    fn visit_tuple_type(&mut self, input: &TupleType) {
96        input.elements().iter().for_each(|input| self.visit_type(input));
97    }
98
99    /* Expressions */
100    type AdditionalInput: Default;
101    type Output: Default;
102
103    fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
104        match input {
105            Expression::Array(array) => self.visit_array(array, additional),
106            Expression::ArrayAccess(access) => self.visit_array_access(access, additional),
107            Expression::AssociatedConstant(constant) => self.visit_associated_constant(constant, additional),
108            Expression::AssociatedFunction(function) => self.visit_associated_function(function, additional),
109            Expression::Async(async_) => self.visit_async(async_, additional),
110            Expression::Binary(binary) => self.visit_binary(binary, additional),
111            Expression::Call(call) => self.visit_call(call, additional),
112            Expression::Cast(cast) => self.visit_cast(cast, additional),
113            Expression::Struct(struct_) => self.visit_struct_init(struct_, additional),
114            Expression::Err(err) => self.visit_err(err, additional),
115            Expression::Path(path) => self.visit_path(path, additional),
116            Expression::Literal(literal) => self.visit_literal(literal, additional),
117            Expression::Locator(locator) => self.visit_locator(locator, additional),
118            Expression::MemberAccess(access) => self.visit_member_access(access, additional),
119            Expression::Repeat(repeat) => self.visit_repeat(repeat, additional),
120            Expression::Ternary(ternary) => self.visit_ternary(ternary, additional),
121            Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
122            Expression::TupleAccess(access) => self.visit_tuple_access(access, additional),
123            Expression::Unary(unary) => self.visit_unary(unary, additional),
124            Expression::Unit(unit) => self.visit_unit(unit, additional),
125        }
126    }
127
128    fn visit_array_access(&mut self, input: &ArrayAccess, _additional: &Self::AdditionalInput) -> Self::Output {
129        self.visit_expression(&input.array, &Default::default());
130        self.visit_expression(&input.index, &Default::default());
131        Default::default()
132    }
133
134    fn visit_member_access(&mut self, input: &MemberAccess, _additional: &Self::AdditionalInput) -> Self::Output {
135        self.visit_expression(&input.inner, &Default::default());
136        Default::default()
137    }
138
139    fn visit_tuple_access(&mut self, input: &TupleAccess, _additional: &Self::AdditionalInput) -> Self::Output {
140        self.visit_expression(&input.tuple, &Default::default());
141        Default::default()
142    }
143
144    fn visit_array(&mut self, input: &ArrayExpression, _additional: &Self::AdditionalInput) -> Self::Output {
145        input.elements.iter().for_each(|expr| {
146            self.visit_expression(expr, &Default::default());
147        });
148        Default::default()
149    }
150
151    fn visit_associated_constant(
152        &mut self,
153        _input: &AssociatedConstantExpression,
154        _additional: &Self::AdditionalInput,
155    ) -> Self::Output {
156        Default::default()
157    }
158
159    fn visit_associated_function(
160        &mut self,
161        input: &AssociatedFunctionExpression,
162        _additional: &Self::AdditionalInput,
163    ) -> Self::Output {
164        input.arguments.iter().for_each(|arg| {
165            self.visit_expression(arg, &Default::default());
166        });
167        Default::default()
168    }
169
170    fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
171        self.visit_block(&input.block);
172        Default::default()
173    }
174
175    fn visit_binary(&mut self, input: &BinaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
176        self.visit_expression(&input.left, &Default::default());
177        self.visit_expression(&input.right, &Default::default());
178        Default::default()
179    }
180
181    fn visit_call(&mut self, input: &CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
182        input.const_arguments.iter().for_each(|expr| {
183            self.visit_expression(expr, &Default::default());
184        });
185        input.arguments.iter().for_each(|expr| {
186            self.visit_expression(expr, &Default::default());
187        });
188        Default::default()
189    }
190
191    fn visit_cast(&mut self, input: &CastExpression, _additional: &Self::AdditionalInput) -> Self::Output {
192        self.visit_expression(&input.expression, &Default::default());
193        Default::default()
194    }
195
196    fn visit_struct_init(&mut self, input: &StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
197        input.const_arguments.iter().for_each(|expr| {
198            self.visit_expression(expr, &Default::default());
199        });
200        for StructVariableInitializer { expression, .. } in input.members.iter() {
201            if let Some(expression) = expression {
202                self.visit_expression(expression, &Default::default());
203            }
204        }
205        Default::default()
206    }
207
208    fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
209        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
210    }
211
212    fn visit_path(&mut self, _input: &Path, _additional: &Self::AdditionalInput) -> Self::Output {
213        Default::default()
214    }
215
216    fn visit_literal(&mut self, _input: &Literal, _additional: &Self::AdditionalInput) -> Self::Output {
217        Default::default()
218    }
219
220    fn visit_locator(&mut self, _input: &LocatorExpression, _additional: &Self::AdditionalInput) -> Self::Output {
221        Default::default()
222    }
223
224    fn visit_repeat(&mut self, input: &RepeatExpression, _additional: &Self::AdditionalInput) -> Self::Output {
225        self.visit_expression(&input.expr, &Default::default());
226        self.visit_expression(&input.count, &Default::default());
227        Default::default()
228    }
229
230    fn visit_ternary(&mut self, input: &TernaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
231        self.visit_expression(&input.condition, &Default::default());
232        self.visit_expression(&input.if_true, &Default::default());
233        self.visit_expression(&input.if_false, &Default::default());
234        Default::default()
235    }
236
237    fn visit_tuple(&mut self, input: &TupleExpression, _additional: &Self::AdditionalInput) -> Self::Output {
238        input.elements.iter().for_each(|expr| {
239            self.visit_expression(expr, &Default::default());
240        });
241        Default::default()
242    }
243
244    fn visit_unary(&mut self, input: &UnaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
245        self.visit_expression(&input.receiver, &Default::default());
246        Default::default()
247    }
248
249    fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
250        Default::default()
251    }
252
253    /* Statements */
254    fn visit_statement(&mut self, input: &Statement) {
255        match input {
256            Statement::Assert(stmt) => self.visit_assert(stmt),
257            Statement::Assign(stmt) => self.visit_assign(stmt),
258            Statement::Block(stmt) => self.visit_block(stmt),
259            Statement::Conditional(stmt) => self.visit_conditional(stmt),
260            Statement::Const(stmt) => self.visit_const(stmt),
261            Statement::Definition(stmt) => self.visit_definition(stmt),
262            Statement::Expression(stmt) => self.visit_expression_statement(stmt),
263            Statement::Iteration(stmt) => self.visit_iteration(stmt),
264            Statement::Return(stmt) => self.visit_return(stmt),
265        }
266    }
267
268    fn visit_assert(&mut self, input: &AssertStatement) {
269        match &input.variant {
270            AssertVariant::Assert(expr) => self.visit_expression(expr, &Default::default()),
271            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
272                self.visit_expression(left, &Default::default());
273                self.visit_expression(right, &Default::default())
274            }
275        };
276    }
277
278    fn visit_assign(&mut self, input: &AssignStatement) {
279        self.visit_expression(&input.place, &Default::default());
280        self.visit_expression(&input.value, &Default::default());
281    }
282
283    fn visit_block(&mut self, input: &Block) {
284        input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
285    }
286
287    fn visit_conditional(&mut self, input: &ConditionalStatement) {
288        self.visit_expression(&input.condition, &Default::default());
289        self.visit_block(&input.then);
290        if let Some(stmt) = input.otherwise.as_ref() {
291            self.visit_statement(stmt);
292        }
293    }
294
295    fn visit_const(&mut self, input: &ConstDeclaration) {
296        self.visit_type(&input.type_);
297        self.visit_expression(&input.value, &Default::default());
298    }
299
300    fn visit_definition(&mut self, input: &DefinitionStatement) {
301        if let Some(ty) = input.type_.as_ref() {
302            self.visit_type(ty)
303        }
304        self.visit_expression(&input.value, &Default::default());
305    }
306
307    fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
308        self.visit_expression(&input.expression, &Default::default());
309    }
310
311    fn visit_iteration(&mut self, input: &IterationStatement) {
312        if let Some(ty) = input.type_.as_ref() {
313            self.visit_type(ty)
314        }
315        self.visit_expression(&input.start, &Default::default());
316        self.visit_expression(&input.stop, &Default::default());
317        self.visit_block(&input.block);
318    }
319
320    fn visit_return(&mut self, input: &ReturnStatement) {
321        self.visit_expression(&input.expression, &Default::default());
322    }
323}
324
325/// A Visitor trait for the program represented by the AST.
326pub trait ProgramVisitor: AstVisitor {
327    fn visit_program(&mut self, input: &Program) {
328        input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
329        input.modules.values().for_each(|module| self.visit_module(module));
330        input.imports.values().for_each(|import| self.visit_import(&import.0));
331        input.stubs.values().for_each(|stub| self.visit_stub(stub));
332    }
333
334    fn visit_program_scope(&mut self, input: &ProgramScope) {
335        input.consts.iter().for_each(|(_, c)| (self.visit_const(c)));
336        input.structs.iter().for_each(|(_, c)| (self.visit_struct(c)));
337        input.mappings.iter().for_each(|(_, c)| (self.visit_mapping(c)));
338        input.functions.iter().for_each(|(_, c)| (self.visit_function(c)));
339        if let Some(c) = input.constructor.as_ref() {
340            self.visit_constructor(c);
341        }
342    }
343
344    fn visit_module(&mut self, input: &Module) {
345        input.consts.iter().for_each(|(_, c)| (self.visit_const(c)));
346        input.structs.iter().for_each(|(_, c)| (self.visit_struct(c)));
347        input.functions.iter().for_each(|(_, c)| (self.visit_function(c)));
348    }
349
350    fn visit_stub(&mut self, _input: &Stub) {}
351
352    fn visit_import(&mut self, input: &Program) {
353        self.visit_program(input)
354    }
355
356    fn visit_struct(&mut self, input: &Composite) {
357        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
358        input.members.iter().for_each(|member| self.visit_type(&member.type_));
359    }
360
361    fn visit_mapping(&mut self, input: &Mapping) {
362        self.visit_type(&input.key_type);
363        self.visit_type(&input.value_type);
364    }
365
366    fn visit_function(&mut self, input: &Function) {
367        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
368        input.input.iter().for_each(|input| self.visit_type(&input.type_));
369        input.output.iter().for_each(|output| self.visit_type(&output.type_));
370        self.visit_type(&input.output_type);
371        self.visit_block(&input.block);
372    }
373
374    fn visit_constructor(&mut self, input: &Constructor) {
375        self.visit_block(&input.block);
376    }
377
378    fn visit_function_stub(&mut self, _input: &FunctionStub) {}
379
380    fn visit_struct_stub(&mut self, _input: &Composite) {}
381}