leo_passes/write_transforming/
statement.rs1use super::WriteTransformingVisitor;
18
19use leo_ast::{
20 ArrayAccess,
21 AssertStatement,
22 AssertVariant,
23 AssignStatement,
24 Block,
25 ConditionalStatement,
26 DefinitionPlace,
27 DefinitionStatement,
28 Expression,
29 ExpressionReconstructor,
30 ExpressionStatement,
31 Identifier,
32 IntegerType,
33 IterationStatement,
34 Literal,
35 MemberAccess,
36 Node,
37 ReturnStatement,
38 Statement,
39 StatementReconstructor,
40 Type,
41};
42
43impl WriteTransformingVisitor<'_> {
44 pub fn define_variable_members(&mut self, name: Identifier, accumulate: &mut Vec<Statement>) {
49 if let Some(members) = self.array_members.get(&name.name).cloned() {
52 for (i, member) in members.iter().cloned().enumerate() {
53 let index = Literal::integer(
55 IntegerType::U8,
56 i.to_string(),
57 Default::default(),
58 self.state.node_builder.next_id(),
59 );
60 self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
61 let access = ArrayAccess {
62 array: name.into(),
63 index: index.into(),
64 span: Default::default(),
65 id: self.state.node_builder.next_id(),
66 };
67 self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
68 let def = DefinitionStatement {
69 place: DefinitionPlace::Single(member),
70 type_: None,
71 value: access.into(),
72 span: Default::default(),
73 id: self.state.node_builder.next_id(),
74 };
75 accumulate.push(def.into());
76 self.define_variable_members(member, accumulate);
78 }
79 } else if let Some(members) = self.struct_members.get(&name.name) {
80 for (&field_name, &member) in members.clone().iter() {
81 let access = MemberAccess {
83 inner: name.into(),
84 name: Identifier::new(field_name, self.state.node_builder.next_id()),
85 span: Default::default(),
86 id: self.state.node_builder.next_id(),
87 };
88 self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
89 let def = DefinitionStatement {
90 place: DefinitionPlace::Single(member),
91 type_: None,
92 value: access.into(),
93 span: Default::default(),
94 id: self.state.node_builder.next_id(),
95 };
96 accumulate.push(def.into());
97 self.define_variable_members(member, accumulate);
99 }
100 }
101 }
102}
103
104impl WriteTransformingVisitor<'_> {
105 pub fn reconstruct_assign_place(&mut self, input: Expression) -> Identifier {
111 use Expression::*;
112 match input {
113 ArrayAccess(array_access) => {
114 let identifier = self.reconstruct_assign_place(array_access.array);
115 self.get_array_member(identifier.name, &array_access.index).expect("We have visited all array writes.")
116 }
117 Identifier(identifier) => identifier,
118 MemberAccess(member_access) => {
119 let identifier = self.reconstruct_assign_place(member_access.inner);
120 self.get_struct_member(identifier.name, member_access.name.name)
121 .expect("We have visited all struct writes.")
122 }
123 TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
124 _ => panic!("Type checking should have ensured there are no other places for assignments"),
125 }
126 }
127
128 fn reconstruct_assign_recurse(&self, place: Identifier, value: Expression, accumulate: &mut Vec<Statement>) {
130 if let Some(array_members) = self.array_members.get(&place.name) {
131 if let Expression::Array(value_array) = value {
132 for (&member, rhs_element) in array_members.iter().zip(value_array.elements) {
137 self.reconstruct_assign_recurse(member, rhs_element, accumulate);
138 }
139 } else {
140 let one_assign = AssignStatement {
145 place: place.into(),
146 value,
147 span: Default::default(),
148 id: self.state.node_builder.next_id(),
149 }
150 .into();
151 accumulate.push(one_assign);
152 for (i, &member) in array_members.iter().enumerate() {
153 let access = ArrayAccess {
154 array: place.into(),
155 index: Literal::integer(
156 IntegerType::U32,
157 format!("{i}u32"),
158 Default::default(),
159 self.state.node_builder.next_id(),
160 )
161 .into(),
162 span: Default::default(),
163 id: self.state.node_builder.next_id(),
164 };
165 self.reconstruct_assign_recurse(member, access.into(), accumulate);
166 }
167 }
168 } else if let Some(struct_members) = self.struct_members.get(&place.name) {
169 if let Expression::Struct(value_struct) = value {
170 for initializer in value_struct.members.into_iter() {
175 let member_name = struct_members.get(&initializer.identifier.name).expect("Member should exist.");
176 let rhs_expression =
177 initializer.expression.expect("This should have been normalized to have a value.");
178 self.reconstruct_assign_recurse(*member_name, rhs_expression, accumulate);
179 }
180 } else {
181 let one_assign = AssignStatement {
186 place: place.into(),
187 value,
188 span: Default::default(),
189 id: self.state.node_builder.next_id(),
190 }
191 .into();
192 accumulate.push(one_assign);
193 for (field, member_name) in struct_members.iter() {
194 let access = MemberAccess {
195 inner: place.into(),
196 name: Identifier::new(*field, self.state.node_builder.next_id()),
197 span: Default::default(),
198 id: self.state.node_builder.next_id(),
199 };
200 self.reconstruct_assign_recurse(*member_name, access.into(), accumulate);
201 }
202 }
203 } else {
204 let stmt = AssignStatement {
205 value,
206 place: place.into(),
207 id: self.state.node_builder.next_id(),
208 span: Default::default(),
209 }
210 .into();
211 accumulate.push(stmt);
212 }
213 }
214}
215
216impl StatementReconstructor for WriteTransformingVisitor<'_> {
217 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
220 let (value, mut statements) = self.reconstruct_expression(input.value);
221 let place = self.reconstruct_assign_place(input.place);
222 self.reconstruct_assign_recurse(place, value, &mut statements);
223 (Statement::dummy(), statements)
224 }
225
226 fn reconstruct_assert(&mut self, input: leo_ast::AssertStatement) -> (Statement, Self::AdditionalOutput) {
227 let mut statements = Vec::new();
228 let stmt = AssertStatement {
229 variant: match input.variant {
230 AssertVariant::Assert(expr) => {
231 let (expr, statements2) = self.reconstruct_expression(expr);
232 statements.extend(statements2);
233 AssertVariant::Assert(expr)
234 }
235 AssertVariant::AssertEq(left, right) => {
236 let (left, statements2) = self.reconstruct_expression(left);
237 statements.extend(statements2);
238 let (right, statements3) = self.reconstruct_expression(right);
239 statements.extend(statements3);
240 AssertVariant::AssertEq(left, right)
241 }
242 AssertVariant::AssertNeq(left, right) => {
243 let (left, statements2) = self.reconstruct_expression(left);
244 statements.extend(statements2);
245 let (right, statements3) = self.reconstruct_expression(right);
246 statements.extend(statements3);
247 AssertVariant::AssertNeq(left, right)
248 }
249 },
250 ..input
251 }
252 .into();
253 (stmt, Default::default())
254 }
255
256 fn reconstruct_block(&mut self, block: Block) -> (Block, Self::AdditionalOutput) {
257 let mut statements = Vec::with_capacity(block.statements.len());
258
259 for statement in block.statements {
261 let (reconstructed_statement, additional_statements) = self.reconstruct_statement(statement);
262 statements.extend(additional_statements);
263 if !reconstructed_statement.is_empty() {
264 statements.push(reconstructed_statement);
265 }
266 }
267
268 (Block { statements, ..block }, Default::default())
269 }
270
271 fn reconstruct_definition(&mut self, mut input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
272 let (value, mut statements) = self.reconstruct_expression(input.value);
273 input.value = value;
274 match input.place.clone() {
275 DefinitionPlace::Single(identifier) => {
276 statements.push(input.into());
277 self.define_variable_members(identifier, &mut statements);
278 }
279 DefinitionPlace::Multiple(identifiers) => {
280 statements.push(input.into());
281 for &identifier in identifiers.iter() {
282 self.define_variable_members(identifier, &mut statements);
283 }
284 }
285 }
286 (Statement::dummy(), statements)
287 }
288
289 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
290 let (expression, statements) = self.reconstruct_expression(input.expression);
291 (ExpressionStatement { expression, ..input }.into(), statements)
292 }
293
294 fn reconstruct_iteration(&mut self, _input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
295 panic!("`IterationStatement`s should not be in the AST at this point.");
296 }
297
298 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
299 let (expression, statements) = self.reconstruct_expression(input.expression);
300 (ReturnStatement { expression, ..input }.into(), statements)
301 }
302
303 fn reconstruct_conditional(&mut self, input: leo_ast::ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
304 let (condition, mut statements) = self.reconstruct_expression(input.condition);
305 let (then, statements2) = self.reconstruct_block(input.then);
306 statements.extend(statements2);
307 let otherwise = input.otherwise.map(|oth| {
308 let (expr, statements3) = self.reconstruct_statement(*oth);
309 statements.extend(statements3);
310 Box::new(expr)
311 });
312 (ConditionalStatement { condition, then, otherwise, ..input }.into(), statements)
313 }
314}