leo_passes/common_subexpression_elimination/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
24/// An atomic expression - path or literal.
25#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
26pub enum Atom {
27    Path(Vec<Symbol>),
28    Literal(LiteralVariant),
29}
30
31/// An expression that can be mapped to a variable, and eliminated if it appears again.
32///
33/// For now we are rather conservative in the types of expressions we allow.
34/// We define this separate type rather than using `Expression` largely for ease
35/// of hashing and comparison while ignoring superfluous information like node ids and
36/// spans. It also makes explicit in the type that subexpressions must be atoms.
37#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
38pub enum Expr {
39    Atom(Atom),
40    Array(Vec<Atom>),
41    ArrayAccess { array: Atom, index: Atom },
42    Binary { op: BinaryOperation, left: Atom, right: Atom },
43    Repeat { value: Atom, count: Atom },
44    Ternary { condition: Atom, if_true: Atom, if_false: Atom },
45    Unary { op: UnaryOperation, receiver: Atom },
46}
47
48impl From<Atom> for Expr {
49    fn from(value: Atom) -> Self {
50        Expr::Atom(value)
51    }
52}
53
54#[derive(Default, Debug)]
55pub struct Scope {
56    pub expressions: HashMap<Expr, Symbol>,
57}
58
59pub struct CommonSubexpressionEliminatingVisitor<'a> {
60    pub state: &'a mut CompilerState,
61
62    pub scopes: Vec<Scope>,
63}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66    pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67        self.scopes.push(Default::default());
68        let result = func(self);
69        self.scopes.pop();
70        result
71    }
72
73    /// Turn `expression` into an `Atom` if possible, looking it up in the expression
74    /// tables when it's a path. Also changes `expression` into the found value.
75    fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76        // Get the ID of the expression.
77        let id = expression.id();
78        // Modify the expression in place if it's a path that can be replaced.
79        let value = match expression {
80            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
81            Expression::Path(path) => {
82                let atom_path =
83                    Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
84                let expr = Expr::Atom(atom_path);
85                if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
86                    // Get the type of the expression.
87                    let type_ = self.state.type_table.get(&id)?;
88                    // Construct a new path for this identifier.
89                    let p = Path::new(
90                        Vec::new(),
91                        Identifier::new(*name, self.state.node_builder.next_id()),
92                        true,
93                        Some(vec![*name]),
94                        path.span(),
95                        self.state.node_builder.next_id(),
96                    );
97                    // Assign the type of the path.
98                    self.state.type_table.insert(p.id(), type_);
99                    // This path is mapped to some name already, so replace it.
100                    *path = p;
101                    Atom::Path(vec![*name])
102                } else {
103                    let Expr::Atom(atom_path) = expr else { unreachable!() };
104                    atom_path
105                }
106            }
107
108            Expression::ArrayAccess(_)
109            | Expression::Intrinsic(_)
110            | Expression::Async(_)
111            | Expression::Array(_)
112            | Expression::Binary(_)
113            | Expression::Call(_)
114            | Expression::Cast(_)
115            | Expression::Err(_)
116            | Expression::Locator(_)
117            | Expression::MemberAccess(_)
118            | Expression::Repeat(_)
119            | Expression::Struct(_)
120            | Expression::Ternary(_)
121            | Expression::Tuple(_)
122            | Expression::TupleAccess(_)
123            | Expression::Unary(_)
124            | Expression::Unit(_) => return None,
125        };
126
127        Some(value)
128    }
129
130    /// Reconstruct the expression, looking it up in the table of expressions to try to replace it with a
131    /// variable.
132    ///
133    /// - `place` If this expression is the right hand side of a definition, `place` is the left hand side,
134    ///
135    /// Returns (transformed expression, place_not_needed). `place_not_needed` is true iff it has been mapped to
136    /// another path, and thus its definition is no longer needed.
137    pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
138        let span = expression.span();
139        let expr: Expr = match &mut expression {
140            Expression::ArrayAccess(array_access) => {
141                let array = self.try_atom(&mut array_access.array)?;
142                let index = self.try_atom(&mut array_access.index)?;
143                Expr::ArrayAccess { array, index }
144            }
145            Expression::Array(array_expression) => {
146                let atoms = array_expression
147                    .elements
148                    .iter_mut()
149                    .map(|elt| self.try_atom(elt))
150                    .collect::<Option<Vec<Atom>>>()?;
151                Expr::Array(atoms)
152            }
153            Expression::Binary(binary_expression) => {
154                let left = self.try_atom(&mut binary_expression.left)?;
155                let right = self.try_atom(&mut binary_expression.right)?;
156                let (left, right) = if matches!(
157                    binary_expression.op,
158                    BinaryOperation::Add
159                        | BinaryOperation::AddWrapped
160                        | BinaryOperation::BitwiseAnd
161                        | BinaryOperation::BitwiseOr
162                        | BinaryOperation::Eq
163                        | BinaryOperation::Neq
164                        | BinaryOperation::Mul
165                ) && right < left
166                {
167                    // If it's a commutative op, order the operands in a deterministic order.
168                    (right, left)
169                } else {
170                    (left, right)
171                };
172                Expr::Binary { op: binary_expression.op, left, right }
173            }
174            Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
175            Expression::Path(path) => {
176                Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
177            }
178            Expression::Repeat(repeat_expression) => {
179                let value = self.try_atom(&mut repeat_expression.expr)?;
180                let count = self.try_atom(&mut repeat_expression.count)?;
181                Expr::Repeat { value, count }
182            }
183            Expression::Ternary(ternary_expression) => {
184                let condition = self.try_atom(&mut ternary_expression.condition)?;
185                let if_true = self.try_atom(&mut ternary_expression.if_true)?;
186                let if_false = self.try_atom(&mut ternary_expression.if_false)?;
187                Expr::Ternary { condition, if_true, if_false }
188            }
189            Expression::Unary(unary) => {
190                let receiver = self.try_atom(&mut unary.receiver)?;
191                Expr::Unary { op: unary.op, receiver }
192            }
193
194            Expression::Intrinsic(intrinsic) => {
195                for arg in &mut intrinsic.arguments {
196                    if !matches!(arg, Expression::Locator(_)) {
197                        self.try_atom(arg)?;
198                    }
199                }
200                return Some((expression, false));
201            }
202
203            Expression::Call(call) => {
204                // Don't worry about the const expressions.
205                for arg in &mut call.arguments {
206                    self.try_atom(arg)?;
207                }
208                return Some((expression, false));
209            }
210
211            Expression::Cast(cast) => {
212                self.try_atom(&mut cast.expression)?;
213                return Some((expression, false));
214            }
215
216            Expression::MemberAccess(member_access) => {
217                self.try_atom(&mut member_access.inner)?;
218                return Some((expression, false));
219            }
220
221            Expression::Struct(struct_expression) => {
222                for initializer in &mut struct_expression.members {
223                    if let Some(expr) = initializer.expression.as_mut() {
224                        self.try_atom(expr)?;
225                    }
226                }
227                return Some((expression, false));
228            }
229
230            Expression::Tuple(tuple_expression) => {
231                // Tuple expressions only exist in return statements at this point in
232                // compilation, so we need only visit each member.
233                tuple_expression.elements = tuple_expression
234                    .elements
235                    .drain(..)
236                    .map(|expr| self.try_expr(expr, None).map(|x| x.0))
237                    .collect::<Option<Vec<_>>>()?;
238                return Some((expression, false));
239            }
240
241            Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
242
243            Expression::Locator(_) | Expression::Async(_) | Expression::Err(_) | Expression::Unit(_) => {
244                return Some((expression, false));
245            }
246        };
247
248        for map in self.scopes.iter().rev() {
249            if let Some(name) = map.expressions.get(&expr).cloned() {
250                // We already have a symbol whose value is this expression.
251                let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
252                // Get the type of the expression.
253                let type_ = self.state.type_table.get(&expression.id())?.clone();
254                // Assign the type of the new expression.
255                self.state.type_table.insert(identifier.id, type_.clone());
256                if let Some(place) = place {
257                    // We were defining a new variable, whose right hand side is already defined, so map
258                    // this variable to the previous variable.
259                    self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
260                    return Some((identifier.into(), true));
261                }
262                return Some((identifier.into(), false));
263            }
264        }
265
266        if let Some(place) = place {
267            // No variable yet refers to this expression, so map the expression to the variable.
268            self.scopes.last_mut().unwrap().expressions.insert(expr, place);
269        }
270
271        Some((expression, false))
272    }
273}