leo_ast/passes/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! This module contains Visitor trait implementations for the AST.
//! It provides a default method for each node, selected according to
//! the type of node it is visiting.
20
21use crate::*;
22
// TODO: The Visitor and Reconstructor patterns need a redesign so that the default implementation can still be invoked easily when a method has been overridden by an implementor.
// Here is a pattern that seems to work:
25// trait ProgramVisitor {
26//     // The trait method that can be overridden
27//     fn visit_program_scope(&mut self);
28//
29//     // Private helper function containing the default implementation
30//     fn default_visit_program_scope(&mut self) {
31//         println!("Do default stuff");
32//     }
33// }
34//
35// struct YourStruct;
36//
37// impl ProgramVisitor for YourStruct {
38//     fn visit_program_scope(&mut self) {
39//         println!("Do custom stuff.");
40//         // Call the default implementation
41//         self.default_visit_program_scope();
42//     }
43// }
44
45/// A Visitor trait for types in the AST.
46pub trait AstVisitor {
47    /* Types */
48    fn visit_type(&mut self, input: &Type) {
49        match input {
50            Type::Array(array_type) => self.visit_array_type(array_type),
51            Type::Composite(composite_type) => self.visit_composite_type(composite_type),
52            Type::Future(future_type) => self.visit_future_type(future_type),
53            Type::Mapping(mapping_type) => self.visit_mapping_type(mapping_type),
54            Type::Optional(optional_type) => self.visit_optional_type(optional_type),
55            Type::Tuple(tuple_type) => self.visit_tuple_type(tuple_type),
56            Type::Vector(array_type) => self.visit_vector_type(array_type),
57            Type::Address
58            | Type::Boolean
59            | Type::Field
60            | Type::Group
61            | Type::Identifier(_)
62            | Type::Integer(_)
63            | Type::Scalar
64            | Type::Signature
65            | Type::String
66            | Type::Numeric
67            | Type::Unit
68            | Type::Err => {}
69        }
70    }
71
72    fn visit_array_type(&mut self, input: &ArrayType) {
73        self.visit_type(&input.element_type);
74        self.visit_expression(&input.length, &Default::default());
75    }
76
77    fn visit_composite_type(&mut self, input: &CompositeType) {
78        input.const_arguments.iter().for_each(|expr| {
79            self.visit_expression(expr, &Default::default());
80        });
81    }
82
83    fn visit_future_type(&mut self, input: &FutureType) {
84        input.inputs.iter().for_each(|input| self.visit_type(input));
85    }
86
87    fn visit_mapping_type(&mut self, input: &MappingType) {
88        self.visit_type(&input.key);
89        self.visit_type(&input.value);
90    }
91
92    fn visit_optional_type(&mut self, input: &OptionalType) {
93        self.visit_type(&input.inner);
94    }
95
96    fn visit_tuple_type(&mut self, input: &TupleType) {
97        input.elements().iter().for_each(|input| self.visit_type(input));
98    }
99
100    fn visit_vector_type(&mut self, input: &VectorType) {
101        self.visit_type(&input.element_type);
102    }
103
104    /* Expressions */
105    type AdditionalInput: Default;
106    type Output: Default;
107
108    fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
109        match input {
110            Expression::Array(array) => self.visit_array(array, additional),
111            Expression::ArrayAccess(access) => self.visit_array_access(access, additional),
112            Expression::Async(async_) => self.visit_async(async_, additional),
113            Expression::Binary(binary) => self.visit_binary(binary, additional),
114            Expression::Call(call) => self.visit_call(call, additional),
115            Expression::Cast(cast) => self.visit_cast(cast, additional),
116            Expression::Struct(struct_) => self.visit_struct_init(struct_, additional),
117            Expression::Err(err) => self.visit_err(err, additional),
118            Expression::Path(path) => self.visit_path(path, additional),
119            Expression::Literal(literal) => self.visit_literal(literal, additional),
120            Expression::Locator(locator) => self.visit_locator(locator, additional),
121            Expression::MemberAccess(access) => self.visit_member_access(access, additional),
122            Expression::Repeat(repeat) => self.visit_repeat(repeat, additional),
123            Expression::Ternary(ternary) => self.visit_ternary(ternary, additional),
124            Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
125            Expression::TupleAccess(access) => self.visit_tuple_access(access, additional),
126            Expression::Unary(unary) => self.visit_unary(unary, additional),
127            Expression::Unit(unit) => self.visit_unit(unit, additional),
128            Expression::Intrinsic(intr) => self.visit_intrinsic(intr, additional),
129        }
130    }
131
132    fn visit_array_access(&mut self, input: &ArrayAccess, _additional: &Self::AdditionalInput) -> Self::Output {
133        self.visit_expression(&input.array, &Default::default());
134        self.visit_expression(&input.index, &Default::default());
135        Default::default()
136    }
137
138    fn visit_member_access(&mut self, input: &MemberAccess, _additional: &Self::AdditionalInput) -> Self::Output {
139        self.visit_expression(&input.inner, &Default::default());
140        Default::default()
141    }
142
143    fn visit_tuple_access(&mut self, input: &TupleAccess, _additional: &Self::AdditionalInput) -> Self::Output {
144        self.visit_expression(&input.tuple, &Default::default());
145        Default::default()
146    }
147
148    fn visit_array(&mut self, input: &ArrayExpression, _additional: &Self::AdditionalInput) -> Self::Output {
149        input.elements.iter().for_each(|expr| {
150            self.visit_expression(expr, &Default::default());
151        });
152        Default::default()
153    }
154
155    fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
156        self.visit_block(&input.block);
157        Default::default()
158    }
159
160    fn visit_binary(&mut self, input: &BinaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
161        self.visit_expression(&input.left, &Default::default());
162        self.visit_expression(&input.right, &Default::default());
163        Default::default()
164    }
165
166    fn visit_call(&mut self, input: &CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
167        input.const_arguments.iter().for_each(|expr| {
168            self.visit_expression(expr, &Default::default());
169        });
170        input.arguments.iter().for_each(|expr| {
171            self.visit_expression(expr, &Default::default());
172        });
173        Default::default()
174    }
175
176    fn visit_intrinsic(&mut self, input: &IntrinsicExpression, _additional: &Self::AdditionalInput) -> Self::Output {
177        input.arguments.iter().for_each(|arg| {
178            self.visit_expression(arg, &Default::default());
179        });
180        Default::default()
181    }
182
183    fn visit_cast(&mut self, input: &CastExpression, _additional: &Self::AdditionalInput) -> Self::Output {
184        self.visit_expression(&input.expression, &Default::default());
185        Default::default()
186    }
187
188    fn visit_struct_init(&mut self, input: &StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
189        input.const_arguments.iter().for_each(|expr| {
190            self.visit_expression(expr, &Default::default());
191        });
192        for StructVariableInitializer { expression, .. } in input.members.iter() {
193            if let Some(expression) = expression {
194                self.visit_expression(expression, &Default::default());
195            }
196        }
197        Default::default()
198    }
199
200    fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
201        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
202    }
203
204    fn visit_path(&mut self, _input: &Path, _additional: &Self::AdditionalInput) -> Self::Output {
205        Default::default()
206    }
207
208    fn visit_literal(&mut self, _input: &Literal, _additional: &Self::AdditionalInput) -> Self::Output {
209        Default::default()
210    }
211
212    fn visit_locator(&mut self, _input: &LocatorExpression, _additional: &Self::AdditionalInput) -> Self::Output {
213        Default::default()
214    }
215
216    fn visit_repeat(&mut self, input: &RepeatExpression, _additional: &Self::AdditionalInput) -> Self::Output {
217        self.visit_expression(&input.expr, &Default::default());
218        self.visit_expression(&input.count, &Default::default());
219        Default::default()
220    }
221
222    fn visit_ternary(&mut self, input: &TernaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
223        self.visit_expression(&input.condition, &Default::default());
224        self.visit_expression(&input.if_true, &Default::default());
225        self.visit_expression(&input.if_false, &Default::default());
226        Default::default()
227    }
228
229    fn visit_tuple(&mut self, input: &TupleExpression, _additional: &Self::AdditionalInput) -> Self::Output {
230        input.elements.iter().for_each(|expr| {
231            self.visit_expression(expr, &Default::default());
232        });
233        Default::default()
234    }
235
236    fn visit_unary(&mut self, input: &UnaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
237        self.visit_expression(&input.receiver, &Default::default());
238        Default::default()
239    }
240
241    fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
242        Default::default()
243    }
244
245    /* Statements */
246    fn visit_statement(&mut self, input: &Statement) {
247        match input {
248            Statement::Assert(stmt) => self.visit_assert(stmt),
249            Statement::Assign(stmt) => self.visit_assign(stmt),
250            Statement::Block(stmt) => self.visit_block(stmt),
251            Statement::Conditional(stmt) => self.visit_conditional(stmt),
252            Statement::Const(stmt) => self.visit_const(stmt),
253            Statement::Definition(stmt) => self.visit_definition(stmt),
254            Statement::Expression(stmt) => self.visit_expression_statement(stmt),
255            Statement::Iteration(stmt) => self.visit_iteration(stmt),
256            Statement::Return(stmt) => self.visit_return(stmt),
257        }
258    }
259
260    fn visit_assert(&mut self, input: &AssertStatement) {
261        match &input.variant {
262            AssertVariant::Assert(expr) => self.visit_expression(expr, &Default::default()),
263            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
264                self.visit_expression(left, &Default::default());
265                self.visit_expression(right, &Default::default())
266            }
267        };
268    }
269
270    fn visit_assign(&mut self, input: &AssignStatement) {
271        self.visit_expression(&input.place, &Default::default());
272        self.visit_expression(&input.value, &Default::default());
273    }
274
275    fn visit_block(&mut self, input: &Block) {
276        input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
277    }
278
279    fn visit_conditional(&mut self, input: &ConditionalStatement) {
280        self.visit_expression(&input.condition, &Default::default());
281        self.visit_block(&input.then);
282        if let Some(stmt) = input.otherwise.as_ref() {
283            self.visit_statement(stmt);
284        }
285    }
286
287    fn visit_const(&mut self, input: &ConstDeclaration) {
288        self.visit_type(&input.type_);
289        self.visit_expression(&input.value, &Default::default());
290    }
291
292    fn visit_definition(&mut self, input: &DefinitionStatement) {
293        if let Some(ty) = input.type_.as_ref() {
294            self.visit_type(ty)
295        }
296        self.visit_expression(&input.value, &Default::default());
297    }
298
299    fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
300        self.visit_expression(&input.expression, &Default::default());
301    }
302
303    fn visit_iteration(&mut self, input: &IterationStatement) {
304        if let Some(ty) = input.type_.as_ref() {
305            self.visit_type(ty)
306        }
307        self.visit_expression(&input.start, &Default::default());
308        self.visit_expression(&input.stop, &Default::default());
309        self.visit_block(&input.block);
310    }
311
312    fn visit_return(&mut self, input: &ReturnStatement) {
313        self.visit_expression(&input.expression, &Default::default());
314    }
315}
316
317/// A Visitor trait for the program represented by the AST.
318pub trait ProgramVisitor: AstVisitor {
319    fn visit_program(&mut self, input: &Program) {
320        input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
321        input.modules.values().for_each(|module| self.visit_module(module));
322        input.imports.values().for_each(|import| self.visit_import(&import.0));
323        input.stubs.values().for_each(|stub| self.visit_stub(stub));
324    }
325
326    fn visit_program_scope(&mut self, input: &ProgramScope) {
327        input.consts.iter().for_each(|(_, c)| self.visit_const(c));
328        input.structs.iter().for_each(|(_, c)| self.visit_struct(c));
329        input.mappings.iter().for_each(|(_, c)| self.visit_mapping(c));
330        input.storage_variables.iter().for_each(|(_, c)| self.visit_storage_variable(c));
331        input.functions.iter().for_each(|(_, c)| self.visit_function(c));
332        if let Some(c) = input.constructor.as_ref() {
333            self.visit_constructor(c);
334        }
335    }
336
337    fn visit_module(&mut self, input: &Module) {
338        input.consts.iter().for_each(|(_, c)| self.visit_const(c));
339        input.structs.iter().for_each(|(_, c)| self.visit_struct(c));
340        input.functions.iter().for_each(|(_, c)| self.visit_function(c));
341    }
342
343    fn visit_stub(&mut self, _input: &Stub) {}
344
345    fn visit_import(&mut self, input: &Program) {
346        self.visit_program(input)
347    }
348
349    fn visit_struct(&mut self, input: &Composite) {
350        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
351        input.members.iter().for_each(|member| self.visit_type(&member.type_));
352    }
353
354    fn visit_mapping(&mut self, input: &Mapping) {
355        self.visit_type(&input.key_type);
356        self.visit_type(&input.value_type);
357    }
358
359    fn visit_storage_variable(&mut self, input: &StorageVariable) {
360        self.visit_type(&input.type_);
361    }
362
363    fn visit_function(&mut self, input: &Function) {
364        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
365        input.input.iter().for_each(|input| self.visit_type(&input.type_));
366        input.output.iter().for_each(|output| self.visit_type(&output.type_));
367        self.visit_type(&input.output_type);
368        self.visit_block(&input.block);
369    }
370
371    fn visit_constructor(&mut self, input: &Constructor) {
372        self.visit_block(&input.block);
373    }
374
375    fn visit_function_stub(&mut self, _input: &FunctionStub) {}
376
377    fn visit_struct_stub(&mut self, _input: &Composite) {}
378}