1use super::OptionLoweringVisitor;
18
19use leo_ast::*;
20use leo_span::{Span, Symbol};
21
22use indexmap::IndexMap;
23
24impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
25 type AdditionalInput = Option<Type>;
26 type AdditionalOutput = Vec<Statement>;
27
28 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
30 let (length, stmts) = self.reconstruct_expression(*input.length, &None);
31 (
32 Type::Array(ArrayType {
33 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
34 length: Box::new(length),
35 }),
36 stmts,
37 )
38 }
39
40 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
41 let mut statements = Vec::new();
42
43 let const_arguments = input
44 .const_arguments
45 .into_iter()
46 .map(|arg| {
47 let (expr, stmts) = self.reconstruct_expression(arg, &None);
48 statements.extend(stmts);
49 expr
50 })
51 .collect();
52
53 (Type::Composite(CompositeType { const_arguments, ..input }), statements)
54 }
55
56 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
57 let (inner_type, _) = self.reconstruct_type(*input.inner.clone());
58
59 let struct_name = self.insert_optional_wrapper_struct(&inner_type);
61
62 (
63 Type::Composite(CompositeType {
64 path: Path::from(Identifier::new(struct_name, self.state.node_builder.next_id())).into_absolute(),
65 const_arguments: vec![], program: None, }),
68 Default::default(),
69 )
70 }
71
72 fn reconstruct_expression(
74 &mut self,
75 input: Expression,
76 additional: &Option<Type>,
77 ) -> (Expression, Self::AdditionalOutput) {
78 if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input {
80 let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&input.id()) else {
81 panic!("Type checking guarantees that `None` has an Optional type");
82 };
83
84 return (self.wrap_none(&inner), vec![]);
85 }
86
87 let (expr, stmts) = match input {
89 Expression::Intrinsic(e) => self.reconstruct_intrinsic(*e, additional),
90 Expression::Async(e) => self.reconstruct_async(e, additional),
91 Expression::Array(e) => self.reconstruct_array(e, additional),
92 Expression::ArrayAccess(e) => self.reconstruct_array_access(*e, additional),
93 Expression::Binary(e) => self.reconstruct_binary(*e, additional),
94 Expression::Call(e) => self.reconstruct_call(*e, additional),
95 Expression::Cast(e) => self.reconstruct_cast(*e, additional),
96 Expression::Struct(e) => self.reconstruct_struct_init(e, additional),
97 Expression::Err(e) => self.reconstruct_err(e, additional),
98 Expression::Path(e) => self.reconstruct_path(e, additional),
99 Expression::Literal(e) => self.reconstruct_literal(e, additional),
100 Expression::Locator(e) => self.reconstruct_locator(e, additional),
101 Expression::MemberAccess(e) => self.reconstruct_member_access(*e, additional),
102 Expression::Repeat(e) => self.reconstruct_repeat(*e, additional),
103 Expression::Ternary(e) => self.reconstruct_ternary(*e, additional),
104 Expression::Tuple(e) => self.reconstruct_tuple(e, additional),
105 Expression::TupleAccess(e) => self.reconstruct_tuple_access(*e, additional),
106 Expression::Unary(e) => self.reconstruct_unary(*e, additional),
107 Expression::Unit(e) => self.reconstruct_unit(e, additional),
108 };
109
110 if let Some(Type::Optional(OptionalType { inner })) = additional {
112 let actual_expr_type =
113 self.state.type_table.get(&expr.id()).expect(
114 "Type table must contain type for this expression ID; IDs are not modified during lowering",
115 );
116
117 if actual_expr_type.can_coerce_to(inner) {
118 return (self.wrap_optional_value(expr, *inner.clone()), stmts);
119 }
120 }
121
122 (expr, stmts)
123 }
124
125 fn reconstruct_array_access(
126 &mut self,
127 mut input: ArrayAccess,
128 _additional: &Self::AdditionalInput,
129 ) -> (Expression, Self::AdditionalOutput) {
130 let (array, mut stmts_array) = self.reconstruct_expression(input.array, &None);
131 let (index, mut stmts_index) = self.reconstruct_expression(input.index, &None);
132
133 input.array = array;
134 input.index = index;
135
136 stmts_array.append(&mut stmts_index);
138
139 (input.into(), stmts_array)
140 }
141
142 fn reconstruct_intrinsic(
143 &mut self,
144 mut input: IntrinsicExpression,
145 _additional: &Option<Type>,
146 ) -> (Expression, Self::AdditionalOutput) {
147 match Intrinsic::from_symbol(input.name, &input.type_parameters) {
148 Some(Intrinsic::OptionalUnwrap) => {
149 let [optional_expr] = &input.arguments[..] else {
150 panic!("guaranteed by type checking");
151 };
152
153 let (reconstructed_optional_expr, mut stmts) =
154 self.reconstruct_expression(optional_expr.clone(), &None);
155
156 let val_access = MemberAccess {
158 inner: reconstructed_optional_expr.clone(),
159 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
160 span: Span::default(),
161 id: self.state.node_builder.next_id(),
162 };
163 let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&optional_expr.id())
164 else {
165 panic!("guaranteed by type checking");
166 };
167 self.state.type_table.insert(val_access.id(), *inner);
168
169 let is_some_access = MemberAccess {
170 inner: reconstructed_optional_expr.clone(),
171 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
172 span: Span::default(),
173 id: self.state.node_builder.next_id(),
174 };
175 self.state.type_table.insert(is_some_access.id(), Type::Boolean);
176
177 let assert_stmt = AssertStatement {
179 variant: AssertVariant::Assert(is_some_access.clone().into()),
180 span: Span::default(),
181 id: self.state.node_builder.next_id(),
182 };
183
184 stmts.push(assert_stmt.into());
186
187 (val_access.into(), stmts)
188 }
189 Some(Intrinsic::OptionalUnwrapOr) => {
190 let [optional_expr, default_expr] = &input.arguments[..] else {
191 panic!("unwrap_or must have 2 arguments: optional and default");
192 };
193
194 let (reconstructed_optional_expr, mut stmts1) =
195 self.reconstruct_expression(optional_expr.clone(), &None);
196
197 let Some(Type::Optional(OptionalType { inner: expected_inner_type })) =
199 self.state.type_table.get(&optional_expr.id())
200 else {
201 panic!("guaranteed by type checking")
202 };
203
204 let (reconstructed_fallback_expr, stmts2) =
205 self.reconstruct_expression(default_expr.clone(), &Some(*expected_inner_type.clone()));
206
207 let val_access = MemberAccess {
209 inner: reconstructed_optional_expr.clone(),
210 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
211 span: Span::default(),
212 id: self.state.node_builder.next_id(),
213 };
214 self.state.type_table.insert(val_access.id(), *expected_inner_type.clone());
215
216 let is_some_access = MemberAccess {
217 inner: reconstructed_optional_expr,
218 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
219 span: Span::default(),
220 id: self.state.node_builder.next_id(),
221 };
222 self.state.type_table.insert(is_some_access.id(), Type::Boolean);
223
224 let ternary_expr = TernaryExpression {
226 condition: is_some_access.into(),
227 if_true: val_access.into(),
228 if_false: reconstructed_fallback_expr,
229 span: Span::default(),
230 id: self.state.node_builder.next_id(),
231 };
232 self.state.type_table.insert(ternary_expr.id(), *expected_inner_type);
233
234 stmts1.extend(stmts2);
235 (ternary_expr.into(), stmts1)
236 }
237 _ => {
238 let statements: Vec<_> = input
239 .arguments
240 .iter_mut()
241 .flat_map(|arg| {
242 let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &None);
243 *arg = expr;
244 stmts
245 })
246 .collect();
247
248 (input.into(), statements)
249 }
250 }
251 }
252
253 fn reconstruct_member_access(
254 &mut self,
255 mut input: MemberAccess,
256 _additional: &Self::AdditionalInput,
257 ) -> (Expression, Self::AdditionalOutput) {
258 let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &None);
259
260 input.inner = inner;
261
262 (input.into(), stmts_inner)
263 }
264
265 fn reconstruct_repeat(
266 &mut self,
267 mut input: RepeatExpression,
268 additional: &Self::AdditionalInput,
269 ) -> (Expression, Self::AdditionalOutput) {
270 let expected_element_type =
272 additional.clone().or_else(|| self.state.type_table.get(&input.id)).and_then(|mut ty| {
273 if let Type::Optional(inner) = ty {
274 ty = *inner.inner;
275 }
276 match ty {
277 Type::Array(array_ty) => Some(*array_ty.element_type),
278 _ => None,
279 }
280 });
281
282 let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &expected_element_type);
284
285 let (count, mut stmts_count) = self.reconstruct_expression(input.count, &None);
286
287 input.expr = expr;
288 input.count = count;
289
290 stmts_expr.append(&mut stmts_count);
291
292 (input.into(), stmts_expr)
293 }
294
295 fn reconstruct_tuple_access(
296 &mut self,
297 mut input: TupleAccess,
298 _additional: &Self::AdditionalInput,
299 ) -> (Expression, Self::AdditionalOutput) {
300 let (tuple, stmts) = self.reconstruct_expression(input.tuple, &None);
301
302 input.tuple = tuple;
303
304 (input.into(), stmts)
305 }
306
307 fn reconstruct_array(
308 &mut self,
309 mut input: ArrayExpression,
310 additional: &Option<Type>,
311 ) -> (Expression, Self::AdditionalOutput) {
312 let expected_element_type = additional
313 .clone()
314 .or_else(|| self.state.type_table.get(&input.id))
315 .and_then(|mut ty| {
316 if let Type::Optional(inner) = ty {
318 ty = *inner.inner;
319 }
320 match ty {
322 Type::Array(array_ty) => Some(*array_ty.element_type),
323 _ => None,
324 }
325 })
326 .expect("guaranteed by type checking");
327
328 let mut all_stmts = Vec::new();
329 let mut new_elements = Vec::with_capacity(input.elements.len());
330
331 for element in input.elements.into_iter() {
332 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_element_type.clone()));
333 all_stmts.append(&mut stmts);
334 new_elements.push(expr);
335 }
336
337 input.elements = new_elements;
338
339 (input.into(), all_stmts)
340 }
341
342 fn reconstruct_binary(
343 &mut self,
344 mut input: BinaryExpression,
345 _additional: &Self::AdditionalInput,
346 ) -> (Expression, Self::AdditionalOutput) {
347 let (left, mut stmts_left) = self.reconstruct_expression(input.left, &None);
348 let (right, mut stmts_right) = self.reconstruct_expression(input.right, &None);
349
350 input.left = left;
351 input.right = right;
352
353 stmts_left.append(&mut stmts_right);
355
356 (input.into(), stmts_left)
357 }
358
359 fn reconstruct_call(
360 &mut self,
361 mut input: CallExpression,
362 _additional: &Self::AdditionalInput,
363 ) -> (Expression, Self::AdditionalOutput) {
364 let callee_program = input.program.unwrap_or(self.program);
365
366 let func_symbol = self
367 .state
368 .symbol_table
369 .lookup_function(&Location::new(callee_program, input.function.absolute_path()))
370 .expect("The symbol table creator should already have visited all functions.")
371 .clone();
372
373 let mut all_stmts = Vec::new();
374
375 let mut const_arguments = Vec::with_capacity(input.const_arguments.len());
377 for (arg, param) in input.const_arguments.into_iter().zip(func_symbol.function.const_parameters.iter()) {
378 let expected_type = Some(param.type_.clone());
379 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
380 all_stmts.append(&mut stmts);
381 const_arguments.push(expr);
382 }
383
384 let mut arguments = Vec::with_capacity(input.arguments.len());
386 for (arg, param) in input.arguments.into_iter().zip(func_symbol.function.input.iter()) {
387 let expected_type = Some(param.type_.clone());
388 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
389 all_stmts.append(&mut stmts);
390 arguments.push(expr);
391 }
392
393 input.const_arguments = const_arguments;
394 input.arguments = arguments;
395
396 (input.into(), all_stmts)
397 }
398
399 fn reconstruct_cast(
400 &mut self,
401 mut input: CastExpression,
402 _additional: &Self::AdditionalInput,
403 ) -> (Expression, Self::AdditionalOutput) {
404 let (expr, stmts) = self.reconstruct_expression(input.expression, &None);
405
406 input.expression = expr;
407
408 (input.into(), stmts)
409 }
410
411 fn reconstruct_struct_init(
412 &mut self,
413 mut input: StructExpression,
414 additional: &Option<Type>,
415 ) -> (Expression, Self::AdditionalOutput) {
416 let (const_parameters, member_types): (Vec<Type>, IndexMap<Symbol, Type>) = {
417 let mut ty = additional.clone().or_else(|| self.state.type_table.get(&input.id)).expect("type checked");
418
419 if let Type::Optional(inner) = ty {
420 ty = *inner.inner;
421 }
422
423 if let Type::Composite(composite) = ty {
424 let program = composite.program.unwrap_or(self.program);
425 let location = Location::new(program, composite.path.absolute_path());
426 let struct_def = self
427 .state
428 .symbol_table
429 .lookup_record(&location)
430 .or_else(|| self.state.symbol_table.lookup_struct(&composite.path.absolute_path()))
431 .or_else(|| self.new_structs.get(&composite.path.identifier().name))
432 .expect("guaranteed by type checking");
433
434 let const_parameters = struct_def.const_parameters.iter().map(|param| param.type_.clone()).collect();
435 let member_types =
436 struct_def.members.iter().map(|member| (member.identifier.name, member.type_.clone())).collect();
437
438 (const_parameters, member_types)
439 } else {
440 panic!("expected Type::Composite")
441 }
442 };
443
444 let (const_arguments, mut const_arg_stmts): (Vec<_>, Vec<_>) = input
446 .const_arguments
447 .into_iter()
448 .zip(const_parameters.iter())
449 .map(|(arg, ty)| self.reconstruct_expression(arg, &Some(ty.clone())))
450 .unzip();
451
452 let (members, mut member_stmts): (Vec<_>, Vec<_>) = input
454 .members
455 .into_iter()
456 .map(|member| {
457 let expected_type =
458 member_types.get(&member.identifier.name).expect("guaranteed by type checking").clone();
459
460 let expression =
461 member.expression.unwrap_or_else(|| Path::from(member.identifier).into_absolute().into());
462
463 let (new_expr, stmts) = self.reconstruct_expression(expression, &Some(expected_type));
464
465 (
466 StructVariableInitializer {
467 identifier: member.identifier,
468 expression: Some(new_expr),
469 span: member.span,
470 id: member.id,
471 },
472 stmts,
473 )
474 })
475 .unzip();
476
477 input.const_arguments = const_arguments;
478 input.members = members;
479
480 const_arg_stmts.append(&mut member_stmts);
482 let all_stmts = const_arg_stmts.into_iter().flatten().collect();
483
484 (input.into(), all_stmts)
485 }
486
487 fn reconstruct_ternary(
488 &mut self,
489 mut input: TernaryExpression,
490 additional: &Self::AdditionalInput,
491 ) -> (Expression, Self::AdditionalOutput) {
492 let type_ = self.state.type_table.get(&input.id());
493 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
494 let additional = if let Some(expected) = additional { Some(expected.clone()) } else { type_ };
495
496 let (if_true, mut stmts_if_true) = self.reconstruct_expression(input.if_true, &additional);
497 let (if_false, mut stmts_if_false) = self.reconstruct_expression(input.if_false, &additional);
498
499 input.condition = condition;
500 input.if_true = if_true;
501 input.if_false = if_false;
502
503 stmts_condition.append(&mut stmts_if_true);
505 stmts_condition.append(&mut stmts_if_false);
506
507 (input.into(), stmts_condition)
508 }
509
510 fn reconstruct_tuple(
511 &mut self,
512 mut input: TupleExpression,
513 additional: &Option<Type>,
514 ) -> (Expression, Self::AdditionalOutput) {
515 let expected_types = additional
517 .clone()
518 .or_else(|| self.state.type_table.get(&input.id))
519 .and_then(|mut ty| {
520 if let Type::Optional(inner) = ty {
522 ty = *inner.inner;
523 }
524
525 match ty {
527 Type::Tuple(tuple_ty) => Some(tuple_ty.elements.clone()),
528 _ => None,
529 }
530 })
531 .expect("guaranteed by type checking");
532
533 let mut all_stmts = Vec::new();
534 let mut new_elements = Vec::with_capacity(input.elements.len());
535
536 for (element, expected_ty) in input.elements.into_iter().zip(expected_types) {
538 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_ty));
539 all_stmts.append(&mut stmts);
540 new_elements.push(expr);
541 }
542
543 input.elements = new_elements;
544
545 (input.into(), all_stmts)
546 }
547
548 fn reconstruct_unary(
549 &mut self,
550 mut input: UnaryExpression,
551 _additional: &Self::AdditionalInput,
552 ) -> (Expression, Self::AdditionalOutput) {
553 let (receiver, stmts) = self.reconstruct_expression(input.receiver, &None);
554
555 input.receiver = receiver;
556
557 (input.into(), stmts)
558 }
559
560 fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
562 let mut all_stmts = Vec::new();
563
564 input.variant = match input.variant {
565 AssertVariant::Assert(expr) => {
566 let (expr, mut stmts) = self.reconstruct_expression(expr, &None);
567 all_stmts.append(&mut stmts);
568 AssertVariant::Assert(expr)
569 }
570 AssertVariant::AssertEq(left, right) => {
571 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
572 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
573 all_stmts.append(&mut stmts_left);
574 all_stmts.append(&mut stmts_right);
575 AssertVariant::AssertEq(left, right)
576 }
577 AssertVariant::AssertNeq(left, right) => {
578 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
579 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
580 all_stmts.append(&mut stmts_left);
581 all_stmts.append(&mut stmts_right);
582 AssertVariant::AssertNeq(left, right)
583 }
584 };
585
586 (input.into(), all_stmts)
587 }
588
589 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
590 let expected_ty = self.state.type_table.get(&input.place.id()).expect("type checked");
591
592 let (new_place, place_stmts) = self.reconstruct_expression(input.place, &None);
593 let (new_value, value_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
594
595 (AssignStatement { place: new_place, value: new_value, ..input }.into(), [place_stmts, value_stmts].concat())
596 }
597
598 fn reconstruct_conditional(&mut self, mut input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
599 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
600 let (then_block, mut stmts_then) = self.reconstruct_block(input.then);
601
602 let otherwise = match input.otherwise {
603 Some(otherwise_stmt) => {
604 let (stmt, mut stmts_otherwise) = self.reconstruct_statement(*otherwise_stmt);
605 stmts_condition.append(&mut stmts_then);
606 stmts_condition.append(&mut stmts_otherwise);
607 Some(Box::new(stmt))
608 }
609 None => {
610 stmts_condition.append(&mut stmts_then);
611 None
612 }
613 };
614
615 input.condition = condition;
616 input.then = then_block;
617 input.otherwise = otherwise;
618
619 (input.into(), stmts_condition)
620 }
621
622 fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
623 let (type_, mut stmts_type) = self.reconstruct_type(input.type_.clone());
624 let (value, mut stmts_value) = self.reconstruct_expression(input.value, &Some(input.type_));
625
626 input.type_ = type_;
627 input.value = value;
628
629 stmts_type.append(&mut stmts_value);
630
631 (input.into(), stmts_type)
632 }
633
634 fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
635 let mut statements = Vec::with_capacity(block.statements.len());
636
637 for statement in block.statements {
638 let (reconstructed_statement, mut additional_stmts) = self.reconstruct_statement(statement);
639 statements.append(&mut additional_stmts);
640 statements.push(reconstructed_statement);
641 }
642
643 block.statements = statements;
644
645 (block, Default::default())
646 }
647
648 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
649 let expected_ty = input
653 .type_
654 .clone()
655 .or_else(|| self.state.type_table.get(&input.value.id()))
656 .expect("guaranteed by type checking");
657
658 let (new_value, additional_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
659
660 input.type_ = input.type_.map(|ty| self.reconstruct_type(ty).0);
661 input.value = new_value;
662
663 (input.into(), additional_stmts)
664 }
665
666 fn reconstruct_expression_statement(
667 &mut self,
668 mut input: ExpressionStatement,
669 ) -> (Statement, Self::AdditionalOutput) {
670 let (expression, stmts) = self.reconstruct_expression(input.expression, &None);
671
672 input.expression = expression;
673
674 (input.into(), stmts)
675 }
676
677 fn reconstruct_iteration(&mut self, mut input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
678 let mut all_stmts = Vec::new();
679
680 let type_ = match input.type_ {
681 Some(ty) => {
682 let (new_ty, mut stmts_ty) = self.reconstruct_type(ty);
683 all_stmts.append(&mut stmts_ty);
684 Some(new_ty)
685 }
686 None => None,
687 };
688
689 let (start, mut stmts_start) = self.reconstruct_expression(input.start, &None);
690 let (stop, mut stmts_stop) = self.reconstruct_expression(input.stop, &None);
691 let (block, mut stmts_block) = self.reconstruct_block(input.block);
692
693 all_stmts.append(&mut stmts_start);
694 all_stmts.append(&mut stmts_stop);
695 all_stmts.append(&mut stmts_block);
696
697 input.type_ = type_;
698 input.start = start;
699 input.stop = stop;
700 input.block = block;
701
702 (input.into(), all_stmts)
703 }
704
705 fn reconstruct_return(&mut self, mut input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
706 let caller_name = self.function.expect("`self.function` is set every time a function is visited.");
707 let caller_path = self.module.iter().cloned().chain(std::iter::once(caller_name)).collect::<Vec<Symbol>>();
708
709 let func_symbol = self
710 .state
711 .symbol_table
712 .lookup_function(&Location::new(self.program, caller_path))
713 .expect("The symbol table creator should already have visited all functions.");
714
715 let return_type = func_symbol.function.output_type.clone();
716
717 let (expression, statements) = self.reconstruct_expression(input.expression, &Some(return_type));
718 input.expression = expression;
719
720 (input.into(), statements)
721 }
722}