leo_passes/monomorphization/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::MonomorphizationVisitor;
18use crate::{ConstPropagationVisitor, Replacer};
19
20use leo_ast::{
21    AstReconstructor,
22    CallExpression,
23    CompositeType,
24    Expression,
25    Identifier,
26    Node as _,
27    ProgramReconstructor,
28    StructExpression,
29    StructVariableInitializer,
30    Type,
31    Variant,
32};
33
34use indexmap::IndexMap;
35use itertools::{Either, Itertools};
36
37impl<'a> MonomorphizationVisitor<'a> {
38    /// Evaluates the given constant arguments if possible.
39    ///
40    /// Returns `Some` with all evaluated expressions if all are constants, or `None` if any argument is not constant.
41    fn try_evaluate_const_args(&mut self, const_args: &[Expression]) -> Option<Vec<Expression>> {
42        let mut const_evaluator = ConstPropagationVisitor::new(self.state, self.program);
43
44        let (evaluated_const_args, non_const_args): (Vec<_>, Vec<_>) = const_args
45            .iter()
46            .map(|arg| const_evaluator.reconstruct_expression(arg.clone(), &()))
47            .partition_map(|(evaluated_arg, evaluated_value)| match (evaluated_value, evaluated_arg) {
48                (Some(_), expr @ Expression::Literal(_)) => Either::Left(expr),
49                _ => Either::Right(()),
50            });
51
52        if !non_const_args.is_empty() { None } else { Some(evaluated_const_args) }
53    }
54}
55
56impl AstReconstructor for MonomorphizationVisitor<'_> {
57    type AdditionalInput = ();
58    type AdditionalOutput = ();
59
60    /* Types */
61    fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
62        // Proceed only if there are some const arguments.
63        if input.const_arguments.is_empty() {
64            return (Type::Composite(input), Default::default());
65        }
66
67        // Ensure all const arguments can be evaluated to literals; if not, we skip this struct type instantiation for
68        // now and mark it as unresolved.
69        //
70        // The types of the const arguments are already checked in the type checking pass.
71        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
72            self.unresolved_struct_types.push(input.clone());
73            return (Type::Composite(input), Default::default());
74        };
75
76        // At this stage, we know that we're going to modify the program
77        self.changed = true;
78        (
79            Type::Composite(CompositeType {
80                path: self.monomorphize_struct(&input.path, &evaluated_const_args),
81                const_arguments: vec![], // remove const arguments
82                program: input.program,
83            }),
84            Default::default(),
85        )
86    }
87
88    /* Expressions */
89    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
90        let opt_old_type = self.state.type_table.get(&input.id());
91        let (new_expr, opt_value) = match input {
92            Expression::Array(array) => self.reconstruct_array(array, &()),
93            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
94            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, &()),
95            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, &()),
96            Expression::Async(async_) => self.reconstruct_async(async_, &()),
97            Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
98            Expression::Call(call) => self.reconstruct_call(*call, &()),
99            Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
100            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
101            Expression::Err(err) => self.reconstruct_err(err, &()),
102            Expression::Path(path) => self.reconstruct_path(path, &()),
103            Expression::Literal(value) => self.reconstruct_literal(value, &()),
104            Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
105            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
106            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
107            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
108            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
109            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
110            Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
111            Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
112        };
113
114        // If the expression was in the type table before, make an entry for the new expression.
115        if let Some(old_type) = opt_old_type {
116            self.state.type_table.insert(new_expr.id(), old_type);
117        }
118
119        (new_expr, opt_value)
120    }
121
122    fn reconstruct_call(
123        &mut self,
124        input_call: CallExpression,
125        _additional: &(),
126    ) -> (Expression, Self::AdditionalOutput) {
127        // Skip calls to functions from other programs.
128        if input_call.program.is_some_and(|prog| prog != self.program) {
129            return (input_call.into(), Default::default());
130        }
131
132        // Proceed only if there are some const arguments.
133        if input_call.const_arguments.is_empty() {
134            return (input_call.into(), Default::default());
135        }
136
137        // Ensure all const arguments can be evaluated to literals; if not, we skip this call for now and mark it as
138        // unresolved.
139        //
140        // The types of the const arguments are already checked in the type checking pass.
141        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input_call.const_arguments) else {
142            self.unresolved_calls.push(input_call.clone());
143            return (input_call.into(), Default::default());
144        };
145
146        // Look up the already reconstructed function by name.
147        let callee_fn = self
148            .reconstructed_functions
149            .get(&input_call.function.absolute_path())
150            .expect("Callee should already be reconstructed (post-order traversal).");
151
152        // Proceed only if the function variant is `inline`.
153        if !matches!(callee_fn.variant, Variant::Inline) {
154            return (input_call.into(), Default::default());
155        }
156
157        // Generate a unique name for the monomorphized function based on const arguments.
158        //
159        // For a function `fn foo::[x: u32, y: u32](..)`, the generated name would be `"foo::[1u32, 2u32]"` for a call
160        // that sets `x` to `1u32` and `y` to `2u32`. We know this name is safe to use because it's not a valid
161        // identifier in the user code.
162        let new_callee_path = input_call.function.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!(
163            "\"{}::[{}]\"",
164            input_call.function.identifier().name,
165            evaluated_const_args.iter().format(", ")
166        )));
167
168        // Check if the new callee name is not already present in `reconstructed_functions`. This ensures that we do not
169        // add a duplicate definition for the same function.
170        if self.reconstructed_functions.get(&new_callee_path.absolute_path()).is_none() {
171            // Build mapping from const parameters to const argument values.
172            let const_param_map: IndexMap<_, _> = callee_fn
173                .const_parameters
174                .iter()
175                .map(|param| param.identifier().name)
176                .zip_eq(&evaluated_const_args)
177                .collect();
178
179            // Function to replace identifier expressions with their corresponding const argument or keep them unchanged.
180            let replace_identifier = |expr: &Expression| match expr {
181                Expression::Path(path) => const_param_map
182                    .get(&path.identifier().name)
183                    .map_or(Expression::Path(path.clone()), |&expr| expr.clone()),
184                _ => expr.clone(),
185            };
186
187            let mut replacer = Replacer::new(replace_identifier, true /* refresh IDs */, self.state);
188
189            // Create a new version of `callee_fn` that has a new name, no const parameters, and a new function ID.
190
191            // First, reconstruct the function by changing all instances of const generic parameters to literals
192            // according to `const_param_map`.
193            let mut function = replacer.reconstruct_function(callee_fn.clone());
194
195            // Now, reconstruct the function to actually monomorphize its content such as generic struct expressions.
196            function = self.reconstruct_function(function);
197            function.identifier = Identifier {
198                name: new_callee_path.identifier().name,
199                span: leo_span::Span::default(),
200                id: self.state.node_builder.next_id(),
201            };
202            function.const_parameters = vec![];
203            function.id = self.state.node_builder.next_id();
204
205            // Keep track of the new function in case other functions need it.
206            self.reconstructed_functions.insert(new_callee_path.absolute_path(), function);
207
208            // Now keep track of the function we just monomorphized
209            self.monomorphized_functions.insert(input_call.function.absolute_path());
210        }
211
212        // At this stage, we know that we're going to modify the program
213        self.changed = true;
214
215        // Finally, construct the updated call expression that points to a monomorphized version and return it.
216        (
217            CallExpression {
218                function: new_callee_path,
219                const_arguments: vec![], // remove const arguments
220                arguments: input_call.arguments,
221                program: input_call.program,
222                span: input_call.span, // Keep pointing to the original call expression
223                id: input_call.id,
224            }
225            .into(),
226            Default::default(),
227        )
228    }
229
230    fn reconstruct_struct_init(
231        &mut self,
232        mut input: StructExpression,
233        _additional: &(),
234    ) -> (Expression, Self::AdditionalOutput) {
235        // Handle all the struct members first
236        let members = input
237            .members
238            .clone()
239            .into_iter()
240            .map(|member| StructVariableInitializer {
241                identifier: member.identifier,
242                expression: member.expression.map(|expr| self.reconstruct_expression(expr, &()).0),
243                span: member.span,
244                id: member.id,
245            })
246            .collect();
247
248        // Proceed only if there are some const arguments.
249        if input.const_arguments.is_empty() {
250            input.members = members;
251            return (input.into(), Default::default());
252        }
253
254        // Ensure all const arguments can be evaluated to literals; if not, we skip this struct expression for now and
255        // mark it as unresolved.
256        //
257        // The types of the const arguments are already checked in the type checking pass.
258        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
259            self.unresolved_struct_exprs.push(input.clone());
260            input.members = members;
261            return (input.into(), Default::default());
262        };
263
264        // At this stage, we know that we're going to modify the program
265        self.changed = true;
266
267        // Finally, construct the updated struct expression that points to a monomorphized version and return it.
268        (
269            StructExpression {
270                path: self.monomorphize_struct(&input.path, &evaluated_const_args),
271                members,
272                const_arguments: vec![], // remove const arguments
273                span: input.span,        // Keep pointing to the original struct expression
274                id: input.id,
275            }
276            .into(),
277            Default::default(),
278        )
279    }
280}