leo_passes/monomorphization/
ast.rs1use super::MonomorphizationVisitor;
18use crate::{ConstPropagationVisitor, Replacer};
19
20use leo_ast::{
21 AstReconstructor,
22 CallExpression,
23 CompositeType,
24 Expression,
25 Identifier,
26 Node as _,
27 ProgramReconstructor,
28 StructExpression,
29 StructVariableInitializer,
30 Type,
31 Variant,
32};
33
34use indexmap::IndexMap;
35use itertools::{Either, Itertools};
36
37impl<'a> MonomorphizationVisitor<'a> {
38 fn try_evaluate_const_args(&mut self, const_args: &[Expression]) -> Option<Vec<Expression>> {
42 let mut const_evaluator = ConstPropagationVisitor::new(self.state, self.program);
43
44 let (evaluated_const_args, non_const_args): (Vec<_>, Vec<_>) = const_args
45 .iter()
46 .map(|arg| const_evaluator.reconstruct_expression(arg.clone(), &()))
47 .partition_map(|(evaluated_arg, evaluated_value)| match (evaluated_value, evaluated_arg) {
48 (Some(_), expr @ Expression::Literal(_)) => Either::Left(expr),
49 _ => Either::Right(()),
50 });
51
52 if !non_const_args.is_empty() { None } else { Some(evaluated_const_args) }
53 }
54}
55
56impl AstReconstructor for MonomorphizationVisitor<'_> {
57 type AdditionalInput = ();
58 type AdditionalOutput = ();
59
60 fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
62 if input.const_arguments.is_empty() {
64 return (Type::Composite(input), Default::default());
65 }
66
67 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
72 self.unresolved_struct_types.push(input.clone());
73 return (Type::Composite(input), Default::default());
74 };
75
76 self.changed = true;
78 (
79 Type::Composite(CompositeType {
80 path: self.monomorphize_struct(&input.path, &evaluated_const_args),
81 const_arguments: vec![], program: input.program,
83 }),
84 Default::default(),
85 )
86 }
87
88 fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
90 let opt_old_type = self.state.type_table.get(&input.id());
91 let (new_expr, opt_value) = match input {
92 Expression::Array(array) => self.reconstruct_array(array, &()),
93 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
94 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, &()),
95 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, &()),
96 Expression::Async(async_) => self.reconstruct_async(async_, &()),
97 Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
98 Expression::Call(call) => self.reconstruct_call(*call, &()),
99 Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
100 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
101 Expression::Err(err) => self.reconstruct_err(err, &()),
102 Expression::Path(path) => self.reconstruct_path(path, &()),
103 Expression::Literal(value) => self.reconstruct_literal(value, &()),
104 Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
105 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
106 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
107 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
108 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
109 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
110 Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
111 Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
112 };
113
114 if let Some(old_type) = opt_old_type {
116 self.state.type_table.insert(new_expr.id(), old_type);
117 }
118
119 (new_expr, opt_value)
120 }
121
122 fn reconstruct_call(
123 &mut self,
124 input_call: CallExpression,
125 _additional: &(),
126 ) -> (Expression, Self::AdditionalOutput) {
127 if input_call.program.is_some_and(|prog| prog != self.program) {
129 return (input_call.into(), Default::default());
130 }
131
132 if input_call.const_arguments.is_empty() {
134 return (input_call.into(), Default::default());
135 }
136
137 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input_call.const_arguments) else {
142 self.unresolved_calls.push(input_call.clone());
143 return (input_call.into(), Default::default());
144 };
145
146 let callee_fn = self
148 .reconstructed_functions
149 .get(&input_call.function.absolute_path())
150 .expect("Callee should already be reconstructed (post-order traversal).");
151
152 if !matches!(callee_fn.variant, Variant::Inline) {
154 return (input_call.into(), Default::default());
155 }
156
157 let new_callee_path = input_call.function.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!(
163 "\"{}::[{}]\"",
164 input_call.function.identifier().name,
165 evaluated_const_args.iter().format(", ")
166 )));
167
168 if self.reconstructed_functions.get(&new_callee_path.absolute_path()).is_none() {
171 let const_param_map: IndexMap<_, _> = callee_fn
173 .const_parameters
174 .iter()
175 .map(|param| param.identifier().name)
176 .zip_eq(&evaluated_const_args)
177 .collect();
178
179 let replace_identifier = |expr: &Expression| match expr {
181 Expression::Path(path) => const_param_map
182 .get(&path.identifier().name)
183 .map_or(Expression::Path(path.clone()), |&expr| expr.clone()),
184 _ => expr.clone(),
185 };
186
187 let mut replacer = Replacer::new(replace_identifier, true , self.state);
188
189 let mut function = replacer.reconstruct_function(callee_fn.clone());
194
195 function = self.reconstruct_function(function);
197 function.identifier = Identifier {
198 name: new_callee_path.identifier().name,
199 span: leo_span::Span::default(),
200 id: self.state.node_builder.next_id(),
201 };
202 function.const_parameters = vec![];
203 function.id = self.state.node_builder.next_id();
204
205 self.reconstructed_functions.insert(new_callee_path.absolute_path(), function);
207
208 self.monomorphized_functions.insert(input_call.function.absolute_path());
210 }
211
212 self.changed = true;
214
215 (
217 CallExpression {
218 function: new_callee_path,
219 const_arguments: vec![], arguments: input_call.arguments,
221 program: input_call.program,
222 span: input_call.span, id: input_call.id,
224 }
225 .into(),
226 Default::default(),
227 )
228 }
229
230 fn reconstruct_struct_init(
231 &mut self,
232 mut input: StructExpression,
233 _additional: &(),
234 ) -> (Expression, Self::AdditionalOutput) {
235 let members = input
237 .members
238 .clone()
239 .into_iter()
240 .map(|member| StructVariableInitializer {
241 identifier: member.identifier,
242 expression: member.expression.map(|expr| self.reconstruct_expression(expr, &()).0),
243 span: member.span,
244 id: member.id,
245 })
246 .collect();
247
248 if input.const_arguments.is_empty() {
250 input.members = members;
251 return (input.into(), Default::default());
252 }
253
254 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
259 self.unresolved_struct_exprs.push(input.clone());
260 input.members = members;
261 return (input.into(), Default::default());
262 };
263
264 self.changed = true;
266
267 (
269 StructExpression {
270 path: self.monomorphize_struct(&input.path, &evaluated_const_args),
271 members,
272 const_arguments: vec![], span: input.span, id: input.id,
275 }
276 .into(),
277 Default::default(),
278 )
279 }
280}