1use crate::*;
22
23pub trait AstReconstructor {
25 type AdditionalOutput: Default;
26 type AdditionalInput: Default;
27
28 fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30 match input {
31 Type::Array(array_type) => self.reconstruct_array_type(array_type),
32 Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33 Type::Future(future_type) => self.reconstruct_future_type(future_type),
34 Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35 Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36 Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37 Type::Address
38 | Type::Boolean
39 | Type::Field
40 | Type::Group
41 | Type::Identifier(_)
42 | Type::Integer(_)
43 | Type::Scalar
44 | Type::Signature
45 | Type::String
46 | Type::Numeric
47 | Type::Unit
48 | Type::Err => (input.clone(), Default::default()),
49 }
50 }
51
52 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
53 (
54 Type::Array(ArrayType {
55 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
56 length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
57 }),
58 Default::default(),
59 )
60 }
61
62 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
63 (
64 Type::Composite(CompositeType {
65 const_arguments: input
66 .const_arguments
67 .into_iter()
68 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
69 .collect(),
70 ..input
71 }),
72 Default::default(),
73 )
74 }
75
76 fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
77 (
78 Type::Future(FutureType {
79 inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
80 ..input
81 }),
82 Default::default(),
83 )
84 }
85
86 fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
87 (
88 Type::Mapping(MappingType {
89 key: Box::new(self.reconstruct_type(*input.key).0),
90 value: Box::new(self.reconstruct_type(*input.value).0),
91 ..input
92 }),
93 Default::default(),
94 )
95 }
96
97 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
98 (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
99 }
100
101 fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
102 (
103 Type::Tuple(TupleType {
104 elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
105 }),
106 Default::default(),
107 )
108 }
109
110 fn reconstruct_expression(
112 &mut self,
113 input: Expression,
114 additional: &Self::AdditionalInput,
115 ) -> (Expression, Self::AdditionalOutput) {
116 match input {
117 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, additional),
118 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, additional),
119 Expression::Async(async_) => self.reconstruct_async(async_, additional),
120 Expression::Array(array) => self.reconstruct_array(array, additional),
121 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
122 Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
123 Expression::Call(call) => self.reconstruct_call(*call, additional),
124 Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
125 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
126 Expression::Err(err) => self.reconstruct_err(err, additional),
127 Expression::Path(path) => self.reconstruct_path(path, additional),
128 Expression::Literal(value) => self.reconstruct_literal(value, additional),
129 Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
130 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
131 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
132 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
133 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
134 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
135 Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
136 Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
137 }
138 }
139
140 fn reconstruct_array_access(
141 &mut self,
142 input: ArrayAccess,
143 _additional: &Self::AdditionalInput,
144 ) -> (Expression, Self::AdditionalOutput) {
145 (
146 ArrayAccess {
147 array: self.reconstruct_expression(input.array, &Default::default()).0,
148 index: self.reconstruct_expression(input.index, &Default::default()).0,
149 ..input
150 }
151 .into(),
152 Default::default(),
153 )
154 }
155
156 fn reconstruct_associated_constant(
157 &mut self,
158 input: AssociatedConstantExpression,
159 _additional: &Self::AdditionalInput,
160 ) -> (Expression, Self::AdditionalOutput) {
161 (input.into(), Default::default())
162 }
163
164 fn reconstruct_associated_function(
165 &mut self,
166 input: AssociatedFunctionExpression,
167 _additional: &Self::AdditionalInput,
168 ) -> (Expression, Self::AdditionalOutput) {
169 (
170 AssociatedFunctionExpression {
171 arguments: input
172 .arguments
173 .into_iter()
174 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
175 .collect(),
176 ..input
177 }
178 .into(),
179 Default::default(),
180 )
181 }
182
183 fn reconstruct_async(
184 &mut self,
185 input: AsyncExpression,
186 _additional: &Self::AdditionalInput,
187 ) -> (Expression, Self::AdditionalOutput) {
188 (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
189 }
190
191 fn reconstruct_member_access(
192 &mut self,
193 input: MemberAccess,
194 _additional: &Self::AdditionalInput,
195 ) -> (Expression, Self::AdditionalOutput) {
196 (
197 MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
198 Default::default(),
199 )
200 }
201
202 fn reconstruct_repeat(
203 &mut self,
204 input: RepeatExpression,
205 _additional: &Self::AdditionalInput,
206 ) -> (Expression, Self::AdditionalOutput) {
207 (
208 RepeatExpression {
209 expr: self.reconstruct_expression(input.expr, &Default::default()).0,
210 count: self.reconstruct_expression(input.count, &Default::default()).0,
211 ..input
212 }
213 .into(),
214 Default::default(),
215 )
216 }
217
218 fn reconstruct_tuple_access(
219 &mut self,
220 input: TupleAccess,
221 _additional: &Self::AdditionalInput,
222 ) -> (Expression, Self::AdditionalOutput) {
223 (
224 TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
225 Default::default(),
226 )
227 }
228
229 fn reconstruct_array(
230 &mut self,
231 input: ArrayExpression,
232 _additional: &Self::AdditionalInput,
233 ) -> (Expression, Self::AdditionalOutput) {
234 (
235 ArrayExpression {
236 elements: input
237 .elements
238 .into_iter()
239 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
240 .collect(),
241 ..input
242 }
243 .into(),
244 Default::default(),
245 )
246 }
247
248 fn reconstruct_binary(
249 &mut self,
250 input: BinaryExpression,
251 _additional: &Self::AdditionalInput,
252 ) -> (Expression, Self::AdditionalOutput) {
253 (
254 BinaryExpression {
255 left: self.reconstruct_expression(input.left, &Default::default()).0,
256 right: self.reconstruct_expression(input.right, &Default::default()).0,
257 ..input
258 }
259 .into(),
260 Default::default(),
261 )
262 }
263
264 fn reconstruct_call(
265 &mut self,
266 input: CallExpression,
267 _additional: &Self::AdditionalInput,
268 ) -> (Expression, Self::AdditionalOutput) {
269 (
270 CallExpression {
271 const_arguments: input
272 .const_arguments
273 .into_iter()
274 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
275 .collect(),
276 arguments: input
277 .arguments
278 .into_iter()
279 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
280 .collect(),
281 ..input
282 }
283 .into(),
284 Default::default(),
285 )
286 }
287
288 fn reconstruct_cast(
289 &mut self,
290 input: CastExpression,
291 _additional: &Self::AdditionalInput,
292 ) -> (Expression, Self::AdditionalOutput) {
293 (
294 CastExpression {
295 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
296 ..input
297 }
298 .into(),
299 Default::default(),
300 )
301 }
302
303 fn reconstruct_struct_init(
304 &mut self,
305 input: StructExpression,
306 _additional: &Self::AdditionalInput,
307 ) -> (Expression, Self::AdditionalOutput) {
308 (
309 StructExpression {
310 const_arguments: input
311 .const_arguments
312 .into_iter()
313 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
314 .collect(),
315 members: input
316 .members
317 .into_iter()
318 .map(|member| StructVariableInitializer {
319 identifier: member.identifier,
320 expression: member
321 .expression
322 .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
323 span: member.span,
324 id: member.id,
325 })
326 .collect(),
327 ..input
328 }
329 .into(),
330 Default::default(),
331 )
332 }
333
334 fn reconstruct_err(
335 &mut self,
336 _input: ErrExpression,
337 _additional: &Self::AdditionalInput,
338 ) -> (Expression, Self::AdditionalOutput) {
339 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
340 }
341
342 fn reconstruct_path(
343 &mut self,
344 input: Path,
345 _additional: &Self::AdditionalInput,
346 ) -> (Expression, Self::AdditionalOutput) {
347 (input.into(), Default::default())
348 }
349
350 fn reconstruct_literal(
351 &mut self,
352 input: Literal,
353 _additional: &Self::AdditionalInput,
354 ) -> (Expression, Self::AdditionalOutput) {
355 (input.into(), Default::default())
356 }
357
358 fn reconstruct_locator(
359 &mut self,
360 input: LocatorExpression,
361 _additional: &Self::AdditionalInput,
362 ) -> (Expression, Self::AdditionalOutput) {
363 (input.into(), Default::default())
364 }
365
366 fn reconstruct_ternary(
367 &mut self,
368 input: TernaryExpression,
369 _additional: &Self::AdditionalInput,
370 ) -> (Expression, Self::AdditionalOutput) {
371 (
372 TernaryExpression {
373 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
374 if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
375 if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
376 span: input.span,
377 id: input.id,
378 }
379 .into(),
380 Default::default(),
381 )
382 }
383
384 fn reconstruct_tuple(
385 &mut self,
386 input: TupleExpression,
387 _additional: &Self::AdditionalInput,
388 ) -> (Expression, Self::AdditionalOutput) {
389 (
390 TupleExpression {
391 elements: input
392 .elements
393 .into_iter()
394 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
395 .collect(),
396 ..input
397 }
398 .into(),
399 Default::default(),
400 )
401 }
402
403 fn reconstruct_unary(
404 &mut self,
405 input: UnaryExpression,
406 _additional: &Self::AdditionalInput,
407 ) -> (Expression, Self::AdditionalOutput) {
408 (
409 UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
410 .into(),
411 Default::default(),
412 )
413 }
414
415 fn reconstruct_unit(
416 &mut self,
417 input: UnitExpression,
418 _additional: &Self::AdditionalInput,
419 ) -> (Expression, Self::AdditionalOutput) {
420 (input.into(), Default::default())
421 }
422
423 fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
425 match input {
426 Statement::Assert(assert) => self.reconstruct_assert(assert),
427 Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
428 Statement::Block(stmt) => {
429 let (stmt, output) = self.reconstruct_block(stmt);
430 (stmt.into(), output)
431 }
432 Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
433 Statement::Const(stmt) => self.reconstruct_const(stmt),
434 Statement::Definition(stmt) => self.reconstruct_definition(stmt),
435 Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
436 Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
437 Statement::Return(stmt) => self.reconstruct_return(stmt),
438 }
439 }
440
441 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
442 (
443 AssertStatement {
444 variant: match input.variant {
445 AssertVariant::Assert(expr) => {
446 AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
447 }
448 AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
449 self.reconstruct_expression(left, &Default::default()).0,
450 self.reconstruct_expression(right, &Default::default()).0,
451 ),
452 AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
453 self.reconstruct_expression(left, &Default::default()).0,
454 self.reconstruct_expression(right, &Default::default()).0,
455 ),
456 },
457 ..input
458 }
459 .into(),
460 Default::default(),
461 )
462 }
463
464 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
465 (
466 AssignStatement {
467 place: self.reconstruct_expression(input.place, &Default::default()).0,
468 value: self.reconstruct_expression(input.value, &Default::default()).0,
469 ..input
470 }
471 .into(),
472 Default::default(),
473 )
474 }
475
476 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
477 (
478 Block {
479 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
480 span: input.span,
481 id: input.id,
482 },
483 Default::default(),
484 )
485 }
486
487 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
488 (
489 ConditionalStatement {
490 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
491 then: self.reconstruct_block(input.then).0,
492 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
493 ..input
494 }
495 .into(),
496 Default::default(),
497 )
498 }
499
500 fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
501 (
502 ConstDeclaration {
503 type_: self.reconstruct_type(input.type_).0,
504 value: self.reconstruct_expression(input.value, &Default::default()).0,
505 ..input
506 }
507 .into(),
508 Default::default(),
509 )
510 }
511
512 fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
513 (
514 DefinitionStatement {
515 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
516 value: self.reconstruct_expression(input.value, &Default::default()).0,
517 ..input
518 }
519 .into(),
520 Default::default(),
521 )
522 }
523
524 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
525 (
526 ExpressionStatement {
527 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
528 ..input
529 }
530 .into(),
531 Default::default(),
532 )
533 }
534
535 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
536 (
537 IterationStatement {
538 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
539 start: self.reconstruct_expression(input.start, &Default::default()).0,
540 stop: self.reconstruct_expression(input.stop, &Default::default()).0,
541 block: self.reconstruct_block(input.block).0,
542 ..input
543 }
544 .into(),
545 Default::default(),
546 )
547 }
548
549 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
550 (
551 ReturnStatement {
552 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
553 ..input
554 }
555 .into(),
556 Default::default(),
557 )
558 }
559}
560
561pub trait ProgramReconstructor: AstReconstructor {
563 fn reconstruct_program(&mut self, input: Program) -> Program {
564 let program_scopes =
565 input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
566 Program {
567 imports: input
568 .imports
569 .into_iter()
570 .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
571 .collect(),
572 stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
573 modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
574 program_scopes,
575 }
576 }
577
578 fn reconstruct_stub(&mut self, input: Stub) -> Stub {
579 Stub {
580 imports: input.imports,
581 stub_id: input.stub_id,
582 consts: input.consts,
583 structs: input.structs,
584 mappings: input.mappings,
585 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
586 span: input.span,
587 }
588 }
589
590 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
591 ProgramScope {
592 program_id: input.program_id,
593 consts: input
594 .consts
595 .into_iter()
596 .map(|(i, c)| match self.reconstruct_const(c) {
597 (Statement::Const(declaration), _) => (i, declaration),
598 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
599 })
600 .collect(),
601 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
602 mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
603 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
604 constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
605 span: input.span,
606 }
607 }
608
609 fn reconstruct_module(&mut self, input: Module) -> Module {
610 Module {
611 program_name: input.program_name,
612 path: input.path,
613 consts: input
614 .consts
615 .into_iter()
616 .map(|(i, c)| match self.reconstruct_const(c) {
617 (Statement::Const(declaration), _) => (i, declaration),
618 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
619 })
620 .collect(),
621 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
622 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
623 }
624 }
625
626 fn reconstruct_function(&mut self, input: Function) -> Function {
627 Function {
628 annotations: input.annotations,
629 variant: input.variant,
630 identifier: input.identifier,
631 const_parameters: input
632 .const_parameters
633 .iter()
634 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
635 .collect(),
636 input: input
637 .input
638 .iter()
639 .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
640 .collect(),
641 output: input
642 .output
643 .iter()
644 .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
645 .collect(),
646 output_type: self.reconstruct_type(input.output_type).0,
647 block: self.reconstruct_block(input.block).0,
648 span: input.span,
649 id: input.id,
650 }
651 }
652
653 fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
654 Constructor {
655 annotations: input.annotations,
656 block: self.reconstruct_block(input.block).0,
657 span: input.span,
658 id: input.id,
659 }
660 }
661
662 fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
663 input
664 }
665
666 fn reconstruct_struct(&mut self, input: Composite) -> Composite {
667 Composite {
668 const_parameters: input
669 .const_parameters
670 .iter()
671 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
672 .collect(),
673 members: input
674 .members
675 .iter()
676 .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
677 .collect(),
678 ..input
679 }
680 }
681
682 fn reconstruct_import(&mut self, input: Program) -> Program {
683 self.reconstruct_program(input)
684 }
685
686 fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
687 Mapping {
688 key_type: self.reconstruct_type(input.key_type).0,
689 value_type: self.reconstruct_type(input.value_type).0,
690 ..input
691 }
692 }
693}