leo_passes/processing_async/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::ProcessingAsyncVisitor;
18use crate::{CompilerState, Replacer};
19use indexmap::{IndexMap, IndexSet};
20use leo_ast::{
21    AstReconstructor,
22    AstVisitor,
23    AsyncExpression,
24    Block,
25    CallExpression,
26    Expression,
27    Function,
28    Identifier,
29    Input,
30    IterationStatement,
31    Location,
32    Node,
33    Path,
34    ProgramVisitor,
35    Statement,
36    TupleAccess,
37    TupleExpression,
38    TupleType,
39    Type,
40    Variant,
41};
42use leo_span::{Span, Symbol};
43
44/// Collects all symbol accesses within an async block,
45/// including both direct variable identifiers (`x`) and tuple field accesses (`x.0`, `x.1`, etc.).
46/// Each access is recorded as a pair: (Symbol, Option<usize>).
47/// - `None` means a direct variable access.
48/// - `Some(index)` means a tuple field access.
49struct SymbolAccessCollector<'a> {
50    state: &'a CompilerState,
51    symbol_accesses: IndexSet<(Vec<Symbol>, Option<usize>)>,
52}
53
54impl AstVisitor for SymbolAccessCollector<'_> {
55    type AdditionalInput = ();
56    type Output = ();
57
58    fn visit_path(&mut self, input: &Path, _: &Self::AdditionalInput) -> Self::Output {
59        self.symbol_accesses.insert((input.absolute_path(), None));
60    }
61
62    fn visit_tuple_access(&mut self, input: &TupleAccess, _: &Self::AdditionalInput) -> Self::Output {
63        // Here we assume that we can't have nested tuples which is currently guaranteed by type
64        // checking. This may change in the future.
65        if let Expression::Path(path) = &input.tuple {
66            // Futures aren't accessed by field; treat the whole thing as a direct variable
67            if let Some(Type::Future(_)) = self.state.type_table.get(&input.tuple.id()) {
68                self.symbol_accesses.insert((path.absolute_path(), None));
69            } else {
70                self.symbol_accesses.insert((path.absolute_path(), Some(input.index.value())));
71            }
72        } else {
73            self.visit_expression(&input.tuple, &());
74        }
75    }
76}
77
// Uses the default `ProgramVisitor` methods; no program-level hooks are overridden here.
impl ProgramVisitor for SymbolAccessCollector<'_> {}
79
80impl AstReconstructor for ProcessingAsyncVisitor<'_> {
81    type AdditionalInput = ();
82    type AdditionalOutput = ();
83
84    /// Transforms an `AsyncExpression` into a standalone async `Function` and returns
85    /// a call to this function. This process:
86    /// - Collects all referenced symbol accesses in the async block.
87    /// - Filters out mappings and constructs typed input parameters.
88    /// - Reconstructs an async function with those inputs and the original block.
89    /// - Builds and returns a `CallExpression` that invokes the new function.
90    fn reconstruct_async(&mut self, input: AsyncExpression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
91        // Step 1: Generate a unique name for the async function
92        let finalize_fn_name = self.state.assigner.unique_symbol(self.current_function, "$");
93
94        // Step 2: Collect all symbol accesses in the async block
95        let mut access_collector = SymbolAccessCollector { state: self.state, symbol_accesses: IndexSet::new() };
96        access_collector.visit_async(&input, &());
97
98        // Stores mapping from accessed symbol (and optional index) to the expression used in replacement
99        let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
100
101        // Helper to create a fresh `Identifier`
102        let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
103            name: symbol,
104            span: Span::default(),
105            id: slf.state.node_builder.next_id(),
106        };
107
108        // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access.
109        //
110        // This function handles both:
111        // - Direct variable accesses (e.g., `foo`)
112        // - Tuple element accesses (e.g., `foo.0`)
113        //
114        // For tuple accesses:
115        // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`.
116        // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by:
117        //     - Reusing existing inputs from `replacements` if already generated via prior field access.
118        //     - Creating new inputs and arguments for any missing elements.
119        // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`.
120        //
121        // This function also ensures deduplication by consulting the `replacements` map:
122        // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated.
123        // - This prevents repeated parameters for accesses like both `foo` and `foo.0`.
124        //
125        // # Parameters
126        // - `symbol`: The symbol being accessed.
127        // - `var_type`: The type of the symbol (may be a tuple or base type).
128        // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access.
129        //
130        // # Returns
131        // A `Vec<(Input, Expression)>`, where:
132        // - `Input` is a parameter for the generated async function.
133        // - `Expression` is the call-site argument expression used to invoke that parameter.
134        let mut make_inputs_and_arguments =
135            |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
136                if replacements.contains_key(&(symbol, index_opt)) {
137                    return vec![]; // No new input needed; argument already exists
138                }
139
140                match index_opt {
141                    Some(index) => {
142                        let Type::Tuple(TupleType { elements }) = var_type else {
143                            panic!("Expected tuple type when accessing tuple field: {symbol}.{index}");
144                        };
145
146                        let synthetic_name = format!("\"{symbol}.{index}\"");
147                        let synthetic_symbol = Symbol::intern(&synthetic_name);
148                        let identifier = make_identifier(slf, synthetic_symbol);
149
150                        let input = Input {
151                            identifier,
152                            mode: leo_ast::Mode::None,
153                            type_: elements[index].clone(),
154                            span: Span::default(),
155                            id: slf.state.node_builder.next_id(),
156                        };
157
158                        replacements.insert((symbol, Some(index)), Path::from(identifier).into_absolute().into());
159
160                        vec![(
161                            input,
162                            TupleAccess {
163                                tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
164                                index: index.into(),
165                                span: Span::default(),
166                                id: slf.state.node_builder.next_id(),
167                            }
168                            .into(),
169                        )]
170                    }
171
172                    None => match var_type {
173                        Type::Tuple(TupleType { elements }) => {
174                            let mut inputs_and_arguments = Vec::with_capacity(elements.len());
175                            let mut tuple_elements = Vec::with_capacity(elements.len());
176
177                            for (i, element_type) in elements.iter().enumerate() {
178                                let key = (symbol, Some(i));
179
180                                // Skip if this field is already handled
181                                if let Some(existing_expr) = replacements.get(&key) {
182                                    tuple_elements.push(existing_expr.clone());
183                                    continue;
184                                }
185
186                                // Otherwise, synthesize identifier and input
187                                let synthetic_name = format!("\"{symbol}.{i}\"");
188                                let synthetic_symbol = Symbol::intern(&synthetic_name);
189                                let identifier = make_identifier(slf, synthetic_symbol);
190
191                                let input = Input {
192                                    identifier,
193                                    mode: leo_ast::Mode::None,
194                                    type_: element_type.clone(),
195                                    span: Span::default(),
196                                    id: slf.state.node_builder.next_id(),
197                                };
198
199                                let expr: Expression = Path::from(identifier).into_absolute().into();
200
201                                replacements.insert(key, expr.clone());
202                                tuple_elements.push(expr.clone());
203                                inputs_and_arguments.push((
204                                    input,
205                                    TupleAccess {
206                                        tuple: Path::from(make_identifier(slf, symbol)).into_absolute().into(),
207                                        index: i.into(),
208                                        span: Span::default(),
209                                        id: slf.state.node_builder.next_id(),
210                                    }
211                                    .into(),
212                                ));
213                            }
214
215                            // Now insert the full tuple (even if all fields were already there)
216                            replacements.insert(
217                                (symbol, None),
218                                Expression::Tuple(TupleExpression {
219                                    elements: tuple_elements,
220                                    span: Span::default(),
221                                    id: slf.state.node_builder.next_id(),
222                                }),
223                            );
224
225                            inputs_and_arguments
226                        }
227
228                        _ => {
229                            let identifier = make_identifier(slf, symbol);
230                            let input = Input {
231                                identifier,
232                                mode: leo_ast::Mode::None,
233                                type_: var_type.clone(),
234                                span: Span::default(),
235                                id: slf.state.node_builder.next_id(),
236                            };
237
238                            replacements.insert((symbol, None), Path::from(identifier).into_absolute().into());
239
240                            let argument = Path::from(make_identifier(slf, symbol)).into_absolute().into();
241                            vec![(input, argument)]
242                        }
243                    },
244                }
245            };
246
247        // Step 3: Resolve symbol accesses into inputs and call arguments
248        let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
249            .symbol_accesses
250            .iter()
251            .filter_map(|(path, index)| {
252                // Skip globals and variables that are local to this block or to one of its children.
253
254                // Skip globals.
255                if self.state.symbol_table.lookup_global(&Location::new(self.current_program, path.to_vec())).is_some()
256                {
257                    return None;
258                }
259
260                // Skip variables that are local to this block or to one of its children.
261                let local_var_name = *path.last().expect("all paths must have at least one segment.");
262                if self.state.symbol_table.is_local_to_or_in_child_scope(input.block.id(), local_var_name) {
263                    return None;
264                }
265
266                // All other variables become parameters to the async function being built.
267                let var = self.state.symbol_table.lookup_local(local_var_name)?;
268                Some(make_inputs_and_arguments(self, local_var_name, &var.type_, *index))
269            })
270            .flatten()
271            .unzip();
272
273        // Step 4: Replacement logic used to patch the async block
274        let replace_expr = |expr: &Expression| -> Expression {
275            match expr {
276                Expression::Path(path) => {
277                    replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
278                }
279
280                Expression::TupleAccess(ta) => {
281                    if let Expression::Path(path) = &ta.tuple {
282                        replacements
283                            .get(&(path.identifier().name, Some(ta.index.value())))
284                            .cloned()
285                            .unwrap_or_else(|| expr.clone())
286                    } else {
287                        expr.clone()
288                    }
289                }
290
291                _ => expr.clone(),
292            }
293        };
294
295        // Step 5: Reconstruct the block with replaced references
296        let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
297        let new_block = replacer.reconstruct_block(input.block.clone()).0;
298
299        // Ensure we're not trying to capture too many variables
300        if inputs.len() > self.max_inputs {
301            self.state.handler.emit_err(leo_errors::StaticAnalyzerError::async_block_capturing_too_many_vars(
302                inputs.len(),
303                self.max_inputs,
304                input.span,
305            ));
306        }
307
308        // Step 6: Define the new async function
309        let function = Function {
310            annotations: vec![],
311            variant: Variant::AsyncFunction,
312            identifier: make_identifier(self, finalize_fn_name),
313            const_parameters: vec![],
314            input: inputs,
315            output: vec![],          // `async function`s can't have returns
316            output_type: Type::Unit, // Always the case for `async function`s
317            block: new_block,
318            span: input.span,
319            id: self.state.node_builder.next_id(),
320        };
321
322        // Register the generated function
323        self.new_async_functions.push((finalize_fn_name, function));
324
325        // Step 7: Create the call expression to invoke the async function
326        let call_to_finalize = CallExpression {
327            function: Path::new(
328                vec![],
329                make_identifier(self, finalize_fn_name),
330                true,
331                Some(vec![finalize_fn_name]), // the finalize function lives in the top level program scope
332                Span::default(),
333                self.state.node_builder.next_id(),
334            ),
335            const_arguments: vec![],
336            arguments,
337            program: Some(self.current_program),
338            span: input.span,
339            id: self.state.node_builder.next_id(),
340        };
341
342        self.modified = true;
343
344        (call_to_finalize.into(), ())
345    }
346
347    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
348        self.in_scope(input.id(), |slf| {
349            (
350                Block {
351                    statements: input.statements.into_iter().map(|s| slf.reconstruct_statement(s).0).collect(),
352                    span: input.span,
353                    id: input.id,
354                },
355                Default::default(),
356            )
357        })
358    }
359
360    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
361        self.in_scope(input.id(), |slf| {
362            (
363                IterationStatement {
364                    type_: input.type_.map(|ty| slf.reconstruct_type(ty).0),
365                    start: slf.reconstruct_expression(input.start, &()).0,
366                    stop: slf.reconstruct_expression(input.stop, &()).0,
367                    block: slf.reconstruct_block(input.block).0,
368                    ..input
369                }
370                .into(),
371                Default::default(),
372            )
373        })
374    }
375}