leo_passes/common/replacer/
mod.rs1use leo_ast::{AstReconstructor, Block, Expression, IterationStatement, Node as _, ProgramReconstructor, Statement};
18
19use crate::CompilerState;
20
21pub struct Replacer<'a, F>
40where
41 F: Fn(&Expression) -> Expression,
42{
43 state: &'a mut CompilerState,
44 refresh_expr_ids: bool,
45 replace: F,
46}
47
48impl<'a, F> Replacer<'a, F>
49where
50 F: Fn(&Expression) -> Expression,
51{
52 pub fn new(replace: F, refresh_expr_ids: bool, state: &'a mut CompilerState) -> Self {
53 Self { replace, refresh_expr_ids, state }
54 }
55}
56
57impl<F> AstReconstructor for Replacer<'_, F>
58where
59 F: Fn(&Expression) -> Expression,
60{
61 type AdditionalOutput = ();
62
63 fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
64 let opt_old_type = self.state.type_table.get(&input.id());
65 let replaced_expr = (self.replace)(&input);
66 let (mut new_expr, additional) = if replaced_expr.id() == input.id() {
67 match input {
69 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
70 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
71 Expression::Async(async_) => self.reconstruct_async(async_),
72 Expression::Array(array) => self.reconstruct_array(array),
73 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
74 Expression::Binary(binary) => self.reconstruct_binary(*binary),
75 Expression::Call(call) => self.reconstruct_call(*call),
76 Expression::Cast(cast) => self.reconstruct_cast(*cast),
77 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
78 Expression::Err(err) => self.reconstruct_err(err),
79 Expression::Path(path) => self.reconstruct_path(path),
80 Expression::Literal(value) => self.reconstruct_literal(value),
81 Expression::Locator(locator) => self.reconstruct_locator(locator),
82 Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
83 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
84 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
85 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
86 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
87 Expression::Unary(unary) => self.reconstruct_unary(*unary),
88 Expression::Unit(unit) => self.reconstruct_unit(unit),
89 }
90 } else {
91 (replaced_expr, Default::default())
92 };
93
94 if self.refresh_expr_ids {
96 new_expr.set_id(self.state.node_builder.next_id());
97 }
98
99 if let Some(old_type) = opt_old_type {
101 self.state.type_table.insert(new_expr.id(), old_type);
102 }
103
104 (new_expr, additional)
105 }
106
107 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
108 (
109 Block {
110 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
111 span: input.span,
112 id: self.state.node_builder.next_id(),
113 },
114 Default::default(),
115 )
116 }
117
118 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
119 (
120 IterationStatement {
121 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
122 start: self.reconstruct_expression(input.start).0,
123 stop: self.reconstruct_expression(input.stop).0,
124 block: self.reconstruct_block(input.block).0,
125 id: self.state.node_builder.next_id(),
126 ..input
127 }
128 .into(),
129 Default::default(),
130 )
131 }
132}
133
134impl<F> ProgramReconstructor for Replacer<'_, F> where F: Fn(&Expression) -> Expression {}