// leo_passes/flattening/program.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::{FlatteningVisitor, ReturnGuard};
18
19use leo_ast::{
20    AstReconstructor,
21    Constructor,
22    Expression,
23    Function,
24    ProgramReconstructor,
25    ProgramScope,
26    ReturnStatement,
27    Statement,
28};
29
30impl ProgramReconstructor for FlatteningVisitor<'_> {
31    /// Flattens a program scope.
32    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
33        self.program = input.program_id.name.name;
34        ProgramScope {
35            program_id: input.program_id,
36            consts: input
37                .consts
38                .into_iter()
39                .map(|(i, c)| match self.reconstruct_const(c) {
40                    (Statement::Const(declaration), _) => (i, declaration),
41                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
42                })
43                .collect(),
44            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
45            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
46            storage_variables: input
47                .storage_variables
48                .into_iter()
49                .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
50                .collect(),
51            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
52            constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
53            span: input.span,
54        }
55    }
56
57    /// Flattens a function's body
58    fn reconstruct_function(&mut self, function: Function) -> Function {
59        // Set when the function is an async function.
60        self.is_async = function.variant.is_async_function();
61
62        // Flatten the function body.
63        let mut block = self.reconstruct_block(function.block).0;
64
65        // Fold the return statements into the block.
66        let returns = std::mem::take(&mut self.returns);
67        let expression_returns: Vec<(Option<Expression>, ReturnStatement)> = returns
68            .into_iter()
69            .map(|(guard, statement)| match guard {
70                ReturnGuard::None => (None, statement),
71                ReturnGuard::Unconstructed(plain) | ReturnGuard::Constructed { plain, .. } => {
72                    (Some(leo_ast::Path::from(plain).into_absolute().into()), statement)
73                }
74            })
75            .collect();
76
77        self.fold_returns(&mut block, expression_returns);
78
79        Function {
80            annotations: function.annotations,
81            variant: function.variant,
82            identifier: function.identifier,
83            const_parameters: function.const_parameters,
84            input: function.input,
85            output: function.output,
86            output_type: function.output_type,
87            block,
88            span: function.span,
89            id: function.id,
90        }
91    }
92
93    /// Flattens a constructor's body.
94    fn reconstruct_constructor(&mut self, constructor: Constructor) -> Constructor {
95        // A constructor is always async.
96        self.is_async = true;
97
98        // Flatten the function body.
99        let mut block = self.reconstruct_block(constructor.block).0;
100
101        // Fold the return statements into the block.
102        let returns = std::mem::take(&mut self.returns);
103        let expression_returns: Vec<(Option<Expression>, ReturnStatement)> = returns
104            .into_iter()
105            .map(|(guard, statement)| match guard {
106                ReturnGuard::None => (None, statement),
107                ReturnGuard::Unconstructed(plain) | ReturnGuard::Constructed { plain, .. } => {
108                    (Some(plain.into()), statement)
109                }
110            })
111            .collect();
112
113        self.fold_returns(&mut block, expression_returns);
114
115        Constructor { annotations: constructor.annotations, block, span: constructor.span, id: constructor.id }
116    }
117}