leo_passes/write_transforming/
expression.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::WriteTransformingVisitor;
18
19use leo_ast::{
20    ArrayAccess,
21    ArrayExpression,
22    AssignStatement,
23    AssociatedFunctionExpression,
24    BinaryExpression,
25    CallExpression,
26    CastExpression,
27    Expression,
28    ExpressionReconstructor,
29    Identifier,
30    MemberAccess,
31    Node,
32    Statement,
33    StructExpression,
34    StructVariableInitializer,
35    TernaryExpression,
36    TupleAccess,
37    TupleExpression,
38    Type,
39    UnaryExpression,
40};
41use leo_span::Symbol;
42
43impl WriteTransformingVisitor<'_> {
44    pub fn get_array_member(&self, array_name: Symbol, index: &Expression) -> Option<Identifier> {
45        let members = self.array_members.get(&array_name)?;
46        let Expression::Literal(lit) = index else {
47            panic!("Const propagation should have ensured this is a literal.");
48        };
49        let index = lit
50            .as_u32()
51            .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
52            as usize;
53        Some(members[index])
54    }
55
56    pub fn get_struct_member(&self, struct_name: Symbol, field_name: Symbol) -> Option<Identifier> {
57        let members = self.struct_members.get(&struct_name)?;
58        members.get(&field_name).cloned()
59    }
60}
61
62impl ExpressionReconstructor for WriteTransformingVisitor<'_> {
63    type AdditionalOutput = Vec<Statement>;
64
65    fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
66        let ty = self.state.type_table.get(&input.id()).unwrap();
67        let mut statements = Vec::new();
68        if let Some(array_members) = self.array_members.get(&input.name) {
69            // Build the array expression from the members.
70            let id = self.state.node_builder.next_id();
71            self.state.type_table.insert(id, ty.clone());
72            let expr = ArrayExpression {
73                elements: array_members
74                    // This clone is unfortunate, but both `array_members` and the closure below borrow self.
75                    .clone()
76                    .iter()
77                    .map(|identifier| {
78                        let (expr, statements2) = self.reconstruct_identifier(*identifier);
79                        statements.extend(statements2);
80                        expr
81                    })
82                    .collect(),
83                span: Default::default(),
84                id,
85            };
86            let statement = AssignStatement {
87                place: input.into(),
88                value: expr.into(),
89                span: Default::default(),
90                id: self.state.node_builder.next_id(),
91            };
92            statements.push(statement.into());
93            (input.into(), statements)
94        } else if let Some(struct_members) = self.struct_members.get(&input.name) {
95            // Build the struct expression from the members.
96            let id = self.state.node_builder.next_id();
97            self.state.type_table.insert(id, ty.clone());
98            let Type::Composite(comp_type) = ty else {
99                panic!("The type of a struct init should be a composite.");
100            };
101            let expr = StructExpression {
102                members: struct_members
103                    // This clone is unfortunate, but both `struct_members` and the closure below borrow self.
104                    .clone()
105                    .iter()
106                    .map(|(field_name, ident)| {
107                        let (expr, statements2) = self.reconstruct_identifier(*ident);
108                        statements.extend(statements2);
109                        StructVariableInitializer {
110                            identifier: Identifier::new(*field_name, self.state.node_builder.next_id()),
111                            expression: Some(expr),
112                            span: Default::default(),
113                            id: self.state.node_builder.next_id(),
114                        }
115                    })
116                    .collect(),
117                name: comp_type.id,
118                span: Default::default(),
119                id,
120            };
121            let statement = AssignStatement {
122                place: input.into(),
123                value: expr.into(),
124                span: Default::default(),
125                id: self.state.node_builder.next_id(),
126            };
127            statements.push(statement.into());
128            (input.into(), statements)
129        } else {
130            // This is not a struct or array whose members are written to, so there's nothing to do.
131            (input.into(), Default::default())
132        }
133    }
134
135    fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
136        let Expression::Identifier(array_name) = input.array else {
137            panic!("SSA ensures that this is an Identifier.");
138        };
139        if let Some(member) = self.get_array_member(array_name.name, &input.index) {
140            self.reconstruct_identifier(member)
141        } else {
142            (input.into(), Default::default())
143        }
144    }
145
146    fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
147        let Expression::Identifier(array_name) = input.inner else {
148            panic!("SSA ensures that this is an Identifier.");
149        };
150        if let Some(member) = self.get_struct_member(array_name.name, input.name.name) {
151            self.reconstruct_identifier(member)
152        } else {
153            (input.into(), Default::default())
154        }
155    }
156
157    // The rest of the methods below don't do anything but traverse - we only modify their default implementations
158    // to combine the `Vec<Statement>` outputs.
159
160    fn reconstruct_associated_function(
161        &mut self,
162        mut input: AssociatedFunctionExpression,
163    ) -> (Expression, Self::AdditionalOutput) {
164        let mut statements = Vec::new();
165        for arg in input.arguments.iter_mut() {
166            let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg));
167            statements.extend(statements2);
168            *arg = expr;
169        }
170        (input.into(), statements)
171    }
172
173    fn reconstruct_tuple_access(&mut self, _input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
174        panic!("`TupleAccess` should not be in the AST at this point.");
175    }
176
177    fn reconstruct_array(&mut self, mut input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
178        let mut statements = Vec::new();
179        for element in input.elements.iter_mut() {
180            let (expr, statements2) = self.reconstruct_expression(std::mem::take(element));
181            statements.extend(statements2);
182            *element = expr;
183        }
184        (input.into(), statements)
185    }
186
187    fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
188        let (left, mut statements) = self.reconstruct_expression(input.left);
189        let (right, statements2) = self.reconstruct_expression(input.right);
190        statements.extend(statements2);
191        (BinaryExpression { left, right, ..input }.into(), statements)
192    }
193
194    fn reconstruct_call(&mut self, mut input: CallExpression) -> (Expression, Self::AdditionalOutput) {
195        let mut statements = Vec::new();
196        for arg in input.arguments.iter_mut() {
197            let (expr, statements2) = self.reconstruct_expression(std::mem::take(arg));
198            statements.extend(statements2);
199            *arg = expr;
200        }
201        (input.into(), statements)
202    }
203
204    fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
205        let (expression, statements) = self.reconstruct_expression(input.expression);
206        (CastExpression { expression, ..input }.into(), statements)
207    }
208
209    fn reconstruct_struct_init(&mut self, mut input: StructExpression) -> (Expression, Self::AdditionalOutput) {
210        let mut statements = Vec::new();
211        for member in input.members.iter_mut() {
212            assert!(member.expression.is_some());
213            let (expr, statements2) = self.reconstruct_expression(member.expression.take().unwrap());
214            statements.extend(statements2);
215            member.expression = Some(expr);
216        }
217
218        (input.into(), statements)
219    }
220
221    fn reconstruct_err(&mut self, _input: leo_ast::ErrExpression) -> (Expression, Self::AdditionalOutput) {
222        std::panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
223    }
224
225    fn reconstruct_literal(&mut self, input: leo_ast::Literal) -> (Expression, Self::AdditionalOutput) {
226        (input.into(), Default::default())
227    }
228
229    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
230        let (condition, mut statements) = self.reconstruct_expression(input.condition);
231        let (if_true, statements2) = self.reconstruct_expression(input.if_true);
232        let (if_false, statements3) = self.reconstruct_expression(input.if_false);
233        statements.extend(statements2);
234        statements.extend(statements3);
235        (TernaryExpression { condition, if_true, if_false, ..input }.into(), statements)
236    }
237
238    fn reconstruct_tuple(&mut self, input: leo_ast::TupleExpression) -> (Expression, Self::AdditionalOutput) {
239        // This should ony appear in a return statement.
240        let mut statements = Vec::new();
241        let elements = input
242            .elements
243            .into_iter()
244            .map(|element| {
245                let (expr, statements2) = self.reconstruct_expression(element);
246                statements.extend(statements2);
247                expr
248            })
249            .collect();
250        (TupleExpression { elements, ..input }.into(), statements)
251    }
252
253    fn reconstruct_unary(&mut self, input: leo_ast::UnaryExpression) -> (Expression, Self::AdditionalOutput) {
254        let (receiver, statements) = self.reconstruct_expression(input.receiver);
255        (UnaryExpression { receiver, ..input }.into(), statements)
256    }
257
258    fn reconstruct_unit(&mut self, input: leo_ast::UnitExpression) -> (Expression, Self::AdditionalOutput) {
259        (input.into(), Default::default())
260    }
261}