leo_passes/const_propagation/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_ast::{
18    interpreter_value::{self, Value},
19    *,
20};
21use leo_errors::StaticAnalyzerError;
22use leo_span::Symbol;
23
24use super::ConstPropagationVisitor;
25
/// Panic message for the `expect` calls that turn a compile-time `Value` back into an AST
/// expression; such a conversion is only expected to fail for future values.
const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
27
28impl AstReconstructor for ConstPropagationVisitor<'_> {
29    type AdditionalInput = ();
30    type AdditionalOutput = Option<Value>;
31
32    /* Types */
33    fn reconstruct_array_type(&mut self, input: leo_ast::ArrayType) -> (leo_ast::Type, Self::AdditionalOutput) {
34        let (length, opt_value) = self.reconstruct_expression(*input.length, &());
35
36        // If we can't evaluate this array length, keep track of it for error reporting later
37        if opt_value.is_none() {
38            self.array_length_not_evaluated = Some(length.span());
39        }
40
41        (
42            leo_ast::Type::Array(leo_ast::ArrayType {
43                element_type: Box::new(self.reconstruct_type(*input.element_type).0),
44                length: Box::new(length),
45            }),
46            Default::default(),
47        )
48    }
49
50    /* Expressions */
51    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
52        let opt_old_type = self.state.type_table.get(&input.id());
53        let (new_expr, opt_value) = match input {
54            Expression::Array(array) => self.reconstruct_array(array, &()),
55            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
56            Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, &()),
57            Expression::Async(async_) => self.reconstruct_async(async_, &()),
58            Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
59            Expression::Call(call) => self.reconstruct_call(*call, &()),
60            Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
61            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
62            Expression::Err(err) => self.reconstruct_err(err, &()),
63            Expression::Path(path) => self.reconstruct_path(path, &()),
64            Expression::Literal(value) => self.reconstruct_literal(value, &()),
65            Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
66            Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
67            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
68            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
69            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
70            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
71            Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
72            Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
73        };
74
75        // If the expression was in the type table before, make an entry for the new expression.
76        if let Some(old_type) = opt_old_type {
77            self.state.type_table.insert(new_expr.id(), old_type);
78        }
79
80        (new_expr, opt_value)
81    }
82
83    fn reconstruct_struct_init(
84        &mut self,
85        mut input: StructExpression,
86        _additional: &(),
87    ) -> (Expression, Self::AdditionalOutput) {
88        let mut values = Vec::new();
89        input.const_arguments.iter_mut().for_each(|arg| {
90            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
91        });
92        for member in input.members.iter_mut() {
93            let expression = member.expression.take().unwrap_or_else(|| {
94                Path::from(member.identifier).with_absolute_path(Some(vec![member.identifier.name])).into()
95            });
96            let (new_expr, value_opt) = self.reconstruct_expression(expression, &());
97            member.expression = Some(new_expr);
98            if let Some(value) = value_opt {
99                values.push(value);
100            }
101        }
102
103        if values.len() == input.members.len() && input.const_arguments.is_empty() {
104            let value = Value::make_struct(
105                input.members.iter().map(|mem| mem.identifier.name).zip(values),
106                self.program,
107                input.path.absolute_path(),
108            );
109            (input.into(), Some(value))
110        } else {
111            (input.into(), None)
112        }
113    }
114
115    fn reconstruct_ternary(
116        &mut self,
117        input: TernaryExpression,
118        _additional: &(),
119    ) -> (Expression, Self::AdditionalOutput) {
120        let (cond, cond_value) = self.reconstruct_expression(input.condition, &());
121
122        match cond_value.and_then(|v| v.try_into().ok()) {
123            Some(true) => self.reconstruct_expression(input.if_true, &()),
124            Some(false) => self.reconstruct_expression(input.if_false, &()),
125            _ => (
126                TernaryExpression {
127                    condition: cond,
128                    if_true: self.reconstruct_expression(input.if_true, &()).0,
129                    if_false: self.reconstruct_expression(input.if_false, &()).0,
130                    ..input
131                }
132                .into(),
133                None,
134            ),
135        }
136    }
137
138    fn reconstruct_array_access(
139        &mut self,
140        input: ArrayAccess,
141        _additional: &(),
142    ) -> (Expression, Self::AdditionalOutput) {
143        let span = input.span();
144        let id = input.id();
145        let array_id = input.array.id();
146        let (array, array_opt) = self.reconstruct_expression(input.array, &());
147        let (index, index_opt) = self.reconstruct_expression(input.index, &());
148        if let Some(index_value) = index_opt {
149            // We can perform compile time bounds checking.
150
151            let ty = self.state.type_table.get(&array_id);
152            let Some(Type::Array(array_ty)) = ty else {
153                panic!("Type checking guaranteed that this is an array.");
154            };
155            let len = array_ty.length.as_u32();
156
157            if let Some(len) = len {
158                let index_in_bounds = matches!(index_value.as_u32(), Some(index) if index < len);
159
160                if !index_in_bounds {
161                    // Only emit a bounds error if we have no other errors yet.
162                    // This prevents a chain of redundant error messages when a loop is unrolled.
163                    if !self.state.handler.had_errors() {
164                        // Get the integer string with no suffix.
165                        let integer_with_suffix = index_value.to_string();
166                        let suffix_index = integer_with_suffix.find(['i', 'u']).unwrap_or(integer_with_suffix.len());
167                        self.emit_err(StaticAnalyzerError::array_bounds(
168                            &integer_with_suffix[..suffix_index],
169                            len,
170                            span,
171                        ));
172                    }
173                } else if let Some(array_value) = array_opt {
174                    // We're in bounds and we can evaluate the array at compile time, so just return the value.
175                    let result_value = array_value
176                        .array_index(index_value.as_u32().unwrap() as usize)
177                        .expect("We already checked bounds.");
178                    return (
179                        self.value_to_expression(&result_value, span, id).expect(VALUE_ERROR),
180                        Some(result_value.clone()),
181                    );
182                }
183            }
184        } else {
185            self.array_index_not_evaluated = Some(index.span());
186        }
187        (ArrayAccess { array, index, ..input }.into(), None)
188    }
189
190    fn reconstruct_intrinsic(
191        &mut self,
192        mut input: leo_ast::IntrinsicExpression,
193        _additional: &(),
194    ) -> (Expression, Self::AdditionalOutput) {
195        let mut values = Vec::new();
196        for argument in input.arguments.iter_mut() {
197            let (new_argument, opt_value) = self.reconstruct_expression(std::mem::take(argument), &());
198            *argument = new_argument;
199            if let Some(value) = opt_value {
200                values.push(value);
201            }
202        }
203
204        let intrinsic = Intrinsic::from_symbol(input.name, &input.type_parameters)
205            .expect("Type checking guarantees this is valid.");
206
207        if values.len() == input.arguments.len() && intrinsic.is_pure() {
208            // We've evaluated every argument, and this function has no side
209            // effects, so maybe we can compute the result at compile time.
210
211            match interpreter_value::evaluate_intrinsic(&mut values, intrinsic, &[], input.span()) {
212                Ok(Some(value)) => {
213                    // Successful evaluation.
214                    let expr = self.value_to_expression_node(&value, &input).expect(VALUE_ERROR);
215                    return (expr, Some(value));
216                }
217                Ok(None) =>
218                    // No errors, but we were unable to evaluate.
219                    {}
220                Err(err) => {
221                    self.emit_err(StaticAnalyzerError::compile_intrinsic(err, input.span()));
222                }
223            }
224        }
225
226        (input.into(), Default::default())
227    }
228
229    fn reconstruct_member_access(
230        &mut self,
231        input: MemberAccess,
232        _additional: &(),
233    ) -> (Expression, Self::AdditionalOutput) {
234        let span = input.span();
235        let id = input.id();
236        let (inner, value_opt) = self.reconstruct_expression(input.inner, &());
237        let member_name = input.name.name;
238        if let Some(struct_) = value_opt {
239            let value_result = struct_.member_access(member_name).expect("Type checking guarantees the member exists.");
240
241            (self.value_to_expression(&value_result, span, id).expect(VALUE_ERROR), Some(value_result.clone()))
242        } else {
243            (MemberAccess { inner, ..input }.into(), None)
244        }
245    }
246
247    fn reconstruct_repeat(
248        &mut self,
249        input: leo_ast::RepeatExpression,
250        _additional: &(),
251    ) -> (Expression, Self::AdditionalOutput) {
252        let (expr, expr_value) = self.reconstruct_expression(input.expr.clone(), &());
253        let (count, count_value) = self.reconstruct_expression(input.count.clone(), &());
254
255        if count_value.is_none() {
256            self.repeat_count_not_evaluated = Some(count.span());
257        }
258
259        match (expr_value, count.as_u32()) {
260            (Some(value), Some(count_u32)) => (
261                RepeatExpression { expr, count, ..input }.into(),
262                Some(Value::make_array(std::iter::repeat_n(value, count_u32 as usize))),
263            ),
264            _ => (RepeatExpression { expr, count, ..input }.into(), None),
265        }
266    }
267
268    fn reconstruct_tuple_access(
269        &mut self,
270        input: TupleAccess,
271        _additional: &(),
272    ) -> (Expression, Self::AdditionalOutput) {
273        let span = input.span();
274        let id = input.id();
275        let (tuple, value_opt) = self.reconstruct_expression(input.tuple, &());
276        if let Some(tuple_value) = value_opt {
277            let value_result = tuple_value.tuple_index(input.index.value()).expect("Type checking checked bounds.");
278            (self.value_to_expression(&value_result, span, id).expect(VALUE_ERROR), Some(value_result.clone()))
279        } else {
280            (TupleAccess { tuple, ..input }.into(), None)
281        }
282    }
283
284    fn reconstruct_array(
285        &mut self,
286        mut input: leo_ast::ArrayExpression,
287        _additional: &(),
288    ) -> (Expression, Self::AdditionalOutput) {
289        let mut values = Vec::new();
290        input.elements.iter_mut().for_each(|element| {
291            let (new_element, value_opt) = self.reconstruct_expression(std::mem::take(element), &());
292            if let Some(value) = value_opt {
293                values.push(value);
294            }
295            *element = new_element;
296        });
297        if values.len() == input.elements.len() {
298            (input.into(), Some(Value::make_array(values.into_iter())))
299        } else {
300            (input.into(), None)
301        }
302    }
303
304    fn reconstruct_binary(
305        &mut self,
306        input: leo_ast::BinaryExpression,
307        _additional: &(),
308    ) -> (Expression, Self::AdditionalOutput) {
309        let span = input.span();
310        let input_id = input.id();
311
312        let (left, lhs_opt_value) = self.reconstruct_expression(input.left, &());
313        let (right, rhs_opt_value) = self.reconstruct_expression(input.right, &());
314
315        if let (Some(lhs_value), Some(rhs_value)) = (lhs_opt_value, rhs_opt_value) {
316            // We were able to evaluate both operands, so we can evaluate this expression.
317            match interpreter_value::evaluate_binary(
318                span,
319                input.op,
320                &lhs_value,
321                &rhs_value,
322                &self.state.type_table.get(&input_id),
323            ) {
324                Ok(new_value) => {
325                    let new_expr = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
326                    return (new_expr, Some(new_value));
327                }
328                Err(err) => self
329                    .emit_err(StaticAnalyzerError::compile_time_binary_op(lhs_value, rhs_value, input.op, err, span)),
330            }
331        }
332
333        (BinaryExpression { left, right, ..input }.into(), None)
334    }
335
336    fn reconstruct_call(
337        &mut self,
338        mut input: leo_ast::CallExpression,
339        _additional: &(),
340    ) -> (Expression, Self::AdditionalOutput) {
341        input.const_arguments.iter_mut().for_each(|arg| {
342            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
343        });
344        input.arguments.iter_mut().for_each(|arg| {
345            *arg = self.reconstruct_expression(std::mem::take(arg), &()).0;
346        });
347        (input.into(), Default::default())
348    }
349
350    fn reconstruct_cast(
351        &mut self,
352        input: leo_ast::CastExpression,
353        _additional: &(),
354    ) -> (Expression, Self::AdditionalOutput) {
355        let span = input.span();
356        let id = input.id();
357
358        let (expr, opt_value) = self.reconstruct_expression(input.expression, &());
359
360        if let Some(value) = opt_value {
361            if let Some(cast_value) = value.cast(&input.type_) {
362                let expr = self.value_to_expression(&cast_value, span, id).expect(VALUE_ERROR);
363                return (expr, Some(cast_value));
364            } else {
365                self.emit_err(StaticAnalyzerError::compile_time_cast(value, &input.type_, span));
366            }
367        }
368        (CastExpression { expression: expr, ..input }.into(), None)
369    }
370
371    fn reconstruct_err(
372        &mut self,
373        _input: leo_ast::ErrExpression,
374        _additional: &(),
375    ) -> (Expression, Self::AdditionalOutput) {
376        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
377    }
378
379    fn reconstruct_path(&mut self, input: leo_ast::Path, _additional: &()) -> (Expression, Self::AdditionalOutput) {
380        // Substitute the identifier with the constant value if it is a constant that's been evaluated.
381        if let Some(expression) = self.state.symbol_table.lookup_const(self.program, &input.absolute_path()) {
382            let (expression, opt_value) = self.reconstruct_expression(expression, &());
383            if opt_value.is_some() {
384                return (expression, opt_value);
385            }
386        }
387
388        (input.into(), None)
389    }
390
391    fn reconstruct_literal(
392        &mut self,
393        mut input: leo_ast::Literal,
394        _additional: &(),
395    ) -> (Expression, Self::AdditionalOutput) {
396        let type_info = self.state.type_table.get(&input.id());
397
398        // If this is an optional, then unwrap it first.
399        let type_info = type_info.as_ref().map(|ty| match ty {
400            Type::Optional(opt) => *opt.inner.clone(),
401            _ => ty.clone(),
402        });
403
404        if let Ok(value) = interpreter_value::literal_to_value(&input, &type_info) {
405            // If we know the type of an unsuffixed literal, might as well change it to a suffixed literal. This way, we
406            // do not have to infer the type again in later passes of type checking.
407            if let LiteralVariant::Unsuffixed(s) = input.variant {
408                match type_info.expect("Expected type information to be available") {
409                    Type::Integer(ty) => input.variant = LiteralVariant::Integer(ty, s),
410                    Type::Field => input.variant = LiteralVariant::Field(s),
411                    Type::Group => input.variant = LiteralVariant::Group(s),
412                    Type::Scalar => input.variant = LiteralVariant::Scalar(s),
413                    _ => panic!("Type checking should have prevented this."),
414                }
415            }
416            (input.into(), Some(value))
417        } else {
418            (input.into(), None)
419        }
420    }
421
422    fn reconstruct_locator(
423        &mut self,
424        input: leo_ast::LocatorExpression,
425        _additional: &(),
426    ) -> (Expression, Self::AdditionalOutput) {
427        (input.into(), Default::default())
428    }
429
430    fn reconstruct_tuple(
431        &mut self,
432        mut input: leo_ast::TupleExpression,
433        _additional: &(),
434    ) -> (Expression, Self::AdditionalOutput) {
435        let mut values = Vec::with_capacity(input.elements.len());
436        for expr in input.elements.iter_mut() {
437            let (new_expr, opt_value) = self.reconstruct_expression(std::mem::take(expr), &());
438            *expr = new_expr;
439            if let Some(value) = opt_value {
440                values.push(value);
441            }
442        }
443
444        let opt_value = if values.len() == input.elements.len() { Some(Value::make_tuple(values)) } else { None };
445
446        (input.into(), opt_value)
447    }
448
449    fn reconstruct_unary(&mut self, input: UnaryExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
450        let input_id = input.id();
451        let span = input.span;
452        let (receiver, opt_value) = self.reconstruct_expression(input.receiver, &());
453
454        if let Some(value) = opt_value {
455            // We were able to evaluate the operand, so we can evaluate the expression.
456            match interpreter_value::evaluate_unary(span, input.op, &value, &self.state.type_table.get(&input_id)) {
457                Ok(new_value) => {
458                    let new_expr = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
459                    return (new_expr, Some(new_value));
460                }
461                Err(err) => self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span)),
462            }
463        }
464        (UnaryExpression { receiver, ..input }.into(), None)
465    }
466
467    fn reconstruct_unit(
468        &mut self,
469        input: leo_ast::UnitExpression,
470        _additional: &(),
471    ) -> (Expression, Self::AdditionalOutput) {
472        (input.into(), None)
473    }
474
475    /* Statements */
476    fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
477        // Catching asserts at compile time is not feasible here due to control flow, but could be done in
478        // a later pass after loops are unrolled and conditionals are flattened.
479        input.variant = match input.variant {
480            AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr, &()).0),
481
482            AssertVariant::AssertEq(lhs, rhs) => AssertVariant::AssertEq(
483                self.reconstruct_expression(lhs, &()).0,
484                self.reconstruct_expression(rhs, &()).0,
485            ),
486
487            AssertVariant::AssertNeq(lhs, rhs) => AssertVariant::AssertNeq(
488                self.reconstruct_expression(lhs, &()).0,
489                self.reconstruct_expression(rhs, &()).0,
490            ),
491        };
492
493        (input.into(), None)
494    }
495
496    fn reconstruct_assign(&mut self, assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
497        let value = self.reconstruct_expression(assign.value, &()).0;
498        let place = self.reconstruct_expression(assign.place, &()).0;
499        (AssignStatement { value, place, ..assign }.into(), None)
500    }
501
502    fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
503        self.in_scope(block.id(), |slf| {
504            block.statements.retain_mut(|statement| {
505                let bogus_statement = Statement::dummy();
506                let this_statement = std::mem::replace(statement, bogus_statement);
507                *statement = slf.reconstruct_statement(this_statement).0;
508                !statement.is_empty()
509            });
510            (block, None)
511        })
512    }
513
514    fn reconstruct_conditional(
515        &mut self,
516        mut conditional: ConditionalStatement,
517    ) -> (Statement, Self::AdditionalOutput) {
518        if let Expression::Literal(lit) = &conditional.condition
519            && let LiteralVariant::Boolean(cond_lit) = lit.variant
520        {
521            if cond_lit {
522                return (Statement::Block(self.reconstruct_block(conditional.then).0), None);
523            } else if let Some(otherwise) = conditional.otherwise {
524                return self.reconstruct_statement(*otherwise);
525            } else {
526                return (Statement::dummy(), None);
527            }
528        }
529
530        conditional.condition = self.reconstruct_expression(conditional.condition, &()).0;
531        conditional.then = self.reconstruct_block(conditional.then).0;
532        if let Some(mut otherwise) = conditional.otherwise {
533            *otherwise = self.reconstruct_statement(*otherwise).0;
534            conditional.otherwise = Some(otherwise);
535        }
536
537        (Statement::Conditional(conditional), None)
538    }
539
540    fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
541        if matches!(input.type_, Type::Optional(_)) {
542            return (input.into(), None);
543        }
544
545        let span = input.span();
546
547        let type_ = self.reconstruct_type(input.type_).0;
548        let (expr, opt_value) = self.reconstruct_expression(input.value, &());
549
550        if opt_value.is_some() {
551            let path: &[Symbol] = if self.state.symbol_table.global_scope() {
552                // Then we need to insert the const with its full module-scoped path.
553                &self.module.iter().copied().chain(std::iter::once(input.place.name)).collect::<Vec<_>>()
554            } else {
555                &[input.place.name]
556            };
557            if self.state.symbol_table.lookup_const(self.program, path).is_none() {
558                // It wasn't already evaluated - insert it and record that we've made a change.
559                self.state.symbol_table.insert_const(self.program, path, expr.clone());
560                self.changed = true;
561            }
562        } else {
563            self.const_not_evaluated = Some(span);
564        }
565
566        input.type_ = type_;
567        input.value = expr;
568
569        (Statement::Const(input), None)
570    }
571
572    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
573        (
574            DefinitionStatement {
575                type_: definition.type_.map(|ty| self.reconstruct_type(ty).0),
576                value: self.reconstruct_expression(definition.value, &()).0,
577                ..definition
578            }
579            .into(),
580            None,
581        )
582    }
583
584    fn reconstruct_expression_statement(
585        &mut self,
586        mut input: ExpressionStatement,
587    ) -> (Statement, Self::AdditionalOutput) {
588        input.expression = self.reconstruct_expression(input.expression, &()).0;
589
590        if matches!(&input.expression, Expression::Unit(..) | Expression::Literal(..)) {
591            // We were able to evaluate this at compile time, but we need to get rid of this statement as
592            // we can't have expression statements that aren't calls.
593            (Statement::dummy(), Default::default())
594        } else {
595            (input.into(), Default::default())
596        }
597    }
598
599    fn reconstruct_iteration(&mut self, iteration: IterationStatement) -> (Statement, Self::AdditionalOutput) {
600        let id = iteration.id();
601        let type_ = iteration.type_.map(|ty| self.reconstruct_type(ty).0);
602        let start = self.reconstruct_expression(iteration.start, &()).0;
603        let stop = self.reconstruct_expression(iteration.stop, &()).0;
604        self.in_scope(id, |slf| {
605            (
606                IterationStatement { type_, start, stop, block: slf.reconstruct_block(iteration.block).0, ..iteration }
607                    .into(),
608                None,
609            )
610        })
611    }
612
613    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
614        (
615            ReturnStatement { expression: self.reconstruct_expression(input.expression, &()).0, ..input }.into(),
616            Default::default(),
617        )
618    }
619}