leo_passes/const_propagation/
statement.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::ConstPropagationVisitor;
18
19use leo_ast::{
20    AssertStatement,
21    AssertVariant,
22    AssignStatement,
23    Block,
24    ConditionalStatement,
25    ConstDeclaration,
26    DefinitionStatement,
27    Expression,
28    ExpressionReconstructor,
29    ExpressionStatement,
30    IterationStatement,
31    Node,
32    ReturnStatement,
33    Statement,
34    StatementReconstructor,
35    TypeReconstructor,
36};
37
38impl StatementReconstructor for ConstPropagationVisitor<'_> {
39    fn reconstruct_assert(&mut self, mut input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
40        // Catching asserts at compile time is not feasible here due to control flow, but could be done in
41        // a later pass after loops are unrolled and conditionals are flattened.
42        input.variant = match input.variant {
43            AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
44
45            AssertVariant::AssertEq(lhs, rhs) => {
46                AssertVariant::AssertEq(self.reconstruct_expression(lhs).0, self.reconstruct_expression(rhs).0)
47            }
48
49            AssertVariant::AssertNeq(lhs, rhs) => {
50                AssertVariant::AssertNeq(self.reconstruct_expression(lhs).0, self.reconstruct_expression(rhs).0)
51            }
52        };
53
54        (input.into(), None)
55    }
56
57    fn reconstruct_assign(&mut self, assign: AssignStatement) -> (Statement, Self::AdditionalOutput) {
58        let value = self.reconstruct_expression(assign.value).0;
59        let place = self.reconstruct_expression(assign.place).0;
60        (AssignStatement { value, place, ..assign }.into(), None)
61    }
62
63    fn reconstruct_block(&mut self, mut block: Block) -> (Block, Self::AdditionalOutput) {
64        self.in_scope(block.id(), |slf| {
65            block.statements.retain_mut(|statement| {
66                let bogus_statement = Statement::dummy();
67                let this_statement = std::mem::replace(statement, bogus_statement);
68                *statement = slf.reconstruct_statement(this_statement).0;
69                !statement.is_empty()
70            });
71            (block, None)
72        })
73    }
74
75    fn reconstruct_conditional(
76        &mut self,
77        mut conditional: ConditionalStatement,
78    ) -> (Statement, Self::AdditionalOutput) {
79        conditional.condition = self.reconstruct_expression(conditional.condition).0;
80        conditional.then = self.reconstruct_block(conditional.then).0;
81        if let Some(mut otherwise) = conditional.otherwise {
82            *otherwise = self.reconstruct_statement(*otherwise).0;
83            conditional.otherwise = Some(otherwise);
84        }
85
86        (Statement::Conditional(conditional), None)
87    }
88
89    fn reconstruct_const(&mut self, mut input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
90        let span = input.span();
91
92        let type_ = self.reconstruct_type(input.type_).0;
93        let (expr, opt_value) = self.reconstruct_expression(input.value);
94
95        if opt_value.is_some() {
96            self.state.symbol_table.insert_const(self.program, input.place.name, expr.clone());
97        } else {
98            self.const_not_evaluated = Some(span);
99        }
100
101        input.type_ = type_;
102        input.value = expr;
103
104        (Statement::Const(input), None)
105    }
106
107    fn reconstruct_definition(&mut self, definition: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
108        (
109            DefinitionStatement {
110                type_: definition.type_.map(|ty| self.reconstruct_type(ty).0),
111                value: self.reconstruct_expression(definition.value).0,
112                ..definition
113            }
114            .into(),
115            None,
116        )
117    }
118
119    fn reconstruct_expression_statement(
120        &mut self,
121        mut input: ExpressionStatement,
122    ) -> (Statement, Self::AdditionalOutput) {
123        input.expression = self.reconstruct_expression(input.expression).0;
124
125        if matches!(&input.expression, Expression::Unit(..) | Expression::Literal(..)) {
126            // We were able to evaluate this at compile time, but we need to get rid of this statement as
127            // we can't have expression statements that aren't calls.
128            (Statement::dummy(), Default::default())
129        } else {
130            (input.into(), Default::default())
131        }
132    }
133
134    fn reconstruct_iteration(&mut self, iteration: IterationStatement) -> (Statement, Self::AdditionalOutput) {
135        let id = iteration.id();
136        let type_ = iteration.type_.map(|ty| self.reconstruct_type(ty).0);
137        let start = self.reconstruct_expression(iteration.start).0;
138        let stop = self.reconstruct_expression(iteration.stop).0;
139        self.in_scope(id, |slf| {
140            (
141                IterationStatement { type_, start, stop, block: slf.reconstruct_block(iteration.block).0, ..iteration }
142                    .into(),
143                None,
144            )
145        })
146    }
147
148    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
149        (
150            ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
151            Default::default(),
152        )
153    }
154}