leo_passes/common_subexpression_elimination/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
/// An atomic expression - path or literal.
///
/// Atoms are the only operands allowed inside an [`Expr`]. They carry no node ids or
/// spans, so two occurrences of the same path or literal compare and hash as equal.
/// The derived ordering is used to put the operands of commutative binary operations
/// into a canonical order.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum Atom {
    /// A path, stored as the names of its qualifier segments followed by the name of
    /// its identifier.
    Path(Vec<Symbol>),
    /// A literal, identified by its variant only.
    Literal(LiteralVariant),
}
30
/// An expression that can be mapped to a variable, and eliminated if it appears again.
///
/// For now we are rather conservative in the types of expressions we allow.
/// We define this separate type rather than using `Expression` largely for ease
/// of hashing and comparison while ignoring superfluous information like node ids and
/// spans. It also makes explicit in the type that subexpressions must be atoms.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
pub enum Expr {
    /// A bare path or literal.
    Atom(Atom),
    /// An array expression; one atom per element, in order.
    Array(Vec<Atom>),
    /// An array access: `array` indexed by `index`.
    ArrayAccess { array: Atom, index: Atom },
    /// A binary operation. For the operations treated as commutative, the operands are
    /// stored in sorted order so that `a + b` and `b + a` produce the same key.
    Binary { op: BinaryOperation, left: Atom, right: Atom },
    /// A repeat expression: `value` repeated `count` times.
    Repeat { value: Atom, count: Atom },
    /// A ternary (conditional) expression.
    Ternary { condition: Atom, if_true: Atom, if_false: Atom },
    /// A unary operation applied to `receiver`.
    Unary { op: UnaryOperation, receiver: Atom },
}
47
48impl From<Atom> for Expr {
49    fn from(value: Atom) -> Self {
50        Expr::Atom(value)
51    }
52}
53
/// One lexical scope's table of expressions that are already held by a variable.
#[derive(Default, Debug)]
pub struct Scope {
    /// Maps an expression to the name of the variable whose value is that expression.
    pub expressions: HashMap<Expr, Symbol>,
}
58
/// Visitor state for the common subexpression elimination pass.
pub struct CommonSubexpressionEliminatingVisitor<'a> {
    /// The compiler state; this visitor reads and writes its type table and draws fresh
    /// node ids from its node builder.
    pub state: &'a mut CompilerState,

    /// Stack of active scopes, innermost last. Lookups walk it from the end backwards.
    pub scopes: Vec<Scope>,
}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66    pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67        self.scopes.push(Default::default());
68        let result = func(self);
69        self.scopes.pop();
70        result
71    }
72
73    /// Turn `expression` into an `Atom` if possible, looking it up in the expression
74    /// tables when it's a path. Also changes `expression` into the found value.
75    fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76        // Get the ID of the expression.
77        let id = expression.id();
78        // Modify the expression in place if it's a path that can be replaced.
79        let value = match expression {
80            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
81            Expression::Path(path) => {
82                let atom_path =
83                    Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
84                let expr = Expr::Atom(atom_path);
85                if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
86                    // Get the type of the expression.
87                    let type_ = self.state.type_table.get(&id)?;
88                    // Construct a new path for this identifier.
89                    let p = Path::new(
90                        Vec::new(),
91                        Identifier::new(*name, self.state.node_builder.next_id()),
92                        true,
93                        Some(vec![*name]),
94                        path.span(),
95                        self.state.node_builder.next_id(),
96                    );
97                    // Assign the type of the path.
98                    self.state.type_table.insert(p.id(), type_);
99                    // This path is mapped to some name already, so replace it.
100                    *path = p;
101                    Atom::Path(vec![*name])
102                } else {
103                    let Expr::Atom(atom_path) = expr else { unreachable!() };
104                    atom_path
105                }
106            }
107
108            Expression::ArrayAccess(_)
109            | Expression::AssociatedConstant(_)
110            | Expression::AssociatedFunction(_)
111            | Expression::Async(_)
112            | Expression::Array(_)
113            | Expression::Binary(_)
114            | Expression::Call(_)
115            | Expression::Cast(_)
116            | Expression::Err(_)
117            | Expression::Locator(_)
118            | Expression::MemberAccess(_)
119            | Expression::Repeat(_)
120            | Expression::Struct(_)
121            | Expression::Ternary(_)
122            | Expression::Tuple(_)
123            | Expression::TupleAccess(_)
124            | Expression::Unary(_)
125            | Expression::Unit(_) => return None,
126        };
127
128        Some(value)
129    }
130
131    /// Reconstruct the expression, looking it up in the table of expressions to try to replace it with a
132    /// variable.
133    ///
134    /// - `place` If this expression is the right hand side of a definition, `place` is the left hand side,
135    ///
136    /// Returns (transformed expression, place_not_needed). `place_not_needed` is true iff it has been mapped to
137    /// another path, and thus its definition is no longer needed.
138    pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
139        let span = expression.span();
140        let expr: Expr = match &mut expression {
141            Expression::ArrayAccess(array_access) => {
142                let array = self.try_atom(&mut array_access.array)?;
143                let index = self.try_atom(&mut array_access.index)?;
144                Expr::ArrayAccess { array, index }
145            }
146            Expression::Array(array_expression) => {
147                let atoms = array_expression
148                    .elements
149                    .iter_mut()
150                    .map(|elt| self.try_atom(elt))
151                    .collect::<Option<Vec<Atom>>>()?;
152                Expr::Array(atoms)
153            }
154            Expression::Binary(binary_expression) => {
155                let left = self.try_atom(&mut binary_expression.left)?;
156                let right = self.try_atom(&mut binary_expression.right)?;
157                let (left, right) = if matches!(
158                    binary_expression.op,
159                    BinaryOperation::Add
160                        | BinaryOperation::AddWrapped
161                        | BinaryOperation::BitwiseAnd
162                        | BinaryOperation::BitwiseOr
163                        | BinaryOperation::Eq
164                        | BinaryOperation::Neq
165                        | BinaryOperation::Mul
166                ) && right < left
167                {
168                    // If it's a commutative op, order the operands in a deterministic order.
169                    (right, left)
170                } else {
171                    (left, right)
172                };
173                Expr::Binary { op: binary_expression.op, left, right }
174            }
175            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
176            Expression::Path(path) => {
177                Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
178            }
179            Expression::Repeat(repeat_expression) => {
180                let value = self.try_atom(&mut repeat_expression.expr)?;
181                let count = self.try_atom(&mut repeat_expression.count)?;
182                Expr::Repeat { value, count }
183            }
184            Expression::Ternary(ternary_expression) => {
185                let condition = self.try_atom(&mut ternary_expression.condition)?;
186                let if_true = self.try_atom(&mut ternary_expression.if_true)?;
187                let if_false = self.try_atom(&mut ternary_expression.if_false)?;
188                Expr::Ternary { condition, if_true, if_false }
189            }
190            Expression::Unary(unary) => {
191                let receiver = self.try_atom(&mut unary.receiver)?;
192                Expr::Unary { op: unary.op, receiver }
193            }
194
195            Expression::AssociatedFunction(associated_function_expression) => {
196                for arg in &mut associated_function_expression.arguments {
197                    if !matches!(arg, Expression::Locator(_)) {
198                        self.try_atom(arg)?;
199                    }
200                }
201                return Some((expression, false));
202            }
203
204            Expression::Call(call) => {
205                // Don't worry about the const expressions.
206                for arg in &mut call.arguments {
207                    self.try_atom(arg)?;
208                }
209                return Some((expression, false));
210            }
211
212            Expression::Cast(cast) => {
213                self.try_atom(&mut cast.expression)?;
214                return Some((expression, false));
215            }
216
217            Expression::MemberAccess(member_access) => {
218                self.try_atom(&mut member_access.inner)?;
219                return Some((expression, false));
220            }
221
222            Expression::Struct(struct_expression) => {
223                for initializer in &mut struct_expression.members {
224                    if let Some(expr) = initializer.expression.as_mut() {
225                        self.try_atom(expr)?;
226                    }
227                }
228                return Some((expression, false));
229            }
230
231            Expression::Tuple(tuple_expression) => {
232                // Tuple expressions only exist in return statements at this point in
233                // compilation, so we need only visit each member.
234                tuple_expression.elements = tuple_expression
235                    .elements
236                    .drain(..)
237                    .map(|expr| self.try_expr(expr, None).map(|x| x.0))
238                    .collect::<Option<Vec<_>>>()?;
239                return Some((expression, false));
240            }
241
242            Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
243
244            Expression::Locator(_)
245            | Expression::Async(_)
246            | Expression::AssociatedConstant(_)
247            | Expression::Err(_)
248            | Expression::Unit(_) => {
249                return Some((expression, false));
250            }
251        };
252
253        for map in self.scopes.iter().rev() {
254            if let Some(name) = map.expressions.get(&expr).cloned() {
255                // We already have a symbol whose value is this expression.
256                let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
257                // Get the type of the expression.
258                let type_ = self.state.type_table.get(&expression.id())?.clone();
259                // Assign the type of the new expression.
260                self.state.type_table.insert(identifier.id, type_.clone());
261                if let Some(place) = place {
262                    // We were defining a new variable, whose right hand side is already defined, so map
263                    // this variable to the previous variable.
264                    self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
265                    return Some((identifier.into(), true));
266                }
267                return Some((identifier.into(), false));
268            }
269        }
270
271        if let Some(place) = place {
272            // No variable yet refers to this expression, so map the expression to the variable.
273            self.scopes.last_mut().unwrap().expressions.insert(expr, place);
274        }
275
276        Some((expression, false))
277    }
278}