leo_passes/loop_unrolling/
visitor.rs1use leo_ast::{
18 AstReconstructor,
19 Block,
20 IterationStatement,
21 Literal,
22 Node,
23 NodeID,
24 Statement,
25 Type,
26 interpreter_value::Value,
27};
28use leo_errors::LoopUnrollerError;
29use leo_span::{Span, Symbol};
30
31use itertools::Either;
32
33use crate::CompilerState;
34
35pub struct UnrollingVisitor<'a> {
36 pub state: &'a mut CompilerState,
37 pub program: Symbol,
39 pub loop_not_unrolled: Option<Span>,
41 pub loop_unrolled: bool,
43}
44
45impl UnrollingVisitor<'_> {
46 pub fn in_scope<T>(&mut self, id: NodeID, func: impl FnOnce(&mut Self) -> T) -> T {
47 self.state.symbol_table.enter_scope(Some(id));
48 let result = func(self);
49 self.state.symbol_table.enter_parent();
50 result
51 }
52
53 pub fn emit_err(&self, err: LoopUnrollerError) {
55 self.state.handler.emit_err(err);
56 }
57
58 pub fn unroll_iteration_statement(&mut self, input: IterationStatement, start: Value, stop: Value) -> Statement {
60 let cast_to_number = |v: Value| -> Result<i128, Statement> {
62 match v.as_i128() {
63 Some(val_as_i128) => Ok(val_as_i128),
64 None => {
65 self.state.handler.emit_err(LoopUnrollerError::value_out_of_i128_bounds(v, input.span()));
66 Err(Statement::dummy())
67 }
68 }
69 };
70
71 let start = match cast_to_number(start) {
73 Ok(v) => v,
74 Err(s) => return s,
75 };
76 let stop = match cast_to_number(stop) {
78 Ok(v) => v,
79 Err(s) => return s,
80 };
81
82 let new_block_id = self.state.node_builder.next_id();
83
84 let iter = if input.inclusive { Either::Left(start..=stop) } else { Either::Right(start..stop) };
85
86 self.in_scope(new_block_id, |slf| {
88 Block {
89 span: input.span,
90 statements: iter.map(|iteration_count| slf.unroll_single_iteration(&input, iteration_count)).collect(),
91 id: new_block_id,
92 }
93 .into()
94 })
95 }
96
97 fn unroll_single_iteration(&mut self, input: &IterationStatement, iteration_count: i128) -> Statement {
99 let const_id = self.state.node_builder.next_id();
101
102 let iterator_type =
103 self.state.type_table.get(&input.variable.id()).expect("guaranteed to have a type after type checking");
104
105 self.state.type_table.insert(const_id, iterator_type.clone());
107
108 let outer_block_id = self.state.node_builder.next_id();
109
110 let Type::Integer(integer_type) = &iterator_type else {
112 unreachable!("Type checking enforces that the iteration variable is of integer type");
113 };
114
115 self.in_scope(outer_block_id, |slf| {
116 let value = Literal::integer(*integer_type, iteration_count.to_string(), Default::default(), const_id);
117
118 slf.state.symbol_table.insert_const(slf.program, &[input.variable.name], value.into());
120
121 let duplicated_body =
122 super::duplicate::duplicate(input.block.clone(), &mut slf.state.symbol_table, &slf.state.node_builder);
123
124 let result = slf.reconstruct_block(duplicated_body).0.into();
125
126 Block { statements: vec![result], span: input.span(), id: outer_block_id }.into()
127 })
128 }
129}