leo_passes/monomorphization/
program.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::MonomorphizationVisitor;
18use leo_ast::{AstReconstructor, Module, Program, ProgramReconstructor, ProgramScope, Statement, Variant};
19use leo_span::sym;
20
21impl ProgramReconstructor for MonomorphizationVisitor<'_> {
22    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
23        // Set the current program name from the input.
24        self.program = input.program_id.name.name;
25
26        // We first reconstruct all structs. Struct fields can instantiate other generic structs that we need to handle
27        // first. We'll then address struct expressions and other struct type instantiations.
28        let struct_order = self.state.struct_graph.post_order().unwrap();
29
30        // Reconstruct structs in post-order.
31        for struct_name in &struct_order {
32            if let Some(r#struct) = self.struct_map.swap_remove(struct_name) {
33                // Perform monomorphization or other reconstruction logic.
34                let reconstructed_struct = self.reconstruct_struct(r#struct);
35                // Store the reconstructed struct for inclusion in the output scope.
36                self.reconstructed_structs.insert(struct_name.clone(), reconstructed_struct);
37            }
38        }
39
40        // If there are some structs left in `struct_map`, that means these are dead structs since they do not show up
41        // in `struct_graph`. Therefore, they won't be reconstructed, implying a change in the reconstructed program.
42        if !self.struct_map.is_empty() {
43            self.changed = true;
44        }
45
46        // Next, handle generic functions
47        //
48        // Compute a post-order traversal of the call graph. This ensures that functions are processed after all their callees.
49        // Make sure to only compute the post order by considering the entry points of the program, which are `async transition`, `transition` and `function`.
50        // We must consider entry points to ignore const generic inlines that have already been monomorphized but never called.
51        let order = self
52            .state
53            .call_graph
54            .post_order_with_filter(|location| {
55                // Filter out locations that are not from this program.
56                if location.program != self.program {
57                    return false;
58                }
59                // Allow constructors.
60                if location.program == self.program && location.path == vec![sym::constructor] {
61                    return true;
62                }
63                self.function_map
64                    .get(&location.path)
65                    .map(|f| {
66                        matches!(
67                            f.variant,
68                            Variant::AsyncTransition | Variant::Transition | Variant::Function | Variant::Script
69                        )
70                    })
71                    .unwrap_or(false)
72            })
73            .unwrap() // This unwrap is safe because the type checker guarantees an acyclic graph.
74            .into_iter()
75            .filter(|location| location.program == self.program).collect::<Vec<_>>();
76
77        for function_name in &order {
78            // Reconstruct functions in post-order.
79            if let Some(function) = self.function_map.swap_remove(&function_name.path) {
80                // Reconstruct the function.
81                let reconstructed_function = self.reconstruct_function(function.clone());
82                // Add the reconstructed function to the mapping.
83                self.reconstructed_functions.insert(function_name.path.clone(), reconstructed_function);
84            }
85        }
86
87        // Get any
88
89        // Now reconstruct mappings
90        let mappings =
91            input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect();
92
93        // Then consts
94        let consts = input
95            .consts
96            .into_iter()
97            .map(|(i, c)| match self.reconstruct_const(c) {
98                (Statement::Const(declaration), _) => (i, declaration),
99                _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
100            })
101            .collect();
102
103        // Reconstruct the constructor last, as it cannot be called by any other function.
104        let constructor = input.constructor.map(|c| self.reconstruct_constructor(c));
105
106        // Now retain only functions that are either not yet monomorphized or are still referenced by calls.
107        self.reconstructed_functions.retain(|f, _| {
108            let is_monomorphized = self.monomorphized_functions.contains(f);
109            let is_still_called = self.unresolved_calls.iter().any(|c| &c.function.absolute_path() == f);
110            !is_monomorphized || is_still_called
111        });
112
113        // Move reconstructed functions into the final `ProgramScope`.
114        // Make sure to place transitions before all the other functions.
115        let (transitions, mut non_transitions): (Vec<_>, Vec<_>) =
116            self.reconstructed_functions.clone().into_iter().partition(|(_, f)| f.variant.is_transition());
117
118        let mut all_functions = transitions;
119        all_functions.append(&mut non_transitions);
120
121        // Return the fully reconstructed scope with updated functions.
122        ProgramScope {
123            program_id: input.program_id,
124            structs: self
125                .reconstructed_structs
126                .iter()
127                .filter_map(|(path, c)| {
128                    // only consider structs defined at program scope. The rest will be added to their parent module.
129                    path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, c.clone()))
130                })
131                .collect(),
132            mappings,
133            functions: all_functions
134                .iter()
135                .filter_map(|(path, f)| {
136                    // only consider functions defined at program scope. The rest will be added to their parent module.
137                    path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, f.clone()))
138                })
139                .collect(),
140            constructor,
141            consts,
142            span: input.span,
143        }
144    }
145
146    fn reconstruct_program(&mut self, input: Program) -> Program {
147        // Populate `self.function_map` using the functions in the program scopes and the modules
148        input
149            .modules
150            .iter()
151            .flat_map(|(module_path, m)| {
152                m.functions.iter().map(move |(name, f)| {
153                    (module_path.iter().cloned().chain(std::iter::once(*name)).collect(), f.clone())
154                })
155            })
156            .chain(
157                input
158                    .program_scopes
159                    .iter()
160                    .flat_map(|(_, scope)| scope.functions.iter().map(|(name, f)| (vec![*name], f.clone()))),
161            )
162            .for_each(|(full_name, f)| {
163                self.function_map.insert(full_name, f);
164            });
165
166        // Populate `self.struct_map` using the structs in the program scopes and the modules
167        input
168            .modules
169            .iter()
170            .flat_map(|(module_path, m)| {
171                m.structs.iter().map(move |(name, f)| {
172                    (module_path.iter().cloned().chain(std::iter::once(*name)).collect(), f.clone())
173                })
174            })
175            .chain(
176                input
177                    .program_scopes
178                    .iter()
179                    .flat_map(|(_, scope)| scope.structs.iter().map(|(name, f)| (vec![*name], f.clone()))),
180            )
181            .for_each(|(full_name, f)| {
182                self.struct_map.insert(full_name, f);
183            });
184
185        // Reconstruct prrogram scopes first then reconstruct the modules after `self.reconstructed_structs`
186        // and `self.reconstructed_functions` have been populated.
187        Program {
188            program_scopes: input
189                .program_scopes
190                .into_iter()
191                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
192                .collect(),
193            modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
194            ..input
195        }
196    }
197
198    fn reconstruct_module(&mut self, input: Module) -> Module {
199        // Here we're reconstructing structs and functions from `reconstructed_functions` and
200        // `reconstructed_structs` based on their paths and whether they match the module path
201        Module {
202            structs: self
203                .reconstructed_structs
204                .iter()
205                .filter_map(|(path, c)| path.split_last().map(|(last, rest)| (last, rest, c)))
206                .filter(|&(_, rest, _)| input.path == rest)
207                .map(|(last, _, c)| (*last, c.clone()))
208                .collect(),
209
210            functions: self
211                .reconstructed_functions
212                .iter()
213                .filter_map(|(path, f)| path.split_last().map(|(last, rest)| (last, rest, f)))
214                .filter(|&(_, rest, _)| input.path == rest)
215                .map(|(last, _, f)| (*last, f.clone()))
216                .collect(),
217            ..input
218        }
219    }
220}