1use leo_ast::{
18 ArrayAccess,
19 BinaryExpression,
20 CastExpression,
21 CoreFunction,
22 Expression,
23 ExpressionReconstructor,
24 LiteralVariant,
25 MemberAccess,
26 Node,
27 RepeatExpression,
28 StructExpression,
29 TernaryExpression,
30 TupleAccess,
31 Type,
32 UnaryExpression,
33 interpreter_value::{self, StructContents, Value},
34};
35use leo_errors::StaticAnalyzerError;
36use leo_span::sym;
37
38use super::{ConstPropagationVisitor, value_to_expression};
39
/// Message passed to `.expect` when converting a folded [`Value`] back into an [`Expression`]
/// with `value_to_expression`; a failure there means a value this pass produced could not be
/// turned back into an expression, which the message treats as impossible for non-future values.
const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
41
42impl ExpressionReconstructor for ConstPropagationVisitor<'_> {
43 type AdditionalOutput = Option<Value>;
44
45 fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
46 let old_id = input.id();
47 let (new_expr, opt_value) = match input {
48 Expression::Array(array) => self.reconstruct_array(array),
49 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
50 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
51 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
52 Expression::Binary(binary) => self.reconstruct_binary(*binary),
53 Expression::Call(call) => self.reconstruct_call(*call),
54 Expression::Cast(cast) => self.reconstruct_cast(*cast),
55 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
56 Expression::Err(err) => self.reconstruct_err(err),
57 Expression::Identifier(identifier) => self.reconstruct_identifier(identifier),
58 Expression::Literal(value) => self.reconstruct_literal(value),
59 Expression::Locator(locator) => self.reconstruct_locator(locator),
60 Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
61 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
62 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
63 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
64 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
65 Expression::Unary(unary) => self.reconstruct_unary(*unary),
66 Expression::Unit(unit) => self.reconstruct_unit(unit),
67 };
68
69 if old_id != new_expr.id() {
70 self.changed = true;
71 let old_type =
72 self.state.type_table.get(&old_id).expect("Type checking guarantees that all expressions have a type.");
73 self.state.type_table.insert(new_expr.id(), old_type);
74 }
75
76 (new_expr, opt_value)
77 }
78
79 fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
80 let mut values = Vec::new();
81 for member in input.members.iter_mut() {
82 if let Some(expr) = std::mem::take(&mut member.expression) {
83 let (new_expr, value_opt) = self.reconstruct_expression(expr);
84 member.expression = Some(new_expr);
85 if let Some(value) = value_opt {
86 values.push(value);
87 }
88 }
89 }
90 if values.len() == input.members.len() {
91 let value = Value::Struct(StructContents {
92 name: input.name.name,
93 contents: input.members.iter().map(|mem| mem.identifier.name).zip(values).collect(),
94 });
95 (input.into(), Some(value))
96 } else {
97 (input.into(), None)
98 }
99 }
100
101 fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
102 let (cond, cond_value) = self.reconstruct_expression(input.condition);
103
104 match cond_value {
105 Some(Value::Bool(true)) => self.reconstruct_expression(input.if_true),
106 Some(Value::Bool(false)) => self.reconstruct_expression(input.if_false),
107 _ => (
108 TernaryExpression {
109 condition: cond,
110 if_true: self.reconstruct_expression(input.if_true).0,
111 if_false: self.reconstruct_expression(input.if_false).0,
112 ..input
113 }
114 .into(),
115 None,
116 ),
117 }
118 }
119
120 fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
121 let span = input.span();
122 let array_id = input.array.id();
123 let (array, value_opt) = self.reconstruct_expression(input.array);
124 let (index, opt_value) = self.reconstruct_expression(input.index);
125 if let Some(value) = opt_value {
126 let ty = self.state.type_table.get(&array_id);
129 let Some(Type::Array(array_ty)) = ty else {
130 panic!("Type checking guaranteed that this is an array.");
131 };
132 let len = array_ty.length.as_u32();
133
134 if let Some(len) = len {
135 let index: u32 = match value {
136 Value::U8(x) => x as u32,
137 Value::U16(x) => x as u32,
138 Value::U32(x) => x,
139 Value::U64(x) => x.try_into().unwrap_or(len),
140 Value::U128(x) => x.try_into().unwrap_or(len),
141 Value::I8(x) => x.try_into().unwrap_or(len),
142 Value::I16(x) => x.try_into().unwrap_or(len),
143 Value::I32(x) => x.try_into().unwrap_or(len),
144 Value::I64(x) => x.try_into().unwrap_or(len),
145 Value::I128(x) => x.try_into().unwrap_or(len),
146 _ => panic!("Type checking guarantees this is an integer"),
147 };
148
149 if index >= len {
150 if !self.state.handler.had_errors() {
153 let str_index = match value {
155 Value::U8(x) => format!("{x}"),
156 Value::U16(x) => format!("{x}"),
157 Value::U32(x) => format!("{x}"),
158 Value::U64(x) => format!("{x}"),
159 Value::U128(x) => format!("{x}"),
160 Value::I8(x) => format!("{x}"),
161 Value::I16(x) => format!("{x}"),
162 Value::I32(x) => format!("{x}"),
163 Value::I64(x) => format!("{x}"),
164 Value::I128(x) => format!("{x}"),
165 _ => unreachable!("We would have panicked above"),
166 };
167
168 self.emit_err(StaticAnalyzerError::array_bounds(str_index, len, span));
169 }
170 } else if let Some(Value::Array(value)) = value_opt {
171 let result_value = value.get(index as usize).expect("We already checked bounds.");
173 return (
174 value_to_expression(result_value, input.span, &self.state.node_builder).expect(VALUE_ERROR),
175 Some(result_value.clone()),
176 );
177 }
178 }
179 } else {
180 self.array_index_not_evaluated = Some(index.span());
181 }
182 (ArrayAccess { array, index, ..input }.into(), None)
183 }
184
185 fn reconstruct_associated_constant(
186 &mut self,
187 input: leo_ast::AssociatedConstantExpression,
188 ) -> (Expression, Self::AdditionalOutput) {
189 let generator = Value::generator();
191 let expr = value_to_expression(&generator, input.span(), &self.state.node_builder).expect(VALUE_ERROR);
192 (expr, Some(generator))
193 }
194
195 fn reconstruct_associated_function(
196 &mut self,
197 mut input: leo_ast::AssociatedFunctionExpression,
198 ) -> (Expression, Self::AdditionalOutput) {
199 let mut values = Vec::new();
200 for argument in input.arguments.iter_mut() {
201 let (new_argument, opt_value) = self.reconstruct_expression(std::mem::take(argument));
202 *argument = new_argument;
203 if let Some(value) = opt_value {
204 values.push(value);
205 }
206 }
207
208 if values.len() == input.arguments.len() && !matches!(input.variant.name, sym::CheatCode | sym::Mapping) {
209 let core_function = CoreFunction::from_symbols(input.variant.name, input.name.name)
212 .expect("Type checking guarantees this is valid.");
213
214 match interpreter_value::evaluate_core_function(&mut values, core_function, &[], input.span()) {
215 Ok(Some(value)) => {
216 let expr = value_to_expression(&value, input.span(), &self.state.node_builder).expect(VALUE_ERROR);
218 return (expr, Some(value));
219 }
220 Ok(None) =>
221 {}
223 Err(err) => {
224 self.emit_err(StaticAnalyzerError::compile_core_function(err, input.span()));
225 }
226 }
227 }
228
229 (input.into(), Default::default())
230 }
231
232 fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
233 let span = input.span();
234 let (inner, value_opt) = self.reconstruct_expression(input.inner);
235 let member_name = input.name.name;
236 if let Some(Value::Struct(contents)) = value_opt {
237 let value_result =
238 contents.contents.get(&member_name).expect("Type checking guarantees the member exists.");
239
240 (
241 value_to_expression(value_result, span, &self.state.node_builder).expect(VALUE_ERROR),
242 Some(value_result.clone()),
243 )
244 } else {
245 (MemberAccess { inner, ..input }.into(), None)
246 }
247 }
248
249 fn reconstruct_repeat(&mut self, input: leo_ast::RepeatExpression) -> (Expression, Self::AdditionalOutput) {
250 let (expr, expr_value) = self.reconstruct_expression(input.expr.clone());
251 let (count, count_value) = self.reconstruct_expression(input.count.clone());
252
253 if count_value.is_none() {
254 self.repeat_count_not_evaluated = Some(count.span());
255 }
256
257 match (expr_value, count.as_u32()) {
258 (Some(value), Some(count_u32)) => {
259 (RepeatExpression { expr, count, ..input }.into(), Some(Value::Array(vec![value; count_u32 as usize])))
260 }
261 _ => (RepeatExpression { expr, count, ..input }.into(), None),
262 }
263 }
264
265 fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
266 let span = input.span();
267 let (tuple, value_opt) = self.reconstruct_expression(input.tuple);
268 if let Some(Value::Tuple(tuple)) = value_opt {
269 let value_result = tuple.get(input.index.value()).expect("Type checking checked bounds.");
270 (
271 value_to_expression(value_result, span, &self.state.node_builder).expect(VALUE_ERROR),
272 Some(value_result.clone()),
273 )
274 } else {
275 (TupleAccess { tuple, ..input }.into(), None)
276 }
277 }
278
279 fn reconstruct_array(&mut self, mut input: leo_ast::ArrayExpression) -> (Expression, Self::AdditionalOutput) {
280 let mut values = Vec::new();
281 input.elements.iter_mut().for_each(|element| {
282 let (new_element, value_opt) = self.reconstruct_expression(std::mem::take(element));
283 if let Some(value) = value_opt {
284 values.push(value);
285 }
286 *element = new_element;
287 });
288 if values.len() == input.elements.len() {
289 (input.into(), Some(Value::Array(values)))
290 } else {
291 (input.into(), None)
292 }
293 }
294
295 fn reconstruct_binary(&mut self, input: leo_ast::BinaryExpression) -> (Expression, Self::AdditionalOutput) {
296 let span = input.span();
297
298 let (left, lhs_opt_value) = self.reconstruct_expression(input.left);
299 let (right, rhs_opt_value) = self.reconstruct_expression(input.right);
300
301 if let (Some(lhs_value), Some(rhs_value)) = (lhs_opt_value, rhs_opt_value) {
302 match interpreter_value::evaluate_binary(span, input.op, &lhs_value, &rhs_value) {
304 Ok(new_value) => {
305 let new_expr = value_to_expression(&new_value, span, &self.state.node_builder).expect(VALUE_ERROR);
306 return (new_expr, Some(new_value));
307 }
308 Err(err) => self
309 .emit_err(StaticAnalyzerError::compile_time_binary_op(lhs_value, rhs_value, input.op, err, span)),
310 }
311 }
312
313 (BinaryExpression { left, right, ..input }.into(), None)
314 }
315
316 fn reconstruct_call(&mut self, mut input: leo_ast::CallExpression) -> (Expression, Self::AdditionalOutput) {
317 input.const_arguments.iter_mut().for_each(|arg| {
318 *arg = self.reconstruct_expression(std::mem::take(arg)).0;
319 });
320 input.arguments.iter_mut().for_each(|arg| {
321 *arg = self.reconstruct_expression(std::mem::take(arg)).0;
322 });
323 (input.into(), Default::default())
324 }
325
326 fn reconstruct_cast(&mut self, input: leo_ast::CastExpression) -> (Expression, Self::AdditionalOutput) {
327 let span = input.span();
328
329 let (expr, opt_value) = self.reconstruct_expression(input.expression);
330
331 if let Some(value) = opt_value {
332 if let Some(cast_value) = value.cast(&input.type_) {
333 let expr = value_to_expression(&cast_value, span, &self.state.node_builder).expect(VALUE_ERROR);
334 return (expr, Some(cast_value));
335 } else {
336 self.emit_err(StaticAnalyzerError::compile_time_cast(value, &input.type_, span));
337 }
338 }
339 (CastExpression { expression: expr, ..input }.into(), None)
340 }
341
342 fn reconstruct_err(&mut self, _input: leo_ast::ErrExpression) -> (Expression, Self::AdditionalOutput) {
343 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
344 }
345
346 fn reconstruct_identifier(&mut self, input: leo_ast::Identifier) -> (Expression, Self::AdditionalOutput) {
347 if let Some(expression) = self.state.symbol_table.lookup_const(self.program, input.name) {
349 let (expression, opt_value) = self.reconstruct_expression(expression);
350 if opt_value.is_some() {
351 return (expression, opt_value);
352 }
353 }
354
355 (input.into(), None)
356 }
357
358 fn reconstruct_literal(&mut self, mut input: leo_ast::Literal) -> (Expression, Self::AdditionalOutput) {
359 let type_info = self.state.type_table.get(&input.id());
360
361 let value =
362 interpreter_value::literal_to_value(&input, &type_info).expect("Failed to convert literal to value");
363
364 if let LiteralVariant::Unsuffixed(s) = input.variant {
367 match type_info.expect("Expected type information to be available") {
368 Type::Integer(ty) => input.variant = LiteralVariant::Integer(ty, s),
369 Type::Field => input.variant = LiteralVariant::Field(s),
370 Type::Group => input.variant = LiteralVariant::Group(s),
371 Type::Scalar => input.variant = LiteralVariant::Scalar(s),
372 _ => panic!("Type checking should have prevented this."),
373 }
374 }
375 (input.into(), Some(value))
376 }
377
378 fn reconstruct_locator(&mut self, input: leo_ast::LocatorExpression) -> (Expression, Self::AdditionalOutput) {
379 (input.into(), Default::default())
380 }
381
382 fn reconstruct_tuple(&mut self, mut input: leo_ast::TupleExpression) -> (Expression, Self::AdditionalOutput) {
383 let mut values = Vec::with_capacity(input.elements.len());
384 for expr in input.elements.iter_mut() {
385 let (new_expr, opt_value) = self.reconstruct_expression(std::mem::take(expr));
386 *expr = new_expr;
387 if let Some(value) = opt_value {
388 values.push(value);
389 }
390 }
391
392 let opt_value = if values.len() == input.elements.len() { Some(Value::Tuple(values)) } else { None };
393
394 (input.into(), opt_value)
395 }
396
397 fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
398 let (receiver, opt_value) = self.reconstruct_expression(input.receiver);
399 let span = input.span;
400
401 if let Some(value) = opt_value {
402 match interpreter_value::evaluate_unary(span, input.op, &value) {
404 Ok(new_value) => {
405 let new_expr = value_to_expression(&new_value, span, &self.state.node_builder).expect(VALUE_ERROR);
406 return (new_expr, Some(new_value));
407 }
408 Err(err) => self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span)),
409 }
410 }
411 (UnaryExpression { receiver, ..input }.into(), None)
412 }
413
414 fn reconstruct_unit(&mut self, input: leo_ast::UnitExpression) -> (Expression, Self::AdditionalOutput) {
415 (input.into(), None)
416 }
417}