leo_passes/monomorphization/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::MonomorphizationVisitor;
18use crate::{ConstPropagationVisitor, Replacer};
19
20use leo_ast::{
21    AstReconstructor,
22    CallExpression,
23    CompositeType,
24    Expression,
25    Identifier,
26    Node as _,
27    ProgramReconstructor,
28    StructExpression,
29    StructVariableInitializer,
30    Type,
31    Variant,
32};
33
34use indexmap::IndexMap;
35use itertools::{Either, Itertools};
36
37impl<'a> MonomorphizationVisitor<'a> {
38    /// Evaluates the given constant arguments if possible.
39    ///
40    /// Returns `Some` with all evaluated expressions if all are constants, or `None` if any argument is not constant.
41    fn try_evaluate_const_args(&mut self, const_args: &[Expression]) -> Option<Vec<Expression>> {
42        let mut const_evaluator = ConstPropagationVisitor::new(self.state, self.program);
43
44        let (evaluated_const_args, non_const_args): (Vec<_>, Vec<_>) = const_args
45            .iter()
46            .map(|arg| const_evaluator.reconstruct_expression(arg.clone(), &()))
47            .partition_map(|(evaluated_arg, evaluated_value)| match (evaluated_value, evaluated_arg) {
48                (Some(_), expr @ Expression::Literal(_)) => Either::Left(expr),
49                _ => Either::Right(()),
50            });
51
52        if !non_const_args.is_empty() { None } else { Some(evaluated_const_args) }
53    }
54}
55
56impl AstReconstructor for MonomorphizationVisitor<'_> {
57    type AdditionalInput = ();
58    type AdditionalOutput = ();
59
60    /* Types */
61    fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
62        // Proceed only if there are some const arguments.
63        if input.const_arguments.is_empty() {
64            return (Type::Composite(input), Default::default());
65        }
66
67        // Ensure all const arguments can be evaluated to literals; if not, we skip this struct type instantiation for
68        // now and mark it as unresolved.
69        //
70        // The types of the const arguments are already checked in the type checking pass.
71        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
72            self.unresolved_struct_types.push(input.clone());
73            return (Type::Composite(input), Default::default());
74        };
75
76        // At this stage, we know that we're going to modify the program
77        self.changed = true;
78        (
79            Type::Composite(CompositeType {
80                path: self.monomorphize_struct(&input.path, &evaluated_const_args),
81                const_arguments: vec![], // remove const arguments
82                program: input.program,
83            }),
84            Default::default(),
85        )
86    }
87
88    /* Expressions */
89    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
90        let opt_old_type = self.state.type_table.get(&input.id());
91        let (new_expr, opt_value) = match input {
92            Expression::Array(array) => self.reconstruct_array(array, &()),
93            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
94            Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, &()),
95            Expression::Async(async_) => self.reconstruct_async(async_, &()),
96            Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
97            Expression::Call(call) => self.reconstruct_call(*call, &()),
98            Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
99            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
100            Expression::Err(err) => self.reconstruct_err(err, &()),
101            Expression::Path(path) => self.reconstruct_path(path, &()),
102            Expression::Literal(value) => self.reconstruct_literal(value, &()),
103            Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
104            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
105            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
106            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
107            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
108            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
109            Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
110            Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
111        };
112
113        // If the expression was in the type table before, make an entry for the new expression.
114        if let Some(old_type) = opt_old_type {
115            self.state.type_table.insert(new_expr.id(), old_type);
116        }
117
118        (new_expr, opt_value)
119    }
120
121    fn reconstruct_call(
122        &mut self,
123        input_call: CallExpression,
124        _additional: &(),
125    ) -> (Expression, Self::AdditionalOutput) {
126        // Skip calls to functions from other programs.
127        if input_call.program.is_some_and(|prog| prog != self.program) {
128            return (input_call.into(), Default::default());
129        }
130
131        // Proceed only if there are some const arguments.
132        if input_call.const_arguments.is_empty() {
133            return (input_call.into(), Default::default());
134        }
135
136        // Ensure all const arguments can be evaluated to literals; if not, we skip this call for now and mark it as
137        // unresolved.
138        //
139        // The types of the const arguments are already checked in the type checking pass.
140        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input_call.const_arguments) else {
141            self.unresolved_calls.push(input_call.clone());
142            return (input_call.into(), Default::default());
143        };
144
145        // Look up the already reconstructed function by name.
146        let callee_fn = self
147            .reconstructed_functions
148            .get(&input_call.function.absolute_path())
149            .expect("Callee should already be reconstructed (post-order traversal).");
150
151        // Proceed only if the function variant is `inline`.
152        if !matches!(callee_fn.variant, Variant::Inline) {
153            return (input_call.into(), Default::default());
154        }
155
156        // Generate a unique name for the monomorphized function based on const arguments.
157        //
158        // For a function `fn foo::[x: u32, y: u32](..)`, the generated name would be `"foo::[1u32, 2u32]"` for a call
159        // that sets `x` to `1u32` and `y` to `2u32`. We know this name is safe to use because it's not a valid
160        // identifier in the user code.
161        let new_callee_path = input_call.function.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!(
162            "\"{}::[{}]\"",
163            input_call.function.identifier().name,
164            evaluated_const_args.iter().format(", ")
165        )));
166
167        // Check if the new callee name is not already present in `reconstructed_functions`. This ensures that we do not
168        // add a duplicate definition for the same function.
169        if self.reconstructed_functions.get(&new_callee_path.absolute_path()).is_none() {
170            // Build mapping from const parameters to const argument values.
171            let const_param_map: IndexMap<_, _> = callee_fn
172                .const_parameters
173                .iter()
174                .map(|param| param.identifier().name)
175                .zip_eq(&evaluated_const_args)
176                .collect();
177
178            // Function to replace identifier expressions with their corresponding const argument or keep them unchanged.
179            let replace_identifier = |expr: &Expression| match expr {
180                Expression::Path(path) => const_param_map
181                    .get(&path.identifier().name)
182                    .map_or(Expression::Path(path.clone()), |&expr| expr.clone()),
183                _ => expr.clone(),
184            };
185
186            let mut replacer = Replacer::new(replace_identifier, true /* refresh IDs */, self.state);
187
188            // Create a new version of `callee_fn` that has a new name, no const parameters, and a new function ID.
189
190            // First, reconstruct the function by changing all instances of const generic parameters to literals
191            // according to `const_param_map`.
192            let mut function = replacer.reconstruct_function(callee_fn.clone());
193
194            // Now, reconstruct the function to actually monomorphize its content such as generic struct expressions.
195            function = self.reconstruct_function(function);
196            function.identifier = Identifier {
197                name: new_callee_path.identifier().name,
198                span: leo_span::Span::default(),
199                id: self.state.node_builder.next_id(),
200            };
201            function.const_parameters = vec![];
202            function.id = self.state.node_builder.next_id();
203
204            // Keep track of the new function in case other functions need it.
205            self.reconstructed_functions.insert(new_callee_path.absolute_path(), function);
206
207            // Now keep track of the function we just monomorphized
208            self.monomorphized_functions.insert(input_call.function.absolute_path());
209        }
210
211        // At this stage, we know that we're going to modify the program
212        self.changed = true;
213
214        // Finally, construct the updated call expression that points to a monomorphized version and return it.
215        (
216            CallExpression {
217                function: new_callee_path,
218                const_arguments: vec![], // remove const arguments
219                arguments: input_call.arguments,
220                program: input_call.program,
221                span: input_call.span, // Keep pointing to the original call expression
222                id: input_call.id,
223            }
224            .into(),
225            Default::default(),
226        )
227    }
228
229    fn reconstruct_struct_init(
230        &mut self,
231        mut input: StructExpression,
232        _additional: &(),
233    ) -> (Expression, Self::AdditionalOutput) {
234        // Handle all the struct members first
235        let members = input
236            .members
237            .clone()
238            .into_iter()
239            .map(|member| StructVariableInitializer {
240                identifier: member.identifier,
241                expression: member.expression.map(|expr| self.reconstruct_expression(expr, &()).0),
242                span: member.span,
243                id: member.id,
244            })
245            .collect();
246
247        // Proceed only if there are some const arguments.
248        if input.const_arguments.is_empty() {
249            input.members = members;
250            return (input.into(), Default::default());
251        }
252
253        // Ensure all const arguments can be evaluated to literals; if not, we skip this struct expression for now and
254        // mark it as unresolved.
255        //
256        // The types of the const arguments are already checked in the type checking pass.
257        let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
258            self.unresolved_struct_exprs.push(input.clone());
259            input.members = members;
260            return (input.into(), Default::default());
261        };
262
263        // At this stage, we know that we're going to modify the program
264        self.changed = true;
265
266        // Finally, construct the updated struct expression that points to a monomorphized version and return it.
267        (
268            StructExpression {
269                path: self.monomorphize_struct(&input.path, &evaluated_const_args),
270                members,
271                const_arguments: vec![], // remove const arguments
272                span: input.span,        // Keep pointing to the original struct expression
273                id: input.id,
274            }
275            .into(),
276            Default::default(),
277        )
278    }
279}