leo_passes/loop_unrolling/
ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_ast::{Expression::Literal, interpreter_value::literal_to_value, *};
18
19use leo_errors::LoopUnrollerError;
20
21use super::UnrollingVisitor;
22
23impl AstReconstructor for UnrollingVisitor<'_> {
24    type AdditionalInput = ();
25    type AdditionalOutput = ();
26
27    /* Expressions */
28    fn reconstruct_repeat(
29        &mut self,
30        input: RepeatExpression,
31        _additional: &(),
32    ) -> (Expression, Self::AdditionalOutput) {
33        // Because the value of `count` affects the type of a repeat expression, we need to assign a new ID to the
34        // reconstructed `RepeatExpression` and update the type table accordingly.
35        let new_id = self.state.node_builder.next_id();
36        let new_count = self.reconstruct_expression(input.count, &()).0;
37        let el_ty = self.state.type_table.get(&input.expr.id()).expect("guaranteed by type checking");
38        self.state.type_table.insert(new_id, Type::Array(ArrayType::new(el_ty, new_count.clone())));
39        (
40            RepeatExpression {
41                expr: self.reconstruct_expression(input.expr, &()).0,
42                count: new_count,
43                id: new_id,
44                ..input
45            }
46            .into(),
47            Default::default(),
48        )
49    }
50
51    fn reconstruct_block(&mut self, mut input: Block) -> (Block, Self::AdditionalOutput) {
52        self.in_scope(input.id(), |slf| {
53            input.statements = input.statements.into_iter().map(|stmt| slf.reconstruct_statement(stmt).0).collect();
54
55            (input, Default::default())
56        })
57    }
58
59    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
60        (input.into(), Default::default())
61    }
62
63    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
64        (
65            DefinitionStatement {
66                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
67                value: self.reconstruct_expression(input.value, &()).0,
68                ..input
69            }
70            .into(),
71            Default::default(),
72        )
73    }
74
75    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
76        // There's no need to reconstruct the bound expressions - they must be constants
77        // which can be evaluated through constant propagation.
78
79        let Literal(start_lit_ref) = &input.start else {
80            self.loop_not_unrolled = Some(input.start.span());
81            return (Statement::Iteration(Box::new(input)), Default::default());
82        };
83
84        let Literal(stop_lit_ref) = &input.stop else {
85            self.loop_not_unrolled = Some(input.stop.span());
86            return (Statement::Iteration(Box::new(input)), Default::default());
87        };
88
89        // Helper to clone and resolve Unsuffixed -> Integer literal based on type table
90        let resolve_unsuffixed = |lit: &leo_ast::Literal, expr_id| {
91            let mut resolved = lit.clone();
92            if let LiteralVariant::Unsuffixed(s) = &resolved.variant {
93                if let Some(Type::Integer(integer_type)) = self.state.type_table.get(&expr_id) {
94                    resolved.variant = LiteralVariant::Integer(integer_type, s.clone());
95                }
96            }
97            resolved
98        };
99
100        // Clone and resolve both literals
101        let resolved_start_lit = resolve_unsuffixed(start_lit_ref, input.start.id());
102        let resolved_stop_lit = resolve_unsuffixed(stop_lit_ref, input.stop.id());
103
104        // Convert resolved literals into constant values
105        let start_value =
106            literal_to_value(&resolved_start_lit, &None).expect("Parsing and type checking guarantee this works.");
107        let stop_value =
108            literal_to_value(&resolved_stop_lit, &None).expect("Parsing and type checking guarantee this works.");
109
110        // Ensure loop bounds are strictly increasing
111        if start_value.gte(&stop_value).expect("Type checking guarantees these are the same type") {
112            self.emit_err(LoopUnrollerError::loop_range_decreasing(input.stop.span()));
113        }
114
115        self.loop_unrolled = true;
116
117        // Actually unroll.
118        (self.unroll_iteration_statement(input, start_value, stop_value), Default::default())
119    }
120}