1use crate::CompilerState;
18
19use leo_ast::{
20 ArrayAccess,
21 ArrayExpression,
22 ArrayType,
23 AstReconstructor,
24 BinaryExpression,
25 BinaryOperation,
26 Block,
27 Composite,
28 CompositeType,
29 Expression,
30 Identifier,
31 IntegerType,
32 Literal,
33 Member,
34 MemberAccess,
35 Node,
36 NonNegativeNumber,
37 ReturnStatement,
38 Statement,
39 StructExpression,
40 StructVariableInitializer,
41 TernaryExpression,
42 TupleAccess,
43 TupleExpression,
44 TupleType,
45 Type,
46 UnitExpression,
47};
48use leo_span::Symbol;
49
50#[derive(Clone, Copy)]
53pub enum Guard {
54 Unconstructed(Identifier),
57
58 Constructed(Identifier),
64}
65
66#[derive(Clone, Copy)]
67pub enum ReturnGuard {
68 None,
70
71 Unconstructed(Identifier),
74
75 Constructed {
77 plain: Identifier,
79
80 any_return: Identifier,
83 },
84}
85
86impl Guard {
87 fn identifier(self) -> Identifier {
88 match self {
89 Guard::Constructed(id) | Guard::Unconstructed(id) => id,
90 }
91 }
92}
93
94pub struct FlatteningVisitor<'a> {
95 pub state: &'a mut CompilerState,
96
97 pub condition_stack: Vec<Guard>,
99
100 pub returns: Vec<(ReturnGuard, ReturnStatement)>,
105
106 pub program: Symbol,
108
109 pub is_async: bool,
111}
112
113impl FlatteningVisitor<'_> {
114 pub fn construct_early_return_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
118 if self.returns.is_empty() {
119 return None;
120 }
121
122 if self.returns.iter().any(|g| matches!(g.0, ReturnGuard::None)) {
123 let place = Identifier {
125 name: self.state.assigner.unique_symbol("true", "$"),
126 span: Default::default(),
127 id: self.state.node_builder.next_id(),
128 };
129 let statement = self.simple_definition(
130 place,
131 Literal::boolean(true, Default::default(), self.state.node_builder.next_id()).into(),
132 );
133 return Some((place, vec![statement]));
134 }
135
136 let start_i = (0..self.returns.len())
139 .rev()
140 .take_while(|&i| matches!(self.returns[i].0, ReturnGuard::Unconstructed(_)))
141 .last()
142 .unwrap_or(self.returns.len());
143
144 let mut statements = Vec::with_capacity(self.returns.len() - start_i);
145
146 for i in start_i..self.returns.len() {
147 let ReturnGuard::Unconstructed(identifier) = self.returns[i].0 else {
148 unreachable!("We assured above that all guards after the index are Unconstructed.");
149 };
150 if i == 0 {
151 self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: identifier };
152 continue;
153 }
154
155 let ReturnGuard::Constructed { any_return: previous_identifier, .. } = self.returns[i - 1].0 else {
156 unreachable!("We're always at an index where previous guards were Constructed.");
157 };
158
159 let binary = BinaryExpression {
161 op: BinaryOperation::Or,
162 left: previous_identifier.into(),
163 right: identifier.into(),
164 span: Default::default(),
165 id: self.state.node_builder.next_id(),
166 };
167 self.state.type_table.insert(binary.id(), Type::Boolean);
168
169 let place = Identifier {
171 name: self.state.assigner.unique_symbol("guard", "$"),
172 span: Default::default(),
173 id: self.state.node_builder.next_id(),
174 };
175 statements.push(self.simple_definition(place, binary.into()));
176
177 self.returns[i].0 = ReturnGuard::Constructed { plain: identifier, any_return: place };
179 }
180
181 let ReturnGuard::Constructed { any_return, .. } = self.returns.last().unwrap().0 else {
182 unreachable!("Above we made all guards Constructed.");
183 };
184
185 Some((any_return, statements))
186 }
187
188 pub fn construct_guard(&mut self) -> Option<(Identifier, Vec<Statement>)> {
193 if self.condition_stack.is_empty() {
194 return None;
195 }
196
197 let start_i = (0..self.condition_stack.len())
202 .rev()
203 .take_while(|&i| matches!(self.condition_stack[i], Guard::Unconstructed(_)))
204 .last()
205 .unwrap_or(self.condition_stack.len());
206
207 let mut statements = Vec::with_capacity(self.condition_stack.len() - start_i);
208
209 for i in start_i..self.condition_stack.len() {
210 let identifier = self.condition_stack[i].identifier();
211 if i == 0 {
212 self.condition_stack[0] = Guard::Constructed(identifier);
213 continue;
214 }
215
216 let previous = self.condition_stack[i - 1].identifier();
217
218 let binary = BinaryExpression {
220 op: BinaryOperation::And,
221 left: previous.into(),
222 right: identifier.into(),
223 span: Default::default(),
224 id: self.state.node_builder.next_id(),
225 };
226 self.state.type_table.insert(binary.id(), Type::Boolean);
227
228 let place = Identifier {
230 name: self.state.assigner.unique_symbol("guard", "$"),
231 span: Default::default(),
232 id: self.state.node_builder.next_id(),
233 };
234 statements.push(self.simple_definition(place, binary.into()));
235
236 self.condition_stack[i] = Guard::Constructed(place);
238 }
239
240 Some((self.condition_stack.last().unwrap().identifier(), statements))
241 }
242
243 pub fn fold_guards(
246 &mut self,
247 prefix: &str,
248 mut guards: Vec<(Option<Expression>, Expression)>,
249 ) -> (Expression, Vec<Statement>) {
250 let (_, last_expression) = guards.pop().unwrap();
252
253 match last_expression {
254 Expression::Unit(_) => (last_expression, Vec::new()),
256 _ => {
258 let mut statements = Vec::with_capacity(guards.len());
260
261 let mut construct_ternary_assignment =
263 |guard: Expression, if_true: Expression, if_false: Expression| {
264 let place = Identifier {
265 name: self.state.assigner.unique_symbol(prefix, "$"),
266 span: Default::default(),
267 id: self.state.node_builder.next_id(),
268 };
269 let Some(type_) = self.state.type_table.get(&if_true.id()) else {
270 panic!("Type checking guarantees that all expressions have a type.");
271 };
272 let ternary = TernaryExpression {
273 condition: guard,
274 if_true,
275 if_false,
276 span: Default::default(),
277 id: self.state.node_builder.next_id(),
278 };
279 self.state.type_table.insert(ternary.id(), type_);
280 let (value, stmts) = self.reconstruct_ternary(ternary);
281 statements.extend(stmts);
282
283 if let Expression::Tuple(..) = &value {
284 value
287 } else {
288 statements.push(self.simple_definition(place, value));
290 place.into()
291 }
292 };
293
294 let expression = guards.into_iter().rev().fold(last_expression, |acc, (guard, expr)| match guard {
295 None => unreachable!("All expressions except for the last one must have a guard."),
296 Some(guard) => construct_ternary_assignment(guard, expr, acc),
298 });
299
300 (expression, statements)
301 }
302 }
303 }
304
305 pub fn unique_simple_definition(&mut self, expr: Expression) -> (Identifier, Statement) {
307 let name = self.state.assigner.unique_symbol("$var", "$");
309 let place = Identifier { name, span: Default::default(), id: self.state.node_builder.next_id() };
311 let statement = self.simple_definition(place, expr);
313
314 (place, statement)
315 }
316
317 pub fn simple_definition(&mut self, lhs: Identifier, rhs: Expression) -> Statement {
319 let type_ = match self.state.type_table.get(&rhs.id()) {
321 Some(type_) => type_,
322 None => unreachable!("Type checking guarantees that all expressions have a type."),
323 };
324 self.state.type_table.insert(lhs.id(), type_);
325 self.state.assigner.simple_definition(lhs, rhs, self.state.node_builder.next_id())
327 }
328
329 pub fn fold_returns(&mut self, block: &mut Block, returns: Vec<(Option<Expression>, ReturnStatement)>) {
331 if !returns.is_empty() {
333 let mut return_expressions = Vec::with_capacity(returns.len());
334
335 for (guard, return_statement) in returns {
337 return_expressions.push((guard.clone(), return_statement.expression));
338 }
339
340 let (expression, stmts) = self.fold_guards("$ret", return_expressions);
342
343 block.statements.extend(stmts);
345
346 block.statements.push(
348 ReturnStatement { expression, span: Default::default(), id: self.state.node_builder.next_id() }.into(),
349 );
350 }
351 else {
353 block.statements.push(
354 ReturnStatement {
355 expression: UnitExpression { span: Default::default(), id: self.state.node_builder.next_id() }
356 .into(),
357 span: Default::default(),
358 id: self.state.node_builder.next_id(),
359 }
360 .into(),
361 );
362 }
363 }
364
365 fn make_array_access_definition(
367 &mut self,
368 i: usize,
369 identifier: Identifier,
370 array_type: &ArrayType,
371 ) -> (Identifier, Statement) {
372 let index =
373 Literal::integer(IntegerType::U32, i.to_string(), Default::default(), self.state.node_builder.next_id());
374 self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
375 let access: Expression = ArrayAccess {
376 array: identifier.into(),
377 index: index.into(),
378 span: Default::default(),
379 id: self.state.node_builder.next_id(),
380 }
381 .into();
382 self.state.type_table.insert(access.id(), array_type.element_type().clone());
383 self.unique_simple_definition(access)
384 }
385
386 pub fn ternary_array(
387 &mut self,
388 array: &ArrayType,
389 condition: &Expression,
390 first: &Identifier,
391 second: &Identifier,
392 ) -> (Expression, Vec<Statement>) {
393 let mut statements = Vec::new();
395 let elements = (0..array.length.as_u32().expect("length should be known at this point") as usize)
397 .map(|i| {
398 let (first, stmt) = self.make_array_access_definition(i, *first, array);
400 statements.push(stmt);
401 let (second, stmt) = self.make_array_access_definition(i, *second, array);
403 statements.push(stmt);
404
405 let ternary = TernaryExpression {
407 condition: condition.clone(),
408 if_true: first.into(),
410 if_false: second.into(),
412 span: Default::default(),
413 id: self.state.node_builder.next_id(),
414 };
415 self.state.type_table.insert(ternary.id(), array.element_type().clone());
416
417 let (expression, stmts) = self.reconstruct_ternary(ternary);
418
419 statements.extend(stmts);
421
422 expression
423 })
424 .collect();
425
426 let (expr, stmts) = self.reconstruct_array(ArrayExpression {
428 elements,
429 span: Default::default(),
430 id: {
431 let id = self.state.node_builder.next_id();
433 self.state.type_table.insert(id, Type::Array(array.clone()));
435 id
436 },
437 });
438
439 statements.extend(stmts);
441
442 let (identifier, statement) = self.unique_simple_definition(expr);
444
445 statements.push(statement);
446
447 (identifier.into(), statements)
448 }
449
450 fn make_struct_access_definition(
452 &mut self,
453 inner: Identifier,
454 name: Identifier,
455 type_: Type,
456 ) -> (Identifier, Statement) {
457 let expr: Expression =
458 MemberAccess { inner: inner.into(), name, span: Default::default(), id: self.state.node_builder.next_id() }
459 .into();
460 self.state.type_table.insert(expr.id(), type_);
461 self.unique_simple_definition(expr)
462 }
463
464 pub fn ternary_struct(
465 &mut self,
466 struct_: &Composite,
467 condition: &Expression,
468 first: &Identifier,
469 second: &Identifier,
470 ) -> (Expression, Vec<Statement>) {
471 let mut statements = Vec::new();
473 let members = struct_
475 .members
476 .iter()
477 .map(|Member { identifier, type_, .. }| {
478 let (first, stmt) = self.make_struct_access_definition(*first, *identifier, type_.clone());
479 statements.push(stmt);
480 let (second, stmt) = self.make_struct_access_definition(*second, *identifier, type_.clone());
481 statements.push(stmt);
482 let ternary = TernaryExpression {
484 condition: condition.clone(),
485 if_true: first.into(),
486 if_false: second.into(),
487 span: Default::default(),
488 id: self.state.node_builder.next_id(),
489 };
490 self.state.type_table.insert(ternary.id(), type_.clone());
491 let (expression, stmts) = self.reconstruct_ternary(ternary);
492
493 statements.extend(stmts);
495
496 StructVariableInitializer {
497 identifier: *identifier,
498 expression: Some(expression),
499 span: Default::default(),
500 id: self.state.node_builder.next_id(),
501 }
502 })
503 .collect();
504
505 let (expr, stmts) = self.reconstruct_struct_init(StructExpression {
506 name: struct_.identifier,
507 const_arguments: Vec::new(), members,
509 span: Default::default(),
510 id: {
511 let id = self.state.node_builder.next_id();
513 self.state.type_table.insert(
515 id,
516 Type::Composite(CompositeType {
517 id: struct_.identifier,
518 const_arguments: Vec::new(), program: struct_.external,
520 }),
521 );
522 id
523 },
524 });
525
526 statements.extend(stmts);
528
529 let (identifier, statement) = self.unique_simple_definition(expr);
531
532 statements.push(statement);
533
534 (identifier.into(), statements)
535 }
536
537 pub fn ternary_tuple(
538 &mut self,
539 tuple_type: &TupleType,
540 condition: &Expression,
541 first: &Expression,
542 second: &Expression,
543 ) -> (Expression, Vec<Statement>) {
544 let make_access = |base_expression: &Expression, i: usize, ty: Type, slf: &mut Self| -> Expression {
545 match base_expression {
546 expr @ Expression::Identifier(..) => {
547 let id = slf.state.node_builder.next_id();
549 slf.state.type_table.insert(id, ty);
551 TupleAccess { tuple: expr.clone(), index: NonNegativeNumber::from(i), span: Default::default(), id }
552 .into()
553 }
554
555 Expression::Tuple(tuple_expr) => tuple_expr.elements[i].clone(),
556
557 _ => panic!("SSA should have prevented this"),
558 }
559 };
560
561 let mut statements = Vec::new();
563 let elements = tuple_type
565 .elements()
566 .iter()
567 .enumerate()
568 .map(|(i, type_)| {
569 let access1 = make_access(first, i, type_.clone(), self);
571 let (first, stmt) = self.unique_simple_definition(access1);
572 statements.push(stmt);
573 let access2 = make_access(second, i, type_.clone(), self);
574 let (second, stmt) = self.unique_simple_definition(access2);
576 statements.push(stmt);
577
578 let ternary = TernaryExpression {
580 condition: condition.clone(),
581 if_true: first.into(),
582 if_false: second.into(),
583 span: Default::default(),
584 id: self.state.node_builder.next_id(),
585 };
586 self.state.type_table.insert(ternary.id(), type_.clone());
587 let (expression, stmts) = self.reconstruct_ternary(ternary);
588
589 statements.extend(stmts);
591
592 expression
593 })
594 .collect();
595
596 let tuple = TupleExpression {
598 elements,
599 span: Default::default(),
600 id: {
601 let id = self.state.node_builder.next_id();
603 self.state.type_table.insert(id, Type::Tuple(tuple_type.clone()));
605 id
606 },
607 };
608 let (expr, stmts) = self.reconstruct_tuple(tuple);
609
610 statements.extend(stmts);
612
613 if let Expression::Identifier(..) = first {
614 let (identifier, statement) = self.unique_simple_definition(expr);
616
617 statements.push(statement);
618
619 (identifier.into(), statements)
620 } else {
621 (expr, statements)
623 }
624 }
625}