leo_passes/processing_async/ast.rs
1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::ProcessingAsyncVisitor;
18use crate::{CompilerState, Replacer};
19use indexmap::{IndexMap, IndexSet};
20use leo_ast::{
21 AstReconstructor,
22 AstVisitor,
23 AsyncExpression,
24 Block,
25 CallExpression,
26 Expression,
27 Function,
28 Identifier,
29 Input,
30 IterationStatement,
31 Location,
32 Node,
33 Path,
34 ProgramVisitor,
35 Statement,
36 TupleAccess,
37 TupleExpression,
38 TupleType,
39 Type,
40 Variant,
41};
42use leo_span::{Span, Symbol};
43
44/// Collects all symbol accesses within an async block,
45/// including both direct variable identifiers (`x`) and tuple field accesses (`x.0`, `x.1`, etc.).
46/// Each access is recorded as a pair: (Symbol, Option<usize>).
47/// - `None` means a direct variable access.
48/// - `Some(index)` means a tuple field access.
49struct SymbolAccessCollector<'a> {
50 state: &'a CompilerState,
51 symbol_accesses: IndexSet<(Vec<Symbol>, Option<usize>)>,
52}
53
54impl AstVisitor for SymbolAccessCollector<'_> {
55 type AdditionalInput = ();
56 type Output = ();
57
58 fn visit_path(&mut self, input: &Path, _: &Self::AdditionalInput) -> Self::Output {
59 self.symbol_accesses.insert((input.absolute_path(), None));
60 }
61
62 fn visit_tuple_access(&mut self, input: &TupleAccess, _: &Self::AdditionalInput) -> Self::Output {
63 // Here we assume that we can't have nested tuples which is currently guaranteed by type
64 // checking. This may change in the future.
65 if let Expression::Path(path) = &input.tuple {
66 // Futures aren't accessed by field; treat the whole thing as a direct variable
67 if let Some(Type::Future(_)) = self.state.type_table.get(&input.tuple.id()) {
68 self.symbol_accesses.insert((path.absolute_path(), None));
69 } else {
70 self.symbol_accesses.insert((path.absolute_path(), Some(input.index.value())));
71 }
72 } else {
73 self.visit_expression(&input.tuple, &());
74 }
75 }
76}
77
78impl ProgramVisitor for SymbolAccessCollector<'_> {}
79
80impl AstReconstructor for ProcessingAsyncVisitor<'_> {
81 type AdditionalInput = ();
82 type AdditionalOutput = ();
83
84 /// Transforms an `AsyncExpression` into a standalone async `Function` and returns
85 /// a call to this function. This process:
86 /// - Collects all referenced symbol accesses in the async block.
87 /// - Filters out mappings and constructs typed input parameters.
88 /// - Reconstructs an async function with those inputs and the original block.
89 /// - Builds and returns a `CallExpression` that invokes the new function.
90 fn reconstruct_async(&mut self, input: AsyncExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
91 // Step 1: Generate a unique name for the async function
92 let finalize_fn_name = self.state.assigner.unique_symbol(self.current_function, "$");
93
94 // Step 2: Collect all symbol accesses in the async block
95 let mut access_collector = SymbolAccessCollector { state: self.state, symbol_accesses: IndexSet::new() };
96 access_collector.visit_async(&input, &());
97
98 // Stores mapping from accessed symbol (and optional index) to the expression used in replacement
99 let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
100
101 // Helper to create a fresh `Identifier`
102 let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
103 name: symbol,
104 span: Span::default(),
105 id: slf.state.node_builder.next_id(),
106 };
107
108 // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access.
109 //
110 // This function handles both:
111 // - Direct variable accesses (e.g., `foo`)
112 // - Tuple element accesses (e.g., `foo.0`)
113 //
114 // For tuple accesses:
115 // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`.
116 // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by:
117 // - Reusing existing inputs from `replacements` if already generated via prior field access.
118 // - Creating new inputs and arguments for any missing elements.
119 // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`.
120 //
121 // This function also ensures deduplication by consulting the `replacements` map:
122 // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated.
123 // - This prevents repeated parameters for accesses like both `foo` and `foo.0`.
124 //
125 // # Parameters
126 // - `symbol`: The symbol being accessed.
127 // - `var_type`: The type of the symbol (may be a tuple or base type).
128 // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access.
129 //
130 // # Returns
131 // A `Vec<(Input, Expression)>`, where:
132 // - `Input` is a parameter for the generated async function.
133 // - `Expression` is the call-site argument expression used to invoke that parameter.
134 let mut make_inputs_and_arguments =
135 |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
136 if replacements.contains_key(&(symbol, index_opt)) {
137 return vec![]; // No new input needed; argument already exists
138 }
139
140 match index_opt {
141 Some(index) => {
142 let Type::Tuple(TupleType { elements }) = var_type else {
143 panic!("Expected tuple type when accessing tuple field: {symbol}.{index}");
144 };
145
146 let synthetic_name = format!("\"{symbol}.{index}\"");
147 let synthetic_symbol = Symbol::intern(&synthetic_name);
148 let identifier = make_identifier(slf, synthetic_symbol);
149
150 let input = Input {
151 identifier,
152 mode: leo_ast::Mode::None,
153 type_: elements[index].clone(),
154 span: Span::default(),
155 id: slf.state.node_builder.next_id(),
156 };
157
158 replacements.insert((symbol, Some(index)), Path::from(identifier).into_absolute().into());
159
160 vec![(
161 input,
162 TupleAccess {
163 tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
164 index: index.into(),
165 span: Span::default(),
166 id: slf.state.node_builder.next_id(),
167 }
168 .into(),
169 )]
170 }
171
172 None => match var_type {
173 Type::Tuple(TupleType { elements }) => {
174 let mut inputs_and_arguments = Vec::with_capacity(elements.len());
175 let mut tuple_elements = Vec::with_capacity(elements.len());
176
177 for (i, element_type) in elements.iter().enumerate() {
178 let key = (symbol, Some(i));
179
180 // Skip if this field is already handled
181 if let Some(existing_expr) = replacements.get(&key) {
182 tuple_elements.push(existing_expr.clone());
183 continue;
184 }
185
186 // Otherwise, synthesize identifier and input
187 let synthetic_name = format!("\"{symbol}.{i}\"");
188 let synthetic_symbol = Symbol::intern(&synthetic_name);
189 let identifier = make_identifier(slf, synthetic_symbol);
190
191 let input = Input {
192 identifier,
193 mode: leo_ast::Mode::None,
194 type_: element_type.clone(),
195 span: Span::default(),
196 id: slf.state.node_builder.next_id(),
197 };
198
199 let expr: Expression = Path::from(identifier).into_absolute().into();
200
201 replacements.insert(key, expr.clone());
202 tuple_elements.push(expr.clone());
203 inputs_and_arguments.push((
204 input,
205 TupleAccess {
206 tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
207 index: i.into(),
208 span: Span::default(),
209 id: slf.state.node_builder.next_id(),
210 }
211 .into(),
212 ));
213 }
214
215 // Now insert the full tuple (even if all fields were already there)
216 replacements.insert(
217 (symbol, None),
218 Expression::Tuple(TupleExpression {
219 elements: tuple_elements,
220 span: Span::default(),
221 id: slf.state.node_builder.next_id(),
222 }),
223 );
224
225 inputs_and_arguments
226 }
227
228 _ => {
229 let identifier = make_identifier(slf, symbol);
230 let input = Input {
231 identifier,
232 mode: leo_ast::Mode::None,
233 type_: var_type.clone(),
234 span: Span::default(),
235 id: slf.state.node_builder.next_id(),
236 };
237
238 replacements.insert((symbol, None), Path::from(identifier).into_absolute().into());
239
240 let argument = Path::from(make_identifier(slf, symbol)).into_absolute().into();
241 vec![(input, argument)]
242 }
243 },
244 }
245 };
246
247 // Step 3: Resolve symbol accesses into inputs and call arguments
248 let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
249 .symbol_accesses
250 .iter()
251 .filter_map(|(path, index)| {
252 // Skip globals and variables that are local to this block or to one of its children.
253
254 // Skip globals.
255 if self.state.symbol_table.lookup_global(&Location::new(self.current_program, path.to_vec())).is_some()
256 {
257 return None;
258 }
259
260 // Skip variables that are local to this block or to one of its children.
261 let local_var_name = *path.last().expect("all paths must have at least one segment.");
262 if self.state.symbol_table.is_local_to_or_in_child_scope(input.block.id(), local_var_name) {
263 return None;
264 }
265
266 // All other variables become parameters to the async function being built.
267 let var = self.state.symbol_table.lookup_local(local_var_name)?;
268 Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index))
269 })
270 .flatten()
271 .unzip();
272
273 // Step 4: Replacement logic used to patch the async block
274 let replace_expr = |expr: &Expression| -> Expression {
275 match expr {
276 Expression::Path(path) => {
277 replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
278 }
279
280 Expression::TupleAccess(ta) => {
281 if let Expression::Path(path) = &ta.tuple {
282 replacements
283 .get(&(path.identifier().name, Some(ta.index.value())))
284 .cloned()
285 .unwrap_or_else(|| expr.clone())
286 } else {
287 expr.clone()
288 }
289 }
290
291 _ => expr.clone(),
292 }
293 };
294
295 // Step 5: Reconstruct the block with replaced references
296 let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
297 let new_block = replacer.reconstruct_block(input.block.clone()).0;
298
299 // Ensure we're not trying to capture too many variables
300 if inputs.len() > self.max_inputs {
301 self.state.handler.emit_err(leo_errors::StaticAnalyzerError::async_block_capturing_too_many_vars(
302 inputs.len(),
303 self.max_inputs,
304 input.span,
305 ));
306 }
307
308 // Step 6: Define the new async function
309 let function = Function {
310 annotations: vec![],
311 variant: Variant::AsyncFunction,
312 identifier: make_identifier(self, finalize_fn_name),
313 const_parameters: vec![],
314 input: inputs,
315 output: vec![], // `async function`s can't have returns
316 output_type: Type::Unit, // Always the case for `async function`s
317 block: new_block,
318 span: input.span,
319 id: self.state.node_builder.next_id(),
320 };
321
322 // Register the generated function
323 self.new_async_functions.push((finalize_fn_name, function));
324
325 // Step 7: Create the call expression to invoke the async function
326 let call_to_finalize = CallExpression {
327 function: Path::new(
328 vec![],
329 make_identifier(self, finalize_fn_name),
330 true,
331 Some(vec![finalize_fn_name]), // the finalize function lives in the top level program scope
332 Span::default(),
333 self.state.node_builder.next_id(),
334 ),
335 const_arguments: vec![],
336 arguments,
337 program: Some(self.current_program),
338 span: input.span,
339 id: self.state.node_builder.next_id(),
340 };
341
342 self.modified = true;
343
344 (call_to_finalize.into(), ())
345 }
346
347 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
348 self.in_scope(input.id(), |slf| {
349 (
350 Block {
351 statements: input.statements.into_iter().map(|s| slf.reconstruct_statement(s).0).collect(),
352 span: input.span,
353 id: input.id,
354 },
355 Default::default(),
356 )
357 })
358 }
359
360 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
361 self.in_scope(input.id(), |slf| {
362 (
363 IterationStatement {
364 type_: input.type_.map(|ty| slf.reconstruct_type(ty).0),
365 start: slf.reconstruct_expression(input.start, &()).0,
366 stop: slf.reconstruct_expression(input.stop, &()).0,
367 block: slf.reconstruct_block(input.block).0,
368 ..input
369 }
370 .into(),
371 Default::default(),
372 )
373 })
374 }
375}