// leo_passes/write_transforming/statement.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::WriteTransformingVisitor;
18
19use leo_ast::{
20    ArrayAccess,
21    AssertStatement,
22    AssertVariant,
23    AssignStatement,
24    Block,
25    ConditionalStatement,
26    DefinitionPlace,
27    DefinitionStatement,
28    Expression,
29    ExpressionReconstructor,
30    ExpressionStatement,
31    Identifier,
32    IntegerType,
33    IterationStatement,
34    Literal,
35    MemberAccess,
36    Node,
37    ReturnStatement,
38    Statement,
39    StatementReconstructor,
40    Type,
41};
42
43impl WriteTransformingVisitor<'_> {
44    /// If `name` is a struct or array whose members are written to, make
45    /// `DefinitionStatement`s for each of its variables that will correspond to
46    /// the members. Note that we create them for all members; unnecessary ones
47    /// will be removed by DCE.
48    pub fn define_variable_members(&mut self, name: Identifier, accumulate: &mut Vec<Statement>) {
49        // The `cloned` here and in the branch below are unfortunate but we need
50        // to mutably borrow `self` again below.
51        if let Some(members) = self.array_members.get(&name.name).cloned() {
52            for (i, member) in members.iter().cloned().enumerate() {
53                // Create a definition for each array index.
54                let index = Literal::integer(
55                    IntegerType::U8,
56                    i.to_string(),
57                    Default::default(),
58                    self.state.node_builder.next_id(),
59                );
60                self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
61                let access = ArrayAccess {
62                    array: name.into(),
63                    index: index.into(),
64                    span: Default::default(),
65                    id: self.state.node_builder.next_id(),
66                };
67                self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
68                let def = DefinitionStatement {
69                    place: DefinitionPlace::Single(member),
70                    type_: None,
71                    value: access.into(),
72                    span: Default::default(),
73                    id: self.state.node_builder.next_id(),
74                };
75                accumulate.push(def.into());
76                // And recurse - maybe its members are also written to.
77                self.define_variable_members(member, accumulate);
78            }
79        } else if let Some(members) = self.struct_members.get(&name.name) {
80            for (&field_name, &member) in members.clone().iter() {
81                // Create a definition for each field.
82                let access = MemberAccess {
83                    inner: name.into(),
84                    name: Identifier::new(field_name, self.state.node_builder.next_id()),
85                    span: Default::default(),
86                    id: self.state.node_builder.next_id(),
87                };
88                self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
89                let def = DefinitionStatement {
90                    place: DefinitionPlace::Single(member),
91                    type_: None,
92                    value: access.into(),
93                    span: Default::default(),
94                    id: self.state.node_builder.next_id(),
95                };
96                accumulate.push(def.into());
97                // And recurse - maybe its members are also written to.
98                self.define_variable_members(member, accumulate);
99            }
100        }
101    }
102}
103
104impl WriteTransformingVisitor<'_> {
105    /// If we're assigning to a struct or array member, find the variable name we're actually writing to,
106    /// recursively if necessary.
107    /// That is, if we have
108    /// `arr[0u32][1u32] = ...`,
109    /// we find the corresponding variable `arr_0_1`.
110    pub fn reconstruct_assign_place(&mut self, input: Expression) -> Identifier {
111        use Expression::*;
112        match input {
113            ArrayAccess(array_access) => {
114                let identifier = self.reconstruct_assign_place(array_access.array);
115                self.get_array_member(identifier.name, &array_access.index).expect("We have visited all array writes.")
116            }
117            Identifier(identifier) => identifier,
118            MemberAccess(member_access) => {
119                let identifier = self.reconstruct_assign_place(member_access.inner);
120                self.get_struct_member(identifier.name, member_access.name.name)
121                    .expect("We have visited all struct writes.")
122            }
123            TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
124            _ => panic!("Type checking should have ensured there are no other places for assignments"),
125        }
126    }
127
128    /// If we're assigning to a struct or array, create assignments to the individual members, if applicable.
129    fn reconstruct_assign_recurse(&self, place: Identifier, value: Expression, accumulate: &mut Vec<Statement>) {
130        if let Some(array_members) = self.array_members.get(&place.name) {
131            if let Expression::Array(value_array) = value {
132                // This was an assignment like
133                // `arr = [a, b, c];`
134                // Change it to this:
135                // `arr_0 = a; arr_1 = b; arr_2 = c`
136                for (&member, rhs_element) in array_members.iter().zip(value_array.elements) {
137                    self.reconstruct_assign_recurse(member, rhs_element, accumulate);
138                }
139            } else {
140                // This was an assignment like
141                // `arr = x;`
142                // Change it to this:
143                // `arr = x; arr_0 = x[0]; arr_1 = x[1]; arr_2 = x[2];`
144                let one_assign = AssignStatement {
145                    place: place.into(),
146                    value,
147                    span: Default::default(),
148                    id: self.state.node_builder.next_id(),
149                }
150                .into();
151                accumulate.push(one_assign);
152                for (i, &member) in array_members.iter().enumerate() {
153                    let access = ArrayAccess {
154                        array: place.into(),
155                        index: Literal::integer(
156                            IntegerType::U32,
157                            format!("{i}u32"),
158                            Default::default(),
159                            self.state.node_builder.next_id(),
160                        )
161                        .into(),
162                        span: Default::default(),
163                        id: self.state.node_builder.next_id(),
164                    };
165                    self.reconstruct_assign_recurse(member, access.into(), accumulate);
166                }
167            }
168        } else if let Some(struct_members) = self.struct_members.get(&place.name) {
169            if let Expression::Struct(value_struct) = value {
170                // This was an assignment like
171                // `struc = S { field0: a, field1: b };`
172                // Change it to this:
173                // `struc_field0 = a; struc_field1 = b;`
174                for initializer in value_struct.members.into_iter() {
175                    let member_name = struct_members.get(&initializer.identifier.name).expect("Member should exist.");
176                    let rhs_expression =
177                        initializer.expression.expect("This should have been normalized to have a value.");
178                    self.reconstruct_assign_recurse(*member_name, rhs_expression, accumulate);
179                }
180            } else {
181                // This was an assignment like
182                // `struc = x;`
183                // Change it to this:
184                // `struc = x; struc_field0 = x.field0; struc_field1 = x.field1;`
185                let one_assign = AssignStatement {
186                    place: place.into(),
187                    value,
188                    span: Default::default(),
189                    id: self.state.node_builder.next_id(),
190                }
191                .into();
192                accumulate.push(one_assign);
193                for (field, member_name) in struct_members.iter() {
194                    let access = MemberAccess {
195                        inner: place.into(),
196                        name: Identifier::new(*field, self.state.node_builder.next_id()),
197                        span: Default::default(),
198                        id: self.state.node_builder.next_id(),
199                    };
200                    self.reconstruct_assign_recurse(*member_name, access.into(), accumulate);
201                }
202            }
203        } else {
204            let stmt = AssignStatement {
205                value,
206                place: place.into(),
207                id: self.state.node_builder.next_id(),
208                span: Default::default(),
209            }
210            .into();
211            accumulate.push(stmt);
212        }
213    }
214}
215
216impl StatementReconstructor for WriteTransformingVisitor<'_> {
217    /// This is the only reconstructing function where we do anything other than traverse and combine statements,
218    /// by calling `reconstruct_assign_place` and `reconstruct_assign_recurse`.
219    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
220        let (value, mut statements) = self.reconstruct_expression(input.value);
221        let place = self.reconstruct_assign_place(input.place);
222        self.reconstruct_assign_recurse(place, value, &mut statements);
223        (Statement::dummy(), statements)
224    }
225
226    fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
227        let mut statements = Vec::new();
228        let stmt = AssertStatement {
229            variant: match input.variant {
230                AssertVariant::Assert(expr) => {
231                    let (expr, statements2) = self.reconstruct_expression(expr);
232                    statements.extend(statements2);
233                    AssertVariant::Assert(expr)
234                }
235                AssertVariant::AssertEq(left, right) => {
236                    let (left, statements2) = self.reconstruct_expression(left);
237                    statements.extend(statements2);
238                    let (right, statements3) = self.reconstruct_expression(right);
239                    statements.extend(statements3);
240                    AssertVariant::AssertEq(left, right)
241                }
242                AssertVariant::AssertNeq(left, right) => {
243                    let (left, statements2) = self.reconstruct_expression(left);
244                    statements.extend(statements2);
245                    let (right, statements3) = self.reconstruct_expression(right);
246                    statements.extend(statements3);
247                    AssertVariant::AssertNeq(left, right)
248                }
249            },
250            ..input
251        }
252        .into();
253        (stmt, Default::default())
254    }
255
256    fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
257        let mut statements = Vec::with_capacity(block.statements.len());
258
259        // Reconstruct the statements in the block, accumulating any additional statements.
260        for statement in block.statements {
261            let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
262            statements.extend(additional_statements);
263            if !reconstructed_statement.is_empty() {
264                statements.push(reconstructed_statement);
265            }
266        }
267
268        (Block { statements, ..block }, Default::default())
269    }
270
271    fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
272        let (value, mut statements) = self.reconstruct_expression(input.value);
273        input.value = value;
274        match input.place.clone() {
275            DefinitionPlace::Single(identifier) => {
276                statements.push(input.into());
277                self.define_variable_members(identifier, &mut statements);
278            }
279            DefinitionPlace::Multiple(identifiers) => {
280                statements.push(input.into());
281                for &identifier in identifiers.iter() {
282                    self.define_variable_members(identifier, &mut statements);
283                }
284            }
285        }
286        (Statement::dummy(), statements)
287    }
288
289    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
290        let (expression, statements) = self.reconstruct_expression(input.expression);
291        (ExpressionStatement { expression, ..input }.into(), statements)
292    }
293
294    fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
295        panic!("`IterationStatement`s should not be in the AST at this point.");
296    }
297
298    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
299        let (expression, statements) = self.reconstruct_expression(input.expression);
300        (ReturnStatement { expression, ..input }.into(), statements)
301    }
302
303    fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
304        let (condition, mut statements) = self.reconstruct_expression(input.condition);
305        let (then, statements2) = self.reconstruct_block(input.then);
306        statements.extend(statements2);
307        let otherwise = input.otherwise.map(|oth| {
308            let (expr, statements3) = self.reconstruct_statement(*oth);
309            statements.extend(statements3);
310            Box::new(expr)
311        });
312        (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
313    }
314}