1use super::{FlatteningVisitor, Guard, ReturnGuard};
18
19use leo_ast::*;
20
21use itertools::Itertools;
22
23impl AstReconstructor for FlatteningVisitor<'_> {
24 type AdditionalOutput = Vec<Statement>;
25
26 fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
29 let mut statements = Vec::new();
30 let mut members = Vec::with_capacity(input.members.len());
31
32 for member in input.members.into_iter() {
34 let (expr, stmts) = self.reconstruct_expression(member.expression.unwrap());
36 statements.extend(stmts);
38 members.push(StructVariableInitializer {
40 identifier: member.identifier,
41 expression: Some(expr),
42 span: member.span,
43 id: member.id,
44 });
45 }
46
47 (StructExpression { members, ..input }.into(), statements)
48 }
49
50 fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
66 let if_true_type = self
67 .state
68 .type_table
69 .get(&input.if_true.id())
70 .expect("Type checking guarantees that all expressions are typed.");
71 let if_false_type = self
72 .state
73 .type_table
74 .get(&input.if_false.id())
75 .expect("Type checking guarantees that all expressions are typed.");
76
77 assert!(if_true_type.eq_flat_relaxed(&if_false_type));
79
80 fn as_identifier(ident_expr: Expression) -> Identifier {
81 let Expression::Identifier(identifier) = ident_expr else {
82 panic!("SSA form should have guaranteed this is an identifier: {}.", ident_expr);
83 };
84 identifier
85 }
86
87 match &if_true_type {
88 Type::Array(if_true_type) => self.ternary_array(
89 if_true_type,
90 &input.condition,
91 &as_identifier(input.if_true),
92 &as_identifier(input.if_false),
93 ),
94 Type::Composite(if_true_type) => {
95 let program = if_true_type.program.unwrap_or(self.program);
97 let if_true_type = self
98 .state
99 .symbol_table
100 .lookup_struct(if_true_type.id.name)
101 .or_else(|| self.state.symbol_table.lookup_record(Location::new(program, if_true_type.id.name)))
102 .expect("This definition should exist")
103 .clone();
104
105 self.ternary_struct(
106 &if_true_type,
107 &input.condition,
108 &as_identifier(input.if_true),
109 &as_identifier(input.if_false),
110 )
111 }
112 Type::Tuple(if_true_type) => {
113 self.ternary_tuple(if_true_type, &input.condition, &input.if_true, &input.if_false)
114 }
115 _ => {
116 assert!(matches!(&input.if_true, Expression::Identifier(..)));
120 assert!(matches!(&input.if_false, Expression::Identifier(..)));
121
122 (input.into(), Default::default())
123 }
124 }
125 }
126
127 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
146 let mut statements = Vec::new();
147
148 if self.is_async {
150 return (input.into(), statements);
151 }
152
153 let assert = AssertStatement {
155 span: input.span,
156 id: input.id,
157 variant: match input.variant {
158 AssertVariant::Assert(expression) => {
159 let (expression, additional_statements) = self.reconstruct_expression(expression);
160 statements.extend(additional_statements);
161 AssertVariant::Assert(expression)
162 }
163 AssertVariant::AssertEq(left, right) => {
164 let (left, additional_statements) = self.reconstruct_expression(left);
165 statements.extend(additional_statements);
166 let (right, additional_statements) = self.reconstruct_expression(right);
167 statements.extend(additional_statements);
168 AssertVariant::AssertEq(left, right)
169 }
170 AssertVariant::AssertNeq(left, right) => {
171 let (left, additional_statements) = self.reconstruct_expression(left);
172 statements.extend(additional_statements);
173 let (right, additional_statements) = self.reconstruct_expression(right);
174 statements.extend(additional_statements);
175 AssertVariant::AssertNeq(left, right)
176 }
177 },
178 };
179
180 let mut guards: Vec<Expression> = Vec::new();
181
182 if let Some((guard, guard_statements)) = self.construct_guard() {
183 statements.extend(guard_statements);
184
185 let not_guard = UnaryExpression {
188 op: UnaryOperation::Not,
189 receiver: guard.into(),
190 span: Default::default(),
191 id: {
192 let id = self.state.node_builder.next_id();
194 self.state.type_table.insert(id, Type::Boolean);
196 id
197 },
198 }
199 .into();
200 let (identifier, statement) = self.unique_simple_definition(not_guard);
201 statements.push(statement);
202 guards.push(identifier.into());
203 }
204
205 if let Some((guard, guard_statements)) = self.construct_early_return_guard() {
207 guards.push(guard.into());
208 statements.extend(guard_statements);
209 }
210
211 if guards.is_empty() {
212 return (assert.into(), statements);
213 }
214
215 let is_eq = matches!(assert.variant, AssertVariant::AssertEq(..));
216
217 let mut expression = match assert.variant {
220 AssertVariant::Assert(expression) => expression,
222
223 AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
225 let binary = BinaryExpression {
226 left,
227 right,
228 op: if is_eq { BinaryOperation::Eq } else { BinaryOperation::Neq },
229 span: Default::default(),
230 id: self.state.node_builder.next_id(),
231 };
232 self.state.type_table.insert(binary.id, Type::Boolean);
233 let (identifier, statement) = self.unique_simple_definition(binary.into());
234 statements.push(statement);
235 identifier.into()
236 }
237 };
238
239 for guard in guards.into_iter() {
243 let binary = BinaryExpression {
244 left: expression,
245 right: guard,
246 op: BinaryOperation::Or,
247 span: Default::default(),
248 id: self.state.node_builder.next_id(),
249 };
250 self.state.type_table.insert(binary.id(), Type::Boolean);
251 let (identifier, statement) = self.unique_simple_definition(binary.into());
252 statements.push(statement);
253 expression = identifier.into();
254 }
255
256 let assert_statement = AssertStatement { variant: AssertVariant::Assert(expression), ..input }.into();
257
258 (assert_statement, statements)
259 }
260
261 fn reconstruct_assign(&mut self, _assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
262 panic!("`AssignStatement`s should not be in the AST at this phase of compilation");
263 }
264
265 fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
269 let mut statements = Vec::with_capacity(block.statements.len());
270
271 for statement in block.statements {
273 let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
274 statements.extend(additional_statements);
275 statements.push(reconstructed_statement);
276 }
277
278 (Block { span: block.span, statements, id: self.state.node_builder.next_id() }, Default::default())
279 }
280
281 fn reconstruct_conditional(&mut self, conditional: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
283 let mut statements = Vec::with_capacity(conditional.then.statements.len());
284
285 if self.is_async {
287 let then_block = self.reconstruct_block(conditional.then).0;
288 let otherwise_block = match conditional.otherwise {
289 Some(statement) => match *statement {
290 Statement::Block(block) => self.reconstruct_block(block).0,
291 _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
292 },
293 None => {
294 Block { span: Default::default(), statements: Vec::new(), id: self.state.node_builder.next_id() }
295 }
296 };
297
298 return (
299 ConditionalStatement {
300 then: then_block,
301 otherwise: Some(Box::new(otherwise_block.into())),
302 ..conditional
303 }
304 .into(),
305 statements,
306 );
307 }
308
309 let place = Identifier {
311 name: self.state.assigner.unique_symbol("condition", "$"),
312 span: Default::default(),
313 id: {
314 let id = self.state.node_builder.next_id();
315 self.state.type_table.insert(id, Type::Boolean);
316 id
317 },
318 };
319
320 statements.push(self.simple_definition(place, conditional.condition.clone()));
321
322 self.condition_stack.push(Guard::Unconstructed(place));
324
325 statements.extend(self.reconstruct_block(conditional.then).0.statements);
327
328 self.condition_stack.pop();
330
331 if let Some(statement) = conditional.otherwise {
333 let not_condition = UnaryExpression {
335 op: UnaryOperation::Not,
336 receiver: conditional.condition.clone(),
337 span: conditional.condition.span(),
338 id: conditional.condition.id(),
339 }
340 .into();
341 let not_place = Identifier {
342 name: self.state.assigner.unique_symbol("condition", "$"),
343 span: Default::default(),
344 id: {
345 let id = self.state.node_builder.next_id();
346 self.state.type_table.insert(id, Type::Boolean);
347 id
348 },
349 };
350 statements.push(self.simple_definition(not_place, not_condition));
351 self.condition_stack.push(Guard::Unconstructed(not_place));
352
353 match *statement {
355 Statement::Block(block) => statements.extend(self.reconstruct_block(block).0.statements),
356 _ => panic!("SSA guarantees that the `otherwise` is always a `Block`"),
357 }
358
359 self.condition_stack.pop();
361 };
362
363 (Statement::dummy(), statements)
364 }
365
366 fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
371 let (value, statements) = self.reconstruct_expression(definition.value);
373 match (definition.place, &value) {
374 (DefinitionPlace::Single(identifier), _) => (self.simple_definition(identifier, value), statements),
375 (DefinitionPlace::Multiple(identifiers), expression) => {
376 let output_type = match &self.state.type_table.get(&expression.id()) {
377 Some(Type::Tuple(tuple_type)) => tuple_type.clone(),
378 _ => panic!("Type checking guarantees that the output type is a tuple."),
379 };
380
381 for (identifier, type_) in identifiers.iter().zip_eq(output_type.elements().iter()) {
382 self.state.type_table.insert(identifier.id, type_.clone());
384 }
385
386 (
387 DefinitionStatement {
388 place: DefinitionPlace::Multiple(identifiers),
389 type_: None,
390 value,
391 span: Default::default(),
392 id: self.state.node_builder.next_id(),
393 }
394 .into(),
395 statements,
396 )
397 }
398 }
399 }
400
401 fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
402 panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
403 }
404
405 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
408 use Expression::*;
409
410 if self.is_async {
412 return (input.into(), Default::default());
413 }
414 let (guard_identifier, statements) = self.construct_guard().unzip();
416
417 let return_guard = guard_identifier.map_or(ReturnGuard::None, ReturnGuard::Unconstructed);
418
419 let is_tuple_ids = matches!(&input.expression, Tuple(tuple_expr) if tuple_expr .elements.iter() .all(|expr| matches!(expr, Identifier(_))));
420 if !matches!(&input.expression, Unit(_) | Identifier(_) | AssociatedConstant(_)) && !is_tuple_ids {
421 panic!("SSA guarantees that the expression is always an identifier, unit expression, or tuple literal.")
422 }
423
424 self.returns.push((return_guard, input));
425
426 (Statement::dummy(), statements.unwrap_or_default())
427 }
428}