1use super::OptionLoweringVisitor;
18
19use leo_ast::*;
20use leo_span::{Span, Symbol};
21
22use indexmap::IndexMap;
23
24impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
25 type AdditionalInput = Option<Type>;
26 type AdditionalOutput = Vec<Statement>;
27
28 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
30 let (length, stmts) = self.reconstruct_expression(*input.length, &None);
31 (
32 Type::Array(ArrayType {
33 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
34 length: Box::new(length),
35 }),
36 stmts,
37 )
38 }
39
40 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
41 let mut statements = Vec::new();
42
43 let const_arguments = input
44 .const_arguments
45 .into_iter()
46 .map(|arg| {
47 let (expr, stmts) = self.reconstruct_expression(arg, &None);
48 statements.extend(stmts);
49 expr
50 })
51 .collect();
52
53 (Type::Composite(CompositeType { const_arguments, ..input }), statements)
54 }
55
56 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
57 let (inner_type, _) = self.reconstruct_type(*input.inner.clone());
58
59 let struct_name = self.insert_optional_wrapper_struct(&inner_type);
61
62 (
63 Type::Composite(CompositeType {
64 path: Path::from(Identifier::new(struct_name, self.state.node_builder.next_id())).into_absolute(),
65 const_arguments: vec![], program: None, }),
68 Default::default(),
69 )
70 }
71
72 fn reconstruct_expression(
74 &mut self,
75 input: Expression,
76 additional: &Option<Type>,
77 ) -> (Expression, Self::AdditionalOutput) {
78 if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input {
80 let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&input.id()) else {
81 panic!("Type checking guarantees that `None` has an Optional type");
82 };
83
84 return (self.wrap_none(&inner), vec![]);
85 }
86
87 let (expr, stmts) = match input {
89 Expression::AssociatedConstant(e) => self.reconstruct_associated_constant(e, additional),
90 Expression::AssociatedFunction(e) => self.reconstruct_associated_function(e, additional),
91 Expression::Async(e) => self.reconstruct_async(e, additional),
92 Expression::Array(e) => self.reconstruct_array(e, additional),
93 Expression::ArrayAccess(e) => self.reconstruct_array_access(*e, additional),
94 Expression::Binary(e) => self.reconstruct_binary(*e, additional),
95 Expression::Call(e) => self.reconstruct_call(*e, additional),
96 Expression::Cast(e) => self.reconstruct_cast(*e, additional),
97 Expression::Struct(e) => self.reconstruct_struct_init(e, additional),
98 Expression::Err(e) => self.reconstruct_err(e, additional),
99 Expression::Path(e) => self.reconstruct_path(e, additional),
100 Expression::Literal(e) => self.reconstruct_literal(e, additional),
101 Expression::Locator(e) => self.reconstruct_locator(e, additional),
102 Expression::MemberAccess(e) => self.reconstruct_member_access(*e, additional),
103 Expression::Repeat(e) => self.reconstruct_repeat(*e, additional),
104 Expression::Ternary(e) => self.reconstruct_ternary(*e, additional),
105 Expression::Tuple(e) => self.reconstruct_tuple(e, additional),
106 Expression::TupleAccess(e) => self.reconstruct_tuple_access(*e, additional),
107 Expression::Unary(e) => self.reconstruct_unary(*e, additional),
108 Expression::Unit(e) => self.reconstruct_unit(e, additional),
109 };
110
111 if let Some(Type::Optional(OptionalType { inner })) = additional {
113 let actual_expr_type =
114 self.state.type_table.get(&expr.id()).expect(
115 "Type table must contain type for this expression ID; IDs are not modified during lowering",
116 );
117
118 if actual_expr_type.can_coerce_to(inner) {
119 return (self.wrap_optional_value(expr, *inner.clone()), stmts);
120 }
121 }
122
123 (expr, stmts)
124 }
125
126 fn reconstruct_array_access(
127 &mut self,
128 mut input: ArrayAccess,
129 _additional: &Self::AdditionalInput,
130 ) -> (Expression, Self::AdditionalOutput) {
131 let (array, mut stmts_array) = self.reconstruct_expression(input.array, &None);
132 let (index, mut stmts_index) = self.reconstruct_expression(input.index, &None);
133
134 input.array = array;
135 input.index = index;
136
137 stmts_array.append(&mut stmts_index);
139
140 (input.into(), stmts_array)
141 }
142
143 fn reconstruct_associated_function(
144 &mut self,
145 mut input: AssociatedFunctionExpression,
146 _additional: &Option<Type>,
147 ) -> (Expression, Self::AdditionalOutput) {
148 match CoreFunction::from_symbols(input.variant.name, input.name.name) {
149 Some(CoreFunction::OptionalUnwrap) => {
150 let [optional_expr] = &input.arguments[..] else {
151 panic!("guaranteed by type checking");
152 };
153
154 let (reconstructed_optional_expr, mut stmts) =
155 self.reconstruct_expression(optional_expr.clone(), &None);
156
157 let val_access = MemberAccess {
159 inner: reconstructed_optional_expr.clone(),
160 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
161 span: Span::default(),
162 id: self.state.node_builder.next_id(),
163 };
164
165 let is_some_access = MemberAccess {
166 inner: reconstructed_optional_expr.clone(),
167 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
168 span: Span::default(),
169 id: self.state.node_builder.next_id(),
170 };
171
172 let assert_stmt = AssertStatement {
174 variant: AssertVariant::Assert(is_some_access.clone().into()),
175 span: Span::default(),
176 id: self.state.node_builder.next_id(),
177 };
178
179 stmts.push(assert_stmt.into());
181
182 (val_access.into(), stmts)
183 }
184 Some(CoreFunction::OptionalUnwrapOr) => {
185 let [optional_expr, default_expr] = &input.arguments[..] else {
186 panic!("unwrap_or must have 2 arguments: optional and default");
187 };
188
189 let (reconstructed_optional_expr, mut stmts1) =
190 self.reconstruct_expression(optional_expr.clone(), &None);
191
192 let Some(Type::Optional(OptionalType { inner: expected_inner_type })) =
194 self.state.type_table.get(&optional_expr.id())
195 else {
196 panic!("guaranteed by type checking")
197 };
198
199 let (reconstructed_fallback_expr, stmts2) =
200 self.reconstruct_expression(default_expr.clone(), &Some(*expected_inner_type.clone()));
201
202 let val_access = MemberAccess {
204 inner: reconstructed_optional_expr.clone(),
205 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
206 span: Span::default(),
207 id: self.state.node_builder.next_id(),
208 };
209
210 let is_some_access = MemberAccess {
211 inner: reconstructed_optional_expr,
212 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
213 span: Span::default(),
214 id: self.state.node_builder.next_id(),
215 };
216
217 let ternary_expr = TernaryExpression {
219 condition: is_some_access.into(),
220 if_true: val_access.into(),
221 if_false: reconstructed_fallback_expr,
222 span: Span::default(),
223 id: self.state.node_builder.next_id(),
224 };
225
226 stmts1.extend(stmts2);
227 (ternary_expr.into(), stmts1)
228 }
229 _ => {
230 let statements: Vec<_> = input
231 .arguments
232 .iter_mut()
233 .flat_map(|arg| {
234 let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &None);
235 *arg = expr;
236 stmts
237 })
238 .collect();
239
240 (input.into(), statements)
241 }
242 }
243 }
244
245 fn reconstruct_member_access(
246 &mut self,
247 mut input: MemberAccess,
248 _additional: &Self::AdditionalInput,
249 ) -> (Expression, Self::AdditionalOutput) {
250 let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &None);
251
252 input.inner = inner;
253
254 (input.into(), stmts_inner)
255 }
256
257 fn reconstruct_repeat(
258 &mut self,
259 mut input: RepeatExpression,
260 additional: &Self::AdditionalInput,
261 ) -> (Expression, Self::AdditionalOutput) {
262 let expected_element_type =
264 additional.clone().or_else(|| self.state.type_table.get(&input.id)).and_then(|mut ty| {
265 if let Type::Optional(inner) = ty {
266 ty = *inner.inner;
267 }
268 match ty {
269 Type::Array(array_ty) => Some(*array_ty.element_type),
270 _ => None,
271 }
272 });
273
274 let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &expected_element_type);
276
277 let (count, mut stmts_count) = self.reconstruct_expression(input.count, &None);
278
279 input.expr = expr;
280 input.count = count;
281
282 stmts_expr.append(&mut stmts_count);
283
284 (input.into(), stmts_expr)
285 }
286
287 fn reconstruct_tuple_access(
288 &mut self,
289 mut input: TupleAccess,
290 _additional: &Self::AdditionalInput,
291 ) -> (Expression, Self::AdditionalOutput) {
292 let (tuple, stmts) = self.reconstruct_expression(input.tuple, &None);
293
294 input.tuple = tuple;
295
296 (input.into(), stmts)
297 }
298
299 fn reconstruct_array(
300 &mut self,
301 mut input: ArrayExpression,
302 additional: &Option<Type>,
303 ) -> (Expression, Self::AdditionalOutput) {
304 let expected_element_type = additional
305 .clone()
306 .or_else(|| self.state.type_table.get(&input.id))
307 .and_then(|mut ty| {
308 if let Type::Optional(inner) = ty {
310 ty = *inner.inner;
311 }
312 match ty {
314 Type::Array(array_ty) => Some(*array_ty.element_type),
315 _ => None,
316 }
317 })
318 .expect("guaranteed by type checking");
319
320 let mut all_stmts = Vec::new();
321 let mut new_elements = Vec::with_capacity(input.elements.len());
322
323 for element in input.elements.into_iter() {
324 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_element_type.clone()));
325 all_stmts.append(&mut stmts);
326 new_elements.push(expr);
327 }
328
329 input.elements = new_elements;
330
331 (input.into(), all_stmts)
332 }
333
334 fn reconstruct_binary(
335 &mut self,
336 mut input: BinaryExpression,
337 _additional: &Self::AdditionalInput,
338 ) -> (Expression, Self::AdditionalOutput) {
339 let (left, mut stmts_left) = self.reconstruct_expression(input.left, &None);
340 let (right, mut stmts_right) = self.reconstruct_expression(input.right, &None);
341
342 input.left = left;
343 input.right = right;
344
345 stmts_left.append(&mut stmts_right);
347
348 (input.into(), stmts_left)
349 }
350
351 fn reconstruct_call(
352 &mut self,
353 mut input: CallExpression,
354 _additional: &Self::AdditionalInput,
355 ) -> (Expression, Self::AdditionalOutput) {
356 let callee_program = input.program.unwrap_or(self.program);
357
358 let func_symbol = self
359 .state
360 .symbol_table
361 .lookup_function(&Location::new(callee_program, input.function.absolute_path()))
362 .expect("The symbol table creator should already have visited all functions.")
363 .clone();
364
365 let mut all_stmts = Vec::new();
366
367 let mut const_arguments = Vec::with_capacity(input.const_arguments.len());
369 for (arg, param) in input.const_arguments.into_iter().zip(func_symbol.function.const_parameters.iter()) {
370 let expected_type = Some(param.type_.clone());
371 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
372 all_stmts.append(&mut stmts);
373 const_arguments.push(expr);
374 }
375
376 let mut arguments = Vec::with_capacity(input.arguments.len());
378 for (arg, param) in input.arguments.into_iter().zip(func_symbol.function.input.iter()) {
379 let expected_type = Some(param.type_.clone());
380 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
381 all_stmts.append(&mut stmts);
382 arguments.push(expr);
383 }
384
385 input.const_arguments = const_arguments;
386 input.arguments = arguments;
387
388 (input.into(), all_stmts)
389 }
390
391 fn reconstruct_cast(
392 &mut self,
393 mut input: CastExpression,
394 _additional: &Self::AdditionalInput,
395 ) -> (Expression, Self::AdditionalOutput) {
396 let (expr, stmts) = self.reconstruct_expression(input.expression, &None);
397
398 input.expression = expr;
399
400 (input.into(), stmts)
401 }
402
403 fn reconstruct_struct_init(
404 &mut self,
405 mut input: StructExpression,
406 additional: &Option<Type>,
407 ) -> (Expression, Self::AdditionalOutput) {
408 let (const_parameters, member_types): (Vec<Type>, IndexMap<Symbol, Type>) = {
409 let mut ty = additional.clone().or_else(|| self.state.type_table.get(&input.id)).expect("type checked");
410
411 if let Type::Optional(inner) = ty {
412 ty = *inner.inner;
413 }
414
415 if let Type::Composite(composite) = ty {
416 let program = composite.program.unwrap_or(self.program);
417 let location = Location::new(program, composite.path.absolute_path());
418 let struct_def = self
419 .state
420 .symbol_table
421 .lookup_record(&location)
422 .or_else(|| self.state.symbol_table.lookup_struct(&composite.path.absolute_path()))
423 .or_else(|| self.new_structs.get(&composite.path.identifier().name))
424 .expect("guaranteed by type checking");
425
426 let const_parameters = struct_def.const_parameters.iter().map(|param| param.type_.clone()).collect();
427 let member_types =
428 struct_def.members.iter().map(|member| (member.identifier.name, member.type_.clone())).collect();
429
430 (const_parameters, member_types)
431 } else {
432 panic!("expected Type::Composite")
433 }
434 };
435
436 let (const_arguments, mut const_arg_stmts): (Vec<_>, Vec<_>) = input
438 .const_arguments
439 .into_iter()
440 .zip(const_parameters.iter())
441 .map(|(arg, ty)| self.reconstruct_expression(arg, &Some(ty.clone())))
442 .unzip();
443
444 let (members, mut member_stmts): (Vec<_>, Vec<_>) = input
446 .members
447 .into_iter()
448 .map(|member| {
449 let expected_type =
450 member_types.get(&member.identifier.name).expect("guaranteed by type checking").clone();
451
452 let expression =
453 member.expression.unwrap_or_else(|| Path::from(member.identifier).into_absolute().into());
454
455 let (new_expr, stmts) = self.reconstruct_expression(expression, &Some(expected_type));
456
457 (
458 StructVariableInitializer {
459 identifier: member.identifier,
460 expression: Some(new_expr),
461 span: member.span,
462 id: member.id,
463 },
464 stmts,
465 )
466 })
467 .unzip();
468
469 input.const_arguments = const_arguments;
470 input.members = members;
471
472 const_arg_stmts.append(&mut member_stmts);
474 let all_stmts = const_arg_stmts.into_iter().flatten().collect();
475
476 (input.into(), all_stmts)
477 }
478
479 fn reconstruct_ternary(
480 &mut self,
481 mut input: TernaryExpression,
482 additional: &Self::AdditionalInput,
483 ) -> (Expression, Self::AdditionalOutput) {
484 let type_ = self.state.type_table.get(&input.id());
485 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
486 let additional = if let Some(expected) = additional { Some(expected.clone()) } else { type_ };
487
488 let (if_true, mut stmts_if_true) = self.reconstruct_expression(input.if_true, &additional);
489 let (if_false, mut stmts_if_false) = self.reconstruct_expression(input.if_false, &additional);
490
491 input.condition = condition;
492 input.if_true = if_true;
493 input.if_false = if_false;
494
495 stmts_condition.append(&mut stmts_if_true);
497 stmts_condition.append(&mut stmts_if_false);
498
499 (input.into(), stmts_condition)
500 }
501
502 fn reconstruct_tuple(
503 &mut self,
504 mut input: TupleExpression,
505 additional: &Option<Type>,
506 ) -> (Expression, Self::AdditionalOutput) {
507 let mut all_stmts = Vec::new();
508 let mut new_elements = Vec::with_capacity(input.elements.len());
509
510 let expected_types = additional
512 .as_ref()
513 .and_then(|ty| {
514 let mut ty = ty.clone();
515
516 if let Type::Optional(inner) = ty {
518 ty = *inner.inner;
519 }
520 if let Type::Tuple(tuple_ty) = ty { Some(tuple_ty.elements.clone()) } else { None }
522 })
523 .expect("guaranteed by type checking");
524
525 for (element, expected_ty) in input.elements.into_iter().zip(expected_types) {
527 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_ty));
528 all_stmts.append(&mut stmts);
529 new_elements.push(expr);
530 }
531
532 input.elements = new_elements;
533
534 (input.into(), all_stmts)
535 }
536
537 fn reconstruct_unary(
538 &mut self,
539 mut input: UnaryExpression,
540 _additional: &Self::AdditionalInput,
541 ) -> (Expression, Self::AdditionalOutput) {
542 let (receiver, stmts) = self.reconstruct_expression(input.receiver, &None);
543
544 input.receiver = receiver;
545
546 (input.into(), stmts)
547 }
548
549 fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
551 let mut all_stmts = Vec::new();
552
553 input.variant = match input.variant {
554 AssertVariant::Assert(expr) => {
555 let (expr, mut stmts) = self.reconstruct_expression(expr, &None);
556 all_stmts.append(&mut stmts);
557 AssertVariant::Assert(expr)
558 }
559 AssertVariant::AssertEq(left, right) => {
560 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
561 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
562 all_stmts.append(&mut stmts_left);
563 all_stmts.append(&mut stmts_right);
564 AssertVariant::AssertEq(left, right)
565 }
566 AssertVariant::AssertNeq(left, right) => {
567 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
568 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
569 all_stmts.append(&mut stmts_left);
570 all_stmts.append(&mut stmts_right);
571 AssertVariant::AssertNeq(left, right)
572 }
573 };
574
575 (input.into(), all_stmts)
576 }
577
578 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
579 let expected_ty = self.state.type_table.get(&input.place.id()).expect("type checked");
580
581 let (new_place, place_stmts) = self.reconstruct_expression(input.place, &None);
582 let (new_value, value_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
583
584 (AssignStatement { place: new_place, value: new_value, ..input }.into(), [place_stmts, value_stmts].concat())
585 }
586
587 fn reconstruct_conditional(&mut self, mut input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
588 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
589 let (then_block, mut stmts_then) = self.reconstruct_block(input.then);
590
591 let otherwise = match input.otherwise {
592 Some(otherwise_stmt) => {
593 let (stmt, mut stmts_otherwise) = self.reconstruct_statement(*otherwise_stmt);
594 stmts_condition.append(&mut stmts_then);
595 stmts_condition.append(&mut stmts_otherwise);
596 Some(Box::new(stmt))
597 }
598 None => {
599 stmts_condition.append(&mut stmts_then);
600 None
601 }
602 };
603
604 input.condition = condition;
605 input.then = then_block;
606 input.otherwise = otherwise;
607
608 (input.into(), stmts_condition)
609 }
610
611 fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
612 let (type_, mut stmts_type) = self.reconstruct_type(input.type_.clone());
613 let (value, mut stmts_value) = self.reconstruct_expression(input.value, &Some(input.type_));
614
615 input.type_ = type_;
616 input.value = value;
617
618 stmts_type.append(&mut stmts_value);
619
620 (input.into(), stmts_type)
621 }
622
623 fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
624 let mut statements = Vec::with_capacity(block.statements.len());
625
626 for statement in block.statements {
627 let (reconstructed_statement, mut additional_stmts) = self.reconstruct_statement(statement);
628 statements.append(&mut additional_stmts);
629 statements.push(reconstructed_statement);
630 }
631
632 block.statements = statements;
633
634 (block, Default::default())
635 }
636
637 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
638 let expected_ty = input
642 .type_
643 .clone()
644 .or_else(|| self.state.type_table.get(&input.value.id()))
645 .expect("guaranteed by type checking");
646
647 let (new_value, additional_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
648
649 input.type_ = input.type_.map(|ty| self.reconstruct_type(ty).0);
650 input.value = new_value;
651
652 (input.into(), additional_stmts)
653 }
654
655 fn reconstruct_expression_statement(
656 &mut self,
657 mut input: ExpressionStatement,
658 ) -> (Statement, Self::AdditionalOutput) {
659 let (expression, stmts) = self.reconstruct_expression(input.expression, &None);
660
661 input.expression = expression;
662
663 (input.into(), stmts)
664 }
665
666 fn reconstruct_iteration(&mut self, mut input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
667 let mut all_stmts = Vec::new();
668
669 let type_ = match input.type_ {
670 Some(ty) => {
671 let (new_ty, mut stmts_ty) = self.reconstruct_type(ty);
672 all_stmts.append(&mut stmts_ty);
673 Some(new_ty)
674 }
675 None => None,
676 };
677
678 let (start, mut stmts_start) = self.reconstruct_expression(input.start, &None);
679 let (stop, mut stmts_stop) = self.reconstruct_expression(input.stop, &None);
680 let (block, mut stmts_block) = self.reconstruct_block(input.block);
681
682 all_stmts.append(&mut stmts_start);
683 all_stmts.append(&mut stmts_stop);
684 all_stmts.append(&mut stmts_block);
685
686 input.type_ = type_;
687 input.start = start;
688 input.stop = stop;
689 input.block = block;
690
691 (input.into(), all_stmts)
692 }
693
694 fn reconstruct_return(&mut self, mut input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
695 let caller_name = self.function.expect("`self.function` is set every time a function is visited.");
696 let caller_path = self.module.iter().cloned().chain(std::iter::once(caller_name)).collect::<Vec<Symbol>>();
697
698 let func_symbol = self
699 .state
700 .symbol_table
701 .lookup_function(&Location::new(self.program, caller_path))
702 .expect("The symbol table creator should already have visited all functions.");
703
704 let return_type = func_symbol.function.output_type.clone();
705
706 let (expression, statements) = self.reconstruct_expression(input.expression, &Some(return_type));
707 input.expression = expression;
708
709 (input.into(), statements)
710 }
711}