leo_passes/common_subexpression_elimination/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
24/// An atomic expression - path or literal.
25#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
26pub enum Atom {
27    Path(Vec<Symbol>),
28    Literal(LiteralVariant),
29}
30
31/// An expression that can be mapped to a variable, and eliminated if it appears again.
32///
33/// For now we are rather conservative in the types of expressions we allow.
34/// We define this separate type rather than using `Expression` largely for ease
35/// of hashing and comparison while ignoring superfluous information like node ids and
36/// spans. It also makes explicit in the type that subexpressions must be atoms.
37#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
38pub enum Expr {
39    Atom(Atom),
40    Array(Vec<Atom>),
41    ArrayAccess { array: Atom, index: Atom },
42    Binary { op: BinaryOperation, left: Atom, right: Atom },
43    Repeat { value: Atom, count: Atom },
44    Ternary { condition: Atom, if_true: Atom, if_false: Atom },
45    Unary { op: UnaryOperation, receiver: Atom },
46}
47
48impl From<Atom> for Expr {
49    fn from(value: Atom) -> Self {
50        Expr::Atom(value)
51    }
52}
53
54#[derive(Default, Debug)]
55pub struct Scope {
56    pub expressions: HashMap<Expr, Symbol>,
57}
58
59pub struct CommonSubexpressionEliminatingVisitor<'a> {
60    pub state: &'a mut CompilerState,
61
62    pub scopes: Vec<Scope>,
63}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66    pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67        self.scopes.push(Default::default());
68        let result = func(self);
69        self.scopes.pop();
70        result
71    }
72
73    /// Turn `expression` into an `Atom` if possible, looking it up in the expression
74    /// tables when it's a path. Also changes `expression` into the found value.
75    fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76        let value = match expression {
77            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
78            Expression::Path(path) => {
79                let atom_path =
80                    Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
81                let expr = Expr::Atom(atom_path);
82                if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
83                    // This path is mapped to some name already, so replace it.
84                    *path = Path::new(
85                        Vec::new(),
86                        Identifier::new(*name, self.state.node_builder.next_id()),
87                        true,
88                        Some(vec![*name]),
89                        path.span(),
90                        self.state.node_builder.next_id(),
91                    );
92                    Atom::Path(vec![*name])
93                } else {
94                    let Expr::Atom(atom_path) = expr else { unreachable!() };
95                    atom_path
96                }
97            }
98
99            Expression::ArrayAccess(_)
100            | Expression::AssociatedConstant(_)
101            | Expression::AssociatedFunction(_)
102            | Expression::Async(_)
103            | Expression::Array(_)
104            | Expression::Binary(_)
105            | Expression::Call(_)
106            | Expression::Cast(_)
107            | Expression::Err(_)
108            | Expression::Locator(_)
109            | Expression::MemberAccess(_)
110            | Expression::Repeat(_)
111            | Expression::Struct(_)
112            | Expression::Ternary(_)
113            | Expression::Tuple(_)
114            | Expression::TupleAccess(_)
115            | Expression::Unary(_)
116            | Expression::Unit(_) => return None,
117        };
118
119        Some(value)
120    }
121
122    /// Reconstruct the expression, looking it up in the table of expressions to try to replace it with a
123    /// variable.
124    ///
125    /// - `place` If this expression is the right hand side of a definition, `place` is the left hand side,
126    ///
127    /// Returns (transformed expression, place_not_needed). `place_not_needed` is true iff it has been mapped to
128    /// another path, and thus its definition is no longer needed.
129    pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
130        let span = expression.span();
131        let expr: Expr = match &mut expression {
132            Expression::ArrayAccess(array_access) => {
133                let array = self.try_atom(&mut array_access.array)?;
134                let index = self.try_atom(&mut array_access.index)?;
135                Expr::ArrayAccess { array, index }
136            }
137            Expression::Array(array_expression) => {
138                let atoms = array_expression
139                    .elements
140                    .iter_mut()
141                    .map(|elt| self.try_atom(elt))
142                    .collect::<Option<Vec<Atom>>>()?;
143                Expr::Array(atoms)
144            }
145            Expression::Binary(binary_expression) => {
146                let left = self.try_atom(&mut binary_expression.left)?;
147                let right = self.try_atom(&mut binary_expression.right)?;
148                let (left, right) = if matches!(
149                    binary_expression.op,
150                    BinaryOperation::Add
151                        | BinaryOperation::AddWrapped
152                        | BinaryOperation::BitwiseAnd
153                        | BinaryOperation::BitwiseOr
154                        | BinaryOperation::Eq
155                        | BinaryOperation::Neq
156                        | BinaryOperation::Mul
157                ) && right < left
158                {
159                    // If it's a commutative op, order the operands in a deterministic order.
160                    (right, left)
161                } else {
162                    (left, right)
163                };
164                Expr::Binary { op: binary_expression.op, left, right }
165            }
166            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
167            Expression::Path(path) => {
168                Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
169            }
170            Expression::Repeat(repeat_expression) => {
171                let value = self.try_atom(&mut repeat_expression.expr)?;
172                let count = self.try_atom(&mut repeat_expression.count)?;
173                Expr::Repeat { value, count }
174            }
175            Expression::Ternary(ternary_expression) => {
176                let condition = self.try_atom(&mut ternary_expression.condition)?;
177                let if_true = self.try_atom(&mut ternary_expression.if_true)?;
178                let if_false = self.try_atom(&mut ternary_expression.if_false)?;
179                Expr::Ternary { condition, if_true, if_false }
180            }
181            Expression::Unary(unary) => {
182                let receiver = self.try_atom(&mut unary.receiver)?;
183                Expr::Unary { op: unary.op, receiver }
184            }
185
186            Expression::AssociatedFunction(associated_function_expression) => {
187                for arg in &mut associated_function_expression.arguments {
188                    if !matches!(arg, Expression::Locator(_)) {
189                        self.try_atom(arg)?;
190                    }
191                }
192                return Some((expression, false));
193            }
194
195            Expression::Call(call) => {
196                // Don't worry about the const expressions.
197                for arg in &mut call.arguments {
198                    self.try_atom(arg)?;
199                }
200                return Some((expression, false));
201            }
202
203            Expression::Cast(cast) => {
204                self.try_atom(&mut cast.expression)?;
205                return Some((expression, false));
206            }
207
208            Expression::MemberAccess(member_access) => {
209                self.try_atom(&mut member_access.inner)?;
210                return Some((expression, false));
211            }
212
213            Expression::Struct(struct_expression) => {
214                for initializer in &mut struct_expression.members {
215                    if let Some(expr) = initializer.expression.as_mut() {
216                        self.try_atom(expr)?;
217                    }
218                }
219                return Some((expression, false));
220            }
221
222            Expression::Tuple(tuple_expression) => {
223                // Tuple expressions only exist in return statements at this point in
224                // compilation, so we need only visit each member.
225                tuple_expression.elements = tuple_expression
226                    .elements
227                    .drain(..)
228                    .map(|expr| self.try_expr(expr, None).map(|x| x.0))
229                    .collect::<Option<Vec<_>>>()?;
230                return Some((expression, false));
231            }
232
233            Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
234
235            Expression::Locator(_)
236            | Expression::Async(_)
237            | Expression::AssociatedConstant(_)
238            | Expression::Err(_)
239            | Expression::Unit(_) => {
240                return Some((expression, false));
241            }
242        };
243
244        for map in self.scopes.iter().rev() {
245            if let Some(name) = map.expressions.get(&expr).cloned() {
246                // We already have a symbol whose value is this expression.
247                let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
248                if let Some(place) = place {
249                    // We were defining a new variable, whose right hand side is already defined, so map
250                    // this variable to the previous variable.
251                    self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
252                    return Some((identifier.into(), true));
253                }
254                return Some((identifier.into(), false));
255            }
256        }
257
258        if let Some(place) = place {
259            // No variable yet refers to this expression, so map the expression to the variable.
260            self.scopes.last_mut().unwrap().expressions.insert(expr, place);
261        }
262
263        Some((expression, false))
264    }
265}