1use crate::*;
22
23pub trait AstReconstructor {
25 type AdditionalOutput: Default;
26
27 fn reconstruct_type(&mut self, input: Type) -> (Type, Self::AdditionalOutput) {
29 match input {
30 Type::Array(array_type) => self.reconstruct_array_type(array_type),
31 Type::Composite(composite_type) => self.reconstruct_composite_type(composite_type),
32 Type::Future(future_type) => self.reconstruct_future_type(future_type),
33 Type::Mapping(mapping_type) => self.reconstruct_mapping_type(mapping_type),
34 Type::Tuple(tuple_type) => self.reconstruct_tuple_type(tuple_type),
35 Type::Address
36 | Type::Boolean
37 | Type::Field
38 | Type::Group
39 | Type::Identifier(_)
40 | Type::Integer(_)
41 | Type::Scalar
42 | Type::Signature
43 | Type::String
44 | Type::Numeric
45 | Type::Unit
46 | Type::Err => (input.clone(), Default::default()),
47 }
48 }
49
50 fn reconstruct_array_type(&mut self, input: ArrayType) -> (Type, Self::AdditionalOutput) {
51 (
52 Type::Array(ArrayType {
53 element_type: Box::new(self.reconstruct_type(*input.element_type).0),
54 length: Box::new(self.reconstruct_expression(*input.length).0),
55 }),
56 Default::default(),
57 )
58 }
59
60 fn reconstruct_composite_type(&mut self, input: CompositeType) -> (Type, Self::AdditionalOutput) {
61 (
62 Type::Composite(CompositeType {
63 const_arguments: input
64 .const_arguments
65 .into_iter()
66 .map(|arg| self.reconstruct_expression(arg).0)
67 .collect(),
68 ..input
69 }),
70 Default::default(),
71 )
72 }
73
74 fn reconstruct_future_type(&mut self, input: FutureType) -> (Type, Self::AdditionalOutput) {
75 (
76 Type::Future(FutureType {
77 inputs: input.inputs.into_iter().map(|input| self.reconstruct_type(input).0).collect(),
78 ..input
79 }),
80 Default::default(),
81 )
82 }
83
84 fn reconstruct_mapping_type(&mut self, input: MappingType) -> (Type, Self::AdditionalOutput) {
85 (
86 Type::Mapping(MappingType {
87 key: Box::new(self.reconstruct_type(*input.key).0),
88 value: Box::new(self.reconstruct_type(*input.value).0),
89 ..input
90 }),
91 Default::default(),
92 )
93 }
94
95 fn reconstruct_tuple_type(&mut self, input: TupleType) -> (Type, Self::AdditionalOutput) {
96 (
97 Type::Tuple(TupleType {
98 elements: input.elements.into_iter().map(|element| self.reconstruct_type(element).0).collect(),
99 }),
100 Default::default(),
101 )
102 }
103
104 fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
106 match input {
107 Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
108 Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
109 Expression::Async(async_) => self.reconstruct_async(async_),
110 Expression::Array(array) => self.reconstruct_array(array),
111 Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
112 Expression::Binary(binary) => self.reconstruct_binary(*binary),
113 Expression::Call(call) => self.reconstruct_call(*call),
114 Expression::Cast(cast) => self.reconstruct_cast(*cast),
115 Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
116 Expression::Err(err) => self.reconstruct_err(err),
117 Expression::Path(path) => self.reconstruct_path(path),
118 Expression::Literal(value) => self.reconstruct_literal(value),
119 Expression::Locator(locator) => self.reconstruct_locator(locator),
120 Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
121 Expression::Repeat(repeat) => self.reconstruct_repeat(*repeat),
122 Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
123 Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
124 Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
125 Expression::Unary(unary) => self.reconstruct_unary(*unary),
126 Expression::Unit(unit) => self.reconstruct_unit(unit),
127 }
128 }
129
130 fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
131 (
132 ArrayAccess {
133 array: self.reconstruct_expression(input.array).0,
134 index: self.reconstruct_expression(input.index).0,
135 ..input
136 }
137 .into(),
138 Default::default(),
139 )
140 }
141
142 fn reconstruct_associated_constant(
143 &mut self,
144 input: AssociatedConstantExpression,
145 ) -> (Expression, Self::AdditionalOutput) {
146 (input.into(), Default::default())
147 }
148
149 fn reconstruct_associated_function(
150 &mut self,
151 input: AssociatedFunctionExpression,
152 ) -> (Expression, Self::AdditionalOutput) {
153 (
154 AssociatedFunctionExpression {
155 arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
156 ..input
157 }
158 .into(),
159 Default::default(),
160 )
161 }
162
163 fn reconstruct_async(&mut self, input: AsyncExpression) -> (Expression, Self::AdditionalOutput) {
164 (AsyncExpression { block: self.reconstruct_block(input.block).0, ..input }.into(), Default::default())
165 }
166
167 fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
168 (MemberAccess { inner: self.reconstruct_expression(input.inner).0, ..input }.into(), Default::default())
169 }
170
171 fn reconstruct_repeat(&mut self, input: RepeatExpression) -> (Expression, Self::AdditionalOutput) {
172 (
173 RepeatExpression {
174 expr: self.reconstruct_expression(input.expr).0,
175 count: self.reconstruct_expression(input.count).0,
176 ..input
177 }
178 .into(),
179 Default::default(),
180 )
181 }
182
183 fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
184 (TupleAccess { tuple: self.reconstruct_expression(input.tuple).0, ..input }.into(), Default::default())
185 }
186
187 fn reconstruct_array(&mut self, input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
188 (
189 ArrayExpression {
190 elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
191 ..input
192 }
193 .into(),
194 Default::default(),
195 )
196 }
197
198 fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
199 (
200 BinaryExpression {
201 left: self.reconstruct_expression(input.left).0,
202 right: self.reconstruct_expression(input.right).0,
203 ..input
204 }
205 .into(),
206 Default::default(),
207 )
208 }
209
210 fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
211 (
212 CallExpression {
213 const_arguments: input
214 .const_arguments
215 .into_iter()
216 .map(|arg| self.reconstruct_expression(arg).0)
217 .collect(),
218 arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
219 ..input
220 }
221 .into(),
222 Default::default(),
223 )
224 }
225
226 fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
227 (
228 CastExpression { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
229 Default::default(),
230 )
231 }
232
233 fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
234 (
235 StructExpression {
236 const_arguments: input
237 .const_arguments
238 .into_iter()
239 .map(|arg| self.reconstruct_expression(arg).0)
240 .collect(),
241 members: input
242 .members
243 .into_iter()
244 .map(|member| StructVariableInitializer {
245 identifier: member.identifier,
246 expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
247 span: member.span,
248 id: member.id,
249 })
250 .collect(),
251 ..input
252 }
253 .into(),
254 Default::default(),
255 )
256 }
257
258 fn reconstruct_err(&mut self, _input: ErrExpression) -> (Expression, Self::AdditionalOutput) {
259 panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
260 }
261
262 fn reconstruct_path(&mut self, input: Path) -> (Expression, Self::AdditionalOutput) {
263 (input.into(), Default::default())
264 }
265
266 fn reconstruct_literal(&mut self, input: Literal) -> (Expression, Self::AdditionalOutput) {
267 (input.into(), Default::default())
268 }
269
270 fn reconstruct_locator(&mut self, input: LocatorExpression) -> (Expression, Self::AdditionalOutput) {
271 (input.into(), Default::default())
272 }
273
274 fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
275 (
276 TernaryExpression {
277 condition: self.reconstruct_expression(input.condition).0,
278 if_true: self.reconstruct_expression(input.if_true).0,
279 if_false: self.reconstruct_expression(input.if_false).0,
280 span: input.span,
281 id: input.id,
282 }
283 .into(),
284 Default::default(),
285 )
286 }
287
288 fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) {
289 (
290 TupleExpression {
291 elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
292 ..input
293 }
294 .into(),
295 Default::default(),
296 )
297 }
298
299 fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
300 (
301 UnaryExpression { receiver: self.reconstruct_expression(input.receiver).0, ..input }.into(),
302 Default::default(),
303 )
304 }
305
306 fn reconstruct_unit(&mut self, input: UnitExpression) -> (Expression, Self::AdditionalOutput) {
307 (input.into(), Default::default())
308 }
309
310 fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
311 match input {
312 Statement::Assert(assert) => self.reconstruct_assert(assert),
313 Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
314 Statement::Block(stmt) => {
315 let (stmt, output) = self.reconstruct_block(stmt);
316 (stmt.into(), output)
317 }
318 Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
319 Statement::Const(stmt) => self.reconstruct_const(stmt),
320 Statement::Definition(stmt) => self.reconstruct_definition(stmt),
321 Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
322 Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
323 Statement::Return(stmt) => self.reconstruct_return(stmt),
324 }
325 }
326
327 fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
328 (
329 AssertStatement {
330 variant: match input.variant {
331 AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
332 AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
333 self.reconstruct_expression(left).0,
334 self.reconstruct_expression(right).0,
335 ),
336 AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
337 self.reconstruct_expression(left).0,
338 self.reconstruct_expression(right).0,
339 ),
340 },
341 ..input
342 }
343 .into(),
344 Default::default(),
345 )
346 }
347
348 fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
349 (
350 AssignStatement {
351 place: self.reconstruct_expression(input.place).0,
352 value: self.reconstruct_expression(input.value).0,
353 ..input
354 }
355 .into(),
356 Default::default(),
357 )
358 }
359
360 fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
361 (
362 Block {
363 statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
364 span: input.span,
365 id: input.id,
366 },
367 Default::default(),
368 )
369 }
370
371 fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
372 (
373 ConditionalStatement {
374 condition: self.reconstruct_expression(input.condition).0,
375 then: self.reconstruct_block(input.then).0,
376 otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
377 ..input
378 }
379 .into(),
380 Default::default(),
381 )
382 }
383
384 fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
385 (
386 ConstDeclaration {
387 type_: self.reconstruct_type(input.type_).0,
388 value: self.reconstruct_expression(input.value).0,
389 ..input
390 }
391 .into(),
392 Default::default(),
393 )
394 }
395
396 fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
397 (
398 DefinitionStatement {
399 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
400 value: self.reconstruct_expression(input.value).0,
401 ..input
402 }
403 .into(),
404 Default::default(),
405 )
406 }
407
408 fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
409 (
410 ExpressionStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
411 Default::default(),
412 )
413 }
414
415 fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
416 (
417 IterationStatement {
418 type_: input.type_.map(|ty| self.reconstruct_type(ty).0),
419 start: self.reconstruct_expression(input.start).0,
420 stop: self.reconstruct_expression(input.stop).0,
421 block: self.reconstruct_block(input.block).0,
422 ..input
423 }
424 .into(),
425 Default::default(),
426 )
427 }
428
429 fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
430 (
431 ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
432 Default::default(),
433 )
434 }
435}
436
437pub trait ProgramReconstructor: AstReconstructor {
439 fn reconstruct_program(&mut self, input: Program) -> Program {
440 Program {
441 imports: input
442 .imports
443 .into_iter()
444 .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
445 .collect(),
446 stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
447 modules: input.modules.into_iter().map(|(id, module)| (id, self.reconstruct_module(module))).collect(),
448 program_scopes: input
449 .program_scopes
450 .into_iter()
451 .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
452 .collect(),
453 }
454 }
455
456 fn reconstruct_stub(&mut self, input: Stub) -> Stub {
457 Stub {
458 imports: input.imports,
459 stub_id: input.stub_id,
460 consts: input.consts,
461 structs: input.structs,
462 mappings: input.mappings,
463 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
464 span: input.span,
465 }
466 }
467
468 fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
469 ProgramScope {
470 program_id: input.program_id,
471 consts: input
472 .consts
473 .into_iter()
474 .map(|(i, c)| match self.reconstruct_const(c) {
475 (Statement::Const(declaration), _) => (i, declaration),
476 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
477 })
478 .collect(),
479 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
480 mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
481 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
482 constructor: input.constructor.map(|c| self.reconstruct_constructor(c)),
483 span: input.span,
484 }
485 }
486
487 fn reconstruct_module(&mut self, input: Module) -> Module {
488 Module {
489 program_name: input.program_name,
490 path: input.path,
491 consts: input
492 .consts
493 .into_iter()
494 .map(|(i, c)| match self.reconstruct_const(c) {
495 (Statement::Const(declaration), _) => (i, declaration),
496 _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
497 })
498 .collect(),
499 structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
500 functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
501 }
502 }
503
504 fn reconstruct_function(&mut self, input: Function) -> Function {
505 Function {
506 annotations: input.annotations,
507 variant: input.variant,
508 identifier: input.identifier,
509 const_parameters: input
510 .const_parameters
511 .iter()
512 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
513 .collect(),
514 input: input
515 .input
516 .iter()
517 .map(|input| Input { type_: self.reconstruct_type(input.type_.clone()).0, ..input.clone() })
518 .collect(),
519 output: input
520 .output
521 .iter()
522 .map(|output| Output { type_: self.reconstruct_type(output.type_.clone()).0, ..output.clone() })
523 .collect(),
524 output_type: self.reconstruct_type(input.output_type).0,
525 block: self.reconstruct_block(input.block).0,
526 span: input.span,
527 id: input.id,
528 }
529 }
530
531 fn reconstruct_constructor(&mut self, input: Constructor) -> Constructor {
532 Constructor {
533 annotations: input.annotations,
534 block: self.reconstruct_block(input.block).0,
535 span: input.span,
536 id: input.id,
537 }
538 }
539
540 fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
541 input
542 }
543
544 fn reconstruct_struct(&mut self, input: Composite) -> Composite {
545 Composite {
546 const_parameters: input
547 .const_parameters
548 .iter()
549 .map(|param| ConstParameter { type_: self.reconstruct_type(param.type_.clone()).0, ..param.clone() })
550 .collect(),
551 members: input
552 .members
553 .iter()
554 .map(|member| Member { type_: self.reconstruct_type(member.type_.clone()).0, ..member.clone() })
555 .collect(),
556 ..input
557 }
558 }
559
560 fn reconstruct_import(&mut self, input: Program) -> Program {
561 self.reconstruct_program(input)
562 }
563
564 fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
565 Mapping {
566 key_type: self.reconstruct_type(input.key_type).0,
567 value_type: self.reconstruct_type(input.value_type).0,
568 ..input
569 }
570 }
571}