1use super::{FlatteningVisitor, Guard, ReturnGuard};
18
19use leo_ast::*;
20
21use itertools::Itertools;
22
23impl AstReconstructor for FlatteningVisitor<'_> {
24 type AdditionalInput = ();
25 type AdditionalOutput = Vec<Statement>;
26
27 fn reconstruct_struct_init(
30 &mut self,
31 input: StructExpression,
32 _additional: &(),
33 ) -> (Expression, Self::AdditionalOutput) {
34 let mut statements = Vec::new();
35 let mut members = Vec::with_capacity(input.members.len());
36
37 for member in input.members.into_iter() {
39 let (expr, stmts) = self.reconstruct_expression(member.expression.unwrap(), &());
41 statements.extend(stmts);
43 members.push(StructVariableInitializer {
45 identifier: member.identifier,
46 expression: Some(expr),
47 span: member.span,
48 id: member.id,
49 });
50 }
51
52 (StructExpression { members, ..input }.into(), statements)
53 }
54
55 fn reconstruct_ternary(
71 &mut self,
72 input: TernaryExpression,
73 _additional: &(),
74 ) -> (Expression, Self::AdditionalOutput) {
75 let if_true_type = self
76 .state
77 .type_table
78 .get(&input.if_true.id())
79 .expect("Type checking guarantees that all expressions are typed.");
80 let if_false_type = self
81 .state
82 .type_table
83 .get(&input.if_false.id())
84 .expect("Type checking guarantees that all expressions are typed.");
85
86 assert!(if_true_type.eq_flat_relaxed(&if_false_type));
88
89 fn as_identifier(path_expr: Expression) -> Identifier {
90 let Expression::Path(path) = path_expr else {
91 panic!("SSA form should have guaranteed this is a path: {path_expr}.");
92 };
93 Identifier { name: path.identifier().name, span: path.span, id: path.id }
94 }
95
96 match &if_true_type {
97 Type::Array(if_true_type) => self.ternary_array(
98 if_true_type,
99 &input.condition,
100 &as_identifier(input.if_true),
101 &as_identifier(input.if_false),
102 ),
103 Type::Composite(if_true_type) => {
104 let program = if_true_type.program.unwrap_or(self.program);
106 let composite_path = if_true_type.path.clone();
107 let if_true_type = self
108 .state
109 .symbol_table
110 .lookup_struct(&composite_path.absolute_path())
111 .or_else(|| {
112 self.state.symbol_table.lookup_record(&Location::new(program, composite_path.absolute_path()))
113 })
114 .expect("This definition should exist")
115 .clone();
116
117 self.ternary_struct(
118 &composite_path,
119 &if_true_type,
120 &input.condition,
121 &as_identifier(input.if_true),
122 &as_identifier(input.if_false),
123 )
124 }
125 Type::Tuple(if_true_type) => {
126 self.ternary_tuple(if_true_type, &input.condition, &input.if_true, &input.if_false)
127 }
128 _ => {
129 assert!(matches!(&input.if_true, Expression::Path(..)));
133 assert!(matches!(&input.if_false, Expression::Path(..)));
134
135 (input.into(), Default::default())
136 }
137 }
138 }
139
140 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
159 let mut statements = Vec::new();
160
161 if self.is_async {
163 return (input.into(), statements);
164 }
165
166 let assert = AssertStatement {
168 span: input.span,
169 id: input.id,
170 variant: match input.variant {
171 AssertVariant::Assert(expression) => {
172 let (expression, additional_statements) = self.reconstruct_expression(expression, &());
173 statements.extend(additional_statements);
174 AssertVariant::Assert(expression)
175 }
176 AssertVariant::AssertEq(left, right) => {
177 let (left, additional_statements) = self.reconstruct_expression(left, &());
178 statements.extend(additional_statements);
179 let (right, additional_statements) = self.reconstruct_expression(right, &());
180 statements.extend(additional_statements);
181 AssertVariant::AssertEq(left, right)
182 }
183 AssertVariant::AssertNeq(left, right) => {
184 let (left, additional_statements) = self.reconstruct_expression(left, &());
185 statements.extend(additional_statements);
186 let (right, additional_statements) = self.reconstruct_expression(right, &());
187 statements.extend(additional_statements);
188 AssertVariant::AssertNeq(left, right)
189 }
190 },
191 };
192
193 let mut guards: Vec<Expression> = Vec::new();
194
195 if let Some((guard, guard_statements)) = self.construct_guard() {
196 statements.extend(guard_statements);
197
198 let not_guard = UnaryExpression {
201 op: UnaryOperation::Not,
202 receiver: Path::from(guard).into(),
203 span: Default::default(),
204 id: {
205 let id = self.state.node_builder.next_id();
207 self.state.type_table.insert(id, Type::Boolean);
209 id
210 },
211 }
212 .into();
213 let (identifier, statement) = self.unique_simple_definition(not_guard);
214 statements.push(statement);
215 guards.push(Path::from(identifier).into());
216 }
217
218 if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
220 guards.push(Path::from(guard).into());
221 statements.extend(guard_statements);
222 }
223
224 if guards.is_empty() {
225 return (assert.into(), statements);
226 }
227
228 let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));
229
230 let mut expression = match assert.variant {
233 AssertVariant::Assert(expression) => expression,
235
236 AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
238 let binary = BinaryExpression {
239 left,
240 right,
241 op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
242 span: Default::default(),
243 id: self.state.node_builder.next_id(),
244 };
245 self.state.type_table.insert(binary.id, Type::Boolean);
246 let (identifier, statement) = self.unique_simple_definition(binary.into());
247 statements.push(statement);
248 Path::from(identifier).into()
249 }
250 };
251
252 for guard in guards.into_iter() {
256 let binary = BinaryExpression {
257 left: expression,
258 right: guard,
259 op: BinaryOperation::Or,
260 span: Default::default(),
261 id: self.state.node_builder.next_id(),
262 };
263 self.state.type_table.insert(binary.id(), Type::Boolean);
264 let (identifier, statement) = self.unique_simple_definition(binary.into());
265 statements.push(statement);
266 expression = Path::from(identifier).into();
267 }
268
269 let assert_statement = AssertStatement { variant: AssertVariant::Assert(expression), ..input }.into();
270
271 (assert_statement, statements)
272 }
273
274 fn reconstruct_assign(&mut self, _assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
275 panic!("`AssignStatement`s should not be in the AST at this phase of compilation");
276 }
277
278 fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
282 let mut statements = Vec::with_capacity(block.statements.len());
283
284 for statement in block.statements {
286 let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
287 statements.extend(additional_statements);
288 statements.push(reconstructed_statement);
289 }
290
291 (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
292 }
293
294 fn reconstruct_conditional(&mut self, conditional: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
296 let mut statements = Vec::with_capacity(conditional.then.statements.len());
297
298 if self.is_async {
300 let then_block = self.reconstruct_block(conditional.then).0;
301 let otherwise_block = match conditional.otherwise {
302 Some(statement) => match *statement {
303 Statement::Block(block) => self.reconstruct_block(block).0,
304 _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
305 },
306 None => {
307 Block { span: Default::default(), statements: Vec::new(), id: self.state.node_builder.next_id() }
308 }
309 };
310
311 return (
312 ConditionalStatement {
313 then: then_block,
314 otherwise: Some(Box::new(otherwise_block.into())),
315 ..conditional
316 }
317 .into(),
318 statements,
319 );
320 }
321
322 let place = Identifier {
324 name: self.state.assigner.unique_symbol("condition", "$"),
325 span: Default::default(),
326 id: {
327 let id = self.state.node_builder.next_id();
328 self.state.type_table.insert(id, Type::Boolean);
329 id
330 },
331 };
332
333 statements.push(self.simple_definition(place, conditional.condition.clone()));
334
335 self.condition_stack.push(Guard::Unconstructed(place));
337
338 statements.extend(self.reconstruct_block(conditional.then).0.statements);
340
341 self.condition_stack.pop();
343
344 if let Some(statement) = conditional.otherwise {
346 let not_condition = UnaryExpression {
348 op: UnaryOperation::Not,
349 receiver: conditional.condition.clone(),
350 span: conditional.condition.span(),
351 id: conditional.condition.id(),
352 }
353 .into();
354 let not_place = Identifier {
355 name: self.state.assigner.unique_symbol("condition", "$"),
356 span: Default::default(),
357 id: {
358 let id = self.state.node_builder.next_id();
359 self.state.type_table.insert(id, Type::Boolean);
360 id
361 },
362 };
363 statements.push(self.simple_definition(not_place, not_condition));
364 self.condition_stack.push(Guard::Unconstructed(not_place));
365
366 match *statement {
368 Statement::Block(block) => statements.extend(self.reconstruct_block(block).0.statements),
369 _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
370 }
371
372 self.condition_stack.pop();
374 };
375
376 (Statement::dummy(), statements)
377 }
378
379 fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
384 let (value, statements) = self.reconstruct_expression(definition.value, &());
386 match (definition.place, &value) {
387 (DefinitionPlace::Single(identifier), _) => (self.simple_definition(identifier, value), statements),
388 (DefinitionPlace::Multiple(identifiers), expression) => {
389 let output_type = match &self.state.type_table.get(&expression.id()) {
390 Some(Type::Tuple(tuple_type)) => tuple_type.clone(),
391 _ => panic!("Type checking guarantees that the output type is a tuple."),
392 };
393
394 for (identifier, type_) in identifiers.iter().zip_eq(output_type.elements().iter()) {
395 self.state.type_table.insert(identifier.id, type_.clone());
397 }
398
399 (
400 DefinitionStatement {
401 place: DefinitionPlace::Multiple(identifiers),
402 type_: None,
403 value,
404 span: Default::default(),
405 id: self.state.node_builder.next_id(),
406 }
407 .into(),
408 statements,
409 )
410 }
411 }
412 }
413
414 fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
415 panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
416 }
417
418 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
421 use Expression::*;
422
423 if self.is_async {
425 return (input.into(), Default::default());
426 }
427 let (guard_identifier, statements) = self.construct_guard().unzip();
429
430 let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);
431
432 let is_tuple_ids = matches!(&input.expression, Tuple(tuple_expr) if tuple_expr .elements.iter() .all(|expr| matches!(expr, Expression::Path(_))));
433 if !matches!(&input.expression, Unit(_) | Expression::Path(_) | AssociatedConstant(_)) && !is_tuple_ids {
434 panic!("SSA guarantees that the expression is always a Path, unit expression, or tuple literal.")
435 }
436
437 self.returns.push((return_guard, input));
438
439 (Statement::dummy(), statements.unwrap_or_default())
440 }
441}