1use crate::*;
22
23pub trait AstReconstructor {
25 type AdditionalOutput: Default;
26 type AdditionalInput: Default;
27
28 fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
30 match input {
31 Type::Array(array_type) => self.reconstruct_array_type(array_type),
32 Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
33 Type::Future(future_type) => self.reconstruct_future_type(future_type),
34 Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
35 Type::Optional(optional_type) => self.reconstruct_optional_type(optional_type),
36 Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
37 Type::Vector(vector_type) => self.reconstruct_vector_type(vector_type),
38 Type::Address
39 | Type::Boolean
40 | Type::Field
41 | Type::Group
42 | Type::Identifier(_)
43 | Type::Integer(_)
44 | Type::Scalar
45 | Type::Signature
46 | Type::String
47 | Type::Numeric
48 | Type::Unit
49 | Type::Err => (input.clone(), Default::default()),
50 }
51 }
52
53 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
54 (
55 Type::Array(ArrayType {
56 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
57 length: Box::new(self.reconstruct_expression(*input.length, &Default::default()).0),
58 }),
59 Default::default(),
60 )
61 }
62
63 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
64 (
65 Type::Composite(CompositeType {
66 const_arguments: input
67 .const_arguments
68 .into_iter()
69 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
70 .collect(),
71 ..input
72 }),
73 Default::default(),
74 )
75 }
76
77 fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
78 (
79 Type::Future(FutureType {
80 inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
81 ..input
82 }),
83 Default::default(),
84 )
85 }
86
87 fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
88 (
89 Type::Mapping(MappingType {
90 key: Box::new(self.reconstruct_type(*input.key).0),
91 value: Box::new(self.reconstruct_type(*input.value).0),
92 ..input
93 }),
94 Default::default(),
95 )
96 }
97
98 fn reconstruct_optional_type(&mut self, input: OptionalType) -> (Type, Self::AdditionalOutput) {
99 (Type::Optional(OptionalType { inner: Box::new(self.reconstruct_type(*input.inner).0) }), Default::default())
100 }
101
102 fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
103 (
104 Type::Tuple(TupleType {
105 elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
106 }),
107 Default::default(),
108 )
109 }
110
111 fn reconstruct_vector_type(&mut self, input: VectorType) -> (Type, Self::AdditionalOutput) {
112 (
113 Type::Vector(VectorType { element_type: Box::new(self.reconstruct_type(*input.element_type).0) }),
114 Default::default(),
115 )
116 }
117
118 fn reconstruct_expression(
120 &mut self,
121 input: Expression,
122 additional: &Self::AdditionalInput,
123 ) -> (Expression, Self::AdditionalOutput) {
124 match input {
125 Expression::Async(async_) => self.reconstruct_async(async_, additional),
126 Expression::Array(array) => self.reconstruct_array(array, additional),
127 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access, additional),
128 Expression::Binary(binary) => self.reconstruct_binary(*binary, additional),
129 Expression::Call(call) => self.reconstruct_call(*call, additional),
130 Expression::Cast(cast) => self.reconstruct_cast(*cast, additional),
131 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_, additional),
132 Expression::Err(err) => self.reconstruct_err(err, additional),
133 Expression::Path(path) => self.reconstruct_path(path, additional),
134 Expression::Literal(value) => self.reconstruct_literal(value, additional),
135 Expression::Locator(locator) => self.reconstruct_locator(locator, additional),
136 Expression::MemberAccess(access) => self.reconstruct_member_access(*access, additional),
137 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat, additional),
138 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary, additional),
139 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple, additional),
140 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access, additional),
141 Expression::Unary(unary) => self.reconstruct_unary(*unary, additional),
142 Expression::Unit(unit) => self.reconstruct_unit(unit, additional),
143 Expression::Intrinsic(intr) => self.reconstruct_intrinsic(*intr, additional),
144 }
145 }
146
147 fn reconstruct_array_access(
148 &mut self,
149 input: ArrayAccess,
150 _additional: &Self::AdditionalInput,
151 ) -> (Expression, Self::AdditionalOutput) {
152 (
153 ArrayAccess {
154 array: self.reconstruct_expression(input.array, &Default::default()).0,
155 index: self.reconstruct_expression(input.index, &Default::default()).0,
156 ..input
157 }
158 .into(),
159 Default::default(),
160 )
161 }
162
163 fn reconstruct_async(
164 &mut self,
165 input: AsyncExpression,
166 _additional: &Self::AdditionalInput,
167 ) -> (Expression, Self::AdditionalOutput) {
168 (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
169 }
170
171 fn reconstruct_member_access(
172 &mut self,
173 input: MemberAccess,
174 _additional: &Self::AdditionalInput,
175 ) -> (Expression, Self::AdditionalOutput) {
176 (
177 MemberAccess { inner: self.reconstruct_expression(input.inner, &Default::default()).0, ..input }.into(),
178 Default::default(),
179 )
180 }
181
182 fn reconstruct_repeat(
183 &mut self,
184 input: RepeatExpression,
185 _additional: &Self::AdditionalInput,
186 ) -> (Expression, Self::AdditionalOutput) {
187 (
188 RepeatExpression {
189 expr: self.reconstruct_expression(input.expr, &Default::default()).0,
190 count: self.reconstruct_expression(input.count, &Default::default()).0,
191 ..input
192 }
193 .into(),
194 Default::default(),
195 )
196 }
197
198 fn reconstruct_intrinsic(
199 &mut self,
200 input: IntrinsicExpression,
201 _additional: &Self::AdditionalInput,
202 ) -> (Expression, Self::AdditionalOutput) {
203 (
204 IntrinsicExpression {
205 arguments: input
206 .arguments
207 .into_iter()
208 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
209 .collect(),
210 ..input
211 }
212 .into(),
213 Default::default(),
214 )
215 }
216
217 fn reconstruct_tuple_access(
218 &mut self,
219 input: TupleAccess,
220 _additional: &Self::AdditionalInput,
221 ) -> (Expression, Self::AdditionalOutput) {
222 (
223 TupleAccess { tuple: self.reconstruct_expression(input.tuple, &Default::default()).0, ..input }.into(),
224 Default::default(),
225 )
226 }
227
228 fn reconstruct_array(
229 &mut self,
230 input: ArrayExpression,
231 _additional: &Self::AdditionalInput,
232 ) -> (Expression, Self::AdditionalOutput) {
233 (
234 ArrayExpression {
235 elements: input
236 .elements
237 .into_iter()
238 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
239 .collect(),
240 ..input
241 }
242 .into(),
243 Default::default(),
244 )
245 }
246
247 fn reconstruct_binary(
248 &mut self,
249 input: BinaryExpression,
250 _additional: &Self::AdditionalInput,
251 ) -> (Expression, Self::AdditionalOutput) {
252 (
253 BinaryExpression {
254 left: self.reconstruct_expression(input.left, &Default::default()).0,
255 right: self.reconstruct_expression(input.right, &Default::default()).0,
256 ..input
257 }
258 .into(),
259 Default::default(),
260 )
261 }
262
263 fn reconstruct_call(
264 &mut self,
265 input: CallExpression,
266 _additional: &Self::AdditionalInput,
267 ) -> (Expression, Self::AdditionalOutput) {
268 (
269 CallExpression {
270 const_arguments: input
271 .const_arguments
272 .into_iter()
273 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
274 .collect(),
275 arguments: input
276 .arguments
277 .into_iter()
278 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
279 .collect(),
280 ..input
281 }
282 .into(),
283 Default::default(),
284 )
285 }
286
287 fn reconstruct_cast(
288 &mut self,
289 input: CastExpression,
290 _additional: &Self::AdditionalInput,
291 ) -> (Expression, Self::AdditionalOutput) {
292 (
293 CastExpression {
294 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
295 ..input
296 }
297 .into(),
298 Default::default(),
299 )
300 }
301
302 fn reconstruct_struct_init(
303 &mut self,
304 input: StructExpression,
305 _additional: &Self::AdditionalInput,
306 ) -> (Expression, Self::AdditionalOutput) {
307 (
308 StructExpression {
309 const_arguments: input
310 .const_arguments
311 .into_iter()
312 .map(|arg| self.reconstruct_expression(arg, &Default::default()).0)
313 .collect(),
314 members: input
315 .members
316 .into_iter()
317 .map(|member| StructVariableInitializer {
318 identifier: member.identifier,
319 expression: member
320 .expression
321 .map(|expr| self.reconstruct_expression(expr, &Default::default()).0),
322 span: member.span,
323 id: member.id,
324 })
325 .collect(),
326 ..input
327 }
328 .into(),
329 Default::default(),
330 )
331 }
332
333 fn reconstruct_err(
334 &mut self,
335 _input: ErrExpression,
336 _additional: &Self::AdditionalInput,
337 ) -> (Expression, Self::AdditionalOutput) {
338 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
339 }
340
341 fn reconstruct_path(
342 &mut self,
343 input: Path,
344 _additional: &Self::AdditionalInput,
345 ) -> (Expression, Self::AdditionalOutput) {
346 (input.into(), Default::default())
347 }
348
349 fn reconstruct_literal(
350 &mut self,
351 input: Literal,
352 _additional: &Self::AdditionalInput,
353 ) -> (Expression, Self::AdditionalOutput) {
354 (input.into(), Default::default())
355 }
356
357 fn reconstruct_locator(
358 &mut self,
359 input: LocatorExpression,
360 _additional: &Self::AdditionalInput,
361 ) -> (Expression, Self::AdditionalOutput) {
362 (input.into(), Default::default())
363 }
364
365 fn reconstruct_ternary(
366 &mut self,
367 input: TernaryExpression,
368 _additional: &Self::AdditionalInput,
369 ) -> (Expression, Self::AdditionalOutput) {
370 (
371 TernaryExpression {
372 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
373 if_true: self.reconstruct_expression(input.if_true, &Default::default()).0,
374 if_false: self.reconstruct_expression(input.if_false, &Default::default()).0,
375 span: input.span,
376 id: input.id,
377 }
378 .into(),
379 Default::default(),
380 )
381 }
382
383 fn reconstruct_tuple(
384 &mut self,
385 input: TupleExpression,
386 _additional: &Self::AdditionalInput,
387 ) -> (Expression, Self::AdditionalOutput) {
388 (
389 TupleExpression {
390 elements: input
391 .elements
392 .into_iter()
393 .map(|element| self.reconstruct_expression(element, &Default::default()).0)
394 .collect(),
395 ..input
396 }
397 .into(),
398 Default::default(),
399 )
400 }
401
402 fn reconstruct_unary(
403 &mut self,
404 input: UnaryExpression,
405 _additional: &Self::AdditionalInput,
406 ) -> (Expression, Self::AdditionalOutput) {
407 (
408 UnaryExpression { receiver: self.reconstruct_expression(input.receiver, &Default::default()).0, ..input }
409 .into(),
410 Default::default(),
411 )
412 }
413
414 fn reconstruct_unit(
415 &mut self,
416 input: UnitExpression,
417 _additional: &Self::AdditionalInput,
418 ) -> (Expression, Self::AdditionalOutput) {
419 (input.into(), Default::default())
420 }
421
422 fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
424 match input {
425 Statement::Assert(assert) => self.reconstruct_assert(assert),
426 Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
427 Statement::Block(stmt) => {
428 let (stmt, output) = self.reconstruct_block(stmt);
429 (stmt.into(), output)
430 }
431 Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
432 Statement::Const(stmt) => self.reconstruct_const(stmt),
433 Statement::Definition(stmt) => self.reconstruct_definition(stmt),
434 Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
435 Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
436 Statement::Return(stmt) => self.reconstruct_return(stmt),
437 }
438 }
439
440 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
441 (
442 AssertStatement {
443 variant: match input.variant {
444 AssertVariant::Assert(expr) => {
445 AssertVariant::Assert(self.reconstruct_expression(expr, &Default::default()).0)
446 }
447 AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
448 self.reconstruct_expression(left, &Default::default()).0,
449 self.reconstruct_expression(right, &Default::default()).0,
450 ),
451 AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
452 self.reconstruct_expression(left, &Default::default()).0,
453 self.reconstruct_expression(right, &Default::default()).0,
454 ),
455 },
456 ..input
457 }
458 .into(),
459 Default::default(),
460 )
461 }
462
463 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
464 (
465 AssignStatement {
466 place: self.reconstruct_expression(input.place, &Default::default()).0,
467 value: self.reconstruct_expression(input.value, &Default::default()).0,
468 ..input
469 }
470 .into(),
471 Default::default(),
472 )
473 }
474
475 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
476 (
477 Block {
478 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
479 span: input.span,
480 id: input.id,
481 },
482 Default::default(),
483 )
484 }
485
486 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
487 (
488 ConditionalStatement {
489 condition: self.reconstruct_expression(input.condition, &Default::default()).0,
490 then: self.reconstruct_block(input.then).0,
491 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
492 ..input
493 }
494 .into(),
495 Default::default(),
496 )
497 }
498
499 fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
500 (
501 ConstDeclaration {
502 type_: self.reconstruct_type(input.type_).0,
503 value: self.reconstruct_expression(input.value, &Default::default()).0,
504 ..input
505 }
506 .into(),
507 Default::default(),
508 )
509 }
510
511 fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
512 (
513 DefinitionStatement {
514 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
515 value: self.reconstruct_expression(input.value, &Default::default()).0,
516 ..input
517 }
518 .into(),
519 Default::default(),
520 )
521 }
522
523 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
524 (
525 ExpressionStatement {
526 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
527 ..input
528 }
529 .into(),
530 Default::default(),
531 )
532 }
533
534 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
535 (
536 IterationStatement {
537 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
538 start: self.reconstruct_expression(input.start, &Default::default()).0,
539 stop: self.reconstruct_expression(input.stop, &Default::default()).0,
540 block: self.reconstruct_block(input.block).0,
541 ..input
542 }
543 .into(),
544 Default::default(),
545 )
546 }
547
548 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
549 (
550 ReturnStatement {
551 expression: self.reconstruct_expression(input.expression, &Default::default()).0,
552 ..input
553 }
554 .into(),
555 Default::default(),
556 )
557 }
558}
559
560pub trait ProgramReconstructor: AstReconstructor {
562 fn reconstruct_program(&mut self, input: Program) -> Program {
563 let program_scopes =
564 input.program_scopes.into_iter().map(|(id, scope)| (id, self.reconstruct_program_scope(scope))).collect();
565 Program {
566 imports: input
567 .imports
568 .into_iter()
569 .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
570 .collect(),
571 stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
572 modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
573 program_scopes,
574 }
575 }
576
577 fn reconstruct_stub(&mut self, input: Stub) -> Stub {
578 Stub {
579 imports: input.imports,
580 stub_id: input.stub_id,
581 consts: input.consts,
582 structs: input.structs,
583 mappings: input.mappings,
584 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
585 span: input.span,
586 }
587 }
588
589 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
590 ProgramScope {
591 program_id: input.program_id,
592 consts: input
593 .consts
594 .into_iter()
595 .map(|(i, c)| match self.reconstruct_const(c) {
596 (Statement::Const(declaration), _) => (i, declaration),
597 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
598 })
599 .collect(),
600 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
601 mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
602 storage_variables: input
603 .storage_variables
604 .into_iter()
605 .map(|(id, storage_variable)| (id, self.reconstruct_storage_variable(storage_variable)))
606 .collect(),
607 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
608 constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
609 span: input.span,
610 }
611 }
612
613 fn reconstruct_module(&mut self, input: Module) -> Module {
614 Module {
615 program_name: input.program_name,
616 path: input.path,
617 consts: input
618 .consts
619 .into_iter()
620 .map(|(i, c)| match self.reconstruct_const(c) {
621 (Statement::Const(declaration), _) => (i, declaration),
622 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
623 })
624 .collect(),
625 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
626 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
627 }
628 }
629
630 fn reconstruct_function(&mut self, input: Function) -> Function {
631 Function {
632 annotations: input.annotations,
633 variant: input.variant,
634 identifier: input.identifier,
635 const_parameters: input
636 .const_parameters
637 .iter()
638 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
639 .collect(),
640 input: input
641 .input
642 .iter()
643 .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
644 .collect(),
645 output: input
646 .output
647 .iter()
648 .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
649 .collect(),
650 output_type: self.reconstruct_type(input.output_type).0,
651 block: self.reconstruct_block(input.block).0,
652 span: input.span,
653 id: input.id,
654 }
655 }
656
657 fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
658 Constructor {
659 annotations: input.annotations,
660 block: self.reconstruct_block(input.block).0,
661 span: input.span,
662 id: input.id,
663 }
664 }
665
666 fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
667 input
668 }
669
670 fn reconstruct_struct(&mut self, input: Composite) -> Composite {
671 Composite {
672 const_parameters: input
673 .const_parameters
674 .iter()
675 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
676 .collect(),
677 members: input
678 .members
679 .iter()
680 .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
681 .collect(),
682 ..input
683 }
684 }
685
686 fn reconstruct_import(&mut self, input: Program) -> Program {
687 self.reconstruct_program(input)
688 }
689
690 fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
691 Mapping {
692 key_type: self.reconstruct_type(input.key_type).0,
693 value_type: self.reconstruct_type(input.value_type).0,
694 ..input
695 }
696 }
697
698 fn reconstruct_storage_variable(&mut self, input: StorageVariable) -> StorageVariable {
699 StorageVariable { type_: self.reconstruct_type(input.type_).0, ..input }
700 }
701}