1use super::WriteTransformingVisitor;
18use leo_ast::*;
19use leo_span::Symbol;
20
21impl WriteTransformingVisitor<'_> {
22 pub fn get_array_member(&self, array_name: Symbol, index: &Expression) -> Option<Identifier> {
23 let members = self.array_members.get(&array_name)?;
24 let Expression::Literal(lit) = index else {
25 panic!("Const propagation should have ensured this is a literal.");
26 };
27 let index = lit
28 .as_u32()
29 .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
30 as usize;
31 Some(members[index])
32 }
33
34 pub fn get_struct_member(&self, struct_name: Symbol, field_name: Symbol) -> Option<Identifier> {
35 let members = self.struct_members.get(&struct_name)?;
36 members.get(&field_name).cloned()
37 }
38}
39
40impl WriteTransformingVisitor<'_> {
41 fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Vec<Statement>) {
42 let ty = self.state.type_table.get(&input.id()).unwrap();
43 let mut statements = Vec::new();
44 if let Some(array_members) = self.array_members.get(&input.name) {
45 let id = self.state.node_builder.next_id();
47 self.state.type_table.insert(id, ty.clone());
48 let expr = ArrayExpression {
49 elements: array_members
50 .clone()
52 .iter()
53 .map(|identifier| {
54 let (expr, statements2) = self.reconstruct_identifier(*identifier);
55 statements.extend(statements2);
56 expr
57 })
58 .collect(),
59 span: Default::default(),
60 id,
61 };
62 let statement = AssignStatement {
63 place: Path::from(input).into_absolute().into(),
64 value: expr.into(),
65 span: Default::default(),
66 id: self.state.node_builder.next_id(),
67 };
68 statements.push(statement.into());
69 (Path::from(input).into_absolute().into(), statements)
70 } else if let Some(struct_members) = self.struct_members.get(&input.name) {
71 let id = self.state.node_builder.next_id();
73 self.state.type_table.insert(id, ty.clone());
74 let Type::Composite(comp_type) = ty else {
75 panic!("The type of a struct init should be a composite.");
76 };
77 let expr = StructExpression {
78 const_arguments: Vec::new(), members: struct_members
80 .clone()
82 .iter()
83 .map(|(field_name, ident)| {
84 let (expr, statements2) = self.reconstruct_identifier(*ident);
85 statements.extend(statements2);
86 StructVariableInitializer {
87 identifier: Identifier::new(*field_name, self.state.node_builder.next_id()),
88 expression: Some(expr),
89 span: Default::default(),
90 id: self.state.node_builder.next_id(),
91 }
92 })
93 .collect(),
94 path: comp_type.path,
95 span: Default::default(),
96 id,
97 };
98 let statement = AssignStatement {
99 place: Path::from(input).into_absolute().into(),
100 value: expr.into(),
101 span: Default::default(),
102 id: self.state.node_builder.next_id(),
103 };
104 statements.push(statement.into());
105 (Path::from(input).into_absolute().into(), statements)
106 } else {
107 (Path::from(input).into_absolute().into(), Default::default())
109 }
110 }
111}
112
113impl AstReconstructor for WriteTransformingVisitor<'_> {
114 type AdditionalInput = ();
115 type AdditionalOutput = Vec<Statement>;
116
117 fn reconstruct_path(&mut self, input: Path, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
119 if input.qualifier().is_empty() {
120 self.reconstruct_identifier(Identifier { name: input.identifier().name, span: input.span, id: input.id })
121 } else {
122 (input.into(), Default::default())
123 }
124 }
125
126 fn reconstruct_array_access(
127 &mut self,
128 input: ArrayAccess,
129 _addiional: &(),
130 ) -> (Expression, Self::AdditionalOutput) {
131 let Expression::Path(ref array_name) = input.array else {
132 panic!("SSA ensures that this is a Path.");
133 };
134 if let Some(member) = self.get_array_member(array_name.identifier().name, &input.index) {
135 self.reconstruct_identifier(member)
136 } else {
137 (input.into(), Default::default())
138 }
139 }
140
141 fn reconstruct_member_access(
142 &mut self,
143 input: MemberAccess,
144 _addiional: &(),
145 ) -> (Expression, Self::AdditionalOutput) {
146 let Expression::Path(ref struct_name) = input.inner else {
147 panic!("SSA ensures that this is a Path.");
148 };
149 if let Some(member) = self.get_struct_member(struct_name.identifier().name, input.name.name) {
150 self.reconstruct_identifier(member)
151 } else {
152 (input.into(), Default::default())
153 }
154 }
155
156 fn reconstruct_associated_function(
160 &mut self,
161 mut input: AssociatedFunctionExpression,
162 _addiional: &(),
163 ) -> (Expression, Self::AdditionalOutput) {
164 let mut statements = Vec::new();
165 for arg in input.arguments.iter_mut() {
166 let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg), &());
167 statements.extend(statements2);
168 *arg = expr;
169 }
170 (input.into(), statements)
171 }
172
173 fn reconstruct_tuple_access(
174 &mut self,
175 _input: TupleAccess,
176 _addiional: &(),
177 ) -> (Expression, Self::AdditionalOutput) {
178 panic!("`TupleAccess` should not be in the AST at this point.");
179 }
180
181 fn reconstruct_array(
182 &mut self,
183 mut input: ArrayExpression,
184 _addiional: &(),
185 ) -> (Expression, Self::AdditionalOutput) {
186 let mut statements = Vec::new();
187 for element in input.elements.iter_mut() {
188 let (expr, statements2) = self.reconstruct_expression(std::mem::take(element), &());
189 statements.extend(statements2);
190 *element = expr;
191 }
192 (input.into(), statements)
193 }
194
195 fn reconstruct_binary(&mut self, input: BinaryExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
196 let (left, mut statements) = self.reconstruct_expression(input.left, &());
197 let (right, statements2) = self.reconstruct_expression(input.right, &());
198 statements.extend(statements2);
199 (BinaryExpression { left, right, ..input }.into(), statements)
200 }
201
202 fn reconstruct_call(&mut self, mut input: CallExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
203 let mut statements = Vec::new();
204 for arg in input.arguments.iter_mut() {
205 let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg), &());
206 statements.extend(statements2);
207 *arg = expr;
208 }
209 (input.into(), statements)
210 }
211
212 fn reconstruct_cast(&mut self, input: CastExpression, _addiional: &()) -> (Expression, Self::AdditionalOutput) {
213 let (expression, statements) = self.reconstruct_expression(input.expression, &());
214 (CastExpression { expression, ..input }.into(), statements)
215 }
216
217 fn reconstruct_struct_init(
218 &mut self,
219 mut input: StructExpression,
220 _addiional: &(),
221 ) -> (Expression, Self::AdditionalOutput) {
222 let mut statements = Vec::new();
223 for member in input.members.iter_mut() {
224 assert!(member.expression.is_some());
225 let (expr, statements2) = self.reconstruct_expression(member.expression.take().unwrap(), &());
226 statements.extend(statements2);
227 member.expression = Some(expr);
228 }
229
230 (input.into(), statements)
231 }
232
233 fn reconstruct_err(
234 &mut self,
235 _input: leo_ast::ErrExpression,
236 _addiional: &(),
237 ) -> (Expression, Self::AdditionalOutput) {
238 std::panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
239 }
240
241 fn reconstruct_literal(
242 &mut self,
243 input: leo_ast::Literal,
244 _addiional: &(),
245 ) -> (Expression, Self::AdditionalOutput) {
246 (input.into(), Default::default())
247 }
248
249 fn reconstruct_ternary(
250 &mut self,
251 input: TernaryExpression,
252 _addiional: &(),
253 ) -> (Expression, Self::AdditionalOutput) {
254 let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
255 let (if_true, statements2) = self.reconstruct_expression(input.if_true, &());
256 let (if_false, statements3) = self.reconstruct_expression(input.if_false, &());
257 statements.extend(statements2);
258 statements.extend(statements3);
259 (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
260 }
261
262 fn reconstruct_tuple(
263 &mut self,
264 input: leo_ast::TupleExpression,
265 _addiional: &(),
266 ) -> (Expression, Self::AdditionalOutput) {
267 let mut statements = Vec::new();
269 let elements = input
270 .elements
271 .into_iter()
272 .map(|element| {
273 let (expr, statements2) = self.reconstruct_expression(element, &());
274 statements.extend(statements2);
275 expr
276 })
277 .collect();
278 (TupleExpression { elements, ..input }.into(), statements)
279 }
280
281 fn reconstruct_unary(
282 &mut self,
283 input: leo_ast::UnaryExpression,
284 _addiional: &(),
285 ) -> (Expression, Self::AdditionalOutput) {
286 let (receiver, statements) = self.reconstruct_expression(input.receiver, &());
287 (UnaryExpression { receiver, ..input }.into(), statements)
288 }
289
290 fn reconstruct_unit(
291 &mut self,
292 input: leo_ast::UnitExpression,
293 _addiional: &(),
294 ) -> (Expression, Self::AdditionalOutput) {
295 (input.into(), Default::default())
296 }
297
298 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
302 let (value, mut statements) = self.reconstruct_expression(input.value, &());
303 let place = self.reconstruct_assign_place(input.place);
304 self.reconstruct_assign_recurse(place, value, &mut statements);
305 (Statement::dummy(), statements)
306 }
307
308 fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
309 let mut statements = Vec::new();
310 let stmt = AssertStatement {
311 variant: match input.variant {
312 AssertVariant::Assert(expr) => {
313 let (expr, statements2) = self.reconstruct_expression(expr, &());
314 statements.extend(statements2);
315 AssertVariant::Assert(expr)
316 }
317 AssertVariant::AssertEq(left, right) => {
318 let (left, statements2) = self.reconstruct_expression(left, &());
319 statements.extend(statements2);
320 let (right, statements3) = self.reconstruct_expression(right, &());
321 statements.extend(statements3);
322 AssertVariant::AssertEq(left, right)
323 }
324 AssertVariant::AssertNeq(left, right) => {
325 let (left, statements2) = self.reconstruct_expression(left, &());
326 statements.extend(statements2);
327 let (right, statements3) = self.reconstruct_expression(right, &());
328 statements.extend(statements3);
329 AssertVariant::AssertNeq(left, right)
330 }
331 },
332 ..input
333 }
334 .into();
335 (stmt, Default::default())
336 }
337
338 fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
339 let mut statements = Vec::with_capacity(block.statements.len());
340
341 for statement in block.statements {
343 let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
344 statements.extend(additional_statements);
345 if !reconstructed_statement.is_empty() {
346 statements.push(reconstructed_statement);
347 }
348 }
349
350 (Block { statements, ..block }, Default::default())
351 }
352
353 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
354 let (value, mut statements) = self.reconstruct_expression(input.value, &());
355 input.value = value;
356 match input.place.clone() {
357 DefinitionPlace::Single(identifier) => {
358 statements.push(input.into());
359 self.define_variable_members(identifier, &mut statements);
360 }
361 DefinitionPlace::Multiple(identifiers) => {
362 statements.push(input.into());
363 for &identifier in identifiers.iter() {
364 self.define_variable_members(identifier, &mut statements);
365 }
366 }
367 }
368 (Statement::dummy(), statements)
369 }
370
371 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
372 let (expression, statements) = self.reconstruct_expression(input.expression, &());
373 (ExpressionStatement { expression, ..input }.into(), statements)
374 }
375
376 fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
377 panic!("`IterationStatement`s should not be in the AST at this point.");
378 }
379
380 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
381 let (expression, statements) = self.reconstruct_expression(input.expression, &());
382 (ReturnStatement { expression, ..input }.into(), statements)
383 }
384
385 fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
386 let (condition, mut statements) = self.reconstruct_expression(input.condition, &());
387 let (then, statements2) = self.reconstruct_block(input.then);
388 statements.extend(statements2);
389 let otherwise = input.otherwise.map(|oth| {
390 let (expr, statements3) = self.reconstruct_statement(*oth);
391 statements.extend(statements3);
392 Box::new(expr)
393 });
394 (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
395 }
396}