leo_passes/monomorphization/program.rs
1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::MonomorphizationVisitor;
18use leo_ast::{AstReconstructor, Module, Program, ProgramReconstructor, ProgramScope, Statement, Variant};
19use leo_span::sym;
20
21impl ProgramReconstructor for MonomorphizationVisitor<'_> {
22 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
23 // Set the current program name from the input.
24 self.program = input.program_id.name.name;
25
26 // We first reconstruct all structs. Struct fields can instantiate other generic structs that we need to handle
27 // first. We'll then address struct expressions and other struct type instantiations.
28 let struct_order = self.state.struct_graph.post_order().unwrap();
29
30 // Reconstruct structs in post-order.
31 for struct_name in &struct_order {
32 if let Some(r#struct) = self.struct_map.swap_remove(struct_name) {
33 // Perform monomorphization or other reconstruction logic.
34 let reconstructed_struct = self.reconstruct_struct(r#struct);
35 // Store the reconstructed struct for inclusion in the output scope.
36 self.reconstructed_structs.insert(struct_name.clone(), reconstructed_struct);
37 }
38 }
39
40 // If there are some structs left in `struct_map`, that means these are dead structs since they do not show up
41 // in `struct_graph`. Therefore, they won't be reconstructed, implying a change in the reconstructed program.
42 if !self.struct_map.is_empty() {
43 self.changed = true;
44 }
45
46 // Next, handle generic functions
47 //
48 // Compute a post-order traversal of the call graph. This ensures that functions are processed after all their callees.
49 // Make sure to only compute the post order by considering the entry points of the program, which are `async transition`, `transition` and `function`.
50 // We must consider entry points to ignore const generic inlines that have already been monomorphized but never called.
51 let order = self
52 .state
53 .call_graph
54 .post_order_with_filter(|location| {
55 // Filter out locations that are not from this program.
56 if location.program != self.program {
57 return false;
58 }
59 // Allow constructors.
60 if location.program == self.program && location.path == vec![sym::constructor] {
61 return true;
62 }
63 self.function_map
64 .get(&location.path)
65 .map(|f| {
66 matches!(
67 f.variant,
68 Variant::AsyncTransition | Variant::Transition | Variant::Function | Variant::Script
69 )
70 })
71 .unwrap_or(false)
72 })
73 .unwrap() // This unwrap is safe because the type checker guarantees an acyclic graph.
74 .into_iter()
75 .filter(|location| location.program == self.program).collect::<Vec<_>>();
76
77 for function_name in &order {
78 // Reconstruct functions in post-order.
79 if let Some(function) = self.function_map.swap_remove(&function_name.path) {
80 // Reconstruct the function.
81 let reconstructed_function = self.reconstruct_function(function.clone());
82 // Add the reconstructed function to the mapping.
83 self.reconstructed_functions.insert(function_name.path.clone(), reconstructed_function);
84 }
85 }
86
87 // Get any
88
89 // Now reconstruct mappings and storage variables
90 let mappings =
91 input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect();
92 let storage_variables = input
93 .storage_variables
94 .into_iter()
95 .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
96 .collect();
97
98 // Then consts
99 let consts = input
100 .consts
101 .into_iter()
102 .map(|(i, c)| match self.reconstruct_const(c) {
103 (Statement::Const(declaration), _) => (i, declaration),
104 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
105 })
106 .collect();
107
108 // Reconstruct the constructor last, as it cannot be called by any other function.
109 let constructor = input.constructor.map(|c| self.reconstruct_constructor(c));
110
111 // Now retain only functions that are either not yet monomorphized or are still referenced by calls.
112 self.reconstructed_functions.retain(|f, _| {
113 let is_monomorphized = self.monomorphized_functions.contains(f);
114 let is_still_called = self.unresolved_calls.iter().any(|c| &c.function.absolute_path() == f);
115 !is_monomorphized || is_still_called
116 });
117
118 // Move reconstructed functions into the final `ProgramScope`.
119 // Make sure to place transitions before all the other functions.
120 let (transitions, mut non_transitions): (Vec<_>, Vec<_>) =
121 self.reconstructed_functions.clone().into_iter().partition(|(_, f)| f.variant.is_transition());
122
123 let mut all_functions = transitions;
124 all_functions.append(&mut non_transitions);
125
126 // Return the fully reconstructed scope with updated functions.
127 ProgramScope {
128 program_id: input.program_id,
129 structs: self
130 .reconstructed_structs
131 .iter()
132 .filter_map(|(path, c)| {
133 // only consider structs defined at program scope. The rest will be added to their parent module.
134 path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, c.clone()))
135 })
136 .collect(),
137 mappings,
138 storage_variables,
139 functions: all_functions
140 .iter()
141 .filter_map(|(path, f)| {
142 // only consider functions defined at program scope. The rest will be added to their parent module.
143 path.split_last().filter(|(_, rest)| rest.is_empty()).map(|(last, _)| (*last, f.clone()))
144 })
145 .collect(),
146 constructor,
147 consts,
148 span: input.span,
149 }
150 }
151
152 fn reconstruct_program(&mut self, input: Program) -> Program {
153 // Populate `self.function_map` using the functions in the program scopes and the modules
154 input
155 .modules
156 .iter()
157 .flat_map(|(module_path, m)| {
158 m.functions.iter().map(move |(name, f)| {
159 (module_path.iter().cloned().chain(std::iter::once(*name)).collect(), f.clone())
160 })
161 })
162 .chain(
163 input
164 .program_scopes
165 .iter()
166 .flat_map(|(_, scope)| scope.functions.iter().map(|(name, f)| (vec![*name], f.clone()))),
167 )
168 .for_each(|(full_name, f)| {
169 self.function_map.insert(full_name, f);
170 });
171
172 // Populate `self.struct_map` using the structs in the program scopes and the modules
173 input
174 .modules
175 .iter()
176 .flat_map(|(module_path, m)| {
177 m.structs.iter().map(move |(name, f)| {
178 (module_path.iter().cloned().chain(std::iter::once(*name)).collect(), f.clone())
179 })
180 })
181 .chain(
182 input
183 .program_scopes
184 .iter()
185 .flat_map(|(_, scope)| scope.structs.iter().map(|(name, f)| (vec![*name], f.clone()))),
186 )
187 .for_each(|(full_name, f)| {
188 self.struct_map.insert(full_name, f);
189 });
190
191 // Reconstruct prrogram scopes first then reconstruct the modules after `self.reconstructed_structs`
192 // and `self.reconstructed_functions` have been populated.
193 Program {
194 program_scopes: input
195 .program_scopes
196 .into_iter()
197 .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
198 .collect(),
199 modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
200 ..input
201 }
202 }
203
204 fn reconstruct_module(&mut self, input: Module) -> Module {
205 // Here we're reconstructing structs and functions from `reconstructed_functions` and
206 // `reconstructed_structs` based on their paths and whether they match the module path
207 Module {
208 structs: self
209 .reconstructed_structs
210 .iter()
211 .filter_map(|(path, c)| path.split_last().map(|(last, rest)| (last, rest, c)))
212 .filter(|&(_, rest, _)| input.path == rest)
213 .map(|(last, _, c)| (*last, c.clone()))
214 .collect(),
215
216 functions: self
217 .reconstructed_functions
218 .iter()
219 .filter_map(|(path, f)| path.split_last().map(|(last, rest)| (last, rest, f)))
220 .filter(|&(_, rest, _)| input.path == rest)
221 .map(|(last, _, f)| (*last, f.clone()))
222 .collect(),
223 ..input
224 }
225 }
226}