// leo_passes/const_propagation/expression.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

17use leo_ast::{
18    ArrayAccess,
19    BinaryExpression,
20    CastExpression,
21    CoreFunction,
22    Expression,
23    ExpressionReconstructor,
24    LiteralVariant,
25    MemberAccess,
26    Node,
27    RepeatExpression,
28    StructExpression,
29    TernaryExpression,
30    TupleAccess,
31    Type,
32    UnaryExpression,
33    interpreter_value::{self, StructContents, Value},
34};
35use leo_errors::StaticAnalyzerError;
36use leo_span::sym;
37
38use super::{ConstPropagationVisitor, value_to_expression};
39
/// Message passed to `.expect(...)` whenever a compile-time `Value` is turned back into an
/// `Expression` via `value_to_expression`. The message states the invariant relied upon:
/// only future values are expected to be non-convertible.
const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
41
42impl ExpressionReconstructor for ConstPropagationVisitor<'_> {
43    type AdditionalOutput = Option<Value>;
44
45    fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
46        let old_id = input.id();
47        let (new_expr, opt_value) = match input {
48            Expression::Array(array) => self.reconstruct_array(array),
49            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
50            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
51            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
52            Expression::Binary(binary) => self.reconstruct_binary(*binary),
53            Expression::Call(call) => self.reconstruct_call(*call),
54            Expression::Cast(cast) => self.reconstruct_cast(*cast),
55            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
56            Expression::Err(err) => self.reconstruct_err(err),
57            Expression::Identifier(identifier) => self.reconstruct_identifier(identifier),
58            Expression::Literal(value) => self.reconstruct_literal(value),
59            Expression::Locator(locator) => self.reconstruct_locator(locator),
60            Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
61            Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
62            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
63            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
64            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
65            Expression::Unary(unary) => self.reconstruct_unary(*unary),
66            Expression::Unit(unit) => self.reconstruct_unit(unit),
67        };
68
69        if old_id != new_expr.id() {
70            self.changed = true;
71            let old_type =
72                self.state.type_table.get(&old_id).expect("Type checking guarantees that all expressions have a type.");
73            self.state.type_table.insert(new_expr.id(), old_type);
74        }
75
76        (new_expr, opt_value)
77    }
78
79    fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
80        let mut values = Vec::new();
81        for member in input.members.iter_mut() {
82            if let Some(expr) = std::mem::take(&mut member.expression) {
83                let (new_expr, value_opt) = self.reconstruct_expression(expr);
84                member.expression = Some(new_expr);
85                if let Some(value) = value_opt {
86                    values.push(value);
87                }
88            }
89        }
90        if values.len() == input.members.len() {
91            let value = Value::Struct(StructContents {
92                name: input.name.name,
93                contents: input.members.iter().map(|mem| mem.identifier.name).zip(values).collect(),
94            });
95            (input.into(), Some(value))
96        } else {
97            (input.into(), None)
98        }
99    }
100
101    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
102        let (cond, cond_value) = self.reconstruct_expression(input.condition);
103
104        match cond_value {
105            Some(Value::Bool(true)) => self.reconstruct_expression(input.if_true),
106            Some(Value::Bool(false)) => self.reconstruct_expression(input.if_false),
107            _ => (
108                TernaryExpression {
109                    condition: cond,
110                    if_true: self.reconstruct_expression(input.if_true).0,
111                    if_false: self.reconstruct_expression(input.if_false).0,
112                    ..input
113                }
114                .into(),
115                None,
116            ),
117        }
118    }
119
120    fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
121        let span = input.span();
122        let array_id = input.array.id();
123        let (array, value_opt) = self.reconstruct_expression(input.array);
124        let (index, opt_value) = self.reconstruct_expression(input.index);
125        if let Some(value) = opt_value {
126            // We can perform compile time bounds checking.
127
128            let ty = self.state.type_table.get(&array_id);
129            let Some(Type::Array(array_ty)) = ty else {
130                panic!("Type checking guaranteed that this is an array.");
131            };
132            let len = array_ty.length.as_u32();
133
134            if let Some(len) = len {
135                let index: u32 = match value {
136                    Value::U8(x) => x as u32,
137                    Value::U16(x) => x as u32,
138                    Value::U32(x) => x,
139                    Value::U64(x) => x.try_into().unwrap_or(len),
140                    Value::U128(x) => x.try_into().unwrap_or(len),
141                    Value::I8(x) => x.try_into().unwrap_or(len),
142                    Value::I16(x) => x.try_into().unwrap_or(len),
143                    Value::I32(x) => x.try_into().unwrap_or(len),
144                    Value::I64(x) => x.try_into().unwrap_or(len),
145                    Value::I128(x) => x.try_into().unwrap_or(len),
146                    _ => panic!("Type checking guarantees this is an integer"),
147                };
148
149                if index >= len {
150                    // Only emit a bounds error if we have no other errors yet.
151                    // This prevents a chain of redundant error messages when a loop is unrolled.
152                    if !self.state.handler.had_errors() {
153                        // Get the integer string with no suffix.
154                        let str_index = match value {
155                            Value::U8(x) => format!("{x}"),
156                            Value::U16(x) => format!("{x}"),
157                            Value::U32(x) => format!("{x}"),
158                            Value::U64(x) => format!("{x}"),
159                            Value::U128(x) => format!("{x}"),
160                            Value::I8(x) => format!("{x}"),
161                            Value::I16(x) => format!("{x}"),
162                            Value::I32(x) => format!("{x}"),
163                            Value::I64(x) => format!("{x}"),
164                            Value::I128(x) => format!("{x}"),
165                            _ => unreachable!("We would have panicked above"),
166                        };
167
168                        self.emit_err(StaticAnalyzerError::array_bounds(str_index, len, span));
169                    }
170                } else if let Some(Value::Array(value)) = value_opt {
171                    // We're in bounds and we can evaluate the array at compile time, so just return the value.
172                    let result_value = value.get(index as usize).expect("We already checked bounds.");
173                    return (
174                        value_to_expression(result_value, input.span, &self.state.node_builder).expect(VALUE_ERROR),
175                        Some(result_value.clone()),
176                    );
177                }
178            }
179        } else {
180            self.array_index_not_evaluated = Some(index.span());
181        }
182        (ArrayAccess { array, index, ..input }.into(), None)
183    }
184
185    fn reconstruct_associated_constant(
186        &mut self,
187        input: leo_ast::AssociatedConstantExpression,
188    ) -> (Expression, Self::AdditionalOutput) {
189        // Currently there is only one associated constant.
190        let generator = Value::generator();
191        let expr = value_to_expression(&generator, input.span(), &self.state.node_builder).expect(VALUE_ERROR);
192        (expr, Some(generator))
193    }
194
195    fn reconstruct_associated_function(
196        &mut self,
197        mut input: leo_ast::AssociatedFunctionExpression,
198    ) -> (Expression, Self::AdditionalOutput) {
199        let mut values = Vec::new();
200        for argument in input.arguments.iter_mut() {
201            let (new_argument, opt_value) = self.reconstruct_expression(std::mem::take(argument));
202            *argument = new_argument;
203            if let Some(value) = opt_value {
204                values.push(value);
205            }
206        }
207
208        if values.len() == input.arguments.len() && !matches!(input.variant.name, sym::CheatCode | sym::Mapping) {
209            // We've evaluated every argument, and this function isn't a cheat code or mapping
210            // operation, so maybe we can compute the result at compile time.
211            let core_function = CoreFunction::from_symbols(input.variant.name, input.name.name)
212                .expect("Type checking guarantees this is valid.");
213
214            match interpreter_value::evaluate_core_function(&mut values, core_function, &[], input.span()) {
215                Ok(Some(value)) => {
216                    // Successful evaluation.
217                    let expr = value_to_expression(&value, input.span(), &self.state.node_builder).expect(VALUE_ERROR);
218                    return (expr, Some(value));
219                }
220                Ok(None) =>
221                    // No errors, but we were unable to evaluate.
222                    {}
223                Err(err) => {
224                    self.emit_err(StaticAnalyzerError::compile_core_function(err, input.span()));
225                }
226            }
227        }
228
229        (input.into(), Default::default())
230    }
231
232    fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
233        let span = input.span();
234        let (inner, value_opt) = self.reconstruct_expression(input.inner);
235        let member_name = input.name.name;
236        if let Some(Value::Struct(contents)) = value_opt {
237            let value_result =
238                contents.contents.get(&member_name).expect("Type checking guarantees the member exists.");
239
240            (
241                value_to_expression(value_result, span, &self.state.node_builder).expect(VALUE_ERROR),
242                Some(value_result.clone()),
243            )
244        } else {
245            (MemberAccess { inner, ..input }.into(), None)
246        }
247    }
248
249    fn reconstruct_repeat(&mut self, input: leo_ast::RepeatExpression) -> (Expression, Self::AdditionalOutput) {
250        let (expr, expr_value) = self.reconstruct_expression(input.expr.clone());
251        let (count, count_value) = self.reconstruct_expression(input.count.clone());
252
253        if count_value.is_none() {
254            self.repeat_count_not_evaluated = Some(count.span());
255        }
256
257        match (expr_value, count.as_u32()) {
258            (Some(value), Some(count_u32)) => {
259                (RepeatExpression { expr, count, ..input }.into(), Some(Value::Array(vec![value; count_u32 as usize])))
260            }
261            _ => (RepeatExpression { expr, count, ..input }.into(), None),
262        }
263    }
264
265    fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
266        let span = input.span();
267        let (tuple, value_opt) = self.reconstruct_expression(input.tuple);
268        if let Some(Value::Tuple(tuple)) = value_opt {
269            let value_result = tuple.get(input.index.value()).expect("Type checking checked bounds.");
270            (
271                value_to_expression(value_result, span, &self.state.node_builder).expect(VALUE_ERROR),
272                Some(value_result.clone()),
273            )
274        } else {
275            (TupleAccess { tuple, ..input }.into(), None)
276        }
277    }
278
279    fn reconstruct_array(&mut self, mut input: leo_ast::ArrayExpression) -> (Expression, Self::AdditionalOutput) {
280        let mut values = Vec::new();
281        input.elements.iter_mut().for_each(|element| {
282            let (new_element, value_opt) = self.reconstruct_expression(std::mem::take(element));
283            if let Some(value) = value_opt {
284                values.push(value);
285            }
286            *element = new_element;
287        });
288        if values.len() == input.elements.len() {
289            (input.into(), Some(Value::Array(values)))
290        } else {
291            (input.into(), None)
292        }
293    }
294
295    fn reconstruct_binary(&mut self, input: leo_ast::BinaryExpression) -> (Expression, Self::AdditionalOutput) {
296        let span = input.span();
297
298        let (left, lhs_opt_value) = self.reconstruct_expression(input.left);
299        let (right, rhs_opt_value) = self.reconstruct_expression(input.right);
300
301        if let (Some(lhs_value), Some(rhs_value)) = (lhs_opt_value, rhs_opt_value) {
302            // We were able to evaluate both operands, so we can evaluate this expression.
303            match interpreter_value::evaluate_binary(span, input.op, &lhs_value, &rhs_value) {
304                Ok(new_value) => {
305                    let new_expr = value_to_expression(&new_value, span, &self.state.node_builder).expect(VALUE_ERROR);
306                    return (new_expr, Some(new_value));
307                }
308                Err(err) => self
309                    .emit_err(StaticAnalyzerError::compile_time_binary_op(lhs_value, rhs_value, input.op, err, span)),
310            }
311        }
312
313        (BinaryExpression { left, right, ..input }.into(), None)
314    }
315
316    fn reconstruct_call(&mut self, mut input: leo_ast::CallExpression) -> (Expression, Self::AdditionalOutput) {
317        input.const_arguments.iter_mut().for_each(|arg| {
318            *arg = self.reconstruct_expression(std::mem::take(arg)).0;
319        });
320        input.arguments.iter_mut().for_each(|arg| {
321            *arg = self.reconstruct_expression(std::mem::take(arg)).0;
322        });
323        (input.into(), Default::default())
324    }
325
326    fn reconstruct_cast(&mut self, input: leo_ast::CastExpression) -> (Expression, Self::AdditionalOutput) {
327        let span = input.span();
328
329        let (expr, opt_value) = self.reconstruct_expression(input.expression);
330
331        if let Some(value) = opt_value {
332            if let Some(cast_value) = value.cast(&input.type_) {
333                let expr = value_to_expression(&cast_value, span, &self.state.node_builder).expect(VALUE_ERROR);
334                return (expr, Some(cast_value));
335            } else {
336                self.emit_err(StaticAnalyzerError::compile_time_cast(value, &input.type_, span));
337            }
338        }
339        (CastExpression { expression: expr, ..input }.into(), None)
340    }
341
342    fn reconstruct_err(&mut self, _input: leo_ast::ErrExpression) -> (Expression, Self::AdditionalOutput) {
343        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
344    }
345
346    fn reconstruct_identifier(&mut self, input: leo_ast::Identifier) -> (Expression, Self::AdditionalOutput) {
347        // Substitute the identifier with the constant value if it is a constant that's been evaluated.
348        if let Some(expression) = self.state.symbol_table.lookup_const(self.program, input.name) {
349            let (expression, opt_value) = self.reconstruct_expression(expression);
350            if opt_value.is_some() {
351                return (expression, opt_value);
352            }
353        }
354
355        (input.into(), None)
356    }
357
358    fn reconstruct_literal(&mut self, mut input: leo_ast::Literal) -> (Expression, Self::AdditionalOutput) {
359        let type_info = self.state.type_table.get(&input.id());
360
361        let value =
362            interpreter_value::literal_to_value(&input, &type_info).expect("Failed to convert literal to value");
363
364        // If we know the type of an unsuffixed literal, might as well change it to a suffixed literal. This way, we
365        // do not have to infer the type again in later passes of type checking.
366        if let LiteralVariant::Unsuffixed(s) = input.variant {
367            match type_info.expect("Expected type information to be available") {
368                Type::Integer(ty) => input.variant = LiteralVariant::Integer(ty, s),
369                Type::Field => input.variant = LiteralVariant::Field(s),
370                Type::Group => input.variant = LiteralVariant::Group(s),
371                Type::Scalar => input.variant = LiteralVariant::Scalar(s),
372                _ => panic!("Type checking should have prevented this."),
373            }
374        }
375        (input.into(), Some(value))
376    }
377
378    fn reconstruct_locator(&mut self, input: leo_ast::LocatorExpression) -> (Expression, Self::AdditionalOutput) {
379        (input.into(), Default::default())
380    }
381
382    fn reconstruct_tuple(&mut self, mut input: leo_ast::TupleExpression) -> (Expression, Self::AdditionalOutput) {
383        let mut values = Vec::with_capacity(input.elements.len());
384        for expr in input.elements.iter_mut() {
385            let (new_expr, opt_value) = self.reconstruct_expression(std::mem::take(expr));
386            *expr = new_expr;
387            if let Some(value) = opt_value {
388                values.push(value);
389            }
390        }
391
392        let opt_value = if values.len() == input.elements.len() { Some(Value::Tuple(values)) } else { None };
393
394        (input.into(), opt_value)
395    }
396
397    fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
398        let (receiver, opt_value) = self.reconstruct_expression(input.receiver);
399        let span = input.span;
400
401        if let Some(value) = opt_value {
402            // We were able to evaluate the operand, so we can evaluate the expression.
403            match interpreter_value::evaluate_unary(span, input.op, &value) {
404                Ok(new_value) => {
405                    let new_expr = value_to_expression(&new_value, span, &self.state.node_builder).expect(VALUE_ERROR);
406                    return (new_expr, Some(new_value));
407                }
408                Err(err) => self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span)),
409            }
410        }
411        (UnaryExpression { receiver, ..input }.into(), None)
412    }
413
414    fn reconstruct_unit(&mut self, input: leo_ast::UnitExpression) -> (Expression, Self::AdditionalOutput) {
415        (input.into(), None)
416    }
417}