leo_passes/loop_unrolling/
visitor.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_ast::{
18    AstReconstructor,
19    Block,
20    IterationStatement,
21    Literal,
22    Node,
23    NodeID,
24    Statement,
25    Type,
26    interpreter_value::Value,
27};
28use leo_errors::LoopUnrollerError;
29use leo_span::{Span, Symbol};
30
31use itertools::Either;
32
33use crate::CompilerState;
34
35pub struct UnrollingVisitor<'a> {
36    pub state: &'a mut CompilerState,
37    /// The current program name.
38    pub program: Symbol,
39    /// If we've encountered a loop that was not unrolled, here's it's spanned.
40    pub loop_not_unrolled: Option<Span>,
41    /// Have we unrolled any loop?
42    pub loop_unrolled: bool,
43}
44
45impl UnrollingVisitor<'_> {
46    pub fn in_scope<T>(&mut self, id: NodeID, func: impl FnOnce(&mut Self) -> T) -> T {
47        self.state.symbol_table.enter_scope(Some(id));
48        let result = func(self);
49        self.state.symbol_table.enter_parent();
50        result
51    }
52
53    /// Emits a Loop Unrolling Error
54    pub fn emit_err(&self, err: LoopUnrollerError) {
55        self.state.handler.emit_err(err);
56    }
57
58    /// Unrolls an IterationStatement.
59    pub fn unroll_iteration_statement(&mut self, input: IterationStatement, start: Value, stop: Value) -> Statement {
60        // We already know these are integers since loop unrolling occurs after type checking.
61        let cast_to_number = |v: Value| -> Result<i128, Statement> {
62            match v.as_i128() {
63                Some(val_as_i128) => Ok(val_as_i128),
64                None => {
65                    self.state.handler.emit_err(LoopUnrollerError::value_out_of_i128_bounds(v, input.span()));
66                    Err(Statement::dummy())
67                }
68            }
69        };
70
71        // Cast `start` to `i128`.
72        let start = match cast_to_number(start) {
73            Ok(v) => v,
74            Err(s) => return s,
75        };
76        // Cast `stop` to `i128`.
77        let stop = match cast_to_number(stop) {
78            Ok(v) => v,
79            Err(s) => return s,
80        };
81
82        let new_block_id = self.state.node_builder.next_id();
83
84        let iter = if input.inclusive { Either::Left(start..=stop) } else { Either::Right(start..stop) };
85
86        // Create a block statement to replace the iteration statement.
87        self.in_scope(new_block_id, |slf| {
88            Block {
89                span: input.span,
90                statements: iter.map(|iteration_count| slf.unroll_single_iteration(&input, iteration_count)).collect(),
91                id: new_block_id,
92            }
93            .into()
94        })
95    }
96
97    /// A helper function to unroll a single iteration an IterationStatement.
98    fn unroll_single_iteration(&mut self, input: &IterationStatement, iteration_count: i128) -> Statement {
99        // Construct a new node ID.
100        let const_id = self.state.node_builder.next_id();
101
102        let iterator_type =
103            self.state.type_table.get(&input.variable.id()).expect("guaranteed to have a type after type checking");
104
105        // Update the type table.
106        self.state.type_table.insert(const_id, iterator_type.clone());
107
108        let outer_block_id = self.state.node_builder.next_id();
109
110        // Reconstruct `iteration_count` as a `Literal`.
111        let Type::Integer(integer_type) = &iterator_type else {
112            unreachable!("Type checking enforces that the iteration variable is of integer type");
113        };
114
115        self.in_scope(outer_block_id, |slf| {
116            let value = Literal::integer(*integer_type, iteration_count.to_string(), Default::default(), const_id);
117
118            // Add the loop variable as a constant for the current scope.
119            slf.state.symbol_table.insert_const(slf.program, &[input.variable.name], value.into());
120
121            let duplicated_body =
122                super::duplicate::duplicate(input.block.clone(), &mut slf.state.symbol_table, &slf.state.node_builder);
123
124            let result = slf.reconstruct_block(duplicated_body).0.into();
125
126            Block { statements: vec![result], span: input.span(), id: outer_block_id }.into()
127        })
128    }
129}