leo_passes/monomorphization/
ast.rs1use super::MonomorphizationVisitor;
18use crate::{ConstPropagationVisitor, Replacer};
19
20use leo_ast::{
21 AstReconstructor,
22 CallExpression,
23 CompositeType,
24 Expression,
25 Identifier,
26 Node as _,
27 ProgramReconstructor,
28 StructExpression,
29 StructVariableInitializer,
30 Type,
31 Variant,
32};
33
34use indexmap::IndexMap;
35use itertools::{Either, Itertools};
36
37impl<'a> MonomorphizationVisitor<'a> {
38 fn try_evaluate_const_args(&mut self, const_args: &[Expression]) -> Option<Vec<Expression>> {
42 let mut const_evaluator = ConstPropagationVisitor::new(self.state, self.program);
43
44 let (evaluated_const_args, non_const_args): (Vec<_>, Vec<_>) = const_args
45 .iter()
46 .map(|arg| const_evaluator.reconstruct_expression(arg.clone(), &()))
47 .partition_map(|(evaluated_arg, evaluated_value)| match (evaluated_value, evaluated_arg) {
48 (Some(_), expr @ Expression::Literal(_)) => Either::Left(expr),
49 _ => Either::Right(()),
50 });
51
52 if !non_const_args.is_empty() { None } else { Some(evaluated_const_args) }
53 }
54}
55
56impl AstReconstructor for MonomorphizationVisitor<'_> {
57 type AdditionalInput = ();
58 type AdditionalOutput = ();
59
60 fn reconstruct_composite_type(&mut self, input: leo_ast::CompositeType) -> (leo_ast::Type, Self::AdditionalOutput) {
62 if input.const_arguments.is_empty() {
64 return (Type::Composite(input), Default::default());
65 }
66
67 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
72 self.unresolved_struct_types.push(input.clone());
73 return (Type::Composite(input), Default::default());
74 };
75
76 self.changed = true;
78 (
79 Type::Composite(CompositeType {
80 path: self.monomorphize_struct(&input.path, &evaluated_const_args),
81 const_arguments: vec![], program: input.program,
83 }),
84 Default::default(),
85 )
86 }
87
88 fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
90 let opt_old_type = self.state.type_table.get(&input.id());
91 let (new_expr, opt_value) = match input {
92 Expression::Array(array) => self.reconstruct_array(array, &()),
93 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
94 Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, &()),
95 Expression::Async(async_) => self.reconstruct_async(async_, &()),
96 Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
97 Expression::Call(call) => self.reconstruct_call(*call, &()),
98 Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
99 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
100 Expression::Err(err) => self.reconstruct_err(err, &()),
101 Expression::Path(path) => self.reconstruct_path(path, &()),
102 Expression::Literal(value) => self.reconstruct_literal(value, &()),
103 Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
104 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
105 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
106 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
107 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
108 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
109 Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
110 Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
111 };
112
113 if let Some(old_type) = opt_old_type {
115 self.state.type_table.insert(new_expr.id(), old_type);
116 }
117
118 (new_expr, opt_value)
119 }
120
121 fn reconstruct_call(
122 &mut self,
123 input_call: CallExpression,
124 _additional: &(),
125 ) -> (Expression, Self::AdditionalOutput) {
126 if input_call.program.is_some_and(|prog| prog != self.program) {
128 return (input_call.into(), Default::default());
129 }
130
131 if input_call.const_arguments.is_empty() {
133 return (input_call.into(), Default::default());
134 }
135
136 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input_call.const_arguments) else {
141 self.unresolved_calls.push(input_call.clone());
142 return (input_call.into(), Default::default());
143 };
144
145 let callee_fn = self
147 .reconstructed_functions
148 .get(&input_call.function.absolute_path())
149 .expect("Callee should already be reconstructed (post-order traversal).");
150
151 if !matches!(callee_fn.variant, Variant::Inline) {
153 return (input_call.into(), Default::default());
154 }
155
156 let new_callee_path = input_call.function.clone().with_updated_last_symbol(leo_span::Symbol::intern(&format!(
162 "\"{}::[{}]\"",
163 input_call.function.identifier().name,
164 evaluated_const_args.iter().format(", ")
165 )));
166
167 if self.reconstructed_functions.get(&new_callee_path.absolute_path()).is_none() {
170 let const_param_map: IndexMap<_, _> = callee_fn
172 .const_parameters
173 .iter()
174 .map(|param| param.identifier().name)
175 .zip_eq(&evaluated_const_args)
176 .collect();
177
178 let replace_identifier = |expr: &Expression| match expr {
180 Expression::Path(path) => const_param_map
181 .get(&path.identifier().name)
182 .map_or(Expression::Path(path.clone()), |&expr| expr.clone()),
183 _ => expr.clone(),
184 };
185
186 let mut replacer = Replacer::new(replace_identifier, true , self.state);
187
188 let mut function = replacer.reconstruct_function(callee_fn.clone());
193
194 function = self.reconstruct_function(function);
196 function.identifier = Identifier {
197 name: new_callee_path.identifier().name,
198 span: leo_span::Span::default(),
199 id: self.state.node_builder.next_id(),
200 };
201 function.const_parameters = vec![];
202 function.id = self.state.node_builder.next_id();
203
204 self.reconstructed_functions.insert(new_callee_path.absolute_path(), function);
206
207 self.monomorphized_functions.insert(input_call.function.absolute_path());
209 }
210
211 self.changed = true;
213
214 (
216 CallExpression {
217 function: new_callee_path,
218 const_arguments: vec![], arguments: input_call.arguments,
220 program: input_call.program,
221 span: input_call.span, id: input_call.id,
223 }
224 .into(),
225 Default::default(),
226 )
227 }
228
229 fn reconstruct_struct_init(
230 &mut self,
231 mut input: StructExpression,
232 _additional: &(),
233 ) -> (Expression, Self::AdditionalOutput) {
234 let members = input
236 .members
237 .clone()
238 .into_iter()
239 .map(|member| StructVariableInitializer {
240 identifier: member.identifier,
241 expression: member.expression.map(|expr| self.reconstruct_expression(expr, &()).0),
242 span: member.span,
243 id: member.id,
244 })
245 .collect();
246
247 if input.const_arguments.is_empty() {
249 input.members = members;
250 return (input.into(), Default::default());
251 }
252
253 let Some(evaluated_const_args) = self.try_evaluate_const_args(&input.const_arguments) else {
258 self.unresolved_struct_exprs.push(input.clone());
259 input.members = members;
260 return (input.into(), Default::default());
261 };
262
263 self.changed = true;
265
266 (
268 StructExpression {
269 path: self.monomorphize_struct(&input.path, &evaluated_const_args),
270 members,
271 const_arguments: vec![], span: input.span, id: input.id,
274 }
275 .into(),
276 Default::default(),
277 )
278 }
279}