leo_passes/function_inlining/
ast.rs1use super::FunctionInliningVisitor;
18use crate::Replacer;
19
20use leo_ast::*;
21
22use indexmap::IndexMap;
23use itertools::Itertools;
24
25impl AstReconstructor for FunctionInliningVisitor<'_> {
26 type AdditionalOutput = Vec<Statement>;
27
28 fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
30 if input.program.unwrap() != self.program {
32 return (input.into(), Default::default());
33 }
34
35 let (_, callee) =
38 self.reconstructed_functions.iter().find(|(symbol, _)| *symbol == input.function.name).unwrap();
39
40 match callee.variant {
42 Variant::Inline => {
43 let parameter_to_argument = callee
45 .input
46 .iter()
47 .map(|input| input.identifier().name)
48 .zip_eq(input.arguments)
49 .collect::<IndexMap<_, _>>();
50
51 let replace = |identifier: &Identifier| {
53 parameter_to_argument.get(&identifier.name).cloned().unwrap_or(Expression::Identifier(*identifier))
54 };
55
56 let mut inlined_statements = Replacer::new(replace, &self.state.node_builder)
57 .reconstruct_block(callee.block.clone())
58 .0
59 .statements;
60
61 let result = match inlined_statements.last() {
63 Some(Statement::Return(_)) => {
64 match inlined_statements.pop().unwrap() {
66 Statement::Return(ReturnStatement { expression, .. }) => expression,
67 _ => panic!("This branch checks that the last statement is a return statement."),
68 }
69 }
70 _ => {
71 let id = self.state.node_builder.next_id();
72 self.state.type_table.insert(id, Type::Unit);
73 UnitExpression { span: Default::default(), id }.into()
74 }
75 };
76
77 (result, inlined_statements)
78 }
79 Variant::Function
80 | Variant::Script
81 | Variant::AsyncFunction
82 | Variant::Transition
83 | Variant::AsyncTransition => (input.into(), Default::default()),
84 }
85 }
86
87 fn reconstruct_assign(&mut self, _input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
89 panic!("`AssignStatement`s should not exist in the AST at this phase of compilation.")
90 }
91
92 fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
94 let mut statements = Vec::with_capacity(block.statements.len());
95
96 for statement in block.statements {
97 let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
98 statements.extend(additional_statements);
99 statements.push(reconstructed_statement);
100 }
101
102 (Block { span: block.span, statements, id: block.id }, Default::default())
103 }
104
105 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
107 if !self.is_async {
108 panic!("`ConditionalStatement`s should not be in the AST at this phase of compilation.")
109 } else {
110 (
111 ConditionalStatement {
112 condition: self.reconstruct_expression(input.condition).0,
113 then: self.reconstruct_block(input.then).0,
114 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
115 span: input.span,
116 id: input.id,
117 }
118 .into(),
119 Default::default(),
120 )
121 }
122 }
123
124 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
127 let (value, mut statements) = self.reconstruct_expression(input.value);
128 match (input.place, value) {
129 (DefinitionPlace::Multiple(left), Expression::Tuple(right)) => {
131 assert_eq!(left.len(), right.elements.len());
132 for (identifier, rhs_value) in left.into_iter().zip(right.elements) {
133 let stmt = DefinitionStatement {
134 place: DefinitionPlace::Single(identifier),
135 type_: None,
136 value: rhs_value,
137 span: Default::default(),
138 id: self.state.node_builder.next_id(),
139 }
140 .into();
141
142 statements.push(stmt);
143 }
144 (Statement::dummy(), statements)
145 }
146
147 (place, value) => {
148 input.value = value;
149 input.place = place;
150 (input.into(), statements)
151 }
152 }
153 }
154
155 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
157 let (expression, additional_statements) = self.reconstruct_expression(input.expression);
160
161 let statement = match expression {
163 Expression::Unit(_) => Statement::dummy(),
164 _ => ExpressionStatement { expression, ..input }.into(),
165 };
166
167 (statement, additional_statements)
168 }
169
170 fn reconstruct_iteration(&mut self, _: IterationStatement) -> (Statement, Self::AdditionalOutput) {
172 panic!("`IterationStatement`s should not be in the AST at this phase of compilation.");
173 }
174}