leo_passes/common_subexpression_elimination/
visitor.rs1use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
24#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
26pub enum Atom {
27 Path(Vec<Symbol>),
28 Literal(LiteralVariant),
29}
30
31#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
38pub enum Expr {
39 Atom(Atom),
40 Array(Vec<Atom>),
41 ArrayAccess { array: Atom, index: Atom },
42 Binary { op: BinaryOperation, left: Atom, right: Atom },
43 Repeat { value: Atom, count: Atom },
44 Ternary { condition: Atom, if_true: Atom, if_false: Atom },
45 Unary { op: UnaryOperation, receiver: Atom },
46}
47
48impl From<Atom> for Expr {
49 fn from(value: Atom) -> Self {
50 Expr::Atom(value)
51 }
52}
53
54#[derive(Default, Debug)]
55pub struct Scope {
56 pub expressions: HashMap<Expr, Symbol>,
57}
58
59pub struct CommonSubexpressionEliminatingVisitor<'a> {
60 pub state: &'a mut CompilerState,
61
62 pub scopes: Vec<Scope>,
63}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66 pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67 self.scopes.push(Default::default());
68 let result = func(self);
69 self.scopes.pop();
70 result
71 }
72
73 fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76 let value = match expression {
77 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
78 Expression::Path(path) => {
79 let atom_path =
80 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
81 let expr = Expr::Atom(atom_path);
82 if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
83 *path = Path::new(
85 Vec::new(),
86 Identifier::new(*name, self.state.node_builder.next_id()),
87 true,
88 Some(vec![*name]),
89 path.span(),
90 self.state.node_builder.next_id(),
91 );
92 Atom::Path(vec![*name])
93 } else {
94 let Expr::Atom(atom_path) = expr else { unreachable!() };
95 atom_path
96 }
97 }
98
99 Expression::ArrayAccess(_)
100 | Expression::AssociatedConstant(_)
101 | Expression::AssociatedFunction(_)
102 | Expression::Async(_)
103 | Expression::Array(_)
104 | Expression::Binary(_)
105 | Expression::Call(_)
106 | Expression::Cast(_)
107 | Expression::Err(_)
108 | Expression::Locator(_)
109 | Expression::MemberAccess(_)
110 | Expression::Repeat(_)
111 | Expression::Struct(_)
112 | Expression::Ternary(_)
113 | Expression::Tuple(_)
114 | Expression::TupleAccess(_)
115 | Expression::Unary(_)
116 | Expression::Unit(_) => return None,
117 };
118
119 Some(value)
120 }
121
122 pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
130 let span = expression.span();
131 let expr: Expr = match &mut expression {
132 Expression::ArrayAccess(array_access) => {
133 let array = self.try_atom(&mut array_access.array)?;
134 let index = self.try_atom(&mut array_access.index)?;
135 Expr::ArrayAccess { array, index }
136 }
137 Expression::Array(array_expression) => {
138 let atoms = array_expression
139 .elements
140 .iter_mut()
141 .map(|elt| self.try_atom(elt))
142 .collect::<Option<Vec<Atom>>>()?;
143 Expr::Array(atoms)
144 }
145 Expression::Binary(binary_expression) => {
146 let left = self.try_atom(&mut binary_expression.left)?;
147 let right = self.try_atom(&mut binary_expression.right)?;
148 let (left, right) = if matches!(
149 binary_expression.op,
150 BinaryOperation::Add
151 | BinaryOperation::AddWrapped
152 | BinaryOperation::BitwiseAnd
153 | BinaryOperation::BitwiseOr
154 | BinaryOperation::Eq
155 | BinaryOperation::Neq
156 | BinaryOperation::Mul
157 ) && right < left
158 {
159 (right, left)
161 } else {
162 (left, right)
163 };
164 Expr::Binary { op: binary_expression.op, left, right }
165 }
166 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
167 Expression::Path(path) => {
168 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
169 }
170 Expression::Repeat(repeat_expression) => {
171 let value = self.try_atom(&mut repeat_expression.expr)?;
172 let count = self.try_atom(&mut repeat_expression.count)?;
173 Expr::Repeat { value, count }
174 }
175 Expression::Ternary(ternary_expression) => {
176 let condition = self.try_atom(&mut ternary_expression.condition)?;
177 let if_true = self.try_atom(&mut ternary_expression.if_true)?;
178 let if_false = self.try_atom(&mut ternary_expression.if_false)?;
179 Expr::Ternary { condition, if_true, if_false }
180 }
181 Expression::Unary(unary) => {
182 let receiver = self.try_atom(&mut unary.receiver)?;
183 Expr::Unary { op: unary.op, receiver }
184 }
185
186 Expression::AssociatedFunction(associated_function_expression) => {
187 for arg in &mut associated_function_expression.arguments {
188 if !matches!(arg, Expression::Locator(_)) {
189 self.try_atom(arg)?;
190 }
191 }
192 return Some((expression, false));
193 }
194
195 Expression::Call(call) => {
196 for arg in &mut call.arguments {
198 self.try_atom(arg)?;
199 }
200 return Some((expression, false));
201 }
202
203 Expression::Cast(cast) => {
204 self.try_atom(&mut cast.expression)?;
205 return Some((expression, false));
206 }
207
208 Expression::MemberAccess(member_access) => {
209 self.try_atom(&mut member_access.inner)?;
210 return Some((expression, false));
211 }
212
213 Expression::Struct(struct_expression) => {
214 for initializer in &mut struct_expression.members {
215 if let Some(expr) = initializer.expression.as_mut() {
216 self.try_atom(expr)?;
217 }
218 }
219 return Some((expression, false));
220 }
221
222 Expression::Tuple(tuple_expression) => {
223 tuple_expression.elements = tuple_expression
226 .elements
227 .drain(..)
228 .map(|expr| self.try_expr(expr, None).map(|x| x.0))
229 .collect::<Option<Vec<_>>>()?;
230 return Some((expression, false));
231 }
232
233 Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
234
235 Expression::Locator(_)
236 | Expression::Async(_)
237 | Expression::AssociatedConstant(_)
238 | Expression::Err(_)
239 | Expression::Unit(_) => {
240 return Some((expression, false));
241 }
242 };
243
244 for map in self.scopes.iter().rev() {
245 if let Some(name) = map.expressions.get(&expr).cloned() {
246 let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
248 if let Some(place) = place {
249 self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
252 return Some((identifier.into(), true));
253 }
254 return Some((identifier.into(), false));
255 }
256 }
257
258 if let Some(place) = place {
259 self.scopes.last_mut().unwrap().expressions.insert(expr, place);
261 }
262
263 Some((expression, false))
264 }
265}