// leo_passes/option_lowering/ast.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

17use super::OptionLoweringVisitor;
18
19use leo_ast::*;
20use leo_span::{Span, Symbol};
21
22use indexmap::IndexMap;
23
24impl leo_ast::AstReconstructor for OptionLoweringVisitor<'_> {
25    type AdditionalInput = Option<Type>;
26    type AdditionalOutput = Vec<Statement>;
27
28    /* Types */
29    fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
30        let (inner_type, _) = self.reconstruct_type(*input.inner.clone());
31
32        // Generate a unique name "u32?", "bool?", etc.
33        let struct_name = crate::make_optional_struct_symbol(&inner_type);
34
35        // Register the struct if it hasn't been already
36        self.new_structs.entry(struct_name).or_insert_with(|| Composite {
37            identifier: Identifier::new(struct_name, self.state.node_builder.next_id()),
38            const_parameters: vec![], // this is not a generic struct
39            members: vec![
40                Member {
41                    mode: Mode::None,
42                    identifier: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
43                    type_: Type::Boolean,
44                    span: Span::default(),
45                    id: self.state.node_builder.next_id(),
46                },
47                Member {
48                    mode: Mode::None,
49                    identifier: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
50                    type_: inner_type,
51                    span: Span::default(),
52                    id: self.state.node_builder.next_id(),
53                },
54            ],
55            external: None,
56            is_record: false,
57            span: Span::default(),
58            id: self.state.node_builder.next_id(),
59        });
60
61        (
62            Type::Composite(CompositeType {
63                path: Path::from(Identifier::new(struct_name, self.state.node_builder.next_id())).into_absolute(),
64                const_arguments: vec![], // this is not a generic struct
65                program: None,           // current program
66            }),
67            Default::default(),
68        )
69    }
70
71    /* Expressions */
72    fn reconstruct_expression(
73        &mut self,
74        input: Expression,
75        additional: &Option<Type>,
76    ) -> (Expression, Self::AdditionalOutput) {
77        // Handle `None` literal separately
78        if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input {
79            let Some(Type::Optional(OptionalType { inner })) = self.state.type_table.get(&input.id()) else {
80                panic!("Type checking guarantees that `None` has an Optional type");
81            };
82
83            return (self.wrap_none(&inner), vec![]);
84        }
85
86        // Reconstruct the expression based on its variant
87        let (expr, stmts) = match input {
88            Expression::AssociatedConstant(e) => self.reconstruct_associated_constant(e, additional),
89            Expression::AssociatedFunction(e) => self.reconstruct_associated_function(e, additional),
90            Expression::Async(e) => self.reconstruct_async(e, additional),
91            Expression::Array(e) => self.reconstruct_array(e, additional),
92            Expression::ArrayAccess(e) => self.reconstruct_array_access(*e, additional),
93            Expression::Binary(e) => self.reconstruct_binary(*e, additional),
94            Expression::Call(e) => self.reconstruct_call(*e, additional),
95            Expression::Cast(e) => self.reconstruct_cast(*e, additional),
96            Expression::Struct(e) => self.reconstruct_struct_init(e, additional),
97            Expression::Err(e) => self.reconstruct_err(e, additional),
98            Expression::Path(e) => self.reconstruct_path(e, additional),
99            Expression::Literal(e) => self.reconstruct_literal(e, additional),
100            Expression::Locator(e) => self.reconstruct_locator(e, additional),
101            Expression::MemberAccess(e) => self.reconstruct_member_access(*e, additional),
102            Expression::Repeat(e) => self.reconstruct_repeat(*e, additional),
103            Expression::Ternary(e) => self.reconstruct_ternary(*e, additional),
104            Expression::Tuple(e) => self.reconstruct_tuple(e, additional),
105            Expression::TupleAccess(e) => self.reconstruct_tuple_access(*e, additional),
106            Expression::Unary(e) => self.reconstruct_unary(*e, additional),
107            Expression::Unit(e) => self.reconstruct_unit(e, additional),
108        };
109
110        // Optionally wrap in an optional if expected type is `Optional<T>`
111        if let Some(Type::Optional(OptionalType { inner })) = additional {
112            let actual_expr_type =
113                self.state.type_table.get(&expr.id()).expect(
114                    "Type table must contain type for this expression ID; IDs are not modified during lowering",
115                );
116
117            if actual_expr_type.can_coerce_to(inner) {
118                return (self.wrap_optional_value(expr, *inner.clone()), stmts);
119            }
120        }
121
122        (expr, stmts)
123    }
124
125    fn reconstruct_array_access(
126        &mut self,
127        mut input: ArrayAccess,
128        _additional: &Self::AdditionalInput,
129    ) -> (Expression, Self::AdditionalOutput) {
130        let (array, mut stmts_array) = self.reconstruct_expression(input.array, &None);
131        let (index, mut stmts_index) = self.reconstruct_expression(input.index, &None);
132
133        input.array = array;
134        input.index = index;
135
136        // Merge side effects
137        stmts_array.append(&mut stmts_index);
138
139        (input.into(), stmts_array)
140    }
141
142    fn reconstruct_associated_function(
143        &mut self,
144        mut input: AssociatedFunctionExpression,
145        _additional: &Option<Type>,
146    ) -> (Expression, Self::AdditionalOutput) {
147        match CoreFunction::from_symbols(input.variant.name, input.name.name) {
148            Some(CoreFunction::OptionalUnwrap) => {
149                let [optional_expr] = &input.arguments[..] else {
150                    panic!("guaranteed by type checking");
151                };
152
153                let (reconstructed_optional_expr, mut stmts) =
154                    self.reconstruct_expression(optional_expr.clone(), &None);
155
156                // Access `.val` and `.is_some` from reconstructed expression
157                let val_access = MemberAccess {
158                    inner: reconstructed_optional_expr.clone(),
159                    name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
160                    span: Span::default(),
161                    id: self.state.node_builder.next_id(),
162                };
163
164                let is_some_access = MemberAccess {
165                    inner: reconstructed_optional_expr.clone(),
166                    name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
167                    span: Span::default(),
168                    id: self.state.node_builder.next_id(),
169                };
170
171                // Create assertion: ensure `is_some` is `true`.
172                let assert_stmt = AssertStatement {
173                    variant: AssertVariant::Assert(is_some_access.clone().into()),
174                    span: Span::default(),
175                    id: self.state.node_builder.next_id(),
176                };
177
178                // Combine all statements.
179                stmts.push(assert_stmt.into());
180
181                (val_access.into(), stmts)
182            }
183            Some(CoreFunction::OptionalUnwrapOr) => {
184                let [optional_expr, default_expr] = &input.arguments[..] else {
185                    panic!("unwrap_or must have 2 arguments: optional and default");
186                };
187
188                let (reconstructed_optional_expr, mut stmts1) =
189                    self.reconstruct_expression(optional_expr.clone(), &None);
190
191                // Extract the inner type from the expected type of the optional argument.
192                let Some(Type::Optional(OptionalType { inner: expected_inner_type })) =
193                    self.state.type_table.get(&optional_expr.id())
194                else {
195                    panic!("guaranteed by type checking")
196                };
197
198                let (reconstructed_fallback_expr, stmts2) =
199                    self.reconstruct_expression(default_expr.clone(), &Some(*expected_inner_type.clone()));
200
201                // Access `.val` and `.is_some` from reconstructed expression
202                let val_access = MemberAccess {
203                    inner: reconstructed_optional_expr.clone(),
204                    name: Identifier::new(Symbol::intern("val"), self.state.node_builder.next_id()),
205                    span: Span::default(),
206                    id: self.state.node_builder.next_id(),
207                };
208
209                let is_some_access = MemberAccess {
210                    inner: reconstructed_optional_expr,
211                    name: Identifier::new(Symbol::intern("is_some"), self.state.node_builder.next_id()),
212                    span: Span::default(),
213                    id: self.state.node_builder.next_id(),
214                };
215
216                // s.is_some ? s.val : fallback
217                let ternary_expr = TernaryExpression {
218                    condition: is_some_access.into(),
219                    if_true: val_access.into(),
220                    if_false: reconstructed_fallback_expr,
221                    span: Span::default(),
222                    id: self.state.node_builder.next_id(),
223                };
224
225                stmts1.extend(stmts2);
226                (ternary_expr.into(), stmts1)
227            }
228            _ => {
229                let statements: Vec<_> = input
230                    .arguments
231                    .iter_mut()
232                    .flat_map(|arg| {
233                        let (expr, stmts) = self.reconstruct_expression(std::mem::take(arg), &None);
234                        *arg = expr;
235                        stmts
236                    })
237                    .collect();
238
239                (input.into(), statements)
240            }
241        }
242    }
243
244    fn reconstruct_member_access(
245        &mut self,
246        mut input: MemberAccess,
247        _additional: &Self::AdditionalInput,
248    ) -> (Expression, Self::AdditionalOutput) {
249        let (inner, stmts_inner) = self.reconstruct_expression(input.inner, &None);
250
251        input.inner = inner;
252
253        (input.into(), stmts_inner)
254    }
255
256    fn reconstruct_repeat(
257        &mut self,
258        mut input: RepeatExpression,
259        additional: &Self::AdditionalInput,
260    ) -> (Expression, Self::AdditionalOutput) {
261        // Derive expected element type from the type of the whole expression
262        let expected_element_type =
263            additional.clone().or_else(|| self.state.type_table.get(&input.id)).and_then(|mut ty| {
264                if let Type::Optional(inner) = ty {
265                    ty = *inner.inner;
266                }
267                match ty {
268                    Type::Array(array_ty) => Some(*array_ty.element_type),
269                    _ => None,
270                }
271            });
272
273        // Use expected type (if available) for `expr`
274        let (expr, mut stmts_expr) = self.reconstruct_expression(input.expr, &expected_element_type);
275
276        let (count, mut stmts_count) = self.reconstruct_expression(input.count, &None);
277
278        input.expr = expr;
279        input.count = count;
280
281        stmts_expr.append(&mut stmts_count);
282
283        (input.into(), stmts_expr)
284    }
285
286    fn reconstruct_tuple_access(
287        &mut self,
288        mut input: TupleAccess,
289        _additional: &Self::AdditionalInput,
290    ) -> (Expression, Self::AdditionalOutput) {
291        let (tuple, stmts) = self.reconstruct_expression(input.tuple, &None);
292
293        input.tuple = tuple;
294
295        (input.into(), stmts)
296    }
297
298    fn reconstruct_array(
299        &mut self,
300        mut input: ArrayExpression,
301        additional: &Option<Type>,
302    ) -> (Expression, Self::AdditionalOutput) {
303        let expected_element_type = additional
304            .clone()
305            .or_else(|| self.state.type_table.get(&input.id))
306            .and_then(|mut ty| {
307                // Unwrap Optional if any
308                if let Type::Optional(inner) = ty {
309                    ty = *inner.inner;
310                }
311                // Expect Array type
312                match ty {
313                    Type::Array(array_ty) => Some(*array_ty.element_type),
314                    _ => None,
315                }
316            })
317            .expect("guaranteed by type checking");
318
319        let mut all_stmts = Vec::new();
320        let mut new_elements = Vec::with_capacity(input.elements.len());
321
322        for element in input.elements.into_iter() {
323            let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_element_type.clone()));
324            all_stmts.append(&mut stmts);
325            new_elements.push(expr);
326        }
327
328        input.elements = new_elements;
329
330        (input.into(), all_stmts)
331    }
332
333    fn reconstruct_binary(
334        &mut self,
335        mut input: BinaryExpression,
336        _additional: &Self::AdditionalInput,
337    ) -> (Expression, Self::AdditionalOutput) {
338        let (left, mut stmts_left) = self.reconstruct_expression(input.left, &None);
339        let (right, mut stmts_right) = self.reconstruct_expression(input.right, &None);
340
341        input.left = left;
342        input.right = right;
343
344        // Merge side effects
345        stmts_left.append(&mut stmts_right);
346
347        (input.into(), stmts_left)
348    }
349
350    fn reconstruct_call(
351        &mut self,
352        mut input: CallExpression,
353        _additional: &Self::AdditionalInput,
354    ) -> (Expression, Self::AdditionalOutput) {
355        let callee_program = input.program.unwrap_or(self.program);
356
357        let func_symbol = self
358            .state
359            .symbol_table
360            .lookup_function(&Location::new(callee_program, input.function.absolute_path()))
361            .expect("The symbol table creator should already have visited all functions.")
362            .clone();
363
364        let mut all_stmts = Vec::new();
365
366        // Reconstruct const arguments with expected types
367        let mut const_arguments = Vec::with_capacity(input.const_arguments.len());
368        for (arg, param) in input.const_arguments.into_iter().zip(func_symbol.function.const_parameters.iter()) {
369            let expected_type = Some(param.type_.clone());
370            let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
371            all_stmts.append(&mut stmts);
372            const_arguments.push(expr);
373        }
374
375        // Reconstruct normal arguments with expected types
376        let mut arguments = Vec::with_capacity(input.arguments.len());
377        for (arg, param) in input.arguments.into_iter().zip(func_symbol.function.input.iter()) {
378            let expected_type = Some(param.type_.clone());
379            let (expr, mut stmts) = self.reconstruct_expression(arg, &expected_type);
380            all_stmts.append(&mut stmts);
381            arguments.push(expr);
382        }
383
384        input.const_arguments = const_arguments;
385        input.arguments = arguments;
386
387        (input.into(), all_stmts)
388    }
389
390    fn reconstruct_cast(
391        &mut self,
392        mut input: CastExpression,
393        _additional: &Self::AdditionalInput,
394    ) -> (Expression, Self::AdditionalOutput) {
395        let (expr, stmts) = self.reconstruct_expression(input.expression, &None);
396
397        input.expression = expr;
398
399        (input.into(), stmts)
400    }
401
402    fn reconstruct_struct_init(
403        &mut self,
404        mut input: StructExpression,
405        additional: &Option<Type>,
406    ) -> (Expression, Self::AdditionalOutput) {
407        let (const_parameters, member_types): (Vec<Type>, IndexMap<Symbol, Type>) = {
408            let mut ty = additional.clone().or_else(|| self.state.type_table.get(&input.id)).expect("type checked");
409
410            if let Type::Optional(inner) = ty {
411                ty = *inner.inner;
412            }
413
414            if let Type::Composite(composite) = ty {
415                let program = composite.program.unwrap_or(self.program);
416                let location = Location::new(program, composite.path.absolute_path());
417                let struct_def = self
418                    .state
419                    .symbol_table
420                    .lookup_record(&location)
421                    .or_else(|| self.state.symbol_table.lookup_struct(&composite.path.absolute_path()))
422                    .or_else(|| self.new_structs.get(&composite.path.identifier().name))
423                    .expect("guaranteed by type checking");
424
425                let const_parameters = struct_def.const_parameters.iter().map(|param| param.type_.clone()).collect();
426                let member_types =
427                    struct_def.members.iter().map(|member| (member.identifier.name, member.type_.clone())).collect();
428
429                (const_parameters, member_types)
430            } else {
431                panic!("expected Type::Composite")
432            }
433        };
434
435        // Reconstruct const arguments with expected types
436        let (const_arguments, mut const_arg_stmts): (Vec<_>, Vec<_>) = input
437            .const_arguments
438            .into_iter()
439            .zip(const_parameters.iter())
440            .map(|(arg, ty)| self.reconstruct_expression(arg, &Some(ty.clone())))
441            .unzip();
442
443        // Reconstruct members
444        let (members, mut member_stmts): (Vec<_>, Vec<_>) = input
445            .members
446            .into_iter()
447            .map(|member| {
448                let expected_type =
449                    member_types.get(&member.identifier.name).expect("guaranteed by type checking").clone();
450
451                let expression =
452                    member.expression.unwrap_or_else(|| Path::from(member.identifier).into_absolute().into());
453
454                let (new_expr, stmts) = self.reconstruct_expression(expression, &Some(expected_type));
455
456                (
457                    StructVariableInitializer {
458                        identifier: member.identifier,
459                        expression: Some(new_expr),
460                        span: member.span,
461                        id: member.id,
462                    },
463                    stmts,
464                )
465            })
466            .unzip();
467
468        input.const_arguments = const_arguments;
469        input.members = members;
470
471        // Merge all side effect statements
472        const_arg_stmts.append(&mut member_stmts);
473        let all_stmts = const_arg_stmts.into_iter().flatten().collect();
474
475        (input.into(), all_stmts)
476    }
477
478    fn reconstruct_ternary(
479        &mut self,
480        mut input: TernaryExpression,
481        additional: &Self::AdditionalInput,
482    ) -> (Expression, Self::AdditionalOutput) {
483        let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
484        let (if_true, mut stmts_if_true) = self.reconstruct_expression(input.if_true, additional);
485        let (if_false, mut stmts_if_false) = self.reconstruct_expression(input.if_false, additional);
486
487        input.condition = condition;
488        input.if_true = if_true;
489        input.if_false = if_false;
490
491        // Merge all side effects
492        stmts_condition.append(&mut stmts_if_true);
493        stmts_condition.append(&mut stmts_if_false);
494
495        (input.into(), stmts_condition)
496    }
497
498    fn reconstruct_tuple(
499        &mut self,
500        mut input: TupleExpression,
501        additional: &Option<Type>,
502    ) -> (Expression, Self::AdditionalOutput) {
503        let mut all_stmts = Vec::new();
504        let mut new_elements = Vec::with_capacity(input.elements.len());
505
506        // Extract tuple element types if additional type info is Some(Type::Tuple).
507        let expected_types = additional
508            .as_ref()
509            .and_then(|ty| {
510                let mut ty = ty.clone();
511
512                // Unwrap Optional if any.
513                if let Type::Optional(inner) = ty {
514                    ty = *inner.inner;
515                }
516                // Expect Tuple type.
517                if let Type::Tuple(tuple_ty) = ty { Some(tuple_ty.elements.clone()) } else { None }
518            })
519            .expect("guaranteed by type checking");
520
521        // Zip elements with expected types and reconstruct with expected type.
522        for (element, expected_ty) in input.elements.into_iter().zip(expected_types) {
523            let (expr, mut stmts) = self.reconstruct_expression(element, &Some(expected_ty));
524            all_stmts.append(&mut stmts);
525            new_elements.push(expr);
526        }
527
528        input.elements = new_elements;
529
530        (input.into(), all_stmts)
531    }
532
533    fn reconstruct_unary(
534        &mut self,
535        mut input: UnaryExpression,
536        _additional: &Self::AdditionalInput,
537    ) -> (Expression, Self::AdditionalOutput) {
538        let (receiver, stmts) = self.reconstruct_expression(input.receiver, &None);
539
540        input.receiver = receiver;
541
542        (input.into(), stmts)
543    }
544
545    /* Statements */
546    fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
547        let mut all_stmts = Vec::new();
548
549        input.variant = match input.variant {
550            AssertVariant::Assert(expr) => {
551                let (expr, mut stmts) = self.reconstruct_expression(expr, &None);
552                all_stmts.append(&mut stmts);
553                AssertVariant::Assert(expr)
554            }
555            AssertVariant::AssertEq(left, right) => {
556                let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
557                let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
558                all_stmts.append(&mut stmts_left);
559                all_stmts.append(&mut stmts_right);
560                AssertVariant::AssertEq(left, right)
561            }
562            AssertVariant::AssertNeq(left, right) => {
563                let (left, mut stmts_left) = self.reconstruct_expression(left, &None);
564                let (right, mut stmts_right) = self.reconstruct_expression(right, &None);
565                all_stmts.append(&mut stmts_left);
566                all_stmts.append(&mut stmts_right);
567                AssertVariant::AssertNeq(left, right)
568            }
569        };
570
571        (input.into(), all_stmts)
572    }
573
574    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
575        let expected_ty = self.state.type_table.get(&input.place.id()).expect("type checked");
576
577        let (new_place, place_stmts) = self.reconstruct_expression(input.place, &None);
578        let (new_value, value_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
579
580        (AssignStatement { place: new_place, value: new_value, ..input }.into(), [place_stmts, value_stmts].concat())
581    }
582
583    fn reconstruct_conditional(&mut self, mut input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
584        let (condition, mut stmts_condition) = self.reconstruct_expression(input.condition, &None);
585        let (then_block, mut stmts_then) = self.reconstruct_block(input.then);
586
587        let otherwise = match input.otherwise {
588            Some(otherwise_stmt) => {
589                let (stmt, mut stmts_otherwise) = self.reconstruct_statement(*otherwise_stmt);
590                stmts_condition.append(&mut stmts_then);
591                stmts_condition.append(&mut stmts_otherwise);
592                Some(Box::new(stmt))
593            }
594            None => {
595                stmts_condition.append(&mut stmts_then);
596                None
597            }
598        };
599
600        input.condition = condition;
601        input.then = then_block;
602        input.otherwise = otherwise;
603
604        (input.into(), stmts_condition)
605    }
606
607    fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
608        let (type_, mut stmts_type) = self.reconstruct_type(input.type_.clone());
609        let (value, mut stmts_value) = self.reconstruct_expression(input.value, &Some(input.type_));
610
611        input.type_ = type_;
612        input.value = value;
613
614        stmts_type.append(&mut stmts_value);
615
616        (input.into(), stmts_type)
617    }
618
619    fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
620        let mut statements = Vec::with_capacity(block.statements.len());
621
622        for statement in block.statements {
623            let (reconstructed_statement, mut additional_stmts) = self.reconstruct_statement(statement);
624            statements.append(&mut additional_stmts);
625            statements.push(reconstructed_statement);
626        }
627
628        block.statements = statements;
629
630        (block, Default::default())
631    }
632
633    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
634        // Use the explicitly provided type if available, otherwise fall back to the type table
635        // Note that we have to consult the type annotation first to handle cases like `let x: u32? = 1`
636        // where the type annotation is a `u32?` while the RHS is a `u32`.
637        let expected_ty = input
638            .type_
639            .clone()
640            .or_else(|| self.state.type_table.get(&input.value.id()))
641            .expect("guaranteed by type checking");
642
643        let (new_value, additional_stmts) = self.reconstruct_expression(input.value, &Some(expected_ty));
644
645        input.type_ = input.type_.map(|ty| self.reconstruct_type(ty).0);
646        input.value = new_value;
647
648        (input.into(), additional_stmts)
649    }
650
651    fn reconstruct_expression_statement(
652        &mut self,
653        mut input: ExpressionStatement,
654    ) -> (Statement, Self::AdditionalOutput) {
655        let (expression, stmts) = self.reconstruct_expression(input.expression, &None);
656
657        input.expression = expression;
658
659        (input.into(), stmts)
660    }
661
662    fn reconstruct_iteration(&mut self, mut input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
663        let mut all_stmts = Vec::new();
664
665        let type_ = match input.type_ {
666            Some(ty) => {
667                let (new_ty, mut stmts_ty) = self.reconstruct_type(ty);
668                all_stmts.append(&mut stmts_ty);
669                Some(new_ty)
670            }
671            None => None,
672        };
673
674        let (start, mut stmts_start) = self.reconstruct_expression(input.start, &None);
675        let (stop, mut stmts_stop) = self.reconstruct_expression(input.stop, &None);
676        let (block, mut stmts_block) = self.reconstruct_block(input.block);
677
678        all_stmts.append(&mut stmts_start);
679        all_stmts.append(&mut stmts_stop);
680        all_stmts.append(&mut stmts_block);
681
682        input.type_ = type_;
683        input.start = start;
684        input.stop = stop;
685        input.block = block;
686
687        (input.into(), all_stmts)
688    }
689
690    fn reconstruct_return(&mut self, mut input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
691        let caller_name = self.function.expect("`self.function` is set every time a function is visited.");
692        let caller_path = self.module.iter().cloned().chain(std::iter::once(caller_name)).collect::<Vec<Symbol>>();
693
694        let func_symbol = self
695            .state
696            .symbol_table
697            .lookup_function(&Location::new(self.program, caller_path))
698            .expect("The symbol table creator should already have visited all functions.");
699
700        let return_type = func_symbol.function.output_type.clone();
701
702        let (expression, statements) = self.reconstruct_expression(input.expression, &Some(return_type));
703        input.expression = expression;
704
705        (input.into(), statements)
706    }
707}