leo_passes/common/replacer/
mod.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_ast::{AstReconstructor, Block, Expression, IterationStatement, Node as _, ProgramReconstructor, Statement};
18
19use crate::CompilerState;
20
21/// A `Replacer` traverses and reconstructs the AST, applying a user-defined replacement function to each `Expression`.
22///
23/// For example, this can be used to:
24/// - **Rename identifiers**: to systematically rename identifiers throughout the AST.
25/// - **Expression interpolation**: such as substituting arguments into function bodies.
26///
27/// During reconstruction, node IDs for blocks and loop bodies are regenerated using a `NodeBuilder` to ensure
28/// that the resulting scopes are different from the original. This avoids issues with stale or conflicting
29/// parent-child relationships between scopes.
30///
31/// The replacement function (`replace`) is applied early in expression reconstruction. If it produces a new
32/// expression (i.e., with a different node ID), it replaces the original; otherwise, the expression is
33/// recursively reconstructed as usual.
34///
35/// Note: Only `Expression` nodes are currently subject to replacement logic; all other AST nodes are
36/// reconstructed structurally.
37///
38/// TODO: Consider whether all nodes (not just scopes) should receive new IDs for consistency.
39pub struct Replacer<'a, F>
40where
41    F: Fn(&Expression) -> Expression,
42{
43    state: &'a mut CompilerState,
44    refresh_expr_ids: bool,
45    replace: F,
46}
47
48impl<'a, F> Replacer<'a, F>
49where
50    F: Fn(&Expression) -> Expression,
51{
52    pub fn new(replace: F, refresh_expr_ids: bool, state: &'a mut CompilerState) -> Self {
53        Self { replace, refresh_expr_ids, state }
54    }
55}
56
57impl<F> AstReconstructor for Replacer<'_, F>
58where
59    F: Fn(&Expression) -> Expression,
60{
61    type AdditionalInput = ();
62    type AdditionalOutput = ();
63
64    fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
65        let opt_old_type = self.state.type_table.get(&input.id());
66        let replaced_expr = (self.replace)(&input);
67        let (mut new_expr, additional) = if replaced_expr.id() == input.id() {
68            // Replacement didn't happen, so just use the default implementation.
69            match input {
70                Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, &()),
71                Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, &()),
72                Expression::Async(async_) => self.reconstruct_async(async_, &()),
73                Expression::Array(array) => self.reconstruct_array(array, &()),
74                Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
75                Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
76                Expression::Call(call) => self.reconstruct_call(*call, &()),
77                Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
78                Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
79                Expression::Err(err) => self.reconstruct_err(err, &()),
80                Expression::Path(path) => self.reconstruct_path(path, &()),
81                Expression::Literal(value) => self.reconstruct_literal(value, &()),
82                Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
83                Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
84                Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
85                Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
86                Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
87                Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
88                Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
89                Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
90            }
91        } else {
92            (replaced_expr, Default::default())
93        };
94
95        // Refresh IDs if required
96        if self.refresh_expr_ids {
97            new_expr.set_id(self.state.node_builder.next_id());
98        }
99
100        // If the expression was in the type table before, make an entry for the new expression.
101        if let Some(old_type) = opt_old_type {
102            self.state.type_table.insert(new_expr.id(), old_type);
103        }
104
105        (new_expr, additional)
106    }
107
108    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
109        (
110            Block {
111                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
112                span: input.span,
113                id: self.state.node_builder.next_id(),
114            },
115            Default::default(),
116        )
117    }
118
119    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
120        (
121            IterationStatement {
122                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
123                start: self.reconstruct_expression(input.start, &()).0,
124                stop: self.reconstruct_expression(input.stop, &()).0,
125                block: self.reconstruct_block(input.block).0,
126                id: self.state.node_builder.next_id(),
127                ..input
128            }
129            .into(),
130            Default::default(),
131        )
132    }
133}
134
135impl<F> ProgramReconstructor for Replacer<'_, F> where F: Fn(&Expression) -> Expression {}