1use crate::CompilerState;
18
19use leo_ast::{
20 ArrayAccess,
21 ArrayExpression,
22 ArrayType,
23 AstReconstructor,
24 BinaryExpression,
25 BinaryOperation,
26 Block,
27 Composite,
28 CompositeType,
29 Expression,
30 Identifier,
31 IntegerType,
32 Literal,
33 Member,
34 MemberAccess,
35 Node,
36 NonNegativeNumber,
37 Path,
38 ReturnStatement,
39 Statement,
40 StructExpression,
41 StructVariableInitializer,
42 TernaryExpression,
43 TupleAccess,
44 TupleExpression,
45 TupleType,
46 Type,
47 UnitExpression,
48};
49use leo_span::Symbol;
50
51#[derive(Clone, Copy)]
54pub enum Guard {
55 Unconstructed(Identifier),
58
59 Constructed(Identifier),
65}
66
67#[derive(Clone, Copy)]
68pub enum ReturnGuard {
69 None,
71
72 Unconstructed(Identifier),
75
76 Constructed {
78 plain: Identifier,
80
81 any_return: Identifier,
84 },
85}
86
87impl Guard {
88 fn identifier(self) -> Identifier {
89 match self {
90 Guard::Constructed(id) | Guard::Unconstructed(id) => id,
91 }
92 }
93}
94
95pub struct FlatteningVisitor<'a> {
96 pub state: &'a mut CompilerState,
97
98 pub condition_stack: Vec<Guard>,
100
101 pub returns: Vec<(ReturnGuard, ReturnStatement)>,
106
107 pub program: Symbol,
109
110 pub is_async: bool,
112}
113
114impl FlatteningVisitor<'_> {
115 pub fn construct_early_return_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
119 if self.returns.is_empty() {
120 return None;
121 }
122
123 if self.returns.iter().any(|g| matches!(g.0, ReturnGuard::None)) {
124 let place = Identifier {
126 name: self.state.assigner.unique_symbol("true", "$"),
127 span: Default::default(),
128 id: self.state.node_builder.next_id(),
129 };
130 let statement = self.simple_definition(
131 place,
132 Literal::boolean(true, Default::default(), self.state.node_builder.next_id()).into(),
133 );
134 return Some((place, vec![statement]));
135 }
136
137 let start_i = (0..self.returns.len())
140 .rev()
141 .take_while(|&i| matches!(self.returns[i].0, ReturnGuard::Unconstructed(_)))
142 .last()
143 .unwrap_or(self.returns.len());
144
145 let mut statements = Vec::with_capacity(self.returns.len() - start_i);
146
147 for i in start_i..self.returns.len() {
148 let ReturnGuard::Unconstructed(identifier) = self.returns[i].0 else {
149 unreachable!("We assured above that all guards after the index are Unconstructed.");
150 };
151 if i == 0 {
152 self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: identifier };
153 continue;
154 }
155
156 let ReturnGuard::Constructed { any_return: previous_identifier, .. } = self.returns[i - 1].0 else {
157 unreachable!("We're always at an index where previous guards were Constructed.");
158 };
159
160 let binary = BinaryExpression {
162 op: BinaryOperation::Or,
163 left: Path::from(previous_identifier).into_absolute().into(),
164 right: Path::from(identifier).into_absolute().into(),
165 span: Default::default(),
166 id: self.state.node_builder.next_id(),
167 };
168 self.state.type_table.insert(binary.id(), Type::Boolean);
169
170 let place = Identifier {
172 name: self.state.assigner.unique_symbol("guard", "$"),
173 span: Default::default(),
174 id: self.state.node_builder.next_id(),
175 };
176 statements.push(self.simple_definition(place, binary.into()));
177
178 self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: place };
180 }
181
182 let ReturnGuard::Constructed { any_return, .. } = self.returns.last().unwrap().0 else {
183 unreachable!("Above we made all guards Constructed.");
184 };
185
186 Some((any_return, statements))
187 }
188
189 pub fn construct_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
194 if self.condition_stack.is_empty() {
195 return None;
196 }
197
198 let start_i = (0..self.condition_stack.len())
203 .rev()
204 .take_while(|&i| matches!(self.condition_stack[i], Guard::Unconstructed(_)))
205 .last()
206 .unwrap_or(self.condition_stack.len());
207
208 let mut statements = Vec::with_capacity(self.condition_stack.len() - start_i);
209
210 for i in start_i..self.condition_stack.len() {
211 let identifier = self.condition_stack[i].identifier();
212 if i == 0 {
213 self.condition_stack[0] = Guard::Constructed(identifier);
214 continue;
215 }
216
217 let previous = self.condition_stack[i - 1].identifier();
218
219 let binary = BinaryExpression {
221 op: BinaryOperation::And,
222 left: Path::from(previous).into_absolute().into(),
223 right: Path::from(identifier).into_absolute().into(),
224 span: Default::default(),
225 id: self.state.node_builder.next_id(),
226 };
227 self.state.type_table.insert(binary.id(), Type::Boolean);
228
229 let place = Identifier {
231 name: self.state.assigner.unique_symbol("guard", "$"),
232 span: Default::default(),
233 id: self.state.node_builder.next_id(),
234 };
235 statements.push(self.simple_definition(place, binary.into()));
236
237 self.condition_stack[i] = Guard::Constructed(place);
239 }
240
241 Some((self.condition_stack.last().unwrap().identifier(), statements))
242 }
243
244 pub fn fold_guards(
247 &mut self,
248 prefix: &str,
249 mut guards: Vec<(Option<Expression>, Expression)>,
250 ) -> (Expression, Vec<Statement>) {
251 let (_, last_expression) = guards.pop().unwrap();
253
254 match last_expression {
255 Expression::Unit(_) => (last_expression, Vec::new()),
257 _ => {
259 let mut statements = Vec::with_capacity(guards.len());
261
262 let mut construct_ternary_assignment =
264 |guard: Expression, if_true: Expression, if_false: Expression| {
265 let place = Identifier {
266 name: self.state.assigner.unique_symbol(prefix, "$"),
267 span: Default::default(),
268 id: self.state.node_builder.next_id(),
269 };
270 let Some(type_) = self.state.type_table.get(&if_true.id()) else {
271 panic!("Type checking guarantees that all expressions have a type.");
272 };
273 let ternary = TernaryExpression {
274 condition: guard,
275 if_true,
276 if_false,
277 span: Default::default(),
278 id: self.state.node_builder.next_id(),
279 };
280 self.state.type_table.insert(ternary.id(), type_);
281 let (value, stmts) = self.reconstruct_ternary(ternary, &());
282 statements.extend(stmts);
283
284 if let Expression::Tuple(..) = &value {
285 value
288 } else {
289 statements.push(self.simple_definition(place, value));
291 Path::from(place).into_absolute().into()
292 }
293 };
294
295 let expression = guards.into_iter().rev().fold(last_expression, |acc, (guard, expr)| match guard {
296 None => unreachable!("All expressions except for the last one must have a guard."),
297 Some(guard) => construct_ternary_assignment(guard, expr, acc),
299 });
300
301 (expression, statements)
302 }
303 }
304 }
305
306 pub fn unique_simple_definition(&mut self, expr: Expression) -> (Identifier, Statement) {
308 let name = self.state.assigner.unique_symbol("$var", "$");
310 let place = Identifier { name, span: Default::default(), id: self.state.node_builder.next_id() };
312 let statement = self.simple_definition(place, expr);
314
315 (place, statement)
316 }
317
318 pub fn simple_definition(&mut self, lhs: Identifier, rhs: Expression) -> Statement {
320 let type_ = match self.state.type_table.get(&rhs.id()) {
322 Some(type_) => type_,
323 None => unreachable!("Type checking guarantees that all expressions have a type."),
324 };
325 self.state.type_table.insert(lhs.id(), type_);
326 self.state.assigner.simple_definition(lhs, rhs, self.state.node_builder.next_id())
328 }
329
330 pub fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
332 if !returns.is_empty() {
334 let mut return_expressions = Vec::with_capacity(returns.len());
335
336 for (guard, return_statement) in returns {
338 return_expressions.push((guard.clone(), return_statement.expression));
339 }
340
341 let (expression, stmts) = self.fold_guards("$ret", return_expressions);
343
344 block.statements.extend(stmts);
346
347 block.statements.push(
349 ReturnStatement { expression, span: Default::default(), id: self.state.node_builder.next_id() }.into(),
350 );
351 }
352 else {
354 block.statements.push(
355 ReturnStatement {
356 expression: UnitExpression { span: Default::default(), id: self.state.node_builder.next_id() }
357 .into(),
358 span: Default::default(),
359 id: self.state.node_builder.next_id(),
360 }
361 .into(),
362 );
363 }
364 }
365
366 fn make_array_access_definition(
368 &mut self,
369 i: usize,
370 identifier: Identifier,
371 array_type: &ArrayType,
372 ) -> (Identifier, Statement) {
373 let index =
374 Literal::integer(IntegerType::U32, i.to_string(), Default::default(), self.state.node_builder.next_id());
375 self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
376 let access: Expression = ArrayAccess {
377 array: Path::from(identifier).into_absolute().into(),
378 index: index.into(),
379 span: Default::default(),
380 id: self.state.node_builder.next_id(),
381 }
382 .into();
383 self.state.type_table.insert(access.id(), array_type.element_type().clone());
384 self.unique_simple_definition(access)
385 }
386
387 pub fn ternary_array(
388 &mut self,
389 array: &ArrayType,
390 condition: &Expression,
391 first: &Identifier,
392 second: &Identifier,
393 ) -> (Expression, Vec<Statement>) {
394 let mut statements = Vec::new();
396 let elements = (0..array.length.as_u32().expect("length should be known at this point") as usize)
398 .map(|i| {
399 let (first, stmt) = self.make_array_access_definition(i, *first, array);
401 statements.push(stmt);
402 let (second, stmt) = self.make_array_access_definition(i, *second, array);
404 statements.push(stmt);
405
406 let ternary = TernaryExpression {
408 condition: condition.clone(),
409 if_true: Path::from(first).into_absolute().into(),
411 if_false: Path::from(second).into_absolute().into(),
413 span: Default::default(),
414 id: self.state.node_builder.next_id(),
415 };
416 self.state.type_table.insert(ternary.id(), array.element_type().clone());
417
418 let (expression, stmts) = self.reconstruct_ternary(ternary, &());
419
420 statements.extend(stmts);
422
423 expression
424 })
425 .collect();
426
427 let (expr, stmts) = self.reconstruct_array(
429 ArrayExpression {
430 elements,
431 span: Default::default(),
432 id: {
433 let id = self.state.node_builder.next_id();
435 self.state.type_table.insert(id, Type::Array(array.clone()));
437 id
438 },
439 },
440 &(),
441 );
442
443 statements.extend(stmts);
445
446 let (identifier, statement) = self.unique_simple_definition(expr);
448
449 statements.push(statement);
450
451 (Path::from(identifier).into_absolute().into(), statements)
452 }
453
454 fn make_struct_access_definition(
456 &mut self,
457 inner: Identifier,
458 name: Identifier,
459 type_: Type,
460 ) -> (Identifier, Statement) {
461 let expr: Expression = MemberAccess {
462 inner: Path::from(inner).into_absolute().into(),
463 name,
464 span: Default::default(),
465 id: self.state.node_builder.next_id(),
466 }
467 .into();
468 self.state.type_table.insert(expr.id(), type_);
469 self.unique_simple_definition(expr)
470 }
471
472 pub fn ternary_struct(
473 &mut self,
474 struct_path: &Path,
475 struct_: &Composite,
476 condition: &Expression,
477 first: &Identifier,
478 second: &Identifier,
479 ) -> (Expression, Vec<Statement>) {
480 let mut statements = Vec::new();
482 let members = struct_
484 .members
485 .iter()
486 .map(|Member { identifier, type_, .. }| {
487 let (first, stmt) = self.make_struct_access_definition(*first, *identifier, type_.clone());
488 statements.push(stmt);
489 let (second, stmt) = self.make_struct_access_definition(*second, *identifier, type_.clone());
490 statements.push(stmt);
491 let ternary = TernaryExpression {
493 condition: condition.clone(),
494 if_true: Path::from(first).into_absolute().into(),
495 if_false: Path::from(second).into_absolute().into(),
496 span: Default::default(),
497 id: self.state.node_builder.next_id(),
498 };
499 self.state.type_table.insert(ternary.id(), type_.clone());
500 let (expression, stmts) = self.reconstruct_ternary(ternary, &());
501
502 statements.extend(stmts);
504
505 StructVariableInitializer {
506 identifier: *identifier,
507 expression: Some(expression),
508 span: Default::default(),
509 id: self.state.node_builder.next_id(),
510 }
511 })
512 .collect();
513
514 let (expr, stmts) = self.reconstruct_struct_init(
515 StructExpression {
516 path: struct_path.clone(),
517 const_arguments: Vec::new(), members,
519 span: Default::default(),
520 id: {
521 let id = self.state.node_builder.next_id();
523 self.state.type_table.insert(
525 id,
526 Type::Composite(CompositeType {
527 path: struct_path.clone(),
528 const_arguments: Vec::new(), program: struct_.external,
530 }),
531 );
532 id
533 },
534 },
535 &(),
536 );
537
538 statements.extend(stmts);
540
541 let (identifier, statement) = self.unique_simple_definition(expr);
543
544 statements.push(statement);
545
546 (Path::from(identifier).into_absolute().into(), statements)
547 }
548
549 pub fn ternary_tuple(
550 &mut self,
551 tuple_type: &TupleType,
552 condition: &Expression,
553 first: &Expression,
554 second: &Expression,
555 ) -> (Expression, Vec<Statement>) {
556 let make_access = |base_expression: &Expression, i: usize, ty: Type, slf: &mut Self| -> Expression {
557 match base_expression {
558 expr @ Expression::Path(..) => {
559 let id = slf.state.node_builder.next_id();
561 slf.state.type_table.insert(id, ty);
563 TupleAccess { tuple: expr.clone(), index: NonNegativeNumber::from(i), span: Default::default(), id }
564 .into()
565 }
566
567 Expression::Tuple(tuple_expr) => tuple_expr.elements[i].clone(),
568
569 _ => panic!("SSA should have prevented this"),
570 }
571 };
572
573 let mut statements = Vec::new();
575 let elements = tuple_type
577 .elements()
578 .iter()
579 .enumerate()
580 .map(|(i, type_)| {
581 let access1 = make_access(first, i, type_.clone(), self);
583 let (first, stmt) = self.unique_simple_definition(access1);
584 statements.push(stmt);
585 let access2 = make_access(second, i, type_.clone(), self);
586 let (second, stmt) = self.unique_simple_definition(access2);
588 statements.push(stmt);
589
590 let ternary = TernaryExpression {
592 condition: condition.clone(),
593 if_true: Path::from(first).into_absolute().into(),
594 if_false: Path::from(second).into_absolute().into(),
595 span: Default::default(),
596 id: self.state.node_builder.next_id(),
597 };
598 self.state.type_table.insert(ternary.id(), type_.clone());
599 let (expression, stmts) = self.reconstruct_ternary(ternary, &());
600
601 statements.extend(stmts);
603
604 expression
605 })
606 .collect();
607
608 let tuple = TupleExpression {
610 elements,
611 span: Default::default(),
612 id: {
613 let id = self.state.node_builder.next_id();
615 self.state.type_table.insert(id, Type::Tuple(tuple_type.clone()));
617 id
618 },
619 };
620 let (expr, stmts) = self.reconstruct_tuple(tuple, &());
621
622 statements.extend(stmts);
624
625 if let Expression::Path(..) = first {
626 let (identifier, statement) = self.unique_simple_definition(expr);
628
629 statements.push(statement);
630
631 (Path::from(identifier).into_absolute().into(), statements)
632 } else {
633 (expr, statements)
635 }
636 }
637}