1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
// Copyright (C) 2019-2024 Aleo Systems Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::{FunctionInliner, Replacer};

use leo_ast::{
    CallExpression,
    Expression,
    ExpressionReconstructor,
    Identifier,
    ReturnStatement,
    Statement,
    StatementReconstructor,
    Type,
    UnitExpression,
    Variant,
};

use indexmap::IndexMap;
use itertools::Itertools;

impl ExpressionReconstructor for FunctionInliner<'_> {
    /// The statements of an inlined callee body, returned alongside the expression that replaces the call.
    type AdditionalOutput = Vec<Statement>;

    /// Inlines a call to a function declared with the `inline` variant.
    ///
    /// Calls into a different program, and calls to any non-`inline` variant, are returned unchanged
    /// with no additional statements. For an `inline` callee:
    /// 1. the callee's body is duplicated with fresh names for its assignments,
    /// 2. every reference to a callee parameter is replaced by the argument passed for it, and
    /// 3. the resulting statements are returned as the additional output.
    ///
    /// If the inlined body ends in a `return`, its value replaces the call expression; otherwise a new
    /// unit expression (registered in the type table as `Type::Unit`) replaces it.
    ///
    /// # Panics
    ///
    /// Panics if `input.program` or `self.program` is unset, if `input.function` is not an identifier,
    /// if the callee has not been reconstructed yet, or if the number of arguments differs from the
    /// number of callee parameters (`zip_eq`).
    fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
        // Type checking guarantees that only functions local to the program scope can be inlined,
        // so calls into any other program are left untouched.
        let callee_program = input.program.expect("the program of a call must be resolved before function inlining");
        let current_program = self.program.expect("the function inliner must know which program it is processing");
        if callee_program != current_program {
            return (Expression::Call(input), Default::default());
        }

        // Get the name of the callee function.
        let function_name = match *input.function {
            Expression::Identifier(identifier) => identifier.name,
            _ => unreachable!("Parser guarantees that `input.function` is always an identifier."),
        };

        // Look up the reconstructed callee function.
        // Since this pass processes functions in post-order, the callee is guaranteed to already be in
        // `self.reconstructed_functions`.
        let (_, callee) = self
            .reconstructed_functions
            .iter()
            .find(|(symbol, _)| *symbol == function_name)
            .expect("callees are reconstructed before their callers (post-order traversal)");

        // Inline the callee function, if required; otherwise, return the call expression unchanged.
        match callee.variant {
            Variant::Inline => {
                // Map each callee parameter name to the argument expression passed for it.
                // `zip_eq` panics if the number of parameters and arguments differ.
                let parameter_to_argument = callee
                    .input
                    .iter()
                    .map(|parameter| parameter.identifier().name)
                    .zip_eq(input.arguments)
                    .collect::<IndexMap<_, _>>();

                // Initialize `self.assignment_renamer` with the function parameters, each mapped to its own
                // name, so that parameters can still be found by name in `parameter_to_argument` after renaming.
                self.assignment_renamer.load(callee.input.iter().map(|parameter| {
                    let identifier = parameter.identifier();
                    (identifier.name, identifier.name, identifier.id)
                }));

                // Duplicate the body of the callee and create a unique assignment statement for each assignment in the body.
                // This is necessary to ensure the inlined variables do not conflict with variables in the caller.
                let unique_block = self.assignment_renamer.reconstruct_block(callee.block.clone()).0;

                // Reset `self.assignment_renamer`.
                self.assignment_renamer.clear();

                // Replace each reference to a callee parameter with the argument passed for it;
                // every other identifier is left unchanged.
                let replace = |identifier: &Identifier| match parameter_to_argument.get(&identifier.name) {
                    Some(argument) => argument.clone(),
                    None => Expression::Identifier(*identifier),
                };
                let mut inlined_statements = Replacer::new(replace).reconstruct_block(unique_block).0.statements;

                // If the inlined block ends with a return statement, its value replaces the call expression;
                // otherwise, the call is replaced by a fresh unit expression.
                let result = match inlined_statements.pop() {
                    Some(Statement::Return(ReturnStatement { expression, .. })) => expression,
                    non_return => {
                        // The final statement (if any) is not a return, so it stays in the inlined body.
                        inlined_statements.extend(non_return);
                        let id = self.node_builder.next_id();
                        self.type_table.insert(id, Type::Unit);
                        Expression::Unit(UnitExpression { span: Default::default(), id })
                    }
                };

                (result, inlined_statements)
            }
            Variant::Function | Variant::AsyncFunction | Variant::Transition | Variant::AsyncTransition => {
                (Expression::Call(input), Default::default())
            }
        }
    }
}