1use super::OptionLoweringVisitor;
18
19use leo_ast::*;
20use leo_span::{Span, Symbol};
21
22use indexmap::IndexMap;
23
24impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
25 type AdditionalInput = Option<Type>;
26 type AdditionalOutput = Vec<Statement>;
27
28 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
30 let (inner_type, _) = self.reconstruct_type(*input.inner.clone());
31
32 let struct_name = crate::make_optional_struct_symbol(&inner_type);
34
35 self.new_structs.entry(struct_name).or_insert_with(|| Composite {
37 identifier: Identifier::new(struct_name, self.state.node_builder.next_id()),
38 const_parameters: vec![], members: vec![
40 Member {
41 mode: Mode::None,
42 identifier: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
43 type_: Type::Boolean,
44 span: Span::default(),
45 id: self.state.node_builder.next_id(),
46 },
47 Member {
48 mode: Mode::None,
49 identifier: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
50 type_: inner_type,
51 span: Span::default(),
52 id: self.state.node_builder.next_id(),
53 },
54 ],
55 external: None,
56 is_record: false,
57 span: Span::default(),
58 id: self.state.node_builder.next_id(),
59 });
60
61 (
62 Type::Composite(CompositeType {
63 path: Path::from(Identifier::new(struct_name, self.state.node_builder.next_id())).into_absolute(),
64 const_arguments: vec![], program: None, }),
67 Default::default(),
68 )
69 }
70
71 fn reconstruct_expression(
73 &mut self,
74 input: Expression,
75 additional: &Option<Type>,
76 ) -> (Expression, Self::AdditionalOutput) {
77 if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input {
79 let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&input.id()) else {
80 panic!("Type checking guarantees that `None` has an Optional type");
81 };
82
83 return (self.wrap_none(&inner), vec![]);
84 }
85
86 let (expr, stmts) = match input {
88 Expression::AssociatedConstant(e) => self.reconstruct_associated_constant(e, additional),
89 Expression::AssociatedFunction(e) => self.reconstruct_associated_function(e, additional),
90 Expression::Async(e) => self.reconstruct_async(e, additional),
91 Expression::Array(e) => self.reconstruct_array(e, additional),
92 Expression::ArrayAccess(e) => self.reconstruct_array_access(*e, additional),
93 Expression::Binary(e) => self.reconstruct_binary(*e, additional),
94 Expression::Call(e) => self.reconstruct_call(*e, additional),
95 Expression::Cast(e) => self.reconstruct_cast(*e, additional),
96 Expression::Struct(e) => self.reconstruct_struct_init(e, additional),
97 Expression::Err(e) => self.reconstruct_err(e, additional),
98 Expression::Path(e) => self.reconstruct_path(e, additional),
99 Expression::Literal(e) => self.reconstruct_literal(e, additional),
100 Expression::Locator(e) => self.reconstruct_locator(e, additional),
101 Expression::MemberAccess(e) => self.reconstruct_member_access(*e, additional),
102 Expression::Repeat(e) => self.reconstruct_repeat(*e, additional),
103 Expression::Ternary(e) => self.reconstruct_ternary(*e, additional),
104 Expression::Tuple(e) => self.reconstruct_tuple(e, additional),
105 Expression::TupleAccess(e) => self.reconstruct_tuple_access(*e, additional),
106 Expression::Unary(e) => self.reconstruct_unary(*e, additional),
107 Expression::Unit(e) => self.reconstruct_unit(e, additional),
108 };
109
110 if let Some(Type::Optional(OptionalType { inner })) = additional {
112 let actual_expr_type =
113 self.state.type_table.get(&expr.id()).expect(
114 "Type table must contain type for this expression ID; IDs are not modified during lowering",
115 );
116
117 if actual_expr_type.can_coerce_to(inner) {
118 return (self.wrap_optional_value(expr, *inner.clone()), stmts);
119 }
120 }
121
122 (expr, stmts)
123 }
124
125 fn reconstruct_array_access(
126 &mut self,
127 mut input: ArrayAccess,
128 _additional: &Self::AdditionalInput,
129 ) -> (Expression, Self::AdditionalOutput) {
130 let (array, mut stmts_array) = self.reconstruct_expression(input.array, &None);
131 let (index, mut stmts_index) = self.reconstruct_expression(input.index, &None);
132
133 input.array = array;
134 input.index = index;
135
136 stmts_array.append(&mut stmts_index);
138
139 (input.into(), stmts_array)
140 }
141
142 fn reconstruct_associated_function(
143 &mut self,
144 mut input: AssociatedFunctionExpression,
145 _additional: &Option<Type>,
146 ) -> (Expression, Self::AdditionalOutput) {
147 match CoreFunction::from_symbols(input.variant.name, input.name.name) {
148 Some(CoreFunction::OptionalUnwrap) => {
149 let [optional_expr] = &input.arguments[..] else {
150 panic!("guaranteed by type checking");
151 };
152
153 let (reconstructed_optional_expr, mut stmts) =
154 self.reconstruct_expression(optional_expr.clone(), &None);
155
156 let val_access = MemberAccess {
158 inner: reconstructed_optional_expr.clone(),
159 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
160 span: Span::default(),
161 id: self.state.node_builder.next_id(),
162 };
163
164 let is_some_access = MemberAccess {
165 inner: reconstructed_optional_expr.clone(),
166 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
167 span: Span::default(),
168 id: self.state.node_builder.next_id(),
169 };
170
171 let assert_stmt = AssertStatement {
173 variant: AssertVariant::Assert(is_some_access.clone().into()),
174 span: Span::default(),
175 id: self.state.node_builder.next_id(),
176 };
177
178 stmts.push(assert_stmt.into());
180
181 (val_access.into(), stmts)
182 }
183 Some(CoreFunction::OptionalUnwrapOr) => {
184 let [optional_expr, default_expr] = &input.arguments[..] else {
185 panic!("unwrap_or must have 2 arguments: optional and default");
186 };
187
188 let (reconstructed_optional_expr, mut stmts1) =
189 self.reconstruct_expression(optional_expr.clone(), &None);
190
191 let Some(Type::Optional(OptionalType { inner: expected_inner_type })) =
193 self.state.type_table.get(&optional_expr.id())
194 else {
195 panic!("guaranteed by type checking")
196 };
197
198 let (reconstructed_fallback_expr, stmts2) =
199 self.reconstruct_expression(default_expr.clone(), &Some(*expected_inner_type.clone()));
200
201 let val_access = MemberAccess {
203 inner: reconstructed_optional_expr.clone(),
204 name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
205 span: Span::default(),
206 id: self.state.node_builder.next_id(),
207 };
208
209 let is_some_access = MemberAccess {
210 inner: reconstructed_optional_expr,
211 name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
212 span: Span::default(),
213 id: self.state.node_builder.next_id(),
214 };
215
216 let ternary_expr = TernaryExpression {
218 condition: is_some_access.into(),
219 if_true: val_access.into(),
220 if_false: reconstructed_fallback_expr,
221 span: Span::default(),
222 id: self.state.node_builder.next_id(),
223 };
224
225 stmts1.extend(stmts2);
226 (ternary_expr.into(), stmts1)
227 }
228 _ => {
229 let statements: Vec<_> = input
230 .arguments
231 .iter_mut()
232 .flat_map(|arg| {
233 let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &None);
234 *arg = expr;
235 stmts
236 })
237 .collect();
238
239 (input.into(), statements)
240 }
241 }
242 }
243
244 fn reconstruct_member_access(
245 &mut self,
246 mut input: MemberAccess,
247 _additional: &Self::AdditionalInput,
248 ) -> (Expression, Self::AdditionalOutput) {
249 let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &None);
250
251 input.inner = inner;
252
253 (input.into(), stmts_inner)
254 }
255
256 fn reconstruct_repeat(
257 &mut self,
258 mut input: RepeatExpression,
259 additional: &Self::AdditionalInput,
260 ) -> (Expression, Self::AdditionalOutput) {
261 let expected_element_type =
263 additional.clone().or_else(|| self.state.type_table.get(&input.id)).and_then(|mut ty| {
264 if let Type::Optional(inner) = ty {
265 ty = *inner.inner;
266 }
267 match ty {
268 Type::Array(array_ty) => Some(*array_ty.element_type),
269 _ => None,
270 }
271 });
272
273 let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &expected_element_type);
275
276 let (count, mut stmts_count) = self.reconstruct_expression(input.count, &None);
277
278 input.expr = expr;
279 input.count = count;
280
281 stmts_expr.append(&mut stmts_count);
282
283 (input.into(), stmts_expr)
284 }
285
286 fn reconstruct_tuple_access(
287 &mut self,
288 mut input: TupleAccess,
289 _additional: &Self::AdditionalInput,
290 ) -> (Expression, Self::AdditionalOutput) {
291 let (tuple, stmts) = self.reconstruct_expression(input.tuple, &None);
292
293 input.tuple = tuple;
294
295 (input.into(), stmts)
296 }
297
298 fn reconstruct_array(
299 &mut self,
300 mut input: ArrayExpression,
301 additional: &Option<Type>,
302 ) -> (Expression, Self::AdditionalOutput) {
303 let expected_element_type = additional
304 .clone()
305 .or_else(|| self.state.type_table.get(&input.id))
306 .and_then(|mut ty| {
307 if let Type::Optional(inner) = ty {
309 ty = *inner.inner;
310 }
311 match ty {
313 Type::Array(array_ty) => Some(*array_ty.element_type),
314 _ => None,
315 }
316 })
317 .expect("guaranteed by type checking");
318
319 let mut all_stmts = Vec::new();
320 let mut new_elements = Vec::with_capacity(input.elements.len());
321
322 for element in input.elements.into_iter() {
323 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_element_type.clone()));
324 all_stmts.append(&mut stmts);
325 new_elements.push(expr);
326 }
327
328 input.elements = new_elements;
329
330 (input.into(), all_stmts)
331 }
332
333 fn reconstruct_binary(
334 &mut self,
335 mut input: BinaryExpression,
336 _additional: &Self::AdditionalInput,
337 ) -> (Expression, Self::AdditionalOutput) {
338 let (left, mut stmts_left) = self.reconstruct_expression(input.left, &None);
339 let (right, mut stmts_right) = self.reconstruct_expression(input.right, &None);
340
341 input.left = left;
342 input.right = right;
343
344 stmts_left.append(&mut stmts_right);
346
347 (input.into(), stmts_left)
348 }
349
350 fn reconstruct_call(
351 &mut self,
352 mut input: CallExpression,
353 _additional: &Self::AdditionalInput,
354 ) -> (Expression, Self::AdditionalOutput) {
355 let callee_program = input.program.unwrap_or(self.program);
356
357 let func_symbol = self
358 .state
359 .symbol_table
360 .lookup_function(&Location::new(callee_program, input.function.absolute_path()))
361 .expect("The symbol table creator should already have visited all functions.")
362 .clone();
363
364 let mut all_stmts = Vec::new();
365
366 let mut const_arguments = Vec::with_capacity(input.const_arguments.len());
368 for (arg, param) in input.const_arguments.into_iter().zip(func_symbol.function.const_parameters.iter()) {
369 let expected_type = Some(param.type_.clone());
370 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
371 all_stmts.append(&mut stmts);
372 const_arguments.push(expr);
373 }
374
375 let mut arguments = Vec::with_capacity(input.arguments.len());
377 for (arg, param) in input.arguments.into_iter().zip(func_symbol.function.input.iter()) {
378 let expected_type = Some(param.type_.clone());
379 let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
380 all_stmts.append(&mut stmts);
381 arguments.push(expr);
382 }
383
384 input.const_arguments = const_arguments;
385 input.arguments = arguments;
386
387 (input.into(), all_stmts)
388 }
389
390 fn reconstruct_cast(
391 &mut self,
392 mut input: CastExpression,
393 _additional: &Self::AdditionalInput,
394 ) -> (Expression, Self::AdditionalOutput) {
395 let (expr, stmts) = self.reconstruct_expression(input.expression, &None);
396
397 input.expression = expr;
398
399 (input.into(), stmts)
400 }
401
402 fn reconstruct_struct_init(
403 &mut self,
404 mut input: StructExpression,
405 additional: &Option<Type>,
406 ) -> (Expression, Self::AdditionalOutput) {
407 let (const_parameters, member_types): (Vec<Type>, IndexMap<Symbol, Type>) = {
408 let mut ty = additional.clone().or_else(|| self.state.type_table.get(&input.id)).expect("type checked");
409
410 if let Type::Optional(inner) = ty {
411 ty = *inner.inner;
412 }
413
414 if let Type::Composite(composite) = ty {
415 let program = composite.program.unwrap_or(self.program);
416 let location = Location::new(program, composite.path.absolute_path());
417 let struct_def = self
418 .state
419 .symbol_table
420 .lookup_record(&location)
421 .or_else(|| self.state.symbol_table.lookup_struct(&composite.path.absolute_path()))
422 .or_else(|| self.new_structs.get(&composite.path.identifier().name))
423 .expect("guaranteed by type checking");
424
425 let const_parameters = struct_def.const_parameters.iter().map(|param| param.type_.clone()).collect();
426 let member_types =
427 struct_def.members.iter().map(|member| (member.identifier.name, member.type_.clone())).collect();
428
429 (const_parameters, member_types)
430 } else {
431 panic!("expected Type::Composite")
432 }
433 };
434
435 let (const_arguments, mut const_arg_stmts): (Vec<_>, Vec<_>) = input
437 .const_arguments
438 .into_iter()
439 .zip(const_parameters.iter())
440 .map(|(arg, ty)| self.reconstruct_expression(arg, &Some(ty.clone())))
441 .unzip();
442
443 let (members, mut member_stmts): (Vec<_>, Vec<_>) = input
445 .members
446 .into_iter()
447 .map(|member| {
448 let expected_type =
449 member_types.get(&member.identifier.name).expect("guaranteed by type checking").clone();
450
451 let expression =
452 member.expression.unwrap_or_else(|| Path::from(member.identifier).into_absolute().into());
453
454 let (new_expr, stmts) = self.reconstruct_expression(expression, &Some(expected_type));
455
456 (
457 StructVariableInitializer {
458 identifier: member.identifier,
459 expression: Some(new_expr),
460 span: member.span,
461 id: member.id,
462 },
463 stmts,
464 )
465 })
466 .unzip();
467
468 input.const_arguments = const_arguments;
469 input.members = members;
470
471 const_arg_stmts.append(&mut member_stmts);
473 let all_stmts = const_arg_stmts.into_iter().flatten().collect();
474
475 (input.into(), all_stmts)
476 }
477
478 fn reconstruct_ternary(
479 &mut self,
480 mut input: TernaryExpression,
481 additional: &Self::AdditionalInput,
482 ) -> (Expression, Self::AdditionalOutput) {
483 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
484 let (if_true, mut stmts_if_true) = self.reconstruct_expression(input.if_true, additional);
485 let (if_false, mut stmts_if_false) = self.reconstruct_expression(input.if_false, additional);
486
487 input.condition = condition;
488 input.if_true = if_true;
489 input.if_false = if_false;
490
491 stmts_condition.append(&mut stmts_if_true);
493 stmts_condition.append(&mut stmts_if_false);
494
495 (input.into(), stmts_condition)
496 }
497
498 fn reconstruct_tuple(
499 &mut self,
500 mut input: TupleExpression,
501 additional: &Option<Type>,
502 ) -> (Expression, Self::AdditionalOutput) {
503 let mut all_stmts = Vec::new();
504 let mut new_elements = Vec::with_capacity(input.elements.len());
505
506 let expected_types = additional
508 .as_ref()
509 .and_then(|ty| {
510 let mut ty = ty.clone();
511
512 if let Type::Optional(inner) = ty {
514 ty = *inner.inner;
515 }
516 if let Type::Tuple(tuple_ty) = ty { Some(tuple_ty.elements.clone()) } else { None }
518 })
519 .expect("guaranteed by type checking");
520
521 for (element, expected_ty) in input.elements.into_iter().zip(expected_types) {
523 let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_ty));
524 all_stmts.append(&mut stmts);
525 new_elements.push(expr);
526 }
527
528 input.elements = new_elements;
529
530 (input.into(), all_stmts)
531 }
532
533 fn reconstruct_unary(
534 &mut self,
535 mut input: UnaryExpression,
536 _additional: &Self::AdditionalInput,
537 ) -> (Expression, Self::AdditionalOutput) {
538 let (receiver, stmts) = self.reconstruct_expression(input.receiver, &None);
539
540 input.receiver = receiver;
541
542 (input.into(), stmts)
543 }
544
545 fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
547 let mut all_stmts = Vec::new();
548
549 input.variant = match input.variant {
550 AssertVariant::Assert(expr) => {
551 let (expr, mut stmts) = self.reconstruct_expression(expr, &None);
552 all_stmts.append(&mut stmts);
553 AssertVariant::Assert(expr)
554 }
555 AssertVariant::AssertEq(left, right) => {
556 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
557 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
558 all_stmts.append(&mut stmts_left);
559 all_stmts.append(&mut stmts_right);
560 AssertVariant::AssertEq(left, right)
561 }
562 AssertVariant::AssertNeq(left, right) => {
563 let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
564 let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
565 all_stmts.append(&mut stmts_left);
566 all_stmts.append(&mut stmts_right);
567 AssertVariant::AssertNeq(left, right)
568 }
569 };
570
571 (input.into(), all_stmts)
572 }
573
574 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
575 let expected_ty = self.state.type_table.get(&input.place.id()).expect("type checked");
576
577 let (new_place, place_stmts) = self.reconstruct_expression(input.place, &None);
578 let (new_value, value_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
579
580 (AssignStatement { place: new_place, value: new_value, ..input }.into(), [place_stmts, value_stmts].concat())
581 }
582
583 fn reconstruct_conditional(&mut self, mut input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
584 let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
585 let (then_block, mut stmts_then) = self.reconstruct_block(input.then);
586
587 let otherwise = match input.otherwise {
588 Some(otherwise_stmt) => {
589 let (stmt, mut stmts_otherwise) = self.reconstruct_statement(*otherwise_stmt);
590 stmts_condition.append(&mut stmts_then);
591 stmts_condition.append(&mut stmts_otherwise);
592 Some(Box::new(stmt))
593 }
594 None => {
595 stmts_condition.append(&mut stmts_then);
596 None
597 }
598 };
599
600 input.condition = condition;
601 input.then = then_block;
602 input.otherwise = otherwise;
603
604 (input.into(), stmts_condition)
605 }
606
607 fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
608 let (type_, mut stmts_type) = self.reconstruct_type(input.type_.clone());
609 let (value, mut stmts_value) = self.reconstruct_expression(input.value, &Some(input.type_));
610
611 input.type_ = type_;
612 input.value = value;
613
614 stmts_type.append(&mut stmts_value);
615
616 (input.into(), stmts_type)
617 }
618
619 fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
620 let mut statements = Vec::with_capacity(block.statements.len());
621
622 for statement in block.statements {
623 let (reconstructed_statement, mut additional_stmts) = self.reconstruct_statement(statement);
624 statements.append(&mut additional_stmts);
625 statements.push(reconstructed_statement);
626 }
627
628 block.statements = statements;
629
630 (block, Default::default())
631 }
632
633 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
634 let expected_ty = input
638 .type_
639 .clone()
640 .or_else(|| self.state.type_table.get(&input.value.id()))
641 .expect("guaranteed by type checking");
642
643 let (new_value, additional_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
644
645 input.type_ = input.type_.map(|ty| self.reconstruct_type(ty).0);
646 input.value = new_value;
647
648 (input.into(), additional_stmts)
649 }
650
651 fn reconstruct_expression_statement(
652 &mut self,
653 mut input: ExpressionStatement,
654 ) -> (Statement, Self::AdditionalOutput) {
655 let (expression, stmts) = self.reconstruct_expression(input.expression, &None);
656
657 input.expression = expression;
658
659 (input.into(), stmts)
660 }
661
662 fn reconstruct_iteration(&mut self, mut input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
663 let mut all_stmts = Vec::new();
664
665 let type_ = match input.type_ {
666 Some(ty) => {
667 let (new_ty, mut stmts_ty) = self.reconstruct_type(ty);
668 all_stmts.append(&mut stmts_ty);
669 Some(new_ty)
670 }
671 None => None,
672 };
673
674 let (start, mut stmts_start) = self.reconstruct_expression(input.start, &None);
675 let (stop, mut stmts_stop) = self.reconstruct_expression(input.stop, &None);
676 let (block, mut stmts_block) = self.reconstruct_block(input.block);
677
678 all_stmts.append(&mut stmts_start);
679 all_stmts.append(&mut stmts_stop);
680 all_stmts.append(&mut stmts_block);
681
682 input.type_ = type_;
683 input.start = start;
684 input.stop = stop;
685 input.block = block;
686
687 (input.into(), all_stmts)
688 }
689
690 fn reconstruct_return(&mut self, mut input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
691 let caller_name = self.function.expect("`self.function` is set every time a function is visited.");
692 let caller_path = self.module.iter().cloned().chain(std::iter::once(caller_name)).collect::<Vec<Symbol>>();
693
694 let func_symbol = self
695 .state
696 .symbol_table
697 .lookup_function(&Location::new(self.program, caller_path))
698 .expect("The symbol table creator should already have visited all functions.");
699
700 let return_type = func_symbol.function.output_type.clone();
701
702 let (expression, statements) = self.reconstruct_expression(input.expression, &Some(return_type));
703 input.expression = expression;
704
705 (input.into(), statements)
706 }
707}