leo_passes/write_transforming/
expression.rs1use super::WriteTransformingVisitor;
18
19use leo_ast::{
20 ArrayAccess,
21 ArrayExpression,
22 AssignStatement,
23 AssociatedFunctionExpression,
24 BinaryExpression,
25 CallExpression,
26 CastExpression,
27 Expression,
28 ExpressionReconstructor,
29 Identifier,
30 MemberAccess,
31 Node,
32 Statement,
33 StructExpression,
34 StructVariableInitializer,
35 TernaryExpression,
36 TupleAccess,
37 TupleExpression,
38 Type,
39 UnaryExpression,
40};
41use leo_span::Symbol;
42
43impl WriteTransformingVisitor<'_> {
44 pub fn get_array_member(&self, array_name: Symbol, index: &Expression) -> Option<Identifier> {
45 let members = self.array_members.get(&array_name)?;
46 let Expression::Literal(lit) = index else {
47 panic!("Const propagation should have ensured this is a literal.");
48 };
49 let index = lit
50 .as_u32()
51 .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
52 as usize;
53 Some(members[index])
54 }
55
56 pub fn get_struct_member(&self, struct_name: Symbol, field_name: Symbol) -> Option<Identifier> {
57 let members = self.struct_members.get(&struct_name)?;
58 members.get(&field_name).cloned()
59 }
60}
61
62impl ExpressionReconstructor for WriteTransformingVisitor<'_> {
63 type AdditionalOutput = Vec<Statement>;
64
65 fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
66 let ty = self.state.type_table.get(&input.id()).unwrap();
67 let mut statements = Vec::new();
68 if let Some(array_members) = self.array_members.get(&input.name) {
69 let id = self.state.node_builder.next_id();
71 self.state.type_table.insert(id, ty.clone());
72 let expr = ArrayExpression {
73 elements: array_members
74 .clone()
76 .iter()
77 .map(|identifier| {
78 let (expr, statements2) = self.reconstruct_identifier(*identifier);
79 statements.extend(statements2);
80 expr
81 })
82 .collect(),
83 span: Default::default(),
84 id,
85 };
86 let statement = AssignStatement {
87 place: input.into(),
88 value: expr.into(),
89 span: Default::default(),
90 id: self.state.node_builder.next_id(),
91 };
92 statements.push(statement.into());
93 (input.into(), statements)
94 } else if let Some(struct_members) = self.struct_members.get(&input.name) {
95 let id = self.state.node_builder.next_id();
97 self.state.type_table.insert(id, ty.clone());
98 let Type::Composite(comp_type) = ty else {
99 panic!("The type of a struct init should be a composite.");
100 };
101 let expr = StructExpression {
102 members: struct_members
103 .clone()
105 .iter()
106 .map(|(field_name, ident)| {
107 let (expr, statements2) = self.reconstruct_identifier(*ident);
108 statements.extend(statements2);
109 StructVariableInitializer {
110 identifier: Identifier::new(*field_name, self.state.node_builder.next_id()),
111 expression: Some(expr),
112 span: Default::default(),
113 id: self.state.node_builder.next_id(),
114 }
115 })
116 .collect(),
117 name: comp_type.id,
118 span: Default::default(),
119 id,
120 };
121 let statement = AssignStatement {
122 place: input.into(),
123 value: expr.into(),
124 span: Default::default(),
125 id: self.state.node_builder.next_id(),
126 };
127 statements.push(statement.into());
128 (input.into(), statements)
129 } else {
130 (input.into(), Default::default())
132 }
133 }
134
135 fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
136 let Expression::Identifier(array_name) = input.array else {
137 panic!("SSA ensures that this is an Identifier.");
138 };
139 if let Some(member) = self.get_array_member(array_name.name, &input.index) {
140 self.reconstruct_identifier(member)
141 } else {
142 (input.into(), Default::default())
143 }
144 }
145
146 fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
147 let Expression::Identifier(array_name) = input.inner else {
148 panic!("SSA ensures that this is an Identifier.");
149 };
150 if let Some(member) = self.get_struct_member(array_name.name, input.name.name) {
151 self.reconstruct_identifier(member)
152 } else {
153 (input.into(), Default::default())
154 }
155 }
156
157 fn reconstruct_associated_function(
161 &mut self,
162 mut input: AssociatedFunctionExpression,
163 ) -> (Expression, Self::AdditionalOutput) {
164 let mut statements = Vec::new();
165 for arg in input.arguments.iter_mut() {
166 let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg));
167 statements.extend(statements2);
168 *arg = expr;
169 }
170 (input.into(), statements)
171 }
172
173 fn reconstruct_tuple_access(&mut self, _input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
174 panic!("`TupleAccess` should not be in the AST at this point.");
175 }
176
177 fn reconstruct_array(&mut self, mut input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
178 let mut statements = Vec::new();
179 for element in input.elements.iter_mut() {
180 let (expr, statements2) = self.reconstruct_expression(std::mem::take(element));
181 statements.extend(statements2);
182 *element = expr;
183 }
184 (input.into(), statements)
185 }
186
187 fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
188 let (left, mut statements) = self.reconstruct_expression(input.left);
189 let (right, statements2) = self.reconstruct_expression(input.right);
190 statements.extend(statements2);
191 (BinaryExpression { left, right, ..input }.into(), statements)
192 }
193
194 fn reconstruct_call(&mut self, mut input: CallExpression) -> (Expression, Self::AdditionalOutput) {
195 let mut statements = Vec::new();
196 for arg in input.arguments.iter_mut() {
197 let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg));
198 statements.extend(statements2);
199 *arg = expr;
200 }
201 (input.into(), statements)
202 }
203
204 fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
205 let (expression, statements) = self.reconstruct_expression(input.expression);
206 (CastExpression { expression, ..input }.into(), statements)
207 }
208
209 fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
210 let mut statements = Vec::new();
211 for member in input.members.iter_mut() {
212 assert!(member.expression.is_some());
213 let (expr, statements2) = self.reconstruct_expression(member.expression.take().unwrap());
214 statements.extend(statements2);
215 member.expression = Some(expr);
216 }
217
218 (input.into(), statements)
219 }
220
221 fn reconstruct_err(&mut self, _input: leo_ast::ErrExpression) -> (Expression, Self::AdditionalOutput) {
222 std::panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
223 }
224
225 fn reconstruct_literal(&mut self, input: leo_ast::Literal) -> (Expression, Self::AdditionalOutput) {
226 (input.into(), Default::default())
227 }
228
229 fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
230 let (condition, mut statements) = self.reconstruct_expression(input.condition);
231 let (if_true, statements2) = self.reconstruct_expression(input.if_true);
232 let (if_false, statements3) = self.reconstruct_expression(input.if_false);
233 statements.extend(statements2);
234 statements.extend(statements3);
235 (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
236 }
237
238 fn reconstruct_tuple(&mut self, input: leo_ast::TupleExpression) -> (Expression, Self::AdditionalOutput) {
239 let mut statements = Vec::new();
241 let elements = input
242 .elements
243 .into_iter()
244 .map(|element| {
245 let (expr, statements2) = self.reconstruct_expression(element);
246 statements.extend(statements2);
247 expr
248 })
249 .collect();
250 (TupleExpression { elements, ..input }.into(), statements)
251 }
252
253 fn reconstruct_unary(&mut self, input: leo_ast::UnaryExpression) -> (Expression, Self::AdditionalOutput) {
254 let (receiver, statements) = self.reconstruct_expression(input.receiver);
255 (UnaryExpression { receiver, ..input }.into(), statements)
256 }
257
258 fn reconstruct_unit(&mut self, input: leo_ast::UnitExpression) -> (Expression, Self::AdditionalOutput) {
259 (input.into(), Default::default())
260 }
261}