leo_passes/common/replacer/
mod.rs1use leo_ast::{AstReconstructor, Block, Expression, IterationStatement, Node as _, ProgramReconstructor, Statement};
18
19use crate::CompilerState;
20
21pub struct Replacer<'a, F>
40where
41 F: Fn(&Expression) -> Expression,
42{
43 state: &'a mut CompilerState,
44 refresh_expr_ids: bool,
45 replace: F,
46}
47
48impl<'a, F> Replacer<'a, F>
49where
50 F: Fn(&Expression) -> Expression,
51{
52 pub fn new(replace: F, refresh_expr_ids: bool, state: &'a mut CompilerState) -> Self {
53 Self { replace, refresh_expr_ids, state }
54 }
55}
56
57impl<F> AstReconstructor for Replacer<'_, F>
58where
59 F: Fn(&Expression) -> Expression,
60{
61 type AdditionalInput = ();
62 type AdditionalOutput = ();
63
64 fn reconstruct_expression(&mut self, input: Expression, _additional: &()) -> (Expression, Self::AdditionalOutput) {
65 let opt_old_type = self.state.type_table.get(&input.id());
66 let replaced_expr = (self.replace)(&input);
67 let (mut new_expr, additional) = if replaced_expr.id() == input.id() {
68 match input {
70 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, &()),
71 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, &()),
72 Expression::Async(async_) => self.reconstruct_async(async_, &()),
73 Expression::Array(array) => self.reconstruct_array(array, &()),
74 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, &()),
75 Expression::Binary(binary) => self.reconstruct_binary(*binary, &()),
76 Expression::Call(call) => self.reconstruct_call(*call, &()),
77 Expression::Cast(cast) => self.reconstruct_cast(*cast, &()),
78 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, &()),
79 Expression::Err(err) => self.reconstruct_err(err, &()),
80 Expression::Path(path) => self.reconstruct_path(path, &()),
81 Expression::Literal(value) => self.reconstruct_literal(value, &()),
82 Expression::Locator(locator) => self.reconstruct_locator(locator, &()),
83 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, &()),
84 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, &()),
85 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, &()),
86 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, &()),
87 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, &()),
88 Expression::Unary(unary) => self.reconstruct_unary(*unary, &()),
89 Expression::Unit(unit) => self.reconstruct_unit(unit, &()),
90 }
91 } else {
92 (replaced_expr, Default::default())
93 };
94
95 if self.refresh_expr_ids {
97 new_expr.set_id(self.state.node_builder.next_id());
98 }
99
100 if let Some(old_type) = opt_old_type {
102 self.state.type_table.insert(new_expr.id(), old_type);
103 }
104
105 (new_expr, additional)
106 }
107
108 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
109 (
110 Block {
111 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
112 span: input.span,
113 id: self.state.node_builder.next_id(),
114 },
115 Default::default(),
116 )
117 }
118
119 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
120 (
121 IterationStatement {
122 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
123 start: self.reconstruct_expression(input.start, &()).0,
124 stop: self.reconstruct_expression(input.stop, &()).0,
125 block: self.reconstruct_block(input.block).0,
126 id: self.state.node_builder.next_id(),
127 ..input
128 }
129 .into(),
130 Default::default(),
131 )
132 }
133}
134
135impl<F> ProgramReconstructor for Replacer<'_, F> where F: Fn(&Expression) -> Expression {}