leo_passes/common/replacer/
mod.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_ast::{AstReconstructor, Block, Expression, IterationStatement, Node as _, ProgramReconstructor, Statement};
18
19use crate::CompilerState;
20
21/// A `Replacer` traverses and reconstructs the AST, applying a user-defined replacement function to each `Expression`.
22///
23/// For example, this can be used to:
24/// - **Rename identifiers**: to systematically rename identifiers throughout the AST.
25/// - **Expression interpolation**: such as substituting arguments into function bodies.
26///
27/// During reconstruction, node IDs for blocks and loop bodies are regenerated using a `NodeBuilder` to ensure
28/// that the resulting scopes are different from the original. This avoids issues with stale or conflicting
29/// parent-child relationships between scopes.
30///
31/// The replacement function (`replace`) is applied early in expression reconstruction. If it produces a new
32/// expression (i.e., with a different node ID), it replaces the original; otherwise, the expression is
33/// recursively reconstructed as usual.
34///
35/// Note: Only `Expression` nodes are currently subject to replacement logic; all other AST nodes are
36/// reconstructed structurally.
37///
38/// TODO: Consider whether all nodes (not just scopes) should receive new IDs for consistency.
39pub struct Replacer<'a, F>
40where
41    F: Fn(&Expression) -> Expression,
42{
43    state: &'a mut CompilerState,
44    refresh_expr_ids: bool,
45    replace: F,
46}
47
48impl<'a, F> Replacer<'a, F>
49where
50    F: Fn(&Expression) -> Expression,
51{
52    pub fn new(replace: F, refresh_expr_ids: bool, state: &'a mut CompilerState) -> Self {
53        Self { replace, refresh_expr_ids, state }
54    }
55}
56
57impl<F> AstReconstructor for Replacer<'_, F>
58where
59    F: Fn(&Expression) -> Expression,
60{
61    type AdditionalOutput = ();
62
63    fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
64        let opt_old_type = self.state.type_table.get(&input.id());
65        let replaced_expr = (self.replace)(&input);
66        let (mut new_expr, additional) = if replaced_expr.id() == input.id() {
67            // Replacement didn't happen, so just use the default implementation.
68            match input {
69                Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
70                Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
71                Expression::Async(async_) => self.reconstruct_async(async_),
72                Expression::Array(array) => self.reconstruct_array(array),
73                Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
74                Expression::Binary(binary) => self.reconstruct_binary(*binary),
75                Expression::Call(call) => self.reconstruct_call(*call),
76                Expression::Cast(cast) => self.reconstruct_cast(*cast),
77                Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
78                Expression::Err(err) => self.reconstruct_err(err),
79                Expression::Path(path) => self.reconstruct_path(path),
80                Expression::Literal(value) => self.reconstruct_literal(value),
81                Expression::Locator(locator) => self.reconstruct_locator(locator),
82                Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
83                Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
84                Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
85                Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
86                Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
87                Expression::Unary(unary) => self.reconstruct_unary(*unary),
88                Expression::Unit(unit) => self.reconstruct_unit(unit),
89            }
90        } else {
91            (replaced_expr, Default::default())
92        };
93
94        // Refresh IDs if required
95        if self.refresh_expr_ids {
96            new_expr.set_id(self.state.node_builder.next_id());
97        }
98
99        // If the expression was in the type table before, make an entry for the new expression.
100        if let Some(old_type) = opt_old_type {
101            self.state.type_table.insert(new_expr.id(), old_type);
102        }
103
104        (new_expr, additional)
105    }
106
107    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
108        (
109            Block {
110                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
111                span: input.span,
112                id: self.state.node_builder.next_id(),
113            },
114            Default::default(),
115        )
116    }
117
118    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
119        (
120            IterationStatement {
121                type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
122                start: self.reconstruct_expression(input.start).0,
123                stop: self.reconstruct_expression(input.stop).0,
124                block: self.reconstruct_block(input.block).0,
125                id: self.state.node_builder.next_id(),
126                ..input
127            }
128            .into(),
129            Default::default(),
130        )
131    }
132}
133
134impl<F> ProgramReconstructor for Replacer<'_, F> where F: Fn(&Expression) -> Expression {}