leo_passes/function_inlining/
ast.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

17use super::FunctionInliningVisitor;
18use crate::Replacer;
19
20use leo_ast::*;
21
22use indexmap::IndexMap;
23use itertools::Itertools;
24
25impl AstReconstructor for FunctionInliningVisitor<'_> {
26    type AdditionalOutput = Vec<Statement>;
27
28    /* Expressions */
29    fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
30        // Type checking guarantees that only functions local to the program scope can be inlined.
31        if input.program.unwrap() != self.program {
32            return (input.into(), Default::default());
33        }
34
35        // Lookup the reconstructed callee function.
36        // Since this pass processes functions in post-order, the callee function is guaranteed to exist in `self.reconstructed_functions`
37        let (_, callee) =
38            self.reconstructed_functions.iter().find(|(symbol, _)| *symbol == input.function.name).unwrap();
39
40        // Inline the callee function, if required, otherwise, return the call expression.
41        match callee.variant {
42            Variant::Inline => {
43                // Construct a mapping from input variables of the callee function to arguments passed to the callee.
44                let parameter_to_argument = callee
45                    .input
46                    .iter()
47                    .map(|input| input.identifier().name)
48                    .zip_eq(input.arguments)
49                    .collect::<IndexMap<_, _>>();
50
51                // Replace each input variable with the appropriate parameter.
52                let replace = |identifier: &Identifier| {
53                    parameter_to_argument.get(&identifier.name).cloned().unwrap_or(Expression::Identifier(*identifier))
54                };
55
56                let mut inlined_statements = Replacer::new(replace, &self.state.node_builder)
57                    .reconstruct_block(callee.block.clone())
58                    .0
59                    .statements;
60
61                // If the inlined block returns a value, then use the value in place of the call expression; otherwise, use the unit expression.
62                let result = match inlined_statements.last() {
63                    Some(Statement::Return(_)) => {
64                        // Note that this unwrap is safe since we know that the last statement is a return statement.
65                        match inlined_statements.pop().unwrap() {
66                            Statement::Return(ReturnStatement { expression, .. }) => expression,
67                            _ => panic!("This branch checks that the last statement is a return statement."),
68                        }
69                    }
70                    _ => {
71                        let id = self.state.node_builder.next_id();
72                        self.state.type_table.insert(id, Type::Unit);
73                        UnitExpression { span: Default::default(), id }.into()
74                    }
75                };
76
77                (result, inlined_statements)
78            }
79            Variant::Function
80            | Variant::Script
81            | Variant::AsyncFunction
82            | Variant::Transition
83            | Variant::AsyncTransition => (input.into(), Default::default()),
84        }
85    }
86
87    /* Statements */
88    fn reconstruct_assign(&mut self, _input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
89        panic!("`AssignStatement`s should not exist in the AST at this phase of compilation.")
90    }
91
92    /// Reconstructs the statements inside a basic block, accumulating any statements produced by function inlining.
93    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
94        let mut statements = Vec::with_capacity(block.statements.len());
95
96        for statement in block.statements {
97            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
98            statements.extend(additional_statements);
99            statements.push(reconstructed_statement);
100        }
101
102        (Block { span: block.span, statements, id: block.id }, Default::default())
103    }
104
105    /// Flattening removes conditional statements from the program.
106    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
107        if !self.is_async {
108            panic!("`ConditionalStatement`s should not be in the AST at this phase of compilation.")
109        } else {
110            (
111                ConditionalStatement {
112                    condition: self.reconstruct_expression(input.condition).0,
113                    then: self.reconstruct_block(input.then).0,
114                    otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
115                    span: input.span,
116                    id: input.id,
117                }
118                .into(),
119                Default::default(),
120            )
121        }
122    }
123
124    /// Reconstruct a definition statement by inlining any function calls.
125    /// This function also segments tuple assignment statements into multiple assignment statements.
126    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
127        let (value, mut statements) = self.reconstruct_expression(input.value);
128        match (input.place, value) {
129            // If we just inlined the production of a tuple literal, we need multiple definition statements.
130            (DefinitionPlace::Multiple(left), Expression::Tuple(right)) => {
131                assert_eq!(left.len(), right.elements.len());
132                for (identifier, rhs_value) in left.into_iter().zip(right.elements) {
133                    let stmt = DefinitionStatement {
134                        place: DefinitionPlace::Single(identifier),
135                        type_: None,
136                        value: rhs_value,
137                        span: Default::default(),
138                        id: self.state.node_builder.next_id(),
139                    }
140                    .into();
141
142                    statements.push(stmt);
143                }
144                (Statement::dummy(), statements)
145            }
146
147            (place, value) => {
148                input.value = value;
149                input.place = place;
150                (input.into(), statements)
151            }
152        }
153    }
154
155    /// Reconstructs expression statements by inlining any function calls.
156    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
157        // Reconstruct the expression.
158        // Note that type checking guarantees that the expression is a function call.
159        let (expression, additional_statements) = self.reconstruct_expression(input.expression);
160
161        // If the resulting expression is a unit expression, return a dummy statement.
162        let statement = match expression {
163            Expression::Unit(_) => Statement::dummy(),
164            _ => ExpressionStatement { expression, ..input }.into(),
165        };
166
167        (statement, additional_statements)
168    }
169
170    /// Loop unrolling unrolls and removes iteration statements from the program.
171    fn reconstruct_iteration(&mut self, _: IterationStatement) -> (Statement, Self::AdditionalOutput) {
172        panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
173    }
174}