leo_passes/loop_unrolling/
unroller.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

use leo_ast::{
    Block,
    Expression,
    ExpressionReconstructor,
    IterationStatement,
    Literal,
    Node,
    NodeBuilder,
    NodeID,
    Statement,
    StatementReconstructor,
    Type,
    Value,
};

use leo_errors::{emitter::Handler, loop_unroller::LoopUnrollerError};
use leo_span::{Span, Symbol};

use crate::{Clusivity, LoopBound, RangeIterator, SymbolTable, TypeTable};

pub struct Unroller<'a> {
    /// The symbol table for the function being processed.
    pub(crate) symbol_table: &'a mut SymbolTable,
    /// A mapping from node IDs to their types.
    pub(crate) type_table: &'a TypeTable,
    /// An error handler used for any errors found during unrolling.
    pub(crate) handler: &'a Handler,
    /// A counter used to generate unique node IDs.
    pub(crate) node_builder: &'a NodeBuilder,
    /// Are we in the midst of unrolling a loop?
    pub(crate) is_unrolling: bool,
    /// The current program name.
    pub(crate) current_program: Option<Symbol>,
    /// If we've encountered a loop that was not unrolled, here's it's spanned.
    pub(crate) loop_not_unrolled: Option<Span>,
    /// Have we unrolled any loop?
    pub(crate) loop_unrolled: bool,
}

impl<'a> Unroller<'a> {
    pub(crate) fn new(
        symbol_table: &'a mut SymbolTable,
        type_table: &'a TypeTable,
        handler: &'a Handler,
        node_builder: &'a NodeBuilder,
    ) -> Self {
        Self {
            symbol_table,
            type_table,
            handler,
            node_builder,
            is_unrolling: false,
            current_program: None,
            loop_not_unrolled: None,
            loop_unrolled: false,
        }
    }

    pub(crate) fn in_scope<T>(&mut self, id: NodeID, func: impl FnOnce(&mut Self) -> T) -> T {
        self.symbol_table.enter_scope(Some(id));
        let result = func(self);
        self.symbol_table.enter_parent();
        result
    }

    /// Emits a Loop Unrolling Error
    pub(crate) fn emit_err(&self, err: LoopUnrollerError) {
        self.handler.emit_err(err);
    }

    /// Unrolls an IterationStatement.
    pub(crate) fn unroll_iteration_statement<I: LoopBound>(&mut self, input: IterationStatement) -> Statement {
        let start: Value = input.start_value.borrow().as_ref().expect("Failed to get start value").clone();
        let stop: Value = input.stop_value.borrow().as_ref().expect("Failed to get stop value").clone();

        // Closure to check that the constant values are valid u128.
        // We already know these are integers since loop unrolling occurs after type checking.
        let cast_to_number = |v: Value| -> Result<I, Statement> {
            match v.try_into() {
                Ok(val_as_u128) => Ok(val_as_u128),
                Err(err) => {
                    self.handler.emit_err(err);
                    Err(Statement::dummy(input.span, self.node_builder.next_id()))
                }
            }
        };

        // Cast `start` to `I`.
        let start = match cast_to_number(start) {
            Ok(v) => v,
            Err(s) => return s,
        };
        // Cast `stop` to `I`.
        let stop = match cast_to_number(stop) {
            Ok(v) => v,
            Err(s) => return s,
        };

        let new_block_id = self.node_builder.next_id();

        // Create a block statement to replace the iteration statement.
        self.in_scope(new_block_id, |slf| {
            Statement::Block(Block {
                span: input.span,
                statements: match input.inclusive {
                    true => {
                        let iter = RangeIterator::new(start, stop, Clusivity::Inclusive);
                        iter.map(|iteration_count| slf.unroll_single_iteration(&input, iteration_count)).collect()
                    }
                    false => {
                        let iter = RangeIterator::new(start, stop, Clusivity::Exclusive);
                        iter.map(|iteration_count| slf.unroll_single_iteration(&input, iteration_count)).collect()
                    }
                },
                id: new_block_id,
            })
        })
    }

    /// A helper function to unroll a single iteration an IterationStatement.
    fn unroll_single_iteration<I: LoopBound>(&mut self, input: &IterationStatement, iteration_count: I) -> Statement {
        // Construct a new node ID.
        let const_id = self.node_builder.next_id();
        // Update the type table.
        self.type_table.insert(const_id, input.type_.clone());

        let outer_block_id = self.node_builder.next_id();

        // Reconstruct `iteration_count` as a `Literal`.
        let Type::Integer(integer_type) = &input.type_ else {
            unreachable!("Type checking enforces that the iteration variable is of integer type");
        };

        self.in_scope(outer_block_id, |slf| {
            let value = Literal::Integer(*integer_type, iteration_count.to_string(), Default::default(), const_id);

            // Add the loop variable as a constant for the current scope.
            slf.symbol_table.insert_const(
                slf.current_program.unwrap(),
                input.variable.name,
                Expression::Literal(value),
            );

            let duplicated_body = super::duplicate::duplicate(input.block.clone(), slf.symbol_table, slf.node_builder);

            let prior_is_unrolling = slf.is_unrolling;
            slf.is_unrolling = true;

            let result = Statement::Block(slf.reconstruct_block(duplicated_body).0);

            slf.is_unrolling = prior_is_unrolling;

            Statement::Block(Block { statements: vec![result], span: input.span(), id: outer_block_id })
        })
    }
}

impl ExpressionReconstructor for Unroller<'_> {
    type AdditionalOutput = ();
}