// leo_passes/const_propagation/ast.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

17use leo_ast::{
18    interpreter_value::{self, Value},
19    *,
20};
21use leo_errors::StaticAnalyzerError;
22use leo_span::{Symbol, sym};
23
24use super::ConstPropagationVisitor;
25
26const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
27
28impl AstReconstructor for ConstPropagationVisitor<'_> {
29    type AdditionalInput = ();
30    type AdditionalOutput = Option<Value>;
31
32    /* Types */
33    fn reconstruct_array_type(&mut self, input: leo_ast::ArrayType) -> (leo_ast::Type, Self::AdditionalOutput) {
34        let (length, opt_value) = self.reconstruct_expression(*input.length, &());
35
36        // If we can't evaluate this array length, keep track of it for error reporting later
37        if opt_value.is_none() {
38            self.array_length_not_evaluated = Some(length.span());
39        }
40
41        (
42            leo_ast::Type::Array(leo_ast::ArrayType {
43                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
44                length: Box::new(length),
45            }),
46            Default::default(),
47        )
48    }
49
50    /* Expressions */
51    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
52        let opt_old_type = self.state.type_table.get(&input.id());
53        let (new_expr, opt_value) = match input {
54            Expression::Array(array) => self.reconstruct_array(array, &()),
55            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
56            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, &()),
57            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, &()),
58            Expression::Async(async_) => self.reconstruct_async(async_, &()),
59            Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
60            Expression::Call(call) => self.reconstruct_call(*call, &()),
61            Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
62            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
63            Expression::Err(err) => self.reconstruct_err(err, &()),
64            Expression::Path(path) => self.reconstruct_path(path, &()),
65            Expression::Literal(value) => self.reconstruct_literal(value, &()),
66            Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
67            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
68            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
69            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
70            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
71            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
72            Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
73            Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
74        };
75
76        // If the expression was in the type table before, make an entry for the new expression.
77        if let Some(old_type) = opt_old_type {
78            self.state.type_table.insert(new_expr.id(), old_type);
79        }
80
81        (new_expr, opt_value)
82    }
83
84    fn reconstruct_struct_init(
85        &mut self,
86        mut input: StructExpression,
87        _additional: &(),
88    ) -> (Expression, Self::AdditionalOutput) {
89        let mut values = Vec::new();
90        input.const_arguments.iter_mut().for_each(|arg| {
91            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
92        });
93        for member in input.members.iter_mut() {
94            let expression = member.expression.take().unwrap_or_else(|| {
95                Path::from(member.identifier).with_absolute_path(Some(vec![member.identifier.name])).into()
96            });
97            let (new_expr, value_opt) = self.reconstruct_expression(expression, &());
98            member.expression = Some(new_expr);
99            if let Some(value) = value_opt {
100                values.push(value);
101            }
102        }
103
104        if values.len() == input.members.len() && input.const_arguments.is_empty() {
105            let value = Value::make_struct(
106                input.members.iter().map(|mem| mem.identifier.name).zip(values),
107                self.program,
108                input.path.absolute_path(),
109            );
110            (input.into(), Some(value))
111        } else {
112            (input.into(), None)
113        }
114    }
115
116    fn reconstruct_ternary(
117        &mut self,
118        input: TernaryExpression,
119        _additional: &(),
120    ) -> (Expression, Self::AdditionalOutput) {
121        let (cond, cond_value) = self.reconstruct_expression(input.condition, &());
122
123        match cond_value.and_then(|v| v.try_into().ok()) {
124            Some(true) => self.reconstruct_expression(input.if_true, &()),
125            Some(false) => self.reconstruct_expression(input.if_false, &()),
126            _ => (
127                TernaryExpression {
128                    condition: cond,
129                    if_true: self.reconstruct_expression(input.if_true, &()).0,
130                    if_false: self.reconstruct_expression(input.if_false, &()).0,
131                    ..input
132                }
133                .into(),
134                None,
135            ),
136        }
137    }
138
139    fn reconstruct_array_access(
140        &mut self,
141        input: ArrayAccess,
142        _additional: &(),
143    ) -> (Expression, Self::AdditionalOutput) {
144        let span = input.span();
145        let id = input.id();
146        let array_id = input.array.id();
147        let (array, array_opt) = self.reconstruct_expression(input.array, &());
148        let (index, index_opt) = self.reconstruct_expression(input.index, &());
149        if let Some(index_value) = index_opt {
150            // We can perform compile time bounds checking.
151
152            let ty = self.state.type_table.get(&array_id);
153            let Some(Type::Array(array_ty)) = ty else {
154                panic!("Type checking guaranteed that this is an array.");
155            };
156            let len = array_ty.length.as_u32();
157
158            if let Some(len) = len {
159                let index_in_bounds = matches!(index_value.as_u32(), Some(index) if index < len);
160
161                if !index_in_bounds {
162                    // Only emit a bounds error if we have no other errors yet.
163                    // This prevents a chain of redundant error messages when a loop is unrolled.
164                    if !self.state.handler.had_errors() {
165                        // Get the integer string with no suffix.
166                        let integer_with_suffix = index_value.to_string();
167                        let suffix_index = integer_with_suffix.find(['i', 'u']).unwrap_or(integer_with_suffix.len());
168                        self.emit_err(StaticAnalyzerError::array_bounds(
169                            &integer_with_suffix[..suffix_index],
170                            len,
171                            span,
172                        ));
173                    }
174                } else if let Some(array_value) = array_opt {
175                    // We're in bounds and we can evaluate the array at compile time, so just return the value.
176                    let result_value = array_value
177                        .array_index(index_value.as_u32().unwrap() as usize)
178                        .expect("We already checked bounds.");
179                    dbg!(&result_value);
180                    return (
181                        self.value_to_expression(&result_value, span, id).expect(VALUE_ERROR),
182                        Some(result_value.clone()),
183                    );
184                }
185            }
186        } else {
187            self.array_index_not_evaluated = Some(index.span());
188        }
189        (ArrayAccess { array, index, ..input }.into(), None)
190    }
191
192    fn reconstruct_associated_constant(
193        &mut self,
194        input: leo_ast::AssociatedConstantExpression,
195        _additional: &(),
196    ) -> (Expression, Self::AdditionalOutput) {
197        // Currently there is only one associated constant.
198        let generator = Value::generator();
199        let expr = self.value_to_expression_node(&generator, &input).expect(VALUE_ERROR);
200        (expr, Some(generator))
201    }
202
203    fn reconstruct_associated_function(
204        &mut self,
205        mut input: leo_ast::AssociatedFunctionExpression,
206        _additional: &(),
207    ) -> (Expression, Self::AdditionalOutput) {
208        let mut values = Vec::new();
209        for argument in input.arguments.iter_mut() {
210            let (new_argument, opt_value) = self.reconstruct_expression(std::mem::take(argument), &());
211            *argument = new_argument;
212            if let Some(value) = opt_value {
213                values.push(value);
214            }
215        }
216
217        if values.len() == input.arguments.len() && !matches!(input.variant.name, sym::CheatCode | sym::Mapping) {
218            // We've evaluated every argument, and this function isn't a cheat code or mapping
219            // operation, so maybe we can compute the result at compile time.
220            let core_function = CoreFunction::from_symbols(input.variant.name, input.name.name)
221                .expect("Type checking guarantees this is valid.");
222
223            match interpreter_value::evaluate_core_function(&mut values, core_function, &[], input.span()) {
224                Ok(Some(value)) => {
225                    // Successful evaluation.
226                    let expr = self.value_to_expression_node(&value, &input).expect(VALUE_ERROR);
227                    return (expr, Some(value));
228                }
229                Ok(None) =>
230                    // No errors, but we were unable to evaluate.
231                    {}
232                Err(err) => {
233                    self.emit_err(StaticAnalyzerError::compile_core_function(err, input.span()));
234                }
235            }
236        }
237
238        (input.into(), Default::default())
239    }
240
241    fn reconstruct_member_access(
242        &mut self,
243        input: MemberAccess,
244        _additional: &(),
245    ) -> (Expression, Self::AdditionalOutput) {
246        let span = input.span();
247        let id = input.id();
248        let (inner, value_opt) = self.reconstruct_expression(input.inner, &());
249        let member_name = input.name.name;
250        if let Some(struct_) = value_opt {
251            let value_result = struct_.member_access(member_name).expect("Type checking guarantees the member exists.");
252
253            (self.value_to_expression(&value_result, span, id).expect(VALUE_ERROR), Some(value_result.clone()))
254        } else {
255            (MemberAccess { inner, ..input }.into(), None)
256        }
257    }
258
259    fn reconstruct_repeat(
260        &mut self,
261        input: leo_ast::RepeatExpression,
262        _additional: &(),
263    ) -> (Expression, Self::AdditionalOutput) {
264        let (expr, expr_value) = self.reconstruct_expression(input.expr.clone(), &());
265        let (count, count_value) = self.reconstruct_expression(input.count.clone(), &());
266
267        if count_value.is_none() {
268            self.repeat_count_not_evaluated = Some(count.span());
269        }
270
271        match (expr_value, count.as_u32()) {
272            (Some(value), Some(count_u32)) => (
273                RepeatExpression { expr, count, ..input }.into(),
274                Some(Value::make_array(std::iter::repeat_n(value, count_u32 as usize))),
275            ),
276            _ => (RepeatExpression { expr, count, ..input }.into(), None),
277        }
278    }
279
280    fn reconstruct_tuple_access(
281        &mut self,
282        input: TupleAccess,
283        _additional: &(),
284    ) -> (Expression, Self::AdditionalOutput) {
285        let span = input.span();
286        let id = input.id();
287        let (tuple, value_opt) = self.reconstruct_expression(input.tuple, &());
288        if let Some(tuple_value) = value_opt {
289            let value_result = tuple_value.tuple_index(input.index.value()).expect("Type checking checked bounds.");
290            (self.value_to_expression(&value_result, span, id).expect(VALUE_ERROR), Some(value_result.clone()))
291        } else {
292            (TupleAccess { tuple, ..input }.into(), None)
293        }
294    }
295
296    fn reconstruct_array(
297        &mut self,
298        mut input: leo_ast::ArrayExpression,
299        _additional: &(),
300    ) -> (Expression, Self::AdditionalOutput) {
301        let mut values = Vec::new();
302        input.elements.iter_mut().for_each(|element| {
303            let (new_element, value_opt) = self.reconstruct_expression(std::mem::take(element), &());
304            if let Some(value) = value_opt {
305                values.push(value);
306            }
307            *element = new_element;
308        });
309        if values.len() == input.elements.len() {
310            (input.into(), Some(Value::make_array(values.into_iter())))
311        } else {
312            (input.into(), None)
313        }
314    }
315
316    fn reconstruct_binary(
317        &mut self,
318        input: leo_ast::BinaryExpression,
319        _additional: &(),
320    ) -> (Expression, Self::AdditionalOutput) {
321        let span = input.span();
322        let input_id = input.id();
323
324        let (left, lhs_opt_value) = self.reconstruct_expression(input.left, &());
325        let (right, rhs_opt_value) = self.reconstruct_expression(input.right, &());
326
327        if let (Some(lhs_value), Some(rhs_value)) = (lhs_opt_value, rhs_opt_value) {
328            // We were able to evaluate both operands, so we can evaluate this expression.
329            match interpreter_value::evaluate_binary(
330                span,
331                input.op,
332                &lhs_value,
333                &rhs_value,
334                &self.state.type_table.get(&input_id),
335            ) {
336                Ok(new_value) => {
337                    let new_expr = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
338                    return (new_expr, Some(new_value));
339                }
340                Err(err) => self
341                    .emit_err(StaticAnalyzerError::compile_time_binary_op(lhs_value, rhs_value, input.op, err, span)),
342            }
343        }
344
345        (BinaryExpression { left, right, ..input }.into(), None)
346    }
347
348    fn reconstruct_call(
349        &mut self,
350        mut input: leo_ast::CallExpression,
351        _additional: &(),
352    ) -> (Expression, Self::AdditionalOutput) {
353        input.const_arguments.iter_mut().for_each(|arg| {
354            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
355        });
356        input.arguments.iter_mut().for_each(|arg| {
357            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
358        });
359        (input.into(), Default::default())
360    }
361
362    fn reconstruct_cast(
363        &mut self,
364        input: leo_ast::CastExpression,
365        _additional: &(),
366    ) -> (Expression, Self::AdditionalOutput) {
367        let span = input.span();
368        let id = input.id();
369
370        let (expr, opt_value) = self.reconstruct_expression(input.expression, &());
371
372        if let Some(value) = opt_value {
373            if let Some(cast_value) = value.cast(&input.type_) {
374                let expr = self.value_to_expression(&cast_value, span, id).expect(VALUE_ERROR);
375                return (expr, Some(cast_value));
376            } else {
377                self.emit_err(StaticAnalyzerError::compile_time_cast(value, &input.type_, span));
378            }
379        }
380        (CastExpression { expression: expr, ..input }.into(), None)
381    }
382
383    fn reconstruct_err(
384        &mut self,
385        _input: leo_ast::ErrExpression,
386        _additional: &(),
387    ) -> (Expression, Self::AdditionalOutput) {
388        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
389    }
390
391    fn reconstruct_path(&mut self, input: leo_ast::Path, _additional: &()) -> (Expression, Self::AdditionalOutput) {
392        // Substitute the identifier with the constant value if it is a constant that's been evaluated.
393        if let Some(expression) = self.state.symbol_table.lookup_const(self.program, &input.absolute_path()) {
394            let (expression, opt_value) = self.reconstruct_expression(expression, &());
395            if opt_value.is_some() {
396                return (expression, opt_value);
397            }
398        }
399
400        (input.into(), None)
401    }
402
403    fn reconstruct_literal(
404        &mut self,
405        mut input: leo_ast::Literal,
406        _additional: &(),
407    ) -> (Expression, Self::AdditionalOutput) {
408        let type_info = self.state.type_table.get(&input.id());
409
410        // If this is an optional, then unwrap it first.
411        let type_info = type_info.as_ref().map(|ty| match ty {
412            Type::Optional(opt) => *opt.inner.clone(),
413            _ => ty.clone(),
414        });
415
416        if let Ok(value) = interpreter_value::literal_to_value(&input, &type_info) {
417            // If we know the type of an unsuffixed literal, might as well change it to a suffixed literal. This way, we
418            // do not have to infer the type again in later passes of type checking.
419            if let LiteralVariant::Unsuffixed(s) = input.variant {
420                match type_info.expect("Expected type information to be available") {
421                    Type::Integer(ty) => input.variant = LiteralVariant::Integer(ty, s),
422                    Type::Field => input.variant = LiteralVariant::Field(s),
423                    Type::Group => input.variant = LiteralVariant::Group(s),
424                    Type::Scalar => input.variant = LiteralVariant::Scalar(s),
425                    _ => panic!("Type checking should have prevented this."),
426                }
427            }
428            (input.into(), Some(value))
429        } else {
430            (input.into(), None)
431        }
432    }
433
434    fn reconstruct_locator(
435        &mut self,
436        input: leo_ast::LocatorExpression,
437        _additional: &(),
438    ) -> (Expression, Self::AdditionalOutput) {
439        (input.into(), Default::default())
440    }
441
442    fn reconstruct_tuple(
443        &mut self,
444        mut input: leo_ast::TupleExpression,
445        _additional: &(),
446    ) -> (Expression, Self::AdditionalOutput) {
447        let mut values = Vec::with_capacity(input.elements.len());
448        for expr in input.elements.iter_mut() {
449            let (new_expr, opt_value) = self.reconstruct_expression(std::mem::take(expr), &());
450            *expr = new_expr;
451            if let Some(value) = opt_value {
452                values.push(value);
453            }
454        }
455
456        let opt_value = if values.len() == input.elements.len() { Some(Value::make_tuple(values)) } else { None };
457
458        (input.into(), opt_value)
459    }
460
461    fn reconstruct_unary(&mut self, input: UnaryExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
462        let input_id = input.id();
463        let span = input.span;
464        let (receiver, opt_value) = self.reconstruct_expression(input.receiver, &());
465
466        if let Some(value) = opt_value {
467            // We were able to evaluate the operand, so we can evaluate the expression.
468            match interpreter_value::evaluate_unary(span, input.op, &value, &self.state.type_table.get(&input_id)) {
469                Ok(new_value) => {
470                    let new_expr = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
471                    return (new_expr, Some(new_value));
472                }
473                Err(err) => self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span)),
474            }
475        }
476        (UnaryExpression { receiver, ..input }.into(), None)
477    }
478
479    fn reconstruct_unit(
480        &mut self,
481        input: leo_ast::UnitExpression,
482        _additional: &(),
483    ) -> (Expression, Self::AdditionalOutput) {
484        (input.into(), None)
485    }
486
487    /* Statements */
488    fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
489        // Catching asserts at compile time is not feasible here due to control flow, but could be done in
490        // a later pass after loops are unrolled and conditionals are flattened.
491        input.variant = match input.variant {
492            AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr, &()).0),
493
494            AssertVariant::AssertEq(lhs, rhs) => AssertVariant::AssertEq(
495                self.reconstruct_expression(lhs, &()).0,
496                self.reconstruct_expression(rhs, &()).0,
497            ),
498
499            AssertVariant::AssertNeq(lhs, rhs) => AssertVariant::AssertNeq(
500                self.reconstruct_expression(lhs, &()).0,
501                self.reconstruct_expression(rhs, &()).0,
502            ),
503        };
504
505        (input.into(), None)
506    }
507
508    fn reconstruct_assign(&mut self, assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
509        let value = self.reconstruct_expression(assign.value, &()).0;
510        let place = self.reconstruct_expression(assign.place, &()).0;
511        (AssignStatement { value, place, ..assign }.into(), None)
512    }
513
514    fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
515        self.in_scope(block.id(), |slf| {
516            block.statements.retain_mut(|statement| {
517                let bogus_statement = Statement::dummy();
518                let this_statement = std::mem::replace(statement, bogus_statement);
519                *statement = slf.reconstruct_statement(this_statement).0;
520                !statement.is_empty()
521            });
522            (block, None)
523        })
524    }
525
526    fn reconstruct_conditional(
527        &mut self,
528        mut conditional: ConditionalStatement,
529    ) -> (Statement, Self::AdditionalOutput) {
530        conditional.condition = self.reconstruct_expression(conditional.condition, &()).0;
531        conditional.then = self.reconstruct_block(conditional.then).0;
532        if let Some(mut otherwise) = conditional.otherwise {
533            *otherwise = self.reconstruct_statement(*otherwise).0;
534            conditional.otherwise = Some(otherwise);
535        }
536
537        (Statement::Conditional(conditional), None)
538    }
539
540    fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
541        if matches!(input.type_, Type::Optional(_)) {
542            return (input.into(), None);
543        }
544
545        let span = input.span();
546
547        let type_ = self.reconstruct_type(input.type_).0;
548        let (expr, opt_value) = self.reconstruct_expression(input.value, &());
549
550        if opt_value.is_some() {
551            let path: &[Symbol] = if self.state.symbol_table.global_scope() {
552                // Then we need to insert the const with its full module-scoped path.
553                &self.module.iter().copied().chain(std::iter::once(input.place.name)).collect::<Vec<_>>()
554            } else {
555                &[input.place.name]
556            };
557            if self.state.symbol_table.lookup_const(self.program, path).is_none() {
558                // It wasn't already evaluated - insert it and record that we've made a change.
559                self.state.symbol_table.insert_const(self.program, path, expr.clone());
560                if self.state.symbol_table.global_scope() {
561                    // We made a change in the global scope, so this was a real change.
562                    self.changed = true;
563                }
564            }
565        } else {
566            self.const_not_evaluated = Some(span);
567        }
568
569        input.type_ = type_;
570        input.value = expr;
571
572        (Statement::Const(input), None)
573    }
574
575    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
576        (
577            DefinitionStatement {
578                type_: definition.type_.map(|ty| self.reconstruct_type(ty).0),
579                value: self.reconstruct_expression(definition.value, &()).0,
580                ..definition
581            }
582            .into(),
583            None,
584        )
585    }
586
587    fn reconstruct_expression_statement(
588        &mut self,
589        mut input: ExpressionStatement,
590    ) -> (Statement, Self::AdditionalOutput) {
591        input.expression = self.reconstruct_expression(input.expression, &()).0;
592
593        if matches!(&input.expression, Expression::Unit(..) | Expression::Literal(..)) {
594            // We were able to evaluate this at compile time, but we need to get rid of this statement as
595            // we can't have expression statements that aren't calls.
596            (Statement::dummy(), Default::default())
597        } else {
598            (input.into(), Default::default())
599        }
600    }
601
602    fn reconstruct_iteration(&mut self, iteration: IterationStatement) -> (Statement, Self::AdditionalOutput) {
603        let id = iteration.id();
604        let type_ = iteration.type_.map(|ty| self.reconstruct_type(ty).0);
605        let start = self.reconstruct_expression(iteration.start, &()).0;
606        let stop = self.reconstruct_expression(iteration.stop, &()).0;
607        self.in_scope(id, |slf| {
608            (
609                IterationStatement { type_, start, stop, block: slf.reconstruct_block(iteration.block).0, ..iteration }
610                    .into(),
611                None,
612            )
613        })
614    }
615
616    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
617        (
618            ReturnStatement { expression: self.reconstruct_expression(input.expression, &()).0, ..input }.into(),
619            Default::default(),
620        )
621    }
622}