// leo_passes/write_transforming/visitor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use crate::CompilerState;

use leo_ast::{
    AssignStatement,
    Expression,
    ExpressionVisitor,
    Identifier,
    Location,
    Node as _,
    Program,
    StatementVisitor,
    Type,
    TypeReconstructor,
    TypeVisitor,
};
use leo_span::Symbol;

use indexmap::IndexMap;

36/// This visitor associates a variable for each member of a struct or array that is
37/// written to. Whenever a member of the struct or array is written to, we change the
38/// assignment to access the variable instead. Whenever the struct or array itself
39/// is accessed, we first rebuild the struct or array from its variables.
40pub struct WriteTransformingVisitor<'a> {
41    pub state: &'a mut CompilerState,
42
43    /// For any struct whose members are written to, a map of its field names to variables
44    /// corresponding to the members.
45    pub struct_members: IndexMap<Symbol, IndexMap<Symbol, Identifier>>,
46
47    /// For any array whose members are written to, a vec containing the variables for each index.
48    pub array_members: IndexMap<Symbol, Vec<Identifier>>,
49
50    pub program: Symbol,
51}
52
53impl<'a> WriteTransformingVisitor<'a> {
54    pub fn new(state: &'a mut CompilerState, program: &Program) -> Self {
55        let visitor = WriteTransformingVisitor {
56            state,
57            struct_members: Default::default(),
58            array_members: Default::default(),
59            program: Symbol::intern(""),
60        };
61
62        // We need to do an initial pass through the AST to identify all arrays and structs that are written to.
63        let mut wtf = WriteTransformingFiller(visitor);
64        wtf.fill(program);
65        wtf.0
66    }
67}
68
69struct WriteTransformingFiller<'a>(WriteTransformingVisitor<'a>);
70
71// We don't actually need to visit expressions here; we're only implementing
72// `ExpressionVisitor` because `StatementVisitor` requires it.
73impl ExpressionVisitor for WriteTransformingFiller<'_> {
74    type AdditionalInput = ();
75    type Output = ();
76
77    fn visit_expression(&mut self, _input: &Expression, _additional: &Self::AdditionalInput) -> Self::Output {}
78}
79
80// All we actually need is `visit_assign`; we're just using `StatementVisitor`'s
81// default traversal.
82impl TypeVisitor for WriteTransformingFiller<'_> {}
83
84impl StatementVisitor for WriteTransformingFiller<'_> {
85    fn visit_assign(&mut self, input: &AssignStatement) {
86        self.access_recurse(&input.place);
87    }
88}
89
90impl WriteTransformingFiller<'_> {
91    fn fill(&mut self, program: &Program) {
92        for (_, scope) in program.program_scopes.iter() {
93            for (_, function) in scope.functions.iter() {
94                self.0.program = scope.program_id.name.name;
95                self.visit_block(&function.block);
96            }
97        }
98    }
99
100    /// Find assignments to arrays and structs and populate `array_members` and `struct_members` with new
101    /// variables names.
102    fn access_recurse(&mut self, place: &Expression) -> Identifier {
103        match place {
104            Expression::Identifier(identifier) => *identifier,
105            Expression::ArrayAccess(array_access) => {
106                let array_name = self.access_recurse(&array_access.array);
107                let members = self.0.array_members.entry(array_name.name).or_insert_with(|| {
108                    let ty = self.0.state.type_table.get(&array_access.array.id()).unwrap();
109                    let Type::Array(arr) = ty else { panic!("Type checking should have prevented this.") };
110                    (0..arr.length.as_u32().expect("length should be known at this point"))
111                        .map(|i| {
112                            let id = self.0.state.node_builder.next_id();
113                            let symbol = self.0.state.assigner.unique_symbol(format_args!("{array_name}#{i}"), "$");
114                            self.0.state.type_table.insert(id, arr.element_type().clone());
115                            Identifier::new(symbol, id)
116                        })
117                        .collect()
118                });
119                let Expression::Literal(lit) = &array_access.index else {
120                    panic!("Const propagation should have ensured this is a literal.");
121                };
122                members[lit
123                    .as_u32()
124                    .expect("Const propagation should have ensured this is in range, and consequently a valid u32.")
125                    as usize]
126            }
127            Expression::MemberAccess(member_access) => {
128                let struct_name = self.access_recurse(&member_access.inner);
129                let members = self.0.struct_members.entry(struct_name.name).or_insert_with(|| {
130                    let ty = self.0.state.type_table.get(&member_access.inner.id()).unwrap();
131                    let Type::Composite(comp) = ty else {
132                        panic!("Type checking should have prevented this.");
133                    };
134                    let struct_ = self
135                        .0
136                        .state
137                        .symbol_table
138                        .lookup_struct(comp.id.name)
139                        .or_else(|| {
140                            self.0
141                                .state
142                                .symbol_table
143                                .lookup_record(Location::new(comp.program.unwrap_or(self.0.program), comp.id.name))
144                        })
145                        .unwrap();
146                    struct_
147                        .members
148                        .iter()
149                        .map(|member| {
150                            let name = member.name();
151                            let id = self.0.state.node_builder.next_id();
152                            let symbol = self.0.state.assigner.unique_symbol(format_args!("{struct_name}#{name}"), "$");
153                            self.0.state.type_table.insert(id, member.type_.clone());
154                            (member.name(), Identifier::new(symbol, id))
155                        })
156                        .collect()
157                });
158                *members.get(&member_access.name.name).expect("Type checking should have ensured this is valid.")
159            }
160            Expression::TupleAccess(_) => panic!("TupleAccess writes should have been removed by Destructuring"),
161            _ => panic!("Type checking should have ensured there are no other places for assignments"),
162        }
163    }
164}
165
166impl TypeReconstructor for WriteTransformingVisitor<'_> {}