// leo_ast/passes/visitor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! This module contains the Visitor traits for the AST.
//! Each trait provides a default method for every kind of node,
//! selected according to the type of node being visited.
20
21use crate::*;
22
// TODO: The Visitor and Reconstructor patterns need a redesign so that the default
// implementation can easily be invoked even when an implementing type overrides the method.
// Here is a pattern that seems to work:
// trait ProgramVisitor {
//     // The trait method that can be overridden.
//     fn visit_program_scope(&mut self);
//
//     // Helper method containing the default implementation
//     // (trait methods are always public, so it cannot be truly private).
//     fn default_visit_program_scope(&mut self) {
//         println!("Do default stuff");
//     }
// }
//
// struct YourStruct;
//
// impl ProgramVisitor for YourStruct {
//     fn visit_program_scope(&mut self) {
//         println!("Do custom stuff.");
//         // Call the default implementation.
//         self.default_visit_program_scope();
//     }
// }
44
45/// A Visitor trait for types in the AST.
46pub trait AstVisitor {
47    /* Types */
48    fn visit_type(&mut self, input: &Type) {
49        match input {
50            Type::Array(array_type) => self.visit_array_type(array_type),
51            Type::Composite(composite_type) => self.visit_composite_type(composite_type),
52            Type::Future(future_type) => self.visit_future_type(future_type),
53            Type::Mapping(mapping_type) => self.visit_mapping_type(mapping_type),
54            Type::Optional(optional_type) => self.visit_optional_type(optional_type),
55            Type::Tuple(tuple_type) => self.visit_tuple_type(tuple_type),
56            Type::Vector(array_type) => self.visit_vector_type(array_type),
57            Type::Address
58            | Type::Boolean
59            | Type::Field
60            | Type::Group
61            | Type::Identifier(_)
62            | Type::Integer(_)
63            | Type::Scalar
64            | Type::Signature
65            | Type::String
66            | Type::Numeric
67            | Type::Unit
68            | Type::Err => {}
69        }
70    }
71
72    fn visit_array_type(&mut self, input: &ArrayType) {
73        self.visit_type(&input.element_type);
74        self.visit_expression(&input.length, &Default::default());
75    }
76
77    fn visit_composite_type(&mut self, input: &CompositeType) {
78        input.const_arguments.iter().for_each(|expr| {
79            self.visit_expression(expr, &Default::default());
80        });
81    }
82
83    fn visit_future_type(&mut self, input: &FutureType) {
84        input.inputs.iter().for_each(|input| self.visit_type(input));
85    }
86
87    fn visit_mapping_type(&mut self, input: &MappingType) {
88        self.visit_type(&input.key);
89        self.visit_type(&input.value);
90    }
91
92    fn visit_optional_type(&mut self, input: &OptionalType) {
93        self.visit_type(&input.inner);
94    }
95
96    fn visit_tuple_type(&mut self, input: &TupleType) {
97        input.elements().iter().for_each(|input| self.visit_type(input));
98    }
99
100    fn visit_vector_type(&mut self, input: &VectorType) {
101        self.visit_type(&input.element_type);
102    }
103
104    /* Expressions */
105    type AdditionalInput: Default;
106    type Output: Default;
107
108    fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
109        match input {
110            Expression::Array(array) => self.visit_array(array, additional),
111            Expression::ArrayAccess(access) => self.visit_array_access(access, additional),
112            Expression::AssociatedConstant(constant) => self.visit_associated_constant(constant, additional),
113            Expression::AssociatedFunction(function) => self.visit_associated_function(function, additional),
114            Expression::Async(async_) => self.visit_async(async_, additional),
115            Expression::Binary(binary) => self.visit_binary(binary, additional),
116            Expression::Call(call) => self.visit_call(call, additional),
117            Expression::Cast(cast) => self.visit_cast(cast, additional),
118            Expression::Struct(struct_) => self.visit_struct_init(struct_, additional),
119            Expression::Err(err) => self.visit_err(err, additional),
120            Expression::Path(path) => self.visit_path(path, additional),
121            Expression::Literal(literal) => self.visit_literal(literal, additional),
122            Expression::Locator(locator) => self.visit_locator(locator, additional),
123            Expression::MemberAccess(access) => self.visit_member_access(access, additional),
124            Expression::Repeat(repeat) => self.visit_repeat(repeat, additional),
125            Expression::Ternary(ternary) => self.visit_ternary(ternary, additional),
126            Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
127            Expression::TupleAccess(access) => self.visit_tuple_access(access, additional),
128            Expression::Unary(unary) => self.visit_unary(unary, additional),
129            Expression::Unit(unit) => self.visit_unit(unit, additional),
130        }
131    }
132
133    fn visit_array_access(&mut self, input: &ArrayAccess, _additional: &Self::AdditionalInput) -> Self::Output {
134        self.visit_expression(&input.array, &Default::default());
135        self.visit_expression(&input.index, &Default::default());
136        Default::default()
137    }
138
139    fn visit_member_access(&mut self, input: &MemberAccess, _additional: &Self::AdditionalInput) -> Self::Output {
140        self.visit_expression(&input.inner, &Default::default());
141        Default::default()
142    }
143
144    fn visit_tuple_access(&mut self, input: &TupleAccess, _additional: &Self::AdditionalInput) -> Self::Output {
145        self.visit_expression(&input.tuple, &Default::default());
146        Default::default()
147    }
148
149    fn visit_array(&mut self, input: &ArrayExpression, _additional: &Self::AdditionalInput) -> Self::Output {
150        input.elements.iter().for_each(|expr| {
151            self.visit_expression(expr, &Default::default());
152        });
153        Default::default()
154    }
155
156    fn visit_associated_constant(
157        &mut self,
158        _input: &AssociatedConstantExpression,
159        _additional: &Self::AdditionalInput,
160    ) -> Self::Output {
161        Default::default()
162    }
163
164    fn visit_associated_function(
165        &mut self,
166        input: &AssociatedFunctionExpression,
167        _additional: &Self::AdditionalInput,
168    ) -> Self::Output {
169        input.arguments.iter().for_each(|arg| {
170            self.visit_expression(arg, &Default::default());
171        });
172        Default::default()
173    }
174
175    fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
176        self.visit_block(&input.block);
177        Default::default()
178    }
179
180    fn visit_binary(&mut self, input: &BinaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
181        self.visit_expression(&input.left, &Default::default());
182        self.visit_expression(&input.right, &Default::default());
183        Default::default()
184    }
185
186    fn visit_call(&mut self, input: &CallExpression, _additional: &Self::AdditionalInput) -> Self::Output {
187        input.const_arguments.iter().for_each(|expr| {
188            self.visit_expression(expr, &Default::default());
189        });
190        input.arguments.iter().for_each(|expr| {
191            self.visit_expression(expr, &Default::default());
192        });
193        Default::default()
194    }
195
196    fn visit_cast(&mut self, input: &CastExpression, _additional: &Self::AdditionalInput) -> Self::Output {
197        self.visit_expression(&input.expression, &Default::default());
198        Default::default()
199    }
200
201    fn visit_struct_init(&mut self, input: &StructExpression, _additional: &Self::AdditionalInput) -> Self::Output {
202        input.const_arguments.iter().for_each(|expr| {
203            self.visit_expression(expr, &Default::default());
204        });
205        for StructVariableInitializer { expression, .. } in input.members.iter() {
206            if let Some(expression) = expression {
207                self.visit_expression(expression, &Default::default());
208            }
209        }
210        Default::default()
211    }
212
213    fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
214        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
215    }
216
217    fn visit_path(&mut self, _input: &Path, _additional: &Self::AdditionalInput) -> Self::Output {
218        Default::default()
219    }
220
221    fn visit_literal(&mut self, _input: &Literal, _additional: &Self::AdditionalInput) -> Self::Output {
222        Default::default()
223    }
224
225    fn visit_locator(&mut self, _input: &LocatorExpression, _additional: &Self::AdditionalInput) -> Self::Output {
226        Default::default()
227    }
228
229    fn visit_repeat(&mut self, input: &RepeatExpression, _additional: &Self::AdditionalInput) -> Self::Output {
230        self.visit_expression(&input.expr, &Default::default());
231        self.visit_expression(&input.count, &Default::default());
232        Default::default()
233    }
234
235    fn visit_ternary(&mut self, input: &TernaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
236        self.visit_expression(&input.condition, &Default::default());
237        self.visit_expression(&input.if_true, &Default::default());
238        self.visit_expression(&input.if_false, &Default::default());
239        Default::default()
240    }
241
242    fn visit_tuple(&mut self, input: &TupleExpression, _additional: &Self::AdditionalInput) -> Self::Output {
243        input.elements.iter().for_each(|expr| {
244            self.visit_expression(expr, &Default::default());
245        });
246        Default::default()
247    }
248
249    fn visit_unary(&mut self, input: &UnaryExpression, _additional: &Self::AdditionalInput) -> Self::Output {
250        self.visit_expression(&input.receiver, &Default::default());
251        Default::default()
252    }
253
254    fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
255        Default::default()
256    }
257
258    /* Statements */
259    fn visit_statement(&mut self, input: &Statement) {
260        match input {
261            Statement::Assert(stmt) => self.visit_assert(stmt),
262            Statement::Assign(stmt) => self.visit_assign(stmt),
263            Statement::Block(stmt) => self.visit_block(stmt),
264            Statement::Conditional(stmt) => self.visit_conditional(stmt),
265            Statement::Const(stmt) => self.visit_const(stmt),
266            Statement::Definition(stmt) => self.visit_definition(stmt),
267            Statement::Expression(stmt) => self.visit_expression_statement(stmt),
268            Statement::Iteration(stmt) => self.visit_iteration(stmt),
269            Statement::Return(stmt) => self.visit_return(stmt),
270        }
271    }
272
273    fn visit_assert(&mut self, input: &AssertStatement) {
274        match &input.variant {
275            AssertVariant::Assert(expr) => self.visit_expression(expr, &Default::default()),
276            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
277                self.visit_expression(left, &Default::default());
278                self.visit_expression(right, &Default::default())
279            }
280        };
281    }
282
283    fn visit_assign(&mut self, input: &AssignStatement) {
284        self.visit_expression(&input.place, &Default::default());
285        self.visit_expression(&input.value, &Default::default());
286    }
287
288    fn visit_block(&mut self, input: &Block) {
289        input.statements.iter().for_each(|stmt| self.visit_statement(stmt));
290    }
291
292    fn visit_conditional(&mut self, input: &ConditionalStatement) {
293        self.visit_expression(&input.condition, &Default::default());
294        self.visit_block(&input.then);
295        if let Some(stmt) = input.otherwise.as_ref() {
296            self.visit_statement(stmt);
297        }
298    }
299
300    fn visit_const(&mut self, input: &ConstDeclaration) {
301        self.visit_type(&input.type_);
302        self.visit_expression(&input.value, &Default::default());
303    }
304
305    fn visit_definition(&mut self, input: &DefinitionStatement) {
306        if let Some(ty) = input.type_.as_ref() {
307            self.visit_type(ty)
308        }
309        self.visit_expression(&input.value, &Default::default());
310    }
311
312    fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
313        self.visit_expression(&input.expression, &Default::default());
314    }
315
316    fn visit_iteration(&mut self, input: &IterationStatement) {
317        if let Some(ty) = input.type_.as_ref() {
318            self.visit_type(ty)
319        }
320        self.visit_expression(&input.start, &Default::default());
321        self.visit_expression(&input.stop, &Default::default());
322        self.visit_block(&input.block);
323    }
324
325    fn visit_return(&mut self, input: &ReturnStatement) {
326        self.visit_expression(&input.expression, &Default::default());
327    }
328}
329
330/// A Visitor trait for the program represented by the AST.
331pub trait ProgramVisitor: AstVisitor {
332    fn visit_program(&mut self, input: &Program) {
333        input.program_scopes.values().for_each(|scope| self.visit_program_scope(scope));
334        input.modules.values().for_each(|module| self.visit_module(module));
335        input.imports.values().for_each(|import| self.visit_import(&import.0));
336        input.stubs.values().for_each(|stub| self.visit_stub(stub));
337    }
338
339    fn visit_program_scope(&mut self, input: &ProgramScope) {
340        input.consts.iter().for_each(|(_, c)| self.visit_const(c));
341        input.structs.iter().for_each(|(_, c)| self.visit_struct(c));
342        input.mappings.iter().for_each(|(_, c)| self.visit_mapping(c));
343        input.storage_variables.iter().for_each(|(_, c)| self.visit_storage_variable(c));
344        input.functions.iter().for_each(|(_, c)| self.visit_function(c));
345        if let Some(c) = input.constructor.as_ref() {
346            self.visit_constructor(c);
347        }
348    }
349
350    fn visit_module(&mut self, input: &Module) {
351        input.consts.iter().for_each(|(_, c)| self.visit_const(c));
352        input.structs.iter().for_each(|(_, c)| self.visit_struct(c));
353        input.functions.iter().for_each(|(_, c)| self.visit_function(c));
354    }
355
356    fn visit_stub(&mut self, _input: &Stub) {}
357
358    fn visit_import(&mut self, input: &Program) {
359        self.visit_program(input)
360    }
361
362    fn visit_struct(&mut self, input: &Composite) {
363        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
364        input.members.iter().for_each(|member| self.visit_type(&member.type_));
365    }
366
367    fn visit_mapping(&mut self, input: &Mapping) {
368        self.visit_type(&input.key_type);
369        self.visit_type(&input.value_type);
370    }
371
372    fn visit_storage_variable(&mut self, input: &StorageVariable) {
373        self.visit_type(&input.type_);
374    }
375
376    fn visit_function(&mut self, input: &Function) {
377        input.const_parameters.iter().for_each(|input| self.visit_type(&input.type_));
378        input.input.iter().for_each(|input| self.visit_type(&input.type_));
379        input.output.iter().for_each(|output| self.visit_type(&output.type_));
380        self.visit_type(&input.output_type);
381        self.visit_block(&input.block);
382    }
383
384    fn visit_constructor(&mut self, input: &Constructor) {
385        self.visit_block(&input.block);
386    }
387
388    fn visit_function_stub(&mut self, _input: &FunctionStub) {}
389
390    fn visit_struct_stub(&mut self, _input: &Composite) {}
391}