leo_passes/common_subexpression_elimination/
visitor.rs1use crate::CompilerState;
18
19use leo_ast::{BinaryOperation, Expression, Identifier, LiteralVariant, Node as _, Path, UnaryOperation};
20use leo_span::Symbol;
21
22use std::collections::HashMap;
23
24#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
26pub enum Atom {
27 Path(Vec<Symbol>),
28 Literal(LiteralVariant),
29}
30
31#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Hash, Debug)]
38pub enum Expr {
39 Atom(Atom),
40 Array(Vec<Atom>),
41 ArrayAccess { array: Atom, index: Atom },
42 Binary { op: BinaryOperation, left: Atom, right: Atom },
43 Repeat { value: Atom, count: Atom },
44 Ternary { condition: Atom, if_true: Atom, if_false: Atom },
45 Unary { op: UnaryOperation, receiver: Atom },
46}
47
48impl From<Atom> for Expr {
49 fn from(value: Atom) -> Self {
50 Expr::Atom(value)
51 }
52}
53
54#[derive(Default, Debug)]
55pub struct Scope {
56 pub expressions: HashMap<Expr, Symbol>,
57}
58
59pub struct CommonSubexpressionEliminatingVisitor<'a> {
60 pub state: &'a mut CompilerState,
61
62 pub scopes: Vec<Scope>,
63}
64
65impl CommonSubexpressionEliminatingVisitor<'_> {
66 pub fn in_scope<T>(&mut self, func: impl FnOnce(&mut Self) -> T) -> T {
67 self.scopes.push(Default::default());
68 let result = func(self);
69 self.scopes.pop();
70 result
71 }
72
73 fn try_atom(&self, expression: &mut Expression) -> Option<Atom> {
76 let id = expression.id();
78 let value = match expression {
80 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()),
81 Expression::Path(path) => {
82 let atom_path =
83 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect());
84 let expr = Expr::Atom(atom_path);
85 if let Some(name) = self.scopes.iter().rev().find_map(|scope| scope.expressions.get(&expr)) {
86 let type_ = self.state.type_table.get(&id)?;
88 let p = Path::new(
90 Vec::new(),
91 Identifier::new(*name, self.state.node_builder.next_id()),
92 true,
93 Some(vec![*name]),
94 path.span(),
95 self.state.node_builder.next_id(),
96 );
97 self.state.type_table.insert(p.id(), type_);
99 *path = p;
101 Atom::Path(vec![*name])
102 } else {
103 let Expr::Atom(atom_path) = expr else { unreachable!() };
104 atom_path
105 }
106 }
107
108 Expression::ArrayAccess(_)
109 | Expression::Intrinsic(_)
110 | Expression::Async(_)
111 | Expression::Array(_)
112 | Expression::Binary(_)
113 | Expression::Call(_)
114 | Expression::Cast(_)
115 | Expression::Err(_)
116 | Expression::Locator(_)
117 | Expression::MemberAccess(_)
118 | Expression::Repeat(_)
119 | Expression::Struct(_)
120 | Expression::Ternary(_)
121 | Expression::Tuple(_)
122 | Expression::TupleAccess(_)
123 | Expression::Unary(_)
124 | Expression::Unit(_) => return None,
125 };
126
127 Some(value)
128 }
129
130 pub fn try_expr(&mut self, mut expression: Expression, place: Option<Symbol>) -> Option<(Expression, bool)> {
138 let span = expression.span();
139 let expr: Expr = match &mut expression {
140 Expression::ArrayAccess(array_access) => {
141 let array = self.try_atom(&mut array_access.array)?;
142 let index = self.try_atom(&mut array_access.index)?;
143 Expr::ArrayAccess { array, index }
144 }
145 Expression::Array(array_expression) => {
146 let atoms = array_expression
147 .elements
148 .iter_mut()
149 .map(|elt| self.try_atom(elt))
150 .collect::<Option<Vec<Atom>>>()?;
151 Expr::Array(atoms)
152 }
153 Expression::Binary(binary_expression) => {
154 let left = self.try_atom(&mut binary_expression.left)?;
155 let right = self.try_atom(&mut binary_expression.right)?;
156 let (left, right) = if matches!(
157 binary_expression.op,
158 BinaryOperation::Add
159 | BinaryOperation::AddWrapped
160 | BinaryOperation::BitwiseAnd
161 | BinaryOperation::BitwiseOr
162 | BinaryOperation::Eq
163 | BinaryOperation::Neq
164 | BinaryOperation::Mul
165 ) && right < left
166 {
167 (right, left)
169 } else {
170 (left, right)
171 };
172 Expr::Binary { op: binary_expression.op, left, right }
173 }
174 Expression::Literal(literal) => Atom::Literal(literal.variant.clone()).into(),
175 Expression::Path(path) => {
176 Atom::Path(path.qualifier().iter().map(|id| id.name).chain([path.identifier().name]).collect()).into()
177 }
178 Expression::Repeat(repeat_expression) => {
179 let value = self.try_atom(&mut repeat_expression.expr)?;
180 let count = self.try_atom(&mut repeat_expression.count)?;
181 Expr::Repeat { value, count }
182 }
183 Expression::Ternary(ternary_expression) => {
184 let condition = self.try_atom(&mut ternary_expression.condition)?;
185 let if_true = self.try_atom(&mut ternary_expression.if_true)?;
186 let if_false = self.try_atom(&mut ternary_expression.if_false)?;
187 Expr::Ternary { condition, if_true, if_false }
188 }
189 Expression::Unary(unary) => {
190 let receiver = self.try_atom(&mut unary.receiver)?;
191 Expr::Unary { op: unary.op, receiver }
192 }
193
194 Expression::Intrinsic(intrinsic) => {
195 for arg in &mut intrinsic.arguments {
196 if !matches!(arg, Expression::Locator(_)) {
197 self.try_atom(arg)?;
198 }
199 }
200 return Some((expression, false));
201 }
202
203 Expression::Call(call) => {
204 for arg in &mut call.arguments {
206 self.try_atom(arg)?;
207 }
208 return Some((expression, false));
209 }
210
211 Expression::Cast(cast) => {
212 self.try_atom(&mut cast.expression)?;
213 return Some((expression, false));
214 }
215
216 Expression::MemberAccess(member_access) => {
217 self.try_atom(&mut member_access.inner)?;
218 return Some((expression, false));
219 }
220
221 Expression::Struct(struct_expression) => {
222 for initializer in &mut struct_expression.members {
223 if let Some(expr) = initializer.expression.as_mut() {
224 self.try_atom(expr)?;
225 }
226 }
227 return Some((expression, false));
228 }
229
230 Expression::Tuple(tuple_expression) => {
231 tuple_expression.elements = tuple_expression
234 .elements
235 .drain(..)
236 .map(|expr| self.try_expr(expr, None).map(|x| x.0))
237 .collect::<Option<Vec<_>>>()?;
238 return Some((expression, false));
239 }
240
241 Expression::TupleAccess(_) => panic!("Tuple access expressions should not exist in this pass."),
242
243 Expression::Locator(_) | Expression::Async(_) | Expression::Err(_) | Expression::Unit(_) => {
244 return Some((expression, false));
245 }
246 };
247
248 for map in self.scopes.iter().rev() {
249 if let Some(name) = map.expressions.get(&expr).cloned() {
250 let identifier = Identifier { name, span, id: self.state.node_builder.next_id() };
252 let type_ = self.state.type_table.get(&expression.id())?.clone();
254 self.state.type_table.insert(identifier.id, type_.clone());
256 if let Some(place) = place {
257 self.scopes.last_mut().unwrap().expressions.insert(Atom::Path(vec![place]).into(), name);
260 return Some((identifier.into(), true));
261 }
262 return Some((identifier.into(), false));
263 }
264 }
265
266 if let Some(place) = place {
267 self.scopes.last_mut().unwrap().expressions.insert(expr, place);
269 }
270
271 Some((expression, false))
272 }
273}