1use crate::*;
22
23pub trait AstReconstructor {
25 type AdditionalOutput: Default;
26 type AdditionalInput: Default;
27
28 fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30 match input {
31 Type::Array(array_type) => self.reconstruct_array_type(array_type),
32 Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33 Type::Future(future_type) => self.reconstruct_future_type(future_type),
34 Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35 Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36 Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37 Type::Vector(vector_type) => self.reconstruct_vector_type(vector_type),
38 Type::Address
39 | Type::Boolean
40 | Type::Field
41 | Type::Group
42 | Type::Identifier(_)
43 | Type::Integer(_)
44 | Type::Scalar
45 | Type::Signature
46 | Type::String
47 | Type::Numeric
48 | Type::Unit
49 | Type::Err => (input.clone(), Default::default()),
50 }
51 }
52
53 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
54 (
55 Type::Array(ArrayType {
56 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
57 length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
58 }),
59 Default::default(),
60 )
61 }
62
63 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
64 (
65 Type::Composite(CompositeType {
66 const_arguments: input
67 .const_arguments
68 .into_iter()
69 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
70 .collect(),
71 ..input
72 }),
73 Default::default(),
74 )
75 }
76
77 fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
78 (
79 Type::Future(FutureType {
80 inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
81 ..input
82 }),
83 Default::default(),
84 )
85 }
86
87 fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
88 (
89 Type::Mapping(MappingType {
90 key: Box::new(self.reconstruct_type(*input.key).0),
91 value: Box::new(self.reconstruct_type(*input.value).0),
92 ..input
93 }),
94 Default::default(),
95 )
96 }
97
98 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
99 (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
100 }
101
102 fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
103 (
104 Type::Tuple(TupleType {
105 elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
106 }),
107 Default::default(),
108 )
109 }
110
111 fn reconstruct_vector_type(&mut self, input: VectorType) -> (Type, Self::AdditionalOutput) {
112 (
113 Type::Vector(VectorType { element_type: Box::new(self.reconstruct_type(*input.element_type).0) }),
114 Default::default(),
115 )
116 }
117
118 fn reconstruct_expression(
120 &mut self,
121 input: Expression,
122 additional: &Self::AdditionalInput,
123 ) -> (Expression, Self::AdditionalOutput) {
124 match input {
125 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant, additional),
126 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function, additional),
127 Expression::Async(async_) => self.reconstruct_async(async_, additional),
128 Expression::Array(array) => self.reconstruct_array(array, additional),
129 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
130 Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
131 Expression::Call(call) => self.reconstruct_call(*call, additional),
132 Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
133 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
134 Expression::Err(err) => self.reconstruct_err(err, additional),
135 Expression::Path(path) => self.reconstruct_path(path, additional),
136 Expression::Literal(value) => self.reconstruct_literal(value, additional),
137 Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
138 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
139 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
140 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
141 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
142 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
143 Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
144 Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
145 }
146 }
147
148 fn reconstruct_array_access(
149 &mut self,
150 input: ArrayAccess,
151 _additional: &Self::AdditionalInput,
152 ) -> (Expression, Self::AdditionalOutput) {
153 (
154 ArrayAccess {
155 array: self.reconstruct_expression(input.array, &Default::default()).0,
156 index: self.reconstruct_expression(input.index, &Default::default()).0,
157 ..input
158 }
159 .into(),
160 Default::default(),
161 )
162 }
163
164 fn reconstruct_associated_constant(
165 &mut self,
166 input: AssociatedConstantExpression,
167 _additional: &Self::AdditionalInput,
168 ) -> (Expression, Self::AdditionalOutput) {
169 (input.into(), Default::default())
170 }
171
172 fn reconstruct_associated_function(
173 &mut self,
174 input: AssociatedFunctionExpression,
175 _additional: &Self::AdditionalInput,
176 ) -> (Expression, Self::AdditionalOutput) {
177 (
178 AssociatedFunctionExpression {
179 arguments: input
180 .arguments
181 .into_iter()
182 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
183 .collect(),
184 ..input
185 }
186 .into(),
187 Default::default(),
188 )
189 }
190
191 fn reconstruct_async(
192 &mut self,
193 input: AsyncExpression,
194 _additional: &Self::AdditionalInput,
195 ) -> (Expression, Self::AdditionalOutput) {
196 (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
197 }
198
199 fn reconstruct_member_access(
200 &mut self,
201 input: MemberAccess,
202 _additional: &Self::AdditionalInput,
203 ) -> (Expression, Self::AdditionalOutput) {
204 (
205 MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
206 Default::default(),
207 )
208 }
209
210 fn reconstruct_repeat(
211 &mut self,
212 input: RepeatExpression,
213 _additional: &Self::AdditionalInput,
214 ) -> (Expression, Self::AdditionalOutput) {
215 (
216 RepeatExpression {
217 expr: self.reconstruct_expression(input.expr, &Default::default()).0,
218 count: self.reconstruct_expression(input.count, &Default::default()).0,
219 ..input
220 }
221 .into(),
222 Default::default(),
223 )
224 }
225
226 fn reconstruct_tuple_access(
227 &mut self,
228 input: TupleAccess,
229 _additional: &Self::AdditionalInput,
230 ) -> (Expression, Self::AdditionalOutput) {
231 (
232 TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
233 Default::default(),
234 )
235 }
236
237 fn reconstruct_array(
238 &mut self,
239 input: ArrayExpression,
240 _additional: &Self::AdditionalInput,
241 ) -> (Expression, Self::AdditionalOutput) {
242 (
243 ArrayExpression {
244 elements: input
245 .elements
246 .into_iter()
247 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
248 .collect(),
249 ..input
250 }
251 .into(),
252 Default::default(),
253 )
254 }
255
256 fn reconstruct_binary(
257 &mut self,
258 input: BinaryExpression,
259 _additional: &Self::AdditionalInput,
260 ) -> (Expression, Self::AdditionalOutput) {
261 (
262 BinaryExpression {
263 left: self.reconstruct_expression(input.left, &Default::default()).0,
264 right: self.reconstruct_expression(input.right, &Default::default()).0,
265 ..input
266 }
267 .into(),
268 Default::default(),
269 )
270 }
271
272 fn reconstruct_call(
273 &mut self,
274 input: CallExpression,
275 _additional: &Self::AdditionalInput,
276 ) -> (Expression, Self::AdditionalOutput) {
277 (
278 CallExpression {
279 const_arguments: input
280 .const_arguments
281 .into_iter()
282 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
283 .collect(),
284 arguments: input
285 .arguments
286 .into_iter()
287 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
288 .collect(),
289 ..input
290 }
291 .into(),
292 Default::default(),
293 )
294 }
295
296 fn reconstruct_cast(
297 &mut self,
298 input: CastExpression,
299 _additional: &Self::AdditionalInput,
300 ) -> (Expression, Self::AdditionalOutput) {
301 (
302 CastExpression {
303 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
304 ..input
305 }
306 .into(),
307 Default::default(),
308 )
309 }
310
311 fn reconstruct_struct_init(
312 &mut self,
313 input: StructExpression,
314 _additional: &Self::AdditionalInput,
315 ) -> (Expression, Self::AdditionalOutput) {
316 (
317 StructExpression {
318 const_arguments: input
319 .const_arguments
320 .into_iter()
321 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
322 .collect(),
323 members: input
324 .members
325 .into_iter()
326 .map(|member| StructVariableInitializer {
327 identifier: member.identifier,
328 expression: member
329 .expression
330 .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
331 span: member.span,
332 id: member.id,
333 })
334 .collect(),
335 ..input
336 }
337 .into(),
338 Default::default(),
339 )
340 }
341
342 fn reconstruct_err(
343 &mut self,
344 _input: ErrExpression,
345 _additional: &Self::AdditionalInput,
346 ) -> (Expression, Self::AdditionalOutput) {
347 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
348 }
349
350 fn reconstruct_path(
351 &mut self,
352 input: Path,
353 _additional: &Self::AdditionalInput,
354 ) -> (Expression, Self::AdditionalOutput) {
355 (input.into(), Default::default())
356 }
357
358 fn reconstruct_literal(
359 &mut self,
360 input: Literal,
361 _additional: &Self::AdditionalInput,
362 ) -> (Expression, Self::AdditionalOutput) {
363 (input.into(), Default::default())
364 }
365
366 fn reconstruct_locator(
367 &mut self,
368 input: LocatorExpression,
369 _additional: &Self::AdditionalInput,
370 ) -> (Expression, Self::AdditionalOutput) {
371 (input.into(), Default::default())
372 }
373
374 fn reconstruct_ternary(
375 &mut self,
376 input: TernaryExpression,
377 _additional: &Self::AdditionalInput,
378 ) -> (Expression, Self::AdditionalOutput) {
379 (
380 TernaryExpression {
381 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
382 if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
383 if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
384 span: input.span,
385 id: input.id,
386 }
387 .into(),
388 Default::default(),
389 )
390 }
391
392 fn reconstruct_tuple(
393 &mut self,
394 input: TupleExpression,
395 _additional: &Self::AdditionalInput,
396 ) -> (Expression, Self::AdditionalOutput) {
397 (
398 TupleExpression {
399 elements: input
400 .elements
401 .into_iter()
402 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
403 .collect(),
404 ..input
405 }
406 .into(),
407 Default::default(),
408 )
409 }
410
411 fn reconstruct_unary(
412 &mut self,
413 input: UnaryExpression,
414 _additional: &Self::AdditionalInput,
415 ) -> (Expression, Self::AdditionalOutput) {
416 (
417 UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
418 .into(),
419 Default::default(),
420 )
421 }
422
423 fn reconstruct_unit(
424 &mut self,
425 input: UnitExpression,
426 _additional: &Self::AdditionalInput,
427 ) -> (Expression, Self::AdditionalOutput) {
428 (input.into(), Default::default())
429 }
430
431 fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
433 match input {
434 Statement::Assert(assert) => self.reconstruct_assert(assert),
435 Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
436 Statement::Block(stmt) => {
437 let (stmt, output) = self.reconstruct_block(stmt);
438 (stmt.into(), output)
439 }
440 Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
441 Statement::Const(stmt) => self.reconstruct_const(stmt),
442 Statement::Definition(stmt) => self.reconstruct_definition(stmt),
443 Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
444 Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
445 Statement::Return(stmt) => self.reconstruct_return(stmt),
446 }
447 }
448
449 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
450 (
451 AssertStatement {
452 variant: match input.variant {
453 AssertVariant::Assert(expr) => {
454 AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
455 }
456 AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
457 self.reconstruct_expression(left, &Default::default()).0,
458 self.reconstruct_expression(right, &Default::default()).0,
459 ),
460 AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
461 self.reconstruct_expression(left, &Default::default()).0,
462 self.reconstruct_expression(right, &Default::default()).0,
463 ),
464 },
465 ..input
466 }
467 .into(),
468 Default::default(),
469 )
470 }
471
472 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
473 (
474 AssignStatement {
475 place: self.reconstruct_expression(input.place, &Default::default()).0,
476 value: self.reconstruct_expression(input.value, &Default::default()).0,
477 ..input
478 }
479 .into(),
480 Default::default(),
481 )
482 }
483
484 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
485 (
486 Block {
487 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
488 span: input.span,
489 id: input.id,
490 },
491 Default::default(),
492 )
493 }
494
495 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
496 (
497 ConditionalStatement {
498 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
499 then: self.reconstruct_block(input.then).0,
500 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
501 ..input
502 }
503 .into(),
504 Default::default(),
505 )
506 }
507
508 fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
509 (
510 ConstDeclaration {
511 type_: self.reconstruct_type(input.type_).0,
512 value: self.reconstruct_expression(input.value, &Default::default()).0,
513 ..input
514 }
515 .into(),
516 Default::default(),
517 )
518 }
519
520 fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
521 (
522 DefinitionStatement {
523 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
524 value: self.reconstruct_expression(input.value, &Default::default()).0,
525 ..input
526 }
527 .into(),
528 Default::default(),
529 )
530 }
531
532 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
533 (
534 ExpressionStatement {
535 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
536 ..input
537 }
538 .into(),
539 Default::default(),
540 )
541 }
542
543 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
544 (
545 IterationStatement {
546 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
547 start: self.reconstruct_expression(input.start, &Default::default()).0,
548 stop: self.reconstruct_expression(input.stop, &Default::default()).0,
549 block: self.reconstruct_block(input.block).0,
550 ..input
551 }
552 .into(),
553 Default::default(),
554 )
555 }
556
557 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
558 (
559 ReturnStatement {
560 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
561 ..input
562 }
563 .into(),
564 Default::default(),
565 )
566 }
567}
568
569pub trait ProgramReconstructor: AstReconstructor {
571 fn reconstruct_program(&mut self, input: Program) -> Program {
572 let program_scopes =
573 input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
574 Program {
575 imports: input
576 .imports
577 .into_iter()
578 .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
579 .collect(),
580 stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
581 modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
582 program_scopes,
583 }
584 }
585
586 fn reconstruct_stub(&mut self, input: Stub) -> Stub {
587 Stub {
588 imports: input.imports,
589 stub_id: input.stub_id,
590 consts: input.consts,
591 structs: input.structs,
592 mappings: input.mappings,
593 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
594 span: input.span,
595 }
596 }
597
598 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
599 ProgramScope {
600 program_id: input.program_id,
601 consts: input
602 .consts
603 .into_iter()
604 .map(|(i, c)| match self.reconstruct_const(c) {
605 (Statement::Const(declaration), _) => (i, declaration),
606 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
607 })
608 .collect(),
609 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
610 mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
611 storage_variables: input
612 .storage_variables
613 .into_iter()
614 .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
615 .collect(),
616 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
617 constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
618 span: input.span,
619 }
620 }
621
622 fn reconstruct_module(&mut self, input: Module) -> Module {
623 Module {
624 program_name: input.program_name,
625 path: input.path,
626 consts: input
627 .consts
628 .into_iter()
629 .map(|(i, c)| match self.reconstruct_const(c) {
630 (Statement::Const(declaration), _) => (i, declaration),
631 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
632 })
633 .collect(),
634 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
635 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
636 }
637 }
638
639 fn reconstruct_function(&mut self, input: Function) -> Function {
640 Function {
641 annotations: input.annotations,
642 variant: input.variant,
643 identifier: input.identifier,
644 const_parameters: input
645 .const_parameters
646 .iter()
647 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
648 .collect(),
649 input: input
650 .input
651 .iter()
652 .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
653 .collect(),
654 output: input
655 .output
656 .iter()
657 .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
658 .collect(),
659 output_type: self.reconstruct_type(input.output_type).0,
660 block: self.reconstruct_block(input.block).0,
661 span: input.span,
662 id: input.id,
663 }
664 }
665
666 fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
667 Constructor {
668 annotations: input.annotations,
669 block: self.reconstruct_block(input.block).0,
670 span: input.span,
671 id: input.id,
672 }
673 }
674
675 fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
676 input
677 }
678
679 fn reconstruct_struct(&mut self, input: Composite) -> Composite {
680 Composite {
681 const_parameters: input
682 .const_parameters
683 .iter()
684 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
685 .collect(),
686 members: input
687 .members
688 .iter()
689 .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
690 .collect(),
691 ..input
692 }
693 }
694
695 fn reconstruct_import(&mut self, input: Program) -> Program {
696 self.reconstruct_program(input)
697 }
698
699 fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
700 Mapping {
701 key_type: self.reconstruct_type(input.key_type).0,
702 value_type: self.reconstruct_type(input.value_type).0,
703 ..input
704 }
705 }
706
707 fn reconstruct_storage_variable(&mut self, input: StorageVariable) -> StorageVariable {
708 StorageVariable { type_: self.reconstruct_type(input.type_).0, ..input }
709 }
710}