leo_passes/ssa_const_propagation/ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::SsaConstPropagationVisitor;
18
19use leo_ast::{
20    interpreter_value::{self, Value},
21    *,
22};
23use leo_errors::StaticAnalyzerError;
24
/// Panic message used when converting a folded compile-time [`Value`] back into an AST
/// expression unexpectedly fails (the conversion is only expected to fail for futures).
const VALUE_ERROR: &str = "A non-future value should always be able to be converted into an expression";
26
27impl AstReconstructor for SsaConstPropagationVisitor<'_> {
28    type AdditionalInput = ();
29    type AdditionalOutput = Option<Value>;
30
31    /// Reconstruct a path expression. If the path refers to a variable that has
32    /// a constant value, replace it with that constant.
33    fn reconstruct_path(&mut self, input: Path, _additional: &()) -> (Expression, Self::AdditionalOutput) {
34        // In SSA form, paths should refer to local variables (or struct members).
35        // Check if this variable has a constant value.
36        let identifier_name = input.identifier().name;
37
38        if let Some(constant_value) = self.constants.get(&identifier_name).cloned() {
39            // Replace the path with the constant value.
40            let span = input.span();
41            let id = input.id();
42            let (new_expr, _) = self.value_to_expression(&constant_value, span, id).expect(VALUE_ERROR);
43            self.changed = true;
44            (new_expr, Some(constant_value))
45        } else {
46            // No constant value for this variable, keep the path as is.
47            (input.into(), None)
48        }
49    }
50
51    /// Reconstruct a literal expression and convert it to a Value.
52    fn reconstruct_literal(&mut self, mut input: Literal, _additional: &()) -> (Expression, Self::AdditionalOutput) {
53        let type_info = self.state.type_table.get(&input.id());
54
55        // If this is an optional, then unwrap it first.
56        let type_info = type_info.as_ref().map(|ty| match ty {
57            Type::Optional(opt) => *opt.inner.clone(),
58            _ => ty.clone(),
59        });
60
61        if let Ok(value) = interpreter_value::literal_to_value(&input, &type_info) {
62            match input.variant {
63                LiteralVariant::Address(ref s) if s.ends_with("aleo") => {
64                    // Do not fold program names as the VM needs to handle them directly
65                    (input.into(), None)
66                }
67
68                // If we know the type of an unsuffixed literal, might as well change it to a suffixed literal.
69                LiteralVariant::Unsuffixed(s) => {
70                    match type_info.expect("Expected type information to be available") {
71                        Type::Integer(ty) => input.variant = LiteralVariant::Integer(ty, s),
72                        Type::Field => input.variant = LiteralVariant::Field(s),
73                        Type::Group => input.variant = LiteralVariant::Group(s),
74                        Type::Scalar => input.variant = LiteralVariant::Scalar(s),
75                        _ => panic!("Type checking should have prevented this."),
76                    }
77                    (input.into(), Some(value))
78                }
79                _ => (input.into(), Some(value)),
80            }
81        } else {
82            (input.into(), None)
83        }
84    }
85
86    /// Reconstruct a binary expression and fold it if both operands are constants.
87    fn reconstruct_binary(
88        &mut self,
89        input: BinaryExpression,
90        _additional: &(),
91    ) -> (Expression, Self::AdditionalOutput) {
92        let span = input.span();
93        let input_id = input.id();
94
95        let (left, lhs_opt_value) = self.reconstruct_expression(input.left, &());
96        let (right, rhs_opt_value) = self.reconstruct_expression(input.right, &());
97
98        if let (Some(lhs_value), Some(rhs_value)) = (lhs_opt_value, rhs_opt_value) {
99            // We were able to evaluate both operands, so we can evaluate this expression.
100            match interpreter_value::evaluate_binary(
101                span,
102                input.op,
103                &lhs_value,
104                &rhs_value,
105                &self.state.type_table.get(&input_id),
106            ) {
107                Ok(new_value) => {
108                    let (new_expr, _) = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
109                    self.changed = true;
110                    return (new_expr, Some(new_value));
111                }
112                Err(err) => self
113                    .emit_err(StaticAnalyzerError::compile_time_binary_op(lhs_value, rhs_value, input.op, err, span)),
114            }
115        }
116
117        (BinaryExpression { left, right, ..input }.into(), None)
118    }
119
120    /// Reconstruct a unary expression and fold it if the operand is a constant.
121    fn reconstruct_unary(&mut self, input: UnaryExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
122        let input_id = input.id();
123        let span = input.span;
124        let (receiver, opt_value) = self.reconstruct_expression(input.receiver, &());
125
126        if let Some(value) = opt_value {
127            // We were able to evaluate the operand, so we can evaluate the expression.
128            match interpreter_value::evaluate_unary(span, input.op, &value, &self.state.type_table.get(&input_id)) {
129                Ok(new_value) => {
130                    let (new_expr, _) = self.value_to_expression(&new_value, span, input_id).expect(VALUE_ERROR);
131                    self.changed = true;
132                    return (new_expr, Some(new_value));
133                }
134                Err(err) => self.emit_err(StaticAnalyzerError::compile_time_unary_op(value, input.op, err, span)),
135            }
136        }
137        (UnaryExpression { receiver, ..input }.into(), None)
138    }
139
140    /// Reconstruct a ternary expression and fold it if the condition is a constant.
141    fn reconstruct_ternary(
142        &mut self,
143        input: TernaryExpression,
144        _additional: &(),
145    ) -> (Expression, Self::AdditionalOutput) {
146        let (cond, cond_value) = self.reconstruct_expression(input.condition, &());
147
148        match cond_value.and_then(|v| v.try_into().ok()) {
149            Some(true) => {
150                self.changed = true;
151                self.reconstruct_expression(input.if_true, &())
152            }
153            Some(false) => {
154                self.changed = true;
155                self.reconstruct_expression(input.if_false, &())
156            }
157            _ => (
158                TernaryExpression {
159                    condition: cond,
160                    if_true: self.reconstruct_expression(input.if_true, &()).0,
161                    if_false: self.reconstruct_expression(input.if_false, &()).0,
162                    ..input
163                }
164                .into(),
165                None,
166            ),
167        }
168    }
169
170    /// Reconstruct an array access expression and fold it if array and index are constants.
171    fn reconstruct_array_access(
172        &mut self,
173        input: ArrayAccess,
174        _additional: &(),
175    ) -> (Expression, Self::AdditionalOutput) {
176        let span = input.span();
177        let id = input.id();
178
179        let (array, array_opt) = self.reconstruct_expression(input.array, &());
180        let (index, index_opt) = self.reconstruct_expression(input.index, &());
181
182        if let Some(index_value) = index_opt
183            && let Some(array_value) = array_opt
184        {
185            let result_value =
186                array_value.array_index(index_value.as_u32().unwrap() as usize).expect("We already checked bounds.");
187            self.changed = true;
188            let (new_expr, _) = self.value_to_expression(&result_value, span, id).expect(VALUE_ERROR);
189            return (new_expr, Some(result_value.clone()));
190        }
191
192        (ArrayAccess { array, index, ..input }.into(), None)
193    }
194
195    /// Reconstruct an array expression and fold it if all elements are constants.
196    fn reconstruct_array(
197        &mut self,
198        mut input: ArrayExpression,
199        _additional: &(),
200    ) -> (Expression, Self::AdditionalOutput) {
201        let mut values = Vec::new();
202        let mut elements_changed = false;
203        input.elements.iter_mut().for_each(|element| {
204            let old_element = element.clone();
205            let (new_element, value_opt) = self.reconstruct_expression(std::mem::take(element), &());
206            // Check if the element actually changed (not just its structure, but if it's a different expression)
207            if old_element.id() != new_element.id() {
208                elements_changed = true;
209            }
210            if let Some(value) = value_opt {
211                values.push(value);
212            }
213            *element = new_element;
214        });
215        // Only set changed if elements actually changed. Don't set changed just because
216        // we can evaluate the array - that would cause an infinite loop since the array
217        // expression structure doesn't change.
218        if elements_changed {
219            self.changed = true;
220        }
221
222        if values.len() == input.elements.len() {
223            (input.into(), Some(Value::make_array(values.into_iter())))
224        } else {
225            (input.into(), None)
226        }
227    }
228
229    /// Reconstruct a tuple expression and fold it if all elements are constants.
230    fn reconstruct_tuple(
231        &mut self,
232        mut input: TupleExpression,
233        _additional: &(),
234    ) -> (Expression, Self::AdditionalOutput) {
235        let mut values = Vec::with_capacity(input.elements.len());
236        let mut elements_changed = false;
237        for expr in input.elements.iter_mut() {
238            let old_expr = expr.clone();
239            let (new_expr, opt_value) = self.reconstruct_expression(std::mem::take(expr), &());
240            // Check if the element actually changed
241            if old_expr.id() != new_expr.id() {
242                elements_changed = true;
243            }
244            *expr = new_expr;
245            if let Some(value) = opt_value {
246                values.push(value);
247            }
248        }
249
250        // Only set changed if elements actually changed. Don't set changed just because
251        // we can evaluate the tuple - that would cause an infinite loop since the tuple
252        // expression structure doesn't change.
253        if elements_changed {
254            self.changed = true;
255        }
256
257        let opt_value = if values.len() == input.elements.len() { Some(Value::make_tuple(values)) } else { None };
258
259        (input.into(), opt_value)
260    }
261
262    /* Statements */
263    /// Reconstruct a definition statement. If the RHS evaluates to a constant, track it
264    /// in the constants map for propagation.
265    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
266        // Reconstruct the RHS expression first.
267        let (new_value, opt_value) = self.reconstruct_expression(input.value, &());
268
269        if let Some(value) = opt_value {
270            match &input.place {
271                DefinitionPlace::Single(identifier) => {
272                    self.constants.insert(identifier.name, value);
273                }
274                DefinitionPlace::Multiple(identifiers) => {
275                    for (i, id) in identifiers.iter().enumerate() {
276                        if let Some(v) = value.tuple_index(i) {
277                            self.constants.insert(id.name, v);
278                        }
279                    }
280                }
281            }
282        }
283
284        input.value = new_value;
285
286        (input.into(), None)
287    }
288
289    fn reconstruct_assign(&mut self, _input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
290        panic!("there should be no assignments at this stage");
291    }
292}