1use crate::*;
22
23pub trait TypeReconstructor: ExpressionReconstructor {
25 fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
26 match input {
27 Type::Array(array_type) => self.reconstruct_array_type(array_type),
28 Type::Future(future_type) => self.reconstruct_future_type(future_type),
29 Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
30 Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
31 Type::Address
32 | Type::Boolean
33 | Type::Composite(_)
34 | Type::Field
35 | Type::Group
36 | Type::Identifier(_)
37 | Type::Integer(_)
38 | Type::Scalar
39 | Type::Signature
40 | Type::String
41 | Type::Numeric
42 | Type::Unit
43 | Type::Err => (input.clone(), Default::default()),
44 }
45 }
46
47 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
48 (
49 Type::Array(ArrayType {
50 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
51 length: Box::new(self.reconstruct_expression(*input.length).0),
52 }),
53 Default::default(),
54 )
55 }
56
57 fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
58 (
59 Type::Future(FutureType {
60 inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
61 ..input
62 }),
63 Default::default(),
64 )
65 }
66
67 fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
68 (
69 Type::Mapping(MappingType {
70 key: Box::new(self.reconstruct_type(*input.key).0),
71 value: Box::new(self.reconstruct_type(*input.value).0),
72 ..input
73 }),
74 Default::default(),
75 )
76 }
77
78 fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
79 (
80 Type::Tuple(TupleType {
81 elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
82 }),
83 Default::default(),
84 )
85 }
86}
87
88pub trait ExpressionReconstructor {
90 type AdditionalOutput: Default;
91
92 fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
93 match input {
94 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
95 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
96 Expression::Array(array) => self.reconstruct_array(array),
97 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
98 Expression::Binary(binary) => self.reconstruct_binary(*binary),
99 Expression::Call(call) => self.reconstruct_call(*call),
100 Expression::Cast(cast) => self.reconstruct_cast(*cast),
101 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
102 Expression::Err(err) => self.reconstruct_err(err),
103 Expression::Identifier(identifier) => self.reconstruct_identifier(identifier),
104 Expression::Literal(value) => self.reconstruct_literal(value),
105 Expression::Locator(locator) => self.reconstruct_locator(locator),
106 Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
107 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
108 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
109 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
110 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
111 Expression::Unary(unary) => self.reconstruct_unary(*unary),
112 Expression::Unit(unit) => self.reconstruct_unit(unit),
113 }
114 }
115
116 fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
117 (
118 ArrayAccess {
119 array: self.reconstruct_expression(input.array).0,
120 index: self.reconstruct_expression(input.index).0,
121 ..input
122 }
123 .into(),
124 Default::default(),
125 )
126 }
127
128 fn reconstruct_associated_constant(
129 &mut self,
130 input: AssociatedConstantExpression,
131 ) -> (Expression, Self::AdditionalOutput) {
132 (input.into(), Default::default())
133 }
134
135 fn reconstruct_associated_function(
136 &mut self,
137 input: AssociatedFunctionExpression,
138 ) -> (Expression, Self::AdditionalOutput) {
139 (
140 AssociatedFunctionExpression {
141 arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
142 ..input
143 }
144 .into(),
145 Default::default(),
146 )
147 }
148
149 fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
150 (MemberAccess { inner: self.reconstruct_expression(input.inner).0, ..input }.into(), Default::default())
151 }
152
153 fn reconstruct_repeat(&mut self, input: RepeatExpression) -> (Expression, Self::AdditionalOutput) {
154 (
155 RepeatExpression {
156 expr: self.reconstruct_expression(input.expr).0,
157 count: self.reconstruct_expression(input.count).0,
158 ..input
159 }
160 .into(),
161 Default::default(),
162 )
163 }
164
165 fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
166 (TupleAccess { tuple: self.reconstruct_expression(input.tuple).0, ..input }.into(), Default::default())
167 }
168
169 fn reconstruct_array(&mut self, input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
170 (
171 ArrayExpression {
172 elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
173 ..input
174 }
175 .into(),
176 Default::default(),
177 )
178 }
179
180 fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
181 (
182 BinaryExpression {
183 left: self.reconstruct_expression(input.left).0,
184 right: self.reconstruct_expression(input.right).0,
185 ..input
186 }
187 .into(),
188 Default::default(),
189 )
190 }
191
192 fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
193 (
194 CallExpression {
195 const_arguments: input
196 .const_arguments
197 .into_iter()
198 .map(|arg| self.reconstruct_expression(arg).0)
199 .collect(),
200 arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
201 ..input
202 }
203 .into(),
204 Default::default(),
205 )
206 }
207
208 fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
209 (
210 CastExpression { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
211 Default::default(),
212 )
213 }
214
215 fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
216 (
217 StructExpression {
218 members: input
219 .members
220 .into_iter()
221 .map(|member| StructVariableInitializer {
222 identifier: member.identifier,
223 expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
224 span: member.span,
225 id: member.id,
226 })
227 .collect(),
228 ..input
229 }
230 .into(),
231 Default::default(),
232 )
233 }
234
235 fn reconstruct_err(&mut self, _input: ErrExpression) -> (Expression, Self::AdditionalOutput) {
236 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
237 }
238
239 fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
240 (input.into(), Default::default())
241 }
242
243 fn reconstruct_literal(&mut self, input: Literal) -> (Expression, Self::AdditionalOutput) {
244 (input.into(), Default::default())
245 }
246
247 fn reconstruct_locator(&mut self, input: LocatorExpression) -> (Expression, Self::AdditionalOutput) {
248 (input.into(), Default::default())
249 }
250
251 fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
252 (
253 TernaryExpression {
254 condition: self.reconstruct_expression(input.condition).0,
255 if_true: self.reconstruct_expression(input.if_true).0,
256 if_false: self.reconstruct_expression(input.if_false).0,
257 span: input.span,
258 id: input.id,
259 }
260 .into(),
261 Default::default(),
262 )
263 }
264
265 fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) {
266 (
267 TupleExpression {
268 elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
269 ..input
270 }
271 .into(),
272 Default::default(),
273 )
274 }
275
276 fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
277 (
278 UnaryExpression { receiver: self.reconstruct_expression(input.receiver).0, ..input }.into(),
279 Default::default(),
280 )
281 }
282
283 fn reconstruct_unit(&mut self, input: UnitExpression) -> (Expression, Self::AdditionalOutput) {
284 (input.into(), Default::default())
285 }
286}
287
288pub trait StatementReconstructor: TypeReconstructor {
290 fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
291 match input {
292 Statement::Assert(assert) => self.reconstruct_assert(assert),
293 Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
294 Statement::Block(stmt) => {
295 let (stmt, output) = self.reconstruct_block(stmt);
296 (stmt.into(), output)
297 }
298 Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
299 Statement::Const(stmt) => self.reconstruct_const(stmt),
300 Statement::Definition(stmt) => self.reconstruct_definition(stmt),
301 Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
302 Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
303 Statement::Return(stmt) => self.reconstruct_return(stmt),
304 }
305 }
306
307 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
308 (
309 AssertStatement {
310 variant: match input.variant {
311 AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
312 AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
313 self.reconstruct_expression(left).0,
314 self.reconstruct_expression(right).0,
315 ),
316 AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
317 self.reconstruct_expression(left).0,
318 self.reconstruct_expression(right).0,
319 ),
320 },
321 ..input
322 }
323 .into(),
324 Default::default(),
325 )
326 }
327
328 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
329 (AssignStatement { value: self.reconstruct_expression(input.value).0, ..input }.into(), Default::default())
330 }
331
332 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
333 (
334 Block {
335 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
336 span: input.span,
337 id: input.id,
338 },
339 Default::default(),
340 )
341 }
342
343 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
344 (
345 ConditionalStatement {
346 condition: self.reconstruct_expression(input.condition).0,
347 then: self.reconstruct_block(input.then).0,
348 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
349 ..input
350 }
351 .into(),
352 Default::default(),
353 )
354 }
355
356 fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
357 (
358 ConstDeclaration {
359 type_: self.reconstruct_type(input.type_).0,
360 value: self.reconstruct_expression(input.value).0,
361 ..input
362 }
363 .into(),
364 Default::default(),
365 )
366 }
367
368 fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
369 (
370 DefinitionStatement {
371 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
372 value: self.reconstruct_expression(input.value).0,
373 ..input
374 }
375 .into(),
376 Default::default(),
377 )
378 }
379
380 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
381 (
382 ExpressionStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
383 Default::default(),
384 )
385 }
386
387 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
388 (
389 IterationStatement {
390 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
391 start: self.reconstruct_expression(input.start).0,
392 stop: self.reconstruct_expression(input.stop).0,
393 block: self.reconstruct_block(input.block).0,
394 ..input
395 }
396 .into(),
397 Default::default(),
398 )
399 }
400
401 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
402 (
403 ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
404 Default::default(),
405 )
406 }
407}
408
409pub trait ProgramReconstructor: StatementReconstructor {
411 fn reconstruct_program(&mut self, input: Program) -> Program {
412 Program {
413 imports: input
414 .imports
415 .into_iter()
416 .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
417 .collect(),
418 stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
419 program_scopes: input
420 .program_scopes
421 .into_iter()
422 .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
423 .collect(),
424 }
425 }
426
427 fn reconstruct_stub(&mut self, input: Stub) -> Stub {
428 Stub {
429 imports: input.imports,
430 stub_id: input.stub_id,
431 consts: input.consts,
432 structs: input.structs,
433 mappings: input.mappings,
434 span: input.span,
435 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
436 }
437 }
438
439 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
440 ProgramScope {
441 program_id: input.program_id,
442 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
443 mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
444 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
445 consts: input
446 .consts
447 .into_iter()
448 .map(|(i, c)| match self.reconstruct_const(c) {
449 (Statement::Const(declaration), _) => (i, declaration),
450 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
451 })
452 .collect(),
453 span: input.span,
454 }
455 }
456
457 fn reconstruct_function(&mut self, input: Function) -> Function {
458 Function {
459 annotations: input.annotations,
460 variant: input.variant,
461 identifier: input.identifier,
462 const_parameters: input
463 .const_parameters
464 .iter()
465 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
466 .collect(),
467 input: input
468 .input
469 .iter()
470 .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
471 .collect(),
472 output: input
473 .output
474 .iter()
475 .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
476 .collect(),
477 output_type: self.reconstruct_type(input.output_type).0,
478 block: self.reconstruct_block(input.block).0,
479 span: input.span,
480 id: input.id,
481 }
482 }
483
484 fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
485 input
486 }
487
488 fn reconstruct_struct(&mut self, input: Composite) -> Composite {
489 Composite {
490 members: input
491 .members
492 .iter()
493 .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
494 .collect(),
495 ..input
496 }
497 }
498
499 fn reconstruct_import(&mut self, input: Program) -> Program {
500 self.reconstruct_program(input)
501 }
502
503 fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
504 Mapping {
505 key_type: self.reconstruct_type(input.key_type).0,
506 value_type: self.reconstruct_type(input.value_type).0,
507 ..input
508 }
509 }
510}