leo_passes/write_transforming/
visitor.rs1use crate::CompilerState;
18
19use leo_ast::{
20 ArrayAccess,
21 AssignStatement,
22 AstVisitor,
23 DefinitionPlace,
24 DefinitionStatement,
25 Expression,
26 Identifier,
27 IntegerType,
28 Literal,
29 Location,
30 MemberAccess,
31 Node as _,
32 Path,
33 Program,
34 Statement,
35 Type,
36};
37use leo_span::Symbol;
38
39use indexmap::IndexMap;
40
41pub struct WriteTransformingVisitor<'a> {
46 pub state: &'a mut CompilerState,
47
48 pub struct_members: IndexMap<Symbol, IndexMap<Symbol, Identifier>>,
51
52 pub array_members: IndexMap<Symbol, Vec<Identifier>>,
54
55 pub program: Symbol,
56}
57
58impl<'a> WriteTransformingVisitor<'a> {
59 pub fn new(state: &'a mut CompilerState, program: &Program) -> Self {
60 let visitor = WriteTransformingVisitor {
61 state,
62 struct_members: Default::default(),
63 array_members: Default::default(),
64 program: Symbol::intern(""),
65 };
66
67 let mut wtf = WriteTransformingFiller(visitor);
69 wtf.fill(program);
70 wtf.0
71 }
72
73 pub fn define_variable_members(&mut self, name: Identifier, accumulate: &mut Vec<Statement>) {
78 if let Some(members) = self.array_members.get(&name.name).cloned() {
81 for (i, member) in members.iter().cloned().enumerate() {
82 let index = Literal::integer(
84 IntegerType::U8,
85 i.to_string(),
86 Default::default(),
87 self.state.node_builder.next_id(),
88 );
89 self.state.type_table.insert(index.id(), Type::Integer(IntegerType::U32));
90 let access = ArrayAccess {
91 array: Path::from(name).into_absolute().into(),
92 index: index.into(),
93 span: Default::default(),
94 id: self.state.node_builder.next_id(),
95 };
96 self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
97 let def = DefinitionStatement {
98 place: DefinitionPlace::Single(member),
99 type_: None,
100 value: access.into(),
101 span: Default::default(),
102 id: self.state.node_builder.next_id(),
103 };
104 accumulate.push(def.into());
105 self.define_variable_members(member, accumulate);
107 }
108 } else if let Some(members) = self.struct_members.get(&name.name) {
109 for (&field_name, &member) in members.clone().iter() {
110 let access = MemberAccess {
112 inner: Path::from(name).into_absolute().into(),
113 name: Identifier::new(field_name, self.state.node_builder.next_id()),
114 span: Default::default(),
115 id: self.state.node_builder.next_id(),
116 };
117 self.state.type_table.insert(access.id(), self.state.type_table.get(&member.id()).unwrap().clone());
118 let def = DefinitionStatement {
119 place: DefinitionPlace::Single(member),
120 type_: None,
121 value: access.into(),
122 span: Default::default(),
123 id: self.state.node_builder.next_id(),
124 };
125 accumulate.push(def.into());
126 self.define_variable_members(member, accumulate);
128 }
129 }
130 }
131
132 pub fn reconstruct_assign_place(&mut self, input: Expression) -> Identifier {
138 use Expression::*;
139 match input {
140 ArrayAccess(array_access) => {
141 let identifier = self.reconstruct_assign_place(array_access.array);
142 self.get_array_member(identifier.name, &array_access.index).expect("We have visited all array writes.")
143 }
144 Path(path) => leo_ast::Identifier { name: path.identifier().name, span: path.span, id: path.id },
145 MemberAccess(member_access) => {
146 let identifier = self.reconstruct_assign_place(member_access.inner);
147 self.get_struct_member(identifier.name, member_access.name.name)
148 .expect("We have visited all struct writes.")
149 }
150 TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
151 _ => panic!("Type checking should have ensured there are no other places for assignments"),
152 }
153 }
154
155 pub fn reconstruct_assign_recurse(&self, place: Identifier, value: Expression, accumulate: &mut Vec<Statement>) {
157 if let Some(array_members) = self.array_members.get(&place.name) {
158 if let Expression::Array(value_array) = value {
159 for (&member, rhs_element) in array_members.iter().zip(value_array.elements) {
164 self.reconstruct_assign_recurse(member, rhs_element, accumulate);
165 }
166 } else {
167 let one_assign = AssignStatement {
172 place: Path::from(place).into_absolute().into(),
173 value,
174 span: Default::default(),
175 id: self.state.node_builder.next_id(),
176 }
177 .into();
178 accumulate.push(one_assign);
179 for (i, &member) in array_members.iter().enumerate() {
180 let access = ArrayAccess {
181 array: Path::from(place).into_absolute().into(),
182 index: Literal::integer(
183 IntegerType::U32,
184 format!("{i}u32"),
185 Default::default(),
186 self.state.node_builder.next_id(),
187 )
188 .into(),
189 span: Default::default(),
190 id: self.state.node_builder.next_id(),
191 };
192 self.reconstruct_assign_recurse(member, access.into(), accumulate);
193 }
194 }
195 } else if let Some(struct_members) = self.struct_members.get(&place.name) {
196 if let Expression::Struct(value_struct) = value {
197 for initializer in value_struct.members.into_iter() {
202 let member_name = struct_members.get(&initializer.identifier.name).expect("Member should exist.");
203 let rhs_expression =
204 initializer.expression.expect("This should have been normalized to have a value.");
205 self.reconstruct_assign_recurse(*member_name, rhs_expression, accumulate);
206 }
207 } else {
208 let one_assign = AssignStatement {
213 place: Path::from(place).into_absolute().into(),
214 value,
215 span: Default::default(),
216 id: self.state.node_builder.next_id(),
217 }
218 .into();
219 accumulate.push(one_assign);
220 for (field, member_name) in struct_members.iter() {
221 let access = MemberAccess {
222 inner: Path::from(place).into_absolute().into(),
223 name: Identifier::new(*field, self.state.node_builder.next_id()),
224 span: Default::default(),
225 id: self.state.node_builder.next_id(),
226 };
227 self.reconstruct_assign_recurse(*member_name, access.into(), accumulate);
228 }
229 }
230 } else {
231 let stmt = AssignStatement {
232 value,
233 place: Path::from(place).into_absolute().into(),
234 id: self.state.node_builder.next_id(),
235 span: Default::default(),
236 }
237 .into();
238 accumulate.push(stmt);
239 }
240 }
241}
242
243struct WriteTransformingFiller<'a>(WriteTransformingVisitor<'a>);
244
245impl AstVisitor for WriteTransformingFiller<'_> {
246 type AdditionalInput = ();
247 type Output = ();
248
249 fn visit_expression(&mut self, _input: &Expression, _additional: &Self::AdditionalInput) -> Self::Output {}
251
252 fn visit_assign(&mut self, input: &AssignStatement) {
254 self.access_recurse(&input.place);
255 }
256}
257
258impl WriteTransformingFiller<'_> {
259 fn fill(&mut self, program: &Program) {
260 for (_, module) in program.modules.iter() {
261 self.0.program = module.program_name;
262 for (_, function) in module.functions.iter() {
263 self.visit_block(&function.block);
264 }
265 }
266 for (_, scope) in program.program_scopes.iter() {
267 for (_, function) in scope.functions.iter() {
268 self.0.program = scope.program_id.name.name;
269 self.visit_block(&function.block);
270 }
271 }
272 }
273
274 fn access_recurse(&mut self, place: &Expression) -> Identifier {
277 match place {
278 Expression::Path(path) => Identifier { name: path.identifier().name, span: path.span, id: path.id },
279 Expression::ArrayAccess(array_access) => {
280 let array_name = self.access_recurse(&array_access.array);
281 let members = self.0.array_members.entry(array_name.name).or_insert_with(|| {
282 let ty = self.0.state.type_table.get(&array_access.array.id()).unwrap();
283 let Type::Array(arr) = ty else { panic!("Type checking should have prevented this.") };
284 (0..arr.length.as_u32().expect("length should be known at this point"))
285 .map(|i| {
286 let id = self.0.state.node_builder.next_id();
287 let symbol = self.0.state.assigner.unique_symbol(format_args!("{array_name}#{i}"), "$");
288 self.0.state.type_table.insert(id, arr.element_type().clone());
289 Identifier::new(symbol, id)
290 })
291 .collect()
292 });
293 let Expression::Literal(lit) = &array_access.index else {
294 panic!("Const propagation should have ensured this is a literal.");
295 };
296 members[lit
297 .as_u32()
298 .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
299 as usize]
300 }
301 Expression::MemberAccess(member_access) => {
302 let struct_name = self.access_recurse(&member_access.inner);
303 let members = self.0.struct_members.entry(struct_name.name).or_insert_with(|| {
304 let ty = self.0.state.type_table.get(&member_access.inner.id()).unwrap();
305 let Type::Composite(comp) = ty else {
306 panic!("Type checking should have prevented this.");
307 };
308 let struct_ = self
309 .0
310 .state
311 .symbol_table
312 .lookup_struct(&comp.path.absolute_path())
313 .or_else(|| {
314 self.0.state.symbol_table.lookup_record(&Location::new(
315 comp.program.unwrap_or(self.0.program),
316 comp.path.absolute_path(),
317 ))
318 })
319 .unwrap();
320 struct_
321 .members
322 .iter()
323 .map(|member| {
324 let name = member.name();
325 let id = self.0.state.node_builder.next_id();
326 let symbol = self.0.state.assigner.unique_symbol(format_args!("{struct_name}#{name}"), "$");
327 self.0.state.type_table.insert(id, member.type_.clone());
328 (member.name(), Identifier::new(symbol, id))
329 })
330 .collect()
331 });
332 *members.get(&member_access.name.name).expect("Type checking should have ensured this is valid.")
333 }
334 Expression::TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
335 _ => panic!("Type checking should have ensured there are no other places for assignments"),
336 }
337 }
338}