leo_passes/common_subexpression_elimination/
visitor.rs1use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
24#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
26pub enum Atom {
27 Path(Vec<Symbol>),
28 Literal(LiteralVariant),
29}
30
31#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
38pub enum Expr {
39 Atom(Atom),
40 Array(Vec<Atom>),
41 ArrayAccess { array: Atom, index: Atom },
42 Binary { op: BinaryOperation, left: Atom, right: Atom },
43 Repeat { value: Atom, count: Atom },
44 Ternary { condition: Atom, if_true: Atom, if_false: Atom },
45 Unary { op: UnaryOperation, receiver: Atom },
46}
47
48impl From<Atom> for Expr {
49 fn from(value: Atom) -> Self {
50 Expr::Atom(value)
51 }
52}
53
54#[derive(Default, Debug)]
55pub struct Scope {
56 pub expressions: HashMap<Expr, Symbol>,
57}
58
59pub struct CommonSubexpressionEliminatingVisitor<'a> {
60 pub state: &'a mut CompilerState,
61
62 pub scopes: Vec<Scope>,
63}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66 pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67 self.scopes.push(Default::default());
68 let result = func(self);
69 self.scopes.pop();
70 result
71 }
72
73 fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76 let id = expression.id();
78 let value = match expression {
80 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
81 Expression::Path(path) => {
82 let atom_path =
83 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
84 let expr = Expr::Atom(atom_path);
85 if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
86 let type_ = self.state.type_table.get(&id)?;
88 let p = Path::new(
90 Vec::new(),
91 Identifier::new(*name, self.state.node_builder.next_id()),
92 true,
93 Some(vec![*name]),
94 path.span(),
95 self.state.node_builder.next_id(),
96 );
97 self.state.type_table.insert(p.id(), type_);
99 *path = p;
101 Atom::Path(vec![*name])
102 } else {
103 let Expr::Atom(atom_path) = expr else { unreachable!() };
104 atom_path
105 }
106 }
107
108 Expression::ArrayAccess(_)
109 | Expression::AssociatedConstant(_)
110 | Expression::AssociatedFunction(_)
111 | Expression::Async(_)
112 | Expression::Array(_)
113 | Expression::Binary(_)
114 | Expression::Call(_)
115 | Expression::Cast(_)
116 | Expression::Err(_)
117 | Expression::Locator(_)
118 | Expression::MemberAccess(_)
119 | Expression::Repeat(_)
120 | Expression::Struct(_)
121 | Expression::Ternary(_)
122 | Expression::Tuple(_)
123 | Expression::TupleAccess(_)
124 | Expression::Unary(_)
125 | Expression::Unit(_) => return None,
126 };
127
128 Some(value)
129 }
130
131 pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
139 let span = expression.span();
140 let expr: Expr = match &mut expression {
141 Expression::ArrayAccess(array_access) => {
142 let array = self.try_atom(&mut array_access.array)?;
143 let index = self.try_atom(&mut array_access.index)?;
144 Expr::ArrayAccess { array, index }
145 }
146 Expression::Array(array_expression) => {
147 let atoms = array_expression
148 .elements
149 .iter_mut()
150 .map(|elt| self.try_atom(elt))
151 .collect::<Option<Vec<Atom>>>()?;
152 Expr::Array(atoms)
153 }
154 Expression::Binary(binary_expression) => {
155 let left = self.try_atom(&mut binary_expression.left)?;
156 let right = self.try_atom(&mut binary_expression.right)?;
157 let (left, right) = if matches!(
158 binary_expression.op,
159 BinaryOperation::Add
160 | BinaryOperation::AddWrapped
161 | BinaryOperation::BitwiseAnd
162 | BinaryOperation::BitwiseOr
163 | BinaryOperation::Eq
164 | BinaryOperation::Neq
165 | BinaryOperation::Mul
166 ) && right < left
167 {
168 (right, left)
170 } else {
171 (left, right)
172 };
173 Expr::Binary { op: binary_expression.op, left, right }
174 }
175 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
176 Expression::Path(path) => {
177 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
178 }
179 Expression::Repeat(repeat_expression) => {
180 let value = self.try_atom(&mut repeat_expression.expr)?;
181 let count = self.try_atom(&mut repeat_expression.count)?;
182 Expr::Repeat { value, count }
183 }
184 Expression::Ternary(ternary_expression) => {
185 let condition = self.try_atom(&mut ternary_expression.condition)?;
186 let if_true = self.try_atom(&mut ternary_expression.if_true)?;
187 let if_false = self.try_atom(&mut ternary_expression.if_false)?;
188 Expr::Ternary { condition, if_true, if_false }
189 }
190 Expression::Unary(unary) => {
191 let receiver = self.try_atom(&mut unary.receiver)?;
192 Expr::Unary { op: unary.op, receiver }
193 }
194
195 Expression::AssociatedFunction(associated_function_expression) => {
196 for arg in &mut associated_function_expression.arguments {
197 if !matches!(arg, Expression::Locator(_)) {
198 self.try_atom(arg)?;
199 }
200 }
201 return Some((expression, false));
202 }
203
204 Expression::Call(call) => {
205 for arg in &mut call.arguments {
207 self.try_atom(arg)?;
208 }
209 return Some((expression, false));
210 }
211
212 Expression::Cast(cast) => {
213 self.try_atom(&mut cast.expression)?;
214 return Some((expression, false));
215 }
216
217 Expression::MemberAccess(member_access) => {
218 self.try_atom(&mut member_access.inner)?;
219 return Some((expression, false));
220 }
221
222 Expression::Struct(struct_expression) => {
223 for initializer in &mut struct_expression.members {
224 if let Some(expr) = initializer.expression.as_mut() {
225 self.try_atom(expr)?;
226 }
227 }
228 return Some((expression, false));
229 }
230
231 Expression::Tuple(tuple_expression) => {
232 tuple_expression.elements = tuple_expression
235 .elements
236 .drain(..)
237 .map(|expr| self.try_expr(expr, None).map(|x| x.0))
238 .collect::<Option<Vec<_>>>()?;
239 return Some((expression, false));
240 }
241
242 Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
243
244 Expression::Locator(_)
245 | Expression::Async(_)
246 | Expression::AssociatedConstant(_)
247 | Expression::Err(_)
248 | Expression::Unit(_) => {
249 return Some((expression, false));
250 }
251 };
252
253 for map in self.scopes.iter().rev() {
254 if let Some(name) = map.expressions.get(&expr).cloned() {
255 let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
257 let type_ = self.state.type_table.get(&expression.id())?.clone();
259 self.state.type_table.insert(identifier.id, type_.clone());
261 if let Some(place) = place {
262 self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
265 return Some((identifier.into(), true));
266 }
267 return Some((identifier.into(), false));
268 }
269 }
270
271 if let Some(place) = place {
272 self.scopes.last_mut().unwrap().expressions.insert(expr, place);
274 }
275
276 Some((expression, false))
277 }
278}