// leo_passes/type_checking/ast.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::*;
18use crate::{VariableSymbol, VariableType};
19
20use leo_ast::{
21    Type::{Future, Tuple},
22    *,
23};
24use leo_errors::{TypeCheckerError, TypeCheckerWarning};
25use leo_span::{Span, Symbol, sym};
26
27use itertools::Itertools as _;
28
29impl TypeCheckingVisitor<'_> {
30    pub fn visit_expression_assign(&mut self, input: &Expression) -> Type {
31        let ty = match input {
32            Expression::ArrayAccess(array_access) => self.visit_array_access_general(array_access, true, &None),
33            Expression::Path(path) if path.qualifier().is_empty() => self.visit_path_assign(path),
34            Expression::MemberAccess(member_access) => self.visit_member_access_general(member_access, true, &None),
35            Expression::TupleAccess(tuple_access) => self.visit_tuple_access_general(tuple_access, true, &None),
36            _ => {
37                self.emit_err(TypeCheckerError::invalid_assignment_target(input, input.span()));
38                Type::Err
39            }
40        };
41
42        // Prohibit assignment to an external record in a narrower conditional scope.
43        let external_record = self.is_external_record(&ty);
44        let external_record_tuple =
45            matches!(&ty, Type::Tuple(tuple) if tuple.elements().iter().any(|ty| self.is_external_record(ty)));
46
47        if external_record || external_record_tuple {
48            let Expression::Path(path) = input else {
49                // This is not valid Leo and will have triggered an error elsewhere.
50                return Type::Err;
51            };
52
53            if !self.symbol_in_conditional_scope(path.identifier().name) {
54                if external_record {
55                    self.emit_err(TypeCheckerError::assignment_to_external_record_cond(&ty, input.span()));
56                } else {
57                    // Note that this will cover both assigning to a tuple variable and assigning to a member of a tuple.
58                    self.emit_err(TypeCheckerError::assignment_to_external_record_tuple_cond(&ty, input.span()));
59                }
60            }
61        }
62
63        // Prohibit reassignment of futures.
64        if let Type::Future(..) = ty {
65            self.emit_err(TypeCheckerError::cannot_reassign_future_variable(input, input.span()));
66        }
67
68        // Prohibit reassignment of mappings.
69        if let Type::Mapping(_) = ty {
70            self.emit_err(TypeCheckerError::cannot_reassign_mapping(input, input.span()));
71        }
72
73        // Add the expression and its associated type to the type table.
74        self.state.type_table.insert(input.id(), ty.clone());
75        ty
76    }
77
78    pub fn visit_array_access_general(&mut self, input: &ArrayAccess, assign: bool, expected: &Option<Type>) -> Type {
79        // Check that the expression is an array.
80        let this_type = if assign {
81            self.visit_expression_assign(&input.array)
82        } else {
83            self.visit_expression(&input.array, &None)
84        };
85        self.assert_array_type(&this_type, input.array.span());
86
87        // Check that the index is an integer type.
88        let mut index_type = self.visit_expression(&input.index, &None);
89
90        if index_type == Type::Numeric {
91            // If the index has type `Numeric`, then it's an unsuffixed literal. Just infer its type to be `u32` and
92            // then check it's validity as a `u32`.
93            index_type = Type::Integer(IntegerType::U32);
94            if let Expression::Literal(literal) = &input.index {
95                self.check_numeric_literal(literal, &index_type);
96            }
97        }
98
99        self.assert_int_type(&index_type, input.index.span());
100
101        // Keep track of the type of the index in the type table.
102        // This is important for when the index is an unsuffixed literal.
103        self.state.type_table.insert(input.index.id(), index_type.clone());
104
105        // Get the element type of the array.
106        let Type::Array(array_type) = this_type else {
107            // We must have already reported an error above, in our type assertion.
108            return Type::Err;
109        };
110
111        let element_type = array_type.element_type();
112
113        // If the expected type is known, then check that the element type is the same as the expected type.
114        self.maybe_assert_type(element_type, expected, input.span());
115
116        // Return the element type of the array.
117        element_type.clone()
118    }
119
120    pub fn visit_member_access_general(&mut self, input: &MemberAccess, assign: bool, expected: &Option<Type>) -> Type {
121        // Handler member access expressions that correspond to valid operands in AVM code.
122        if !assign {
123            match &input.inner {
124                // If the access expression is of the form `self.<name>`, then check the <name> is valid.
125                Expression::Path(path) if path.identifier().name == sym::SelfLower => {
126                    match input.name.name {
127                        sym::address => {
128                            return Type::Address;
129                        }
130                        sym::caller => {
131                            // Check that the operation is not invoked in a `finalize` block.
132                            self.check_access_allowed("self.caller", false, input.name.span());
133                            let ty = Type::Address;
134                            self.maybe_assert_type(&ty, expected, input.span());
135                            return ty;
136                        }
137                        sym::checksum => {
138                            return Type::Array(ArrayType::new(
139                                Type::Integer(IntegerType::U8),
140                                Expression::Literal(Literal::integer(
141                                    IntegerType::U8,
142                                    "32".to_string(),
143                                    Default::default(),
144                                    Default::default(),
145                                )),
146                            ));
147                        }
148                        sym::edition => {
149                            return Type::Integer(IntegerType::U16);
150                        }
151                        sym::id => {
152                            return Type::Address;
153                        }
154                        sym::program_owner => {
155                            // Check that the operation is only invoked in a `finalize` block.
156                            self.check_access_allowed("self.program_owner", true, input.name.span());
157                            return Type::Address;
158                        }
159                        sym::signer => {
160                            // Check that operation is not invoked in a `finalize` block.
161                            self.check_access_allowed("self.signer", false, input.name.span());
162                            let ty = Type::Address;
163                            self.maybe_assert_type(&ty, expected, input.span());
164                            return ty;
165                        }
166                        _ => {
167                            self.emit_err(TypeCheckerError::invalid_self_access(input.name.span()));
168                            return Type::Err;
169                        }
170                    }
171                }
172                // If the access expression is of the form `block.<name>`, then check the <name> is valid.
173                Expression::Path(path) if path.identifier().name == sym::block => match input.name.name {
174                    sym::height => {
175                        // Check that the operation is invoked in a `finalize` block.
176                        self.check_access_allowed("block.height", true, input.name.span());
177                        let ty = Type::Integer(IntegerType::U32);
178                        self.maybe_assert_type(&ty, expected, input.span());
179                        return ty;
180                    }
181                    _ => {
182                        self.emit_err(TypeCheckerError::invalid_block_access(input.name.span()));
183                        return Type::Err;
184                    }
185                },
186                // If the access expression is of the form `network.<name>`, then check that the <name> is valid.
187                Expression::Path(path) if path.identifier().name == sym::network => match input.name.name {
188                    sym::id => {
189                        // Check that the operation is not invoked outside a `finalize` block.
190                        self.check_access_allowed("network.id", true, input.name.span());
191                        let ty = Type::Integer(IntegerType::U16);
192                        self.maybe_assert_type(&ty, expected, input.span());
193                        return ty;
194                    }
195                    _ => {
196                        self.emit_err(TypeCheckerError::invalid_block_access(input.name.span()));
197                        return Type::Err;
198                    }
199                },
200                _ => {}
201            }
202        }
203
204        let ty = if assign {
205            self.visit_expression_assign(&input.inner)
206        } else {
207            self.visit_expression(&input.inner, &None)
208        };
209
210        // Make sure we're not assigning to a member of an external record.
211        if assign && self.is_external_record(&ty) {
212            self.emit_err(TypeCheckerError::assignment_to_external_record_member(&ty, input.span));
213        }
214
215        // Check that the type of `inner` in `inner.name` is a struct.
216        match ty {
217            Type::Err => Type::Err,
218            Type::Composite(ref struct_) => {
219                // Retrieve the struct definition associated with `identifier`.
220                let Some(struct_) = self
221                    .lookup_struct(struct_.program.or(self.scope_state.program_name), &struct_.path.absolute_path())
222                else {
223                    self.emit_err(TypeCheckerError::undefined_type(ty, input.inner.span()));
224                    return Type::Err;
225                };
226                // Check that `input.name` is a member of the struct.
227                match struct_.members.iter().find(|member| member.name() == input.name.name) {
228                    // Case where `input.name` is a member of the struct.
229                    Some(Member { type_, .. }) => {
230                        // Check that the type of `input.name` is the same as `expected`.
231                        self.maybe_assert_type(type_, expected, input.span());
232                        type_.clone()
233                    }
234                    // Case where `input.name` is not a member of the struct.
235                    None => {
236                        self.emit_err(TypeCheckerError::invalid_struct_variable(
237                            input.name,
238                            &struct_,
239                            input.name.span(),
240                        ));
241                        Type::Err
242                    }
243                }
244            }
245            type_ => {
246                self.emit_err(TypeCheckerError::type_should_be2(type_, "a struct or record", input.inner.span()));
247                Type::Err
248            }
249        }
250    }
251
252    pub fn visit_tuple_access_general(&mut self, input: &TupleAccess, assign: bool, expected: &Option<Type>) -> Type {
253        let this_type = if assign {
254            self.visit_expression_assign(&input.tuple)
255        } else {
256            self.visit_expression(&input.tuple, &None)
257        };
258        match this_type {
259            Type::Err => Type::Err,
260            Type::Tuple(tuple) => {
261                // Check out of range input.
262                let index = input.index.value();
263                let Some(actual) = tuple.elements().get(index) else {
264                    self.emit_err(TypeCheckerError::tuple_out_of_range(index, tuple.length(), input.span()));
265                    return Type::Err;
266                };
267
268                self.maybe_assert_type(actual, expected, input.span());
269
270                actual.clone()
271            }
272            Type::Future(_) => {
273                // Get the fully inferred type.
274                let Some(Type::Future(inferred_f)) = self.state.type_table.get(&input.tuple.id()) else {
275                    // If a future type was not inferred, we will have already reported an error.
276                    return Type::Err;
277                };
278
279                if inferred_f.location.is_none() {
280                    // This generally means that the `Future` is produced by an `async` block expression and not an
281                    // `async function` function call.
282                    self.emit_err(TypeCheckerError::invalid_async_block_future_access(input.span()));
283                    return Type::Err;
284                }
285
286                let Some(actual) = inferred_f.inputs().get(input.index.value()) else {
287                    self.emit_err(TypeCheckerError::invalid_future_access(
288                        input.index.value(),
289                        inferred_f.inputs().len(),
290                        input.span(),
291                    ));
292                    return Type::Err;
293                };
294
295                // If all inferred types weren't the same, the member will be of type `Type::Err`.
296                if let Type::Err = actual {
297                    self.emit_err(TypeCheckerError::future_error_member(input.index.value(), input.span()));
298                    return Type::Err;
299                }
300
301                self.maybe_assert_type(actual, expected, input.span());
302
303                actual.clone()
304            }
305            type_ => {
306                self.emit_err(TypeCheckerError::type_should_be2(type_, "a tuple or future", input.span()));
307                Type::Err
308            }
309        }
310    }
311
312    pub fn visit_path_assign(&mut self, input: &Path) -> Type {
313        // Lookup the variable in the symbol table and retrieve its type.
314        let Some(var) =
315            self.state.symbol_table.lookup_path(self.scope_state.program_name.unwrap(), &input.absolute_path())
316        else {
317            self.emit_err(TypeCheckerError::unknown_sym("variable", input, input.span));
318            return Type::Err;
319        };
320
321        // If the variable exists, then check that it is not a constant.
322        match &var.declaration {
323            VariableType::Const => self.emit_err(TypeCheckerError::cannot_assign_to_const_var(input, var.span)),
324            VariableType::ConstParameter => {
325                self.emit_err(TypeCheckerError::cannot_assign_to_generic_const_function_parameter(input, input.span))
326            }
327            VariableType::Input(Mode::Constant) => {
328                self.emit_err(TypeCheckerError::cannot_assign_to_const_input(input, var.span))
329            }
330            VariableType::Mut | VariableType::Input(_) => {}
331        }
332
333        // If the variable exists and it's in an async function, then check that it is in the current conditional scope.
334        if self.scope_state.variant.unwrap().is_async_function()
335            && !self.symbol_in_conditional_scope(input.identifier().name)
336        {
337            self.emit_err(TypeCheckerError::async_cannot_assign_outside_conditional(input, "function", var.span));
338        }
339
340        // Similarly, if the variable exists and it's in an async block, then check that it is in the current conditional scope.
341        if self.async_block_id.is_some() && !self.symbol_in_conditional_scope(input.identifier().name) {
342            self.emit_err(TypeCheckerError::async_cannot_assign_outside_conditional(input, "block", var.span));
343        }
344
345        if let Some(async_block_id) = self.async_block_id {
346            if !self.state.symbol_table.is_defined_in_scope_or_ancestor_until(async_block_id, input.identifier().name) {
347                // If we're inside an async block (i.e. in the scope of its block or one if its child scopes) and if
348                // we're trying to assign to a variable that is not local to the block (or its child scopes), then we
349                // should error out.
350                self.emit_err(TypeCheckerError::cannot_assign_to_vars_outside_async_block(
351                    input.identifier().name,
352                    input.span,
353                ));
354            }
355        }
356
357        var.type_.clone()
358    }
359
360    /// Infers the type of an expression, but returns Type::Err and emits an error if the result is Type::Numeric.
361    /// Used to disallow numeric types in specific contexts where they are not valid or expected.
362    pub(crate) fn visit_expression_reject_numeric(&mut self, expr: &Expression, expected: &Option<Type>) -> Type {
363        let mut inferred = self.visit_expression(expr, expected);
364        match inferred {
365            Type::Numeric => {
366                self.emit_inference_failure_error(&mut inferred, expr);
367                Type::Err
368            }
369            _ => inferred,
370        }
371    }
372
373    /// Infers the type of an expression, and if it is `Type::Numeric`, coerces it to `U32`, validates it, and
374    /// records it in the type table.
375    pub(crate) fn visit_expression_infer_default_u32(&mut self, expr: &Expression) -> Type {
376        let mut inferred = self.visit_expression(expr, &None);
377
378        if inferred == Type::Numeric {
379            inferred = Type::Integer(IntegerType::U32);
380
381            if let Expression::Literal(literal) = expr {
382                if !self.check_numeric_literal(literal, &inferred) {
383                    inferred = Type::Err;
384                }
385            }
386
387            self.state.type_table.insert(expr.id(), inferred.clone());
388        }
389
390        inferred
391    }
392}
393
394impl AstVisitor for TypeCheckingVisitor<'_> {
395    type AdditionalInput = Option<Type>;
396    type Output = Type;
397
398    /* Types */
399    fn visit_array_type(&mut self, input: &ArrayType) {
400        self.visit_type(&input.element_type);
401        self.visit_expression_infer_default_u32(&input.length);
402    }
403
404    fn visit_composite_type(&mut self, input: &CompositeType) {
405        let struct_ = self.lookup_struct(self.scope_state.program_name, &input.path.absolute_path()).clone();
406
407        if let Some(struct_) = struct_ {
408            // Check the number of const arguments against the number of the struct's const parameters
409            if struct_.const_parameters.len() != input.const_arguments.len() {
410                self.emit_err(TypeCheckerError::incorrect_num_const_args(
411                    "Struct type",
412                    struct_.const_parameters.len(),
413                    input.const_arguments.len(),
414                    input.path.span,
415                ));
416            }
417
418            // Check the types of const arguments against the types of the struct's const parameters
419            for (expected, argument) in struct_.const_parameters.iter().zip(input.const_arguments.iter()) {
420                self.visit_expression(argument, &Some(expected.type_().clone()));
421            }
422        } else if !input.const_arguments.is_empty() {
423            self.emit_err(TypeCheckerError::unexpected_const_args(input, input.path.span));
424        }
425    }
426
427    /* Expressions */
428    fn visit_expression(&mut self, input: &Expression, additional: &Self::AdditionalInput) -> Self::Output {
429        let output = match input {
430            Expression::Array(array) => self.visit_array(array, additional),
431            Expression::ArrayAccess(access) => self.visit_array_access_general(access, false, additional),
432            Expression::AssociatedConstant(constant) => self.visit_associated_constant(constant, additional),
433            Expression::AssociatedFunction(function) => self.visit_associated_function(function, additional),
434            Expression::Async(async_) => self.visit_async(async_, additional),
435            Expression::Binary(binary) => self.visit_binary(binary, additional),
436            Expression::Call(call) => self.visit_call(call, additional),
437            Expression::Cast(cast) => self.visit_cast(cast, additional),
438            Expression::Struct(struct_) => self.visit_struct_init(struct_, additional),
439            Expression::Err(err) => self.visit_err(err, additional),
440            Expression::Path(path) => self.visit_path(path, additional),
441            Expression::Literal(literal) => self.visit_literal(literal, additional),
442            Expression::Locator(locator) => self.visit_locator(locator, additional),
443            Expression::MemberAccess(access) => self.visit_member_access_general(access, false, additional),
444            Expression::Repeat(repeat) => self.visit_repeat(repeat, additional),
445            Expression::Ternary(ternary) => self.visit_ternary(ternary, additional),
446            Expression::Tuple(tuple) => self.visit_tuple(tuple, additional),
447            Expression::TupleAccess(access) => self.visit_tuple_access_general(access, false, additional),
448            Expression::Unary(unary) => self.visit_unary(unary, additional),
449            Expression::Unit(unit) => self.visit_unit(unit, additional),
450        };
451
452        // Add the expression and its associated type to the symbol table.
453        self.state.type_table.insert(input.id(), output.clone());
454        output
455    }
456
457    fn visit_array_access(&mut self, _input: &ArrayAccess, _additional: &Self::AdditionalInput) -> Self::Output {
458        panic!("Should not be called.");
459    }
460
461    fn visit_member_access(&mut self, _input: &MemberAccess, _additional: &Self::AdditionalInput) -> Self::Output {
462        panic!("Should not be called.");
463    }
464
465    fn visit_tuple_access(&mut self, _input: &TupleAccess, _additional: &Self::AdditionalInput) -> Self::Output {
466        panic!("Should not be called.");
467    }
468
469    fn visit_array(&mut self, input: &ArrayExpression, additional: &Self::AdditionalInput) -> Self::Output {
470        if input.elements.is_empty() {
471            self.emit_err(TypeCheckerError::array_empty(input.span()));
472            return Type::Err;
473        }
474
475        // Grab the element type from the expected type if the expected type is an array or if it's
476        // an optional array
477        let element_type = match additional {
478            Some(Type::Array(array_ty)) => Some(array_ty.element_type().clone()),
479            Some(Type::Optional(opt)) => match &*opt.inner {
480                Type::Array(array_ty) => Some(array_ty.element_type().clone()),
481                _ => None,
482            },
483            _ => None,
484        };
485
486        let inferred_type = self.visit_expression_reject_numeric(&input.elements[0], &element_type);
487
488        if input.elements.len() > self.limits.max_array_elements {
489            self.emit_err(TypeCheckerError::array_too_large(
490                input.elements.len(),
491                self.limits.max_array_elements,
492                input.span(),
493            ));
494        }
495
496        for expression in input.elements[1..].iter() {
497            let next_type = self.visit_expression_reject_numeric(expression, &element_type);
498
499            if next_type == Type::Err {
500                return Type::Err;
501            }
502
503            if let Some(ref element_type) = element_type {
504                self.assert_type(&next_type, element_type, expression.span());
505            } else {
506                self.assert_type(&next_type, &inferred_type, expression.span());
507            }
508        }
509
510        if inferred_type == Type::Err {
511            return Type::Err;
512        }
513
514        let type_ = Type::Array(ArrayType::new(
515            inferred_type,
516            Expression::Literal(Literal {
517                // The default type for array length is `U32`.
518                variant: LiteralVariant::Integer(IntegerType::U32, input.elements.len().to_string()),
519                id: self.state.node_builder.next_id(),
520                span: Span::default(),
521            }),
522        ));
523
524        self.maybe_assert_type(&type_, additional, input.span());
525
526        type_
527    }
528
529    fn visit_repeat(&mut self, input: &RepeatExpression, additional: &Self::AdditionalInput) -> Self::Output {
530        // Grab the element type from the expected type if the expected type is an array or if it's
531        // an optional array
532        let expected_element_type = match additional {
533            Some(Type::Array(array_ty)) => Some(array_ty.element_type().clone()),
534            Some(Type::Optional(opt)) => match &*opt.inner {
535                Type::Array(array_ty) => Some(array_ty.element_type().clone()),
536                _ => None,
537            },
538            _ => None,
539        };
540
541        let inferred_element_type = self.visit_expression_reject_numeric(&input.expr, &expected_element_type);
542
543        // Now infer the type of `count`. If it's an unsuffixed literal (i.e. has `Type::Numeric`), then infer it to be
544        // a `U32` as the default type.
545        self.visit_expression_infer_default_u32(&input.count);
546
547        // If we can already evaluate the repeat count as a `u32`, then make sure it's not 0 or  greater than the array
548        // size limit.
549        if let Some(count) = input.count.as_u32() {
550            if count == 0 {
551                self.emit_err(TypeCheckerError::array_empty(input.span()));
552                return Type::Err;
553            }
554
555            if count > self.limits.max_array_elements as u32 {
556                self.emit_err(TypeCheckerError::array_too_large(count, self.limits.max_array_elements, input.span()));
557            }
558        }
559
560        let type_ = Type::Array(ArrayType::new(inferred_element_type, input.count.clone()));
561
562        self.maybe_assert_type(&type_, additional, input.span());
563        type_
564    }
565
566    fn visit_associated_constant(
567        &mut self,
568        input: &AssociatedConstantExpression,
569        expected: &Self::AdditionalInput,
570    ) -> Self::Output {
571        // Check associated constant type and constant name
572        let Some(core_constant) = self.get_core_constant(&input.ty, &input.name) else {
573            self.emit_err(TypeCheckerError::invalid_associated_constant(input, input.span));
574            return Type::Err;
575        };
576        let type_ = core_constant.to_type();
577        self.maybe_assert_type(&type_, expected, input.span());
578        type_
579    }
580
581    fn visit_associated_function(
582        &mut self,
583        input: &AssociatedFunctionExpression,
584        expected: &Self::AdditionalInput,
585    ) -> Self::Output {
586        // Check core struct name and function.
587        let Some(core_instruction) = self.get_core_function_call(&input.variant, &input.name) else {
588            self.emit_err(TypeCheckerError::invalid_core_function_call(input, input.span()));
589            return Type::Err;
590        };
591        // Check that operation is not restricted to finalize blocks.
592        if !matches!(self.scope_state.variant, Some(Variant::AsyncFunction) | Some(Variant::Script))
593            && self.async_block_id.is_none()
594            && core_instruction.is_finalize_command()
595        {
596            self.emit_err(TypeCheckerError::operation_must_be_in_async_block_or_function(input.span()));
597        }
598
599        let return_type = if let CoreFunction::OptionalUnwrapOr = core_instruction {
600            // Ensure we have exactly two arguments
601            let (optional_expr, fallback_expr) = match &input.arguments[..] {
602                [opt, fallback] => (opt, fallback),
603                _ => {
604                    self.emit_err(TypeCheckerError::incorrect_num_args_to_call(
605                        core_instruction.num_args(),
606                        input.arguments.len(),
607                        input.span(),
608                    ));
609                    return Type::Err;
610                }
611            };
612
613            // Type check the first argument normally
614            let optional_ty = self.visit_expression(optional_expr, &None);
615
616            // This emits an error if the type is not optional
617            self.assert_optional_type(&optional_ty, optional_expr.span());
618
619            // If the type is not Optional, we return Err
620            let Type::Optional(OptionalType { inner }) = optional_ty else {
621                return Type::Err;
622            };
623
624            // Use the inner type of the optional as the expected type for the fallback
625            let fallback_ty = self.visit_expression_reject_numeric(fallback_expr, &Some(*inner.clone()));
626            self.assert_type(&fallback_ty, &inner, fallback_expr.span());
627
628            // Return the final type: the inner type (what unwrap_or returns)
629            *inner
630        } else {
631            // Get the types of the arguments. Error out on arguments that have `Type::Numeric`. We could potentially do
632            // better for some of the core functions, but that can get pretty tedious because it would have to be function
633            // specific.
634            let arguments_with_types = input
635                .arguments
636                .iter()
637                .map(|arg| (self.visit_expression_reject_numeric(arg, &None), arg))
638                .collect::<Vec<_>>();
639
640            // Check return type if the expected type is known.
641            self.check_core_function_call(core_instruction.clone(), &arguments_with_types, input.span())
642        };
643
644        // Check return type if the expected type is known.
645        self.maybe_assert_type(&return_type, expected, input.span());
646
647        // Await futures here so that can use the argument variable names to lookup.
648        if core_instruction == CoreFunction::FutureAwait && input.arguments.len() != 1 {
649            self.emit_err(TypeCheckerError::can_only_await_one_future_at_a_time(input.span));
650        }
651
652        return_type
653    }
654
655    fn visit_async(&mut self, input: &AsyncExpression, _additional: &Self::AdditionalInput) -> Self::Output {
656        // Step into an async block
657        self.async_block_id = Some(input.block.id);
658
659        // A few restrictions
660        if self.scope_state.is_conditional {
661            self.emit_err(TypeCheckerError::async_block_in_conditional(input.span));
662        }
663
664        if !matches!(self.scope_state.variant, Some(Variant::AsyncTransition) | Some(Variant::Script)) {
665            self.emit_err(TypeCheckerError::illegal_async_block_location(input.span));
666        }
667
668        if self.scope_state.already_contains_an_async_block {
669            self.emit_err(TypeCheckerError::multiple_async_blocks_not_allowed(input.span));
670        }
671
672        if self.scope_state.has_called_finalize {
673            self.emit_err(TypeCheckerError::conflicting_async_call_and_block(input.span));
674        }
675
676        self.visit_block(&input.block);
677
678        // This scope now already has an async block
679        self.scope_state.already_contains_an_async_block = true;
680
681        // Step out of the async block
682        self.async_block_id = None;
683
684        // The type of the async block is just a `Future` with no `Location` (i.e. not produced by an explicit `async
685        // function`) and no inputs since we're not allowed to access inputs of a `Future` produced by an `async block.
686        Type::Future(FutureType::new(Vec::new(), None, false))
687    }
688
    /// Type checks a binary expression and returns its type.
    ///
    /// `destination` is the type the surrounding context expects for the whole expression, if known.
    /// For the arithmetic and bitwise operators, `destination` is passed through `unwrap_optional_type` to
    /// obtain the expected operand type, and the computed result is passed through `wrap_if_optional` before
    /// being returned. (Presumably this lets `T op T` flow into an `Optional<T>` destination; TODO confirm
    /// against those helpers.) Logical, equality, and comparison operators always return `Type::Boolean`.
    ///
    /// Operand types that are still the unresolved `Type::Numeric` placeholder are resolved from the other
    /// operand where the operator allows it; otherwise an inference-failure error is emitted.
    fn visit_binary(&mut self, input: &BinaryExpression, destination: &Self::AdditionalInput) -> Self::Output {
        // Returns the common operand type when both sides agree (per `eq_user`). If either side is already
        // `Type::Err`, returns `Type::Err` without reporting another error. Otherwise it reports a type
        // mismatch and returns `Type::Err`.
        let assert_same_type = |slf: &Self, t1: &Type, t2: &Type| -> Type {
            if t1 == &Type::Err || t2 == &Type::Err {
                Type::Err
            } else if !t1.eq_user(t2) {
                slf.emit_err(TypeCheckerError::operation_types_mismatch(input.op, t1, t2, input.span()));
                Type::Err
            } else {
                t1.clone()
            }
        };

        // This closure attempts to resolve numeric type inference between two operands.
        // It handles the following cases:
        // - If both types are unknown numeric placeholders (`Numeric`), emit errors for both.
        // - If one type is `Numeric` and the other is an error (`Err`), emit an inference error for the `Numeric` side.
        // - If one type is a known numeric type and the other is `Numeric`, infer the unknown type, record it in the
        //   type table, and pass the operand to `check_numeric_literal` if it is a literal.
        // - If one type is `Numeric` but the other is not a valid numeric type, emit an error.
        // - Otherwise, do nothing (types are already resolved or not subject to inference).
        let infer_numeric_types = |slf: &Self, left_type: &mut Type, right_type: &mut Type| {
            use Type::*;

            match (&*left_type, &*right_type) {
                // Case: Both types are unknown numeric types – cannot infer either side
                (Numeric, Numeric) => {
                    slf.emit_inference_failure_error(left_type, &input.left);
                    slf.emit_inference_failure_error(right_type, &input.right);
                }

                // Case: Left is unknown numeric, right is erroneous – propagate error to left
                (Numeric, Err) => slf.emit_inference_failure_error(left_type, &input.left),

                // Case: Right is unknown numeric, left is erroneous – propagate error to right
                (Err, Numeric) => slf.emit_inference_failure_error(right_type, &input.right),

                // Case: Right type is unknown numeric, infer it from known left type
                (Integer(_) | Field | Group | Scalar, Numeric) => {
                    *right_type = left_type.clone();
                    slf.state.type_table.insert(input.right.id(), right_type.clone());
                    if let Expression::Literal(literal) = &input.right {
                        slf.check_numeric_literal(literal, right_type);
                    }
                }

                // Case: Left type is unknown numeric, infer it from known right type
                (Numeric, Integer(_) | Field | Group | Scalar) => {
                    *left_type = right_type.clone();
                    slf.state.type_table.insert(input.left.id(), left_type.clone());
                    if let Expression::Literal(literal) = &input.left {
                        slf.check_numeric_literal(literal, left_type);
                    }
                }

                // Case: Left type is numeric but right is invalid for numeric inference – error on left
                (Numeric, _) => slf.emit_inference_failure_error(left_type, &input.left),

                // Case: Right type is numeric but left is invalid for numeric inference – error on right
                (_, Numeric) => slf.emit_inference_failure_error(right_type, &input.right),

                // No inference or error needed. Rely on further operator-specific checks.
                _ => {}
            }
        };

        match input.op {
            // Logical operators: both operands and the result are `bool`.
            BinaryOperation::And | BinaryOperation::Or | BinaryOperation::Nand | BinaryOperation::Nor => {
                self.maybe_assert_type(&Type::Boolean, destination, input.span());
                self.visit_expression(&input.left, &Some(Type::Boolean));
                self.visit_expression(&input.right, &Some(Type::Boolean));
                Type::Boolean
            }
            // Bitwise operators: operands must be the same `bool` or integer type.
            BinaryOperation::BitwiseAnd | BinaryOperation::BitwiseOr | BinaryOperation::Xor => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_bool_int_type(&t1, input.left.span());
                self.assert_bool_int_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);
                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Addition: operands must be the same field, group, scalar, or integer type.
            BinaryOperation::Add => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                let assert_add_type = |type_: &Type, span: Span| {
                    if !matches!(type_, Type::Err | Type::Field | Type::Group | Type::Scalar | Type::Integer(_)) {
                        self.emit_err(TypeCheckerError::type_should_be2(
                            type_,
                            "a field, group, scalar, or integer",
                            span,
                        ));
                    }
                };

                assert_add_type(&t1, input.left.span());
                assert_add_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Subtraction: operands must be the same field, group, or integer type (checked by
            // `assert_field_group_int_type`).
            BinaryOperation::Sub => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_field_group_int_type(&t1, input.left.span());
                self.assert_field_group_int_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Multiplication: `field * field`, `group * scalar` / `scalar * group`, or same-integer-type products.
            BinaryOperation::Mul => {
                let unwrapped_dest = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the same as unwrapped destination except when it is
                // a `Type::Group`. In that case, the two operands should be a `Type::Group` and `Type::Scalar` but we can't
                // know which one is which.
                let expected = if matches!(unwrapped_dest, Some(Type::Group)) { &None } else { &unwrapped_dest };
                let mut t1 = self.visit_expression(&input.left, expected);
                let mut t2 = self.visit_expression(&input.right, expected);

                // - If one side is `Group` and the other is an unresolved `Numeric`, infer the `Numeric` as a `Scalar`,
                //   since `Group * Scalar = Group`.
                // - Similarly, if one side is `Scalar` and the other is `Numeric`, infer the `Numeric` as `Group`.
                //
                // If no special case applies, default to inferring types between `t1` and `t2` as-is.
                // (In the special cases a temporary `Scalar`/`Group` stands in for the known side so that only the
                // `Numeric` side is updated.)
                match (&t1, &t2) {
                    (Type::Group, Type::Numeric) => infer_numeric_types(self, &mut Type::Scalar, &mut t2),
                    (Type::Numeric, Type::Group) => infer_numeric_types(self, &mut t1, &mut Type::Scalar),
                    (Type::Scalar, Type::Numeric) => infer_numeric_types(self, &mut Type::Group, &mut t2),
                    (Type::Numeric, Type::Scalar) => infer_numeric_types(self, &mut t1, &mut Type::Group),
                    (_, _) => infer_numeric_types(self, &mut t1, &mut t2),
                }

                // Final sanity checks
                let result_t = match (&t1, &t2) {
                    (Type::Err, _) | (_, Type::Err) => Type::Err,
                    (Type::Group, Type::Scalar) | (Type::Scalar, Type::Group) => Type::Group,
                    (Type::Field, Type::Field) => Type::Field,
                    (Type::Integer(integer_type1), Type::Integer(integer_type2)) if integer_type1 == integer_type2 => {
                        t1.clone()
                    }
                    _ => {
                        self.emit_err(TypeCheckerError::mul_types_mismatch(t1, t2, input.span()));
                        Type::Err
                    }
                };

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Division: operands must be the same field or integer type.
            BinaryOperation::Div => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_field_int_type(&t1, input.left.span());
                self.assert_field_int_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Remainder: operands must be the same integer type.
            BinaryOperation::Rem | BinaryOperation::RemWrapped => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_int_type(&t1, input.left.span());
                self.assert_int_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Modulo: operands must be the same unsigned integer type.
            BinaryOperation::Mod => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_unsigned_type(&t1, input.left.span());
                self.assert_unsigned_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Exponentiation: `field ** field`, or an integer base with a `u8`/`u16`/`u32` exponent; the result has
            // the base's type.
            BinaryOperation::Pow => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type of `left` is the unwrapped destination
                let mut t1 = self.visit_expression(&input.left, &operand_expected);

                // The expected type of `right` is `field`, `u8`, `u16`, or `u32` so leave it as `None` for now.
                let mut t2 = self.visit_expression(&input.right, &None);

                // If one side is a `Field` and the other is a `Numeric`, infer the `Numeric` as a `Field`.
                // Otherwise, error out for each `Numeric` (e.g. an unsuffixed exponent next to an integer base is not
                // inferred, because the exponent's integer type need not match the base's).
                if matches!((&t1, &t2), (Type::Field, Type::Numeric) | (Type::Numeric, Type::Field)) {
                    infer_numeric_types(self, &mut t1, &mut t2);
                } else {
                    if matches!(t1, Type::Numeric) {
                        self.emit_inference_failure_error(&mut t1, &input.left);
                    }
                    if matches!(t2, Type::Numeric) {
                        self.emit_inference_failure_error(&mut t2, &input.right);
                    }
                }

                // Now sanity check everything
                let ty = match (&t1, &t2) {
                    (Type::Err, _) | (_, Type::Err) => Type::Err,
                    (Type::Field, Type::Field) => Type::Field,
                    // An integer base keeps its type even if the exponent is invalid; the error is still reported.
                    (base @ Type::Integer(_), t2) => {
                        if !matches!(
                            t2,
                            Type::Integer(IntegerType::U8)
                                | Type::Integer(IntegerType::U16)
                                | Type::Integer(IntegerType::U32)
                        ) {
                            self.emit_err(TypeCheckerError::pow_types_mismatch(base, t2, input.span()));
                        }
                        base.clone()
                    }
                    _ => {
                        self.emit_err(TypeCheckerError::pow_types_mismatch(t1, t2, input.span()));
                        Type::Err
                    }
                };

                self.maybe_assert_type(&ty, destination, input.span());

                self.wrap_if_optional(ty, destination)
            }
            // Equality: operands must have the same type (per `eq_user`); the result is `bool`.
            BinaryOperation::Eq | BinaryOperation::Neq => {
                // Handle type inference for `None` as a special case.
                //
                // If either side of the binary expression is the literal `None`, we first type check the other side
                // without any expected type to infer its type. Then we type check the `None` side using that inferred type
                // as context, in hopes of resolving it to a more specific optional type.
                //
                // This helps with cases like `x == None`, allowing us to infer the type of `x` and apply it to `None`.
                // However, this is **not sufficient for the general case**. For instance, in something like `[None] == [x]`,
                // we won't be able to infer the type of `None`.
                let (mut t1, mut t2) =
                    if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input.right {
                        let t1 = self.visit_expression(&input.left, &None);
                        (t1.clone(), self.visit_expression(&input.right, &Some(t1.clone())))
                    } else if let Expression::Literal(Literal { variant: LiteralVariant::None, .. }) = input.left {
                        let t2 = self.visit_expression(&input.right, &None);
                        (self.visit_expression(&input.left, &Some(t2.clone())), t2)
                    } else {
                        (self.visit_expression(&input.left, &None), self.visit_expression(&input.right, &None))
                    };

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                let _ = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&Type::Boolean, destination, input.span());

                Type::Boolean
            }
            // Ordering comparisons: result is `bool`.
            BinaryOperation::Lt | BinaryOperation::Gt | BinaryOperation::Lte | BinaryOperation::Gte => {
                // Assert left and right are equal field, scalar, or integer types.
                let mut t1 = self.visit_expression(&input.left, &None);
                let mut t2 = self.visit_expression(&input.right, &None);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                let assert_compare_type = |type_: &Type, span: Span| {
                    if !matches!(type_, Type::Err | Type::Field | Type::Scalar | Type::Integer(_)) {
                        self.emit_err(TypeCheckerError::type_should_be2(type_, "a field, scalar, or integer", span));
                    }
                };

                assert_compare_type(&t1, input.left.span());
                assert_compare_type(&t2, input.right.span());

                let _ = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&Type::Boolean, destination, input.span());

                Type::Boolean
            }
            // Wrapping arithmetic: operands must be the same integer type.
            BinaryOperation::AddWrapped
            | BinaryOperation::SubWrapped
            | BinaryOperation::DivWrapped
            | BinaryOperation::MulWrapped => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type for both `left` and `right` is the unwrapped type
                let mut t1 = self.visit_expression(&input.left, &operand_expected);
                let mut t2 = self.visit_expression(&input.right, &operand_expected);

                // Infer `Numeric` types if possible
                infer_numeric_types(self, &mut t1, &mut t2);

                // Now sanity check everything
                self.assert_int_type(&t1, input.left.span());
                self.assert_int_type(&t2, input.right.span());

                let result_t = assert_same_type(self, &t1, &t2);

                self.maybe_assert_type(&result_t, destination, input.span());

                self.wrap_if_optional(result_t, destination)
            }
            // Shifts and wrapping exponentiation: integer left operand, `u8`/`u16`/`u32` right operand; the result
            // has the left operand's type. Unresolved `Numeric` operands are rejected rather than inferred.
            // NOTE(review): unlike the other arithmetic arms, the result is not checked against `destination` via
            // `maybe_assert_type` here; the left operand is only visited with the unwrapped destination as its
            // expected type.
            BinaryOperation::Shl
            | BinaryOperation::ShlWrapped
            | BinaryOperation::Shr
            | BinaryOperation::ShrWrapped
            | BinaryOperation::PowWrapped => {
                let operand_expected = self.unwrap_optional_type(destination);

                // The expected type of `left` is the unwrapped `destination`
                let t1 = self.visit_expression_reject_numeric(&input.left, &operand_expected);

                // The expected type of `right` is `u8`, `u16`, or `u32` (checked below), so leave it as `None` for now.
                let t2 = self.visit_expression_reject_numeric(&input.right, &None);

                self.assert_int_type(&t1, input.left.span());

                if !matches!(
                    &t2,
                    Type::Err
                        | Type::Integer(IntegerType::U8)
                        | Type::Integer(IntegerType::U16)
                        | Type::Integer(IntegerType::U32)
                ) {
                    self.emit_err(TypeCheckerError::shift_type_magnitude(input.op, t2, input.right.span()));
                }

                self.wrap_if_optional(t1, destination)
            }
        }
    }
1084
1085    fn visit_call(&mut self, input: &CallExpression, expected: &Self::AdditionalInput) -> Self::Output {
1086        let callee_program = input.program.or(self.scope_state.program_name).unwrap();
1087
1088        let callee_path = input.function.absolute_path();
1089
1090        let Some(func_symbol) =
1091            self.state.symbol_table.lookup_function(&Location::new(callee_program, callee_path.clone()))
1092        else {
1093            self.emit_err(TypeCheckerError::unknown_sym("function", input.function.clone(), input.function.span()));
1094            return Type::Err;
1095        };
1096
1097        let func = func_symbol.function.clone();
1098
1099        // Check that the call is valid.
1100        // We always set the variant before entering the body of a function, so this unwrap works.
1101        match self.scope_state.variant.unwrap() {
1102            Variant::AsyncFunction | Variant::Function if !matches!(func.variant, Variant::Inline) => self.emit_err(
1103                TypeCheckerError::can_only_call_inline_function("a `function`, `inline`, or `constructor`", input.span),
1104            ),
1105            Variant::Transition | Variant::AsyncTransition
1106                if matches!(func.variant, Variant::Transition)
1107                    && input.program.is_none_or(|program| program == self.scope_state.program_name.unwrap()) =>
1108            {
1109                self.emit_err(TypeCheckerError::cannot_invoke_call_to_local_transition_function(input.span))
1110            }
1111            _ => {}
1112        }
1113
1114        // Check that the call is not to an external `inline` function.
1115        if func.variant == Variant::Inline
1116            && input.program.is_some_and(|program| program != self.scope_state.program_name.unwrap())
1117        {
1118            self.emit_err(TypeCheckerError::cannot_call_external_inline_function(input.span));
1119        }
1120
1121        // Make sure we're not calling a non-inline from an async block
1122        if self.async_block_id.is_some() && !matches!(func.variant, Variant::Inline) {
1123            self.emit_err(TypeCheckerError::can_only_call_inline_function("an async block", input.span));
1124        }
1125
1126        // Async functions return a single future.
1127        let mut ret = if func.variant == Variant::AsyncFunction {
1128            // Async functions always return futures.
1129            Type::Future(FutureType::new(
1130                Vec::new(),
1131                Some(Location::new(callee_program, input.function.absolute_path())),
1132                false,
1133            ))
1134        } else if func.variant == Variant::AsyncTransition {
1135            // Fully infer future type.
1136            let Some(inputs) =
1137                self.async_function_input_types.get(&Location::new(callee_program, vec![Symbol::intern(&format!(
1138                    "finalize/{}",
1139                    input.function.identifier().name
1140                ))]))
1141            else {
1142                self.emit_err(TypeCheckerError::async_function_not_found(input.function.clone(), input.span));
1143                return Type::Future(FutureType::new(
1144                    Vec::new(),
1145                    Some(Location::new(callee_program, callee_path.clone())),
1146                    false,
1147                ));
1148            };
1149
1150            let future_type = Type::Future(FutureType::new(
1151                inputs.clone(),
1152                Some(Location::new(callee_program, callee_path.clone())),
1153                true,
1154            ));
1155            let fully_inferred_type = match &func.output_type {
1156                Type::Tuple(tup) => Type::Tuple(TupleType::new(
1157                    tup.elements()
1158                        .iter()
1159                        .map(|t| if matches!(t, Type::Future(_)) { future_type.clone() } else { t.clone() })
1160                        .collect::<Vec<Type>>(),
1161                )),
1162                Type::Future(_) => future_type,
1163                _ => panic!("Invalid output type for async transition."),
1164            };
1165            self.assert_and_return_type(fully_inferred_type, expected, input.span())
1166        } else {
1167            self.assert_and_return_type(func.output_type, expected, input.span())
1168        };
1169
1170        // Check number of function arguments.
1171        if func.input.len() != input.arguments.len() {
1172            self.emit_err(TypeCheckerError::incorrect_num_args_to_call(
1173                func.input.len(),
1174                input.arguments.len(),
1175                input.span(),
1176            ));
1177        }
1178
1179        // Check the number of const arguments against the number of the function's const parameters
1180        if func.const_parameters.len() != input.const_arguments.len() {
1181            self.emit_err(TypeCheckerError::incorrect_num_const_args(
1182                "Call",
1183                func.const_parameters.len(),
1184                input.const_arguments.len(),
1185                input.span(),
1186            ));
1187        }
1188
1189        // Check the types of const arguments against the types of the function's const parameters
1190        for (expected, argument) in func.const_parameters.iter().zip(input.const_arguments.iter()) {
1191            self.visit_expression(argument, &Some(expected.type_().clone()));
1192        }
1193
1194        let (mut input_futures, mut inferred_finalize_inputs) = (Vec::new(), Vec::new());
1195        for (expected, argument) in func.input.iter().zip(input.arguments.iter()) {
1196            // Get the type of the expression. If the type is not known, do not attempt to attempt any further inference.
1197            let ty = self.visit_expression(argument, &Some(expected.type_().clone()));
1198
1199            if ty == Type::Err {
1200                return Type::Err;
1201            }
1202            // Extract information about futures that are being consumed.
1203            if func.variant == Variant::AsyncFunction && matches!(expected.type_(), Type::Future(_)) {
1204                // Consume the future.
1205                let option_name = match argument {
1206                    Expression::Path(path) => Some(path.identifier().name),
1207                    Expression::TupleAccess(tuple_access) => {
1208                        if let Expression::Path(path) = &tuple_access.tuple {
1209                            Some(path.identifier().name)
1210                        } else {
1211                            None
1212                        }
1213                    }
1214                    _ => None,
1215                };
1216
1217                if let Some(name) = option_name {
1218                    match self.scope_state.futures.shift_remove(&name) {
1219                        Some(future) => {
1220                            self.scope_state.call_location = Some(future);
1221                        }
1222                        None => {
1223                            self.emit_err(TypeCheckerError::unknown_future_consumed(name, argument.span()));
1224                        }
1225                    }
1226                }
1227
1228                match argument {
1229                    Expression::Path(_) | Expression::Call(_) | Expression::TupleAccess(_) => {
1230                        match &self.scope_state.call_location {
1231                            Some(location) => {
1232                                // Get the external program and function name.
1233                                input_futures.push(location.clone());
1234                                // Get the full inferred type.
1235                                inferred_finalize_inputs.push(ty);
1236                            }
1237                            None => {
1238                                self.emit_err(TypeCheckerError::unknown_future_consumed(argument, argument.span()));
1239                            }
1240                        }
1241                    }
1242                    _ => {
1243                        self.emit_err(TypeCheckerError::unknown_future_consumed("unknown", argument.span()));
1244                    }
1245                }
1246            } else {
1247                inferred_finalize_inputs.push(ty);
1248            }
1249        }
1250
1251        let caller_program =
1252            self.scope_state.program_name.expect("`program_name` is always set before traversing a program scope");
1253        // Note: Constructors are added to the call graph under the `constructor` symbol.
1254        // This is safe since `constructor` is a reserved token and cannot be used as a function name.
1255        let caller_function = if self.scope_state.is_constructor {
1256            sym::constructor
1257        } else {
1258            self.scope_state.function.expect("`function` is always set before traversing a function scope")
1259        };
1260
1261        // This is the path to the function that we're in
1262        let caller_path = self
1263            .scope_state
1264            .module_name
1265            .iter()
1266            .cloned()
1267            .chain(std::iter::once(caller_function))
1268            .collect::<Vec<Symbol>>();
1269
1270        let caller = Location::new(caller_program, caller_path.clone());
1271        let callee = Location::new(callee_program, callee_path.clone());
1272        self.state.call_graph.add_edge(caller, callee);
1273
1274        if func.variant.is_transition() && self.scope_state.variant == Some(Variant::AsyncTransition) {
1275            if self.scope_state.has_called_finalize {
1276                self.emit_err(TypeCheckerError::external_call_after_async("function call", input.span));
1277            }
1278
1279            if self.scope_state.already_contains_an_async_block {
1280                self.emit_err(TypeCheckerError::external_call_after_async("block", input.span));
1281            }
1282        }
1283
1284        // Propagate futures from async functions and transitions.
1285        if func.variant.is_async_function() {
1286            // Cannot have async calls in a conditional block.
1287            if self.scope_state.is_conditional {
1288                self.emit_err(TypeCheckerError::async_call_in_conditional(input.span));
1289            }
1290
1291            // Can only call async functions and external async transitions from an async transition body.
1292            if !matches!(self.scope_state.variant, Some(Variant::AsyncTransition) | Some(Variant::Script)) {
1293                self.emit_err(TypeCheckerError::async_call_can_only_be_done_from_async_transition(input.span));
1294            }
1295
1296            // Can only call an async function once in a transition function body.
1297            if self.scope_state.has_called_finalize {
1298                self.emit_err(TypeCheckerError::must_call_async_function_once(input.span));
1299            }
1300
1301            if self.scope_state.already_contains_an_async_block {
1302                self.emit_err(TypeCheckerError::conflicting_async_call_and_block(input.span));
1303            }
1304
1305            // Check that all futures consumed.
1306            if !self.scope_state.futures.is_empty() {
1307                self.emit_err(TypeCheckerError::not_all_futures_consumed(
1308                    self.scope_state.futures.iter().map(|(f, _)| f).join(", "),
1309                    input.span,
1310                ));
1311            }
1312            self.state
1313                .symbol_table
1314                .attach_finalizer(
1315                    Location::new(callee_program, caller_path),
1316                    Location::new(callee_program, callee_path.clone()),
1317                    input_futures,
1318                    inferred_finalize_inputs.clone(),
1319                )
1320                .expect("Failed to attach finalizer");
1321            // Create expectation for finalize inputs that will be checked when checking corresponding finalize function signature.
1322            self.async_function_callers
1323                .entry(Location::new(self.scope_state.program_name.unwrap(), callee_path.clone()))
1324                .or_default()
1325                .insert(self.scope_state.location());
1326
1327            // Set scope state flag.
1328            self.scope_state.has_called_finalize = true;
1329
1330            // Update ret to reflect fully inferred future type.
1331            ret = Type::Future(FutureType::new(
1332                inferred_finalize_inputs,
1333                Some(Location::new(callee_program, callee_path.clone())),
1334                true,
1335            ));
1336
1337            // Type check in case the expected type is known.
1338            self.assert_and_return_type(ret.clone(), expected, input.span());
1339        }
1340
1341        // Set call location so that definition statement knows where future comes from.
1342        self.scope_state.call_location = Some(Location::new(callee_program, callee_path.clone()));
1343
1344        ret
1345    }
1346
1347    fn visit_cast(&mut self, input: &CastExpression, expected: &Self::AdditionalInput) -> Self::Output {
1348        let expression_type = self.visit_expression_reject_numeric(&input.expression, &None);
1349
1350        let assert_castable_type = |actual: &Type, span: Span| {
1351            if !matches!(
1352                actual,
1353                Type::Integer(_) | Type::Boolean | Type::Field | Type::Group | Type::Scalar | Type::Address | Type::Err,
1354            ) {
1355                self.emit_err(TypeCheckerError::type_should_be2(
1356                    actual,
1357                    "an integer, bool, field, group, scalar, or address",
1358                    span,
1359                ));
1360            }
1361        };
1362
1363        assert_castable_type(&input.type_, input.span());
1364
1365        assert_castable_type(&expression_type, input.expression.span());
1366
1367        self.maybe_assert_type(&input.type_, expected, input.span());
1368
1369        input.type_.clone()
1370    }
1371
1372    fn visit_struct_init(&mut self, input: &StructExpression, additional: &Self::AdditionalInput) -> Self::Output {
1373        let struct_ = self.lookup_struct(self.scope_state.program_name, &input.path.absolute_path()).clone();
1374        let Some(struct_) = struct_ else {
1375            self.emit_err(TypeCheckerError::unknown_sym("struct or record", input.path.clone(), input.path.span()));
1376            return Type::Err;
1377        };
1378
1379        // Check the number of const arguments against the number of the struct's const parameters
1380        if struct_.const_parameters.len() != input.const_arguments.len() {
1381            self.emit_err(TypeCheckerError::incorrect_num_const_args(
1382                "Struct expression",
1383                struct_.const_parameters.len(),
1384                input.const_arguments.len(),
1385                input.span(),
1386            ));
1387        }
1388
1389        // Check the types of const arguments against the types of the struct's const parameters
1390        for (expected, argument) in struct_.const_parameters.iter().zip(input.const_arguments.iter()) {
1391            self.visit_expression(argument, &Some(expected.type_().clone()));
1392        }
1393
1394        // Note that it is sufficient for the `program` to be `None` as composite types can only be initialized
1395        // in the program in which they are defined.
1396        let type_ = Type::Composite(CompositeType {
1397            path: input.path.clone(),
1398            const_arguments: input.const_arguments.clone(),
1399            program: None,
1400        });
1401        self.maybe_assert_type(&type_, additional, input.path.span());
1402
1403        // Check number of struct members.
1404        if struct_.members.len() != input.members.len() {
1405            self.emit_err(TypeCheckerError::incorrect_num_struct_members(
1406                struct_.members.len(),
1407                input.members.len(),
1408                input.span(),
1409            ));
1410        }
1411
1412        for Member { identifier, type_, .. } in struct_.members.iter() {
1413            if let Some(actual) = input.members.iter().find(|member| member.identifier.name == identifier.name) {
1414                match &actual.expression {
1415                    None => {
1416                        // If `expression` is None, then the member uses the identifier shorthand, e.g. `Foo { a }`
1417                        // We visit it as an expression rather than just calling `visit_path` so it will get
1418                        // put into the type table.
1419                        self.visit_expression(
1420                            &Path::from(actual.identifier)
1421                                .with_absolute_path(Some(
1422                                    self.scope_state
1423                                        .module_name
1424                                        .iter()
1425                                        .cloned()
1426                                        .chain(std::iter::once(actual.identifier.name))
1427                                        .collect::<Vec<Symbol>>(),
1428                                ))
1429                                .into(),
1430                            &Some(type_.clone()),
1431                        );
1432                    }
1433                    Some(expr) => {
1434                        // Otherwise, visit the associated expression.
1435                        self.visit_expression(expr, &Some(type_.clone()));
1436                    }
1437                };
1438            } else {
1439                self.emit_err(TypeCheckerError::missing_struct_member(struct_.identifier, identifier, input.span()));
1440            };
1441        }
1442
1443        if struct_.is_record {
1444            // First, ensure that the current scope is not an async function. Records should not be instantiated in
1445            // async functions
1446            if self.scope_state.variant == Some(Variant::AsyncFunction) {
1447                self.state
1448                    .handler
1449                    .emit_err(TypeCheckerError::records_not_allowed_inside_async("function", input.span()));
1450            }
1451
1452            // Similarly, ensure that the current scope is not an async block. Records should not be instantiated in
1453            // async blocks
1454            if self.async_block_id.is_some() {
1455                self.state.handler.emit_err(TypeCheckerError::records_not_allowed_inside_async("block", input.span()));
1456            }
1457
1458            // Records where the `owner` is `self.caller` can be problematic because `self.caller` can be a program
1459            // address and programs can't spend records. Emit a warning in this case.
1460            //
1461            // Multiple occurrences of `owner` here is an error but that should be flagged somewhere else.
1462            input.members.iter().filter(|init| init.identifier.name == sym::owner).for_each(|init| {
1463                if let Some(Expression::MemberAccess(access)) = &init.expression {
1464                    if let MemberAccess {
1465                        inner: Expression::Path(path),
1466                        name: Identifier { name: sym::caller, .. },
1467                        ..
1468                    } = &**access
1469                    {
1470                        if path.identifier().name == sym::SelfLower {
1471                            self.emit_warning(TypeCheckerWarning::caller_as_record_owner(
1472                                input.path.clone(),
1473                                access.span(),
1474                            ));
1475                        }
1476                    }
1477                }
1478            });
1479        }
1480
1481        type_
1482    }
1483
1484    // We do not want to panic on `ErrExpression`s in order to propagate as many errors as possible.
1485    fn visit_err(&mut self, _input: &ErrExpression, _additional: &Self::AdditionalInput) -> Self::Output {
1486        Type::Err
1487    }
1488
1489    fn visit_path(&mut self, input: &Path, expected: &Self::AdditionalInput) -> Self::Output {
1490        let var = self.state.symbol_table.lookup_path(self.scope_state.program_name.unwrap(), &input.absolute_path());
1491
1492        if let Some(var) = var {
1493            self.maybe_assert_type(&var.type_, expected, input.span());
1494            var.type_.clone()
1495        } else {
1496            self.emit_err(TypeCheckerError::unknown_sym("variable", input, input.span()));
1497            Type::Err
1498        }
1499    }
1500
1501    fn visit_literal(&mut self, input: &Literal, expected: &Self::AdditionalInput) -> Self::Output {
1502        let span = input.span();
1503
1504        macro_rules! parse_and_return {
1505            ($ty:ty, $variant:expr, $str:expr, $label:expr) => {{
1506                self.parse_integer_literal::<$ty>($str, span, $label);
1507                Type::Integer($variant)
1508            }};
1509        }
1510
1511        let type_ = match &input.variant {
1512            LiteralVariant::Address(..) => Type::Address,
1513            LiteralVariant::Boolean(..) => Type::Boolean,
1514            LiteralVariant::Field(..) => Type::Field,
1515            LiteralVariant::Scalar(..) => Type::Scalar,
1516            LiteralVariant::String(..) => {
1517                self.emit_err(TypeCheckerError::strings_are_not_supported(span));
1518                Type::String
1519            }
1520            LiteralVariant::Integer(kind, string) => match kind {
1521                IntegerType::U8 => parse_and_return!(u8, IntegerType::U8, string, "u8"),
1522                IntegerType::U16 => parse_and_return!(u16, IntegerType::U16, string, "u16"),
1523                IntegerType::U32 => parse_and_return!(u32, IntegerType::U32, string, "u32"),
1524                IntegerType::U64 => parse_and_return!(u64, IntegerType::U64, string, "u64"),
1525                IntegerType::U128 => parse_and_return!(u128, IntegerType::U128, string, "u128"),
1526                IntegerType::I8 => parse_and_return!(i8, IntegerType::I8, string, "i8"),
1527                IntegerType::I16 => parse_and_return!(i16, IntegerType::I16, string, "i16"),
1528                IntegerType::I32 => parse_and_return!(i32, IntegerType::I32, string, "i32"),
1529                IntegerType::I64 => parse_and_return!(i64, IntegerType::I64, string, "i64"),
1530                IntegerType::I128 => parse_and_return!(i128, IntegerType::I128, string, "i128"),
1531            },
1532            LiteralVariant::Group(s) => {
1533                let trimmed = s.trim_start_matches('-').trim_start_matches('0');
1534                if !trimmed.is_empty()
1535                    && format!("{trimmed}group")
1536                        .parse::<snarkvm::prelude::Group<snarkvm::prelude::TestnetV0>>()
1537                        .is_err()
1538                {
1539                    self.emit_err(TypeCheckerError::invalid_int_value(trimmed, "group", span));
1540                }
1541                Type::Group
1542            }
1543            LiteralVariant::Unsuffixed(_) => match expected {
1544                Some(ty @ Type::Integer(_) | ty @ Type::Field | ty @ Type::Group | ty @ Type::Scalar) => {
1545                    self.check_numeric_literal(input, ty);
1546                    ty.clone()
1547                }
1548                Some(ty @ Type::Optional(opt)) => {
1549                    // Handle optional expected type, e.g., u32?
1550                    let inner = &opt.inner;
1551                    match &**inner {
1552                        Type::Integer(_) | Type::Field | Type::Group | Type::Scalar => {
1553                            self.check_numeric_literal(input, inner);
1554                            Type::Optional(OptionalType { inner: Box::new(*inner.clone()) })
1555                        }
1556                        _ => {
1557                            self.emit_err(TypeCheckerError::unexpected_unsuffixed_numeral(
1558                                format!("type `{ty}`"),
1559                                span,
1560                            ));
1561                            Type::Err
1562                        }
1563                    }
1564                }
1565                Some(ty) => {
1566                    self.emit_err(TypeCheckerError::unexpected_unsuffixed_numeral(format!("type `{ty}`"), span));
1567                    Type::Err
1568                }
1569                None => Type::Numeric,
1570            },
1571            LiteralVariant::None => {
1572                if let Some(ty @ Type::Optional(_)) = expected {
1573                    ty.clone()
1574                } else if let Some(ty) = expected {
1575                    self.emit_err(TypeCheckerError::none_found_non_optional(format!("{ty}"), span));
1576                    Type::Err
1577                } else {
1578                    self.emit_err(TypeCheckerError::could_not_determine_type(format!("{input}"), span));
1579                    Type::Err
1580                }
1581            }
1582        };
1583
1584        self.maybe_assert_type(&type_, expected, span);
1585
1586        type_
1587    }
1588
1589    fn visit_locator(&mut self, input: &LocatorExpression, expected: &Self::AdditionalInput) -> Self::Output {
1590        let maybe_var =
1591            self.state.symbol_table.lookup_global(&Location::new(input.program.name.name, vec![input.name])).cloned();
1592        if let Some(var) = maybe_var {
1593            self.maybe_assert_type(&var.type_, expected, input.span());
1594            var.type_
1595        } else {
1596            self.emit_err(TypeCheckerError::unknown_sym("variable", input.name, input.span()));
1597            Type::Err
1598        }
1599    }
1600
1601    fn visit_ternary(&mut self, input: &TernaryExpression, expected: &Self::AdditionalInput) -> Self::Output {
1602        self.visit_expression(&input.condition, &Some(Type::Boolean));
1603
1604        let t1 = self.visit_expression_reject_numeric(&input.if_true, expected);
1605        let t2 = self.visit_expression_reject_numeric(&input.if_false, expected);
1606
1607        let typ = if t1 == Type::Err || t2 == Type::Err {
1608            Type::Err
1609        } else if !t1.can_coerce_to(&t2) && !t2.can_coerce_to(&t1) {
1610            self.emit_err(TypeCheckerError::ternary_branch_mismatch(t1, t2, input.span()));
1611            Type::Err
1612        } else if let Some(expected) = expected {
1613            expected.clone()
1614        } else {
1615            t1
1616        };
1617
1618        // Make sure this isn't an external record type - won't work as we can't construct it.
1619        if self.is_external_record(&typ) {
1620            self.emit_err(TypeCheckerError::ternary_over_external_records(&typ, input.span));
1621        }
1622
1623        // None of its members may be external record types either.
1624        if let Type::Tuple(tuple) = &typ {
1625            if tuple.elements().iter().any(|ty| self.is_external_record(ty)) {
1626                self.emit_err(TypeCheckerError::ternary_over_external_records(&typ, input.span));
1627            }
1628        }
1629
1630        typ
1631    }
1632
1633    fn visit_tuple(&mut self, input: &TupleExpression, expected: &Self::AdditionalInput) -> Self::Output {
1634        if let Some(expected) = expected {
1635            if let Type::Tuple(expected_types) = expected {
1636                // If the expected type is a tuple, then ensure it's compatible with `input`
1637
1638                // First, make sure that the number of tuple elements is correct
1639                if expected_types.length() != input.elements.len() {
1640                    self.emit_err(TypeCheckerError::incorrect_tuple_length(
1641                        expected_types.length(),
1642                        input.elements.len(),
1643                        input.span(),
1644                    ));
1645                }
1646
1647                // Now make sure that none of the tuple elements is a tuple
1648                input.elements.iter().zip(expected_types.elements()).for_each(|(expr, expected_el_ty)| {
1649                    if matches!(expr, Expression::Tuple(_)) {
1650                        self.emit_err(TypeCheckerError::nested_tuple_expression(expr.span()));
1651                    }
1652                    self.visit_expression(expr, &Some(expected_el_ty.clone()));
1653                });
1654
1655                // Just return the expected type since we proved it's correct
1656                expected.clone()
1657            } else {
1658                // If the expected type is not a tuple, then we just error out
1659
1660                // This is the expected type of the tuple based on its individual fields
1661                let field_types = input
1662                    .elements
1663                    .iter()
1664                    .map(|field| {
1665                        let ty = self.visit_expression(field, &None);
1666                        if ty == Type::Numeric {
1667                            self.emit_err(TypeCheckerError::could_not_determine_type(field.clone(), field.span()));
1668                            Type::Err
1669                        } else {
1670                            ty
1671                        }
1672                    })
1673                    .collect::<Vec<_>>();
1674                if field_types.iter().all(|f| *f != Type::Err) {
1675                    let tuple_type = Type::Tuple(TupleType::new(field_types));
1676                    self.emit_err(TypeCheckerError::type_should_be2(tuple_type, expected, input.span()));
1677                }
1678
1679                // Recover with the expected type anyways
1680                expected.clone()
1681            }
1682        } else {
1683            // If no `expected` type is provided, then we analyze the tuple itself and infer its type
1684
1685            // We still need to check that none of the tuple elements is a tuple
1686            input.elements.iter().for_each(|expr| {
1687                if matches!(expr, Expression::Tuple(_)) {
1688                    self.emit_err(TypeCheckerError::nested_tuple_expression(expr.span()));
1689                }
1690            });
1691
1692            Type::Tuple(TupleType::new(
1693                input
1694                    .elements
1695                    .iter()
1696                    .map(|field| {
1697                        let ty = self.visit_expression(field, &None);
1698                        if ty == Type::Numeric {
1699                            self.emit_err(TypeCheckerError::could_not_determine_type(field.clone(), field.span()));
1700                            Type::Err
1701                        } else {
1702                            ty
1703                        }
1704                    })
1705                    .collect::<Vec<_>>(),
1706            ))
1707        }
1708    }
1709
1710    fn visit_unary(&mut self, input: &UnaryExpression, destination: &Self::AdditionalInput) -> Self::Output {
1711        let operand_expected = self.unwrap_optional_type(destination);
1712
1713        let assert_signed_int = |slf: &mut Self, type_: &Type| {
1714            if !matches!(
1715                type_,
1716                Type::Err
1717                    | Type::Integer(IntegerType::I8)
1718                    | Type::Integer(IntegerType::I16)
1719                    | Type::Integer(IntegerType::I32)
1720                    | Type::Integer(IntegerType::I64)
1721                    | Type::Integer(IntegerType::I128)
1722            ) {
1723                slf.emit_err(TypeCheckerError::type_should_be2(type_, "a signed integer", input.span()));
1724            }
1725        };
1726
1727        let ty = match input.op {
1728            UnaryOperation::Abs => {
1729                let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
1730                assert_signed_int(self, &type_);
1731                type_
1732            }
1733            UnaryOperation::AbsWrapped => {
1734                let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
1735                assert_signed_int(self, &type_);
1736                type_
1737            }
1738            UnaryOperation::Double => {
1739                let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
1740                if !matches!(&type_, Type::Err | Type::Field | Type::Group) {
1741                    self.emit_err(TypeCheckerError::type_should_be2(&type_, "a field or group", input.span()));
1742                }
1743                type_
1744            }
1745            UnaryOperation::Inverse => {
1746                let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
1747                if type_ == Type::Numeric {
1748                    // We can actually infer to `field` here because only fields can be inverted
1749                    type_ = Type::Field;
1750                    self.state.type_table.insert(input.receiver.id(), Type::Field);
1751                } else {
1752                    self.assert_type(&type_, &Type::Field, input.span());
1753                }
1754                type_
1755            }
1756            UnaryOperation::Negate => {
1757                let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
1758                if !matches!(
1759                    &type_,
1760                    Type::Err
1761                        | Type::Integer(IntegerType::I8)
1762                        | Type::Integer(IntegerType::I16)
1763                        | Type::Integer(IntegerType::I32)
1764                        | Type::Integer(IntegerType::I64)
1765                        | Type::Integer(IntegerType::I128)
1766                        | Type::Group
1767                        | Type::Field
1768                ) {
1769                    self.emit_err(TypeCheckerError::type_should_be2(
1770                        &type_,
1771                        "a signed integer, group, or field",
1772                        input.receiver.span(),
1773                    ));
1774                }
1775                type_
1776            }
1777            UnaryOperation::Not => {
1778                let type_ = self.visit_expression_reject_numeric(&input.receiver, &operand_expected);
1779                if !matches!(&type_, Type::Err | Type::Boolean | Type::Integer(_)) {
1780                    self.emit_err(TypeCheckerError::type_should_be2(&type_, "a bool or integer", input.span()));
1781                }
1782                type_
1783            }
1784            UnaryOperation::Square => {
1785                let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
1786                if type_ == Type::Numeric {
1787                    // We can actually infer to `field` here because only fields can be squared
1788                    type_ = Type::Field;
1789                    self.state.type_table.insert(input.receiver.id(), Type::Field);
1790                } else {
1791                    self.assert_type(&type_, &Type::Field, input.span());
1792                }
1793                type_
1794            }
1795            UnaryOperation::SquareRoot => {
1796                let mut type_ = self.visit_expression(&input.receiver, &operand_expected);
1797                if type_ == Type::Numeric {
1798                    // We can actually infer to `field` here because only fields can be square-rooted
1799                    type_ = Type::Field;
1800                    self.state.type_table.insert(input.receiver.id(), Type::Field);
1801                } else {
1802                    self.assert_type(&type_, &Type::Field, input.span());
1803                }
1804                type_
1805            }
1806            UnaryOperation::ToXCoordinate | UnaryOperation::ToYCoordinate => {
1807                let _operand_type = self.visit_expression(&input.receiver, &Some(Type::Group));
1808                self.maybe_assert_type(&Type::Field, destination, input.span());
1809                Type::Field
1810            }
1811        };
1812
1813        self.maybe_assert_type(&ty, destination, input.span());
1814
1815        self.wrap_if_optional(ty, destination)
1816    }
1817
1818    fn visit_unit(&mut self, _input: &UnitExpression, _additional: &Self::AdditionalInput) -> Self::Output {
1819        Type::Unit
1820    }
1821
1822    /* Statements */
1823    fn visit_statement(&mut self, input: &Statement) {
1824        // No statements can follow a return statement.
1825        if self.scope_state.has_return {
1826            self.emit_err(TypeCheckerError::unreachable_code_after_return(input.span()));
1827            return;
1828        }
1829
1830        match input {
1831            Statement::Assert(stmt) => self.visit_assert(stmt),
1832            Statement::Assign(stmt) => self.visit_assign(stmt),
1833            Statement::Block(stmt) => self.visit_block(stmt),
1834            Statement::Conditional(stmt) => self.visit_conditional(stmt),
1835            Statement::Const(stmt) => self.visit_const(stmt),
1836            Statement::Definition(stmt) => self.visit_definition(stmt),
1837            Statement::Expression(stmt) => self.visit_expression_statement(stmt),
1838            Statement::Iteration(stmt) => self.visit_iteration(stmt),
1839            Statement::Return(stmt) => self.visit_return(stmt),
1840        }
1841    }
1842
1843    fn visit_assert(&mut self, input: &AssertStatement) {
1844        match &input.variant {
1845            AssertVariant::Assert(expr) => {
1846                let _type = self.visit_expression(expr, &Some(Type::Boolean));
1847            }
1848            AssertVariant::AssertEq(left, right) | AssertVariant::AssertNeq(left, right) => {
1849                let t1 = self.visit_expression_reject_numeric(left, &None);
1850                let t2 = self.visit_expression_reject_numeric(right, &None);
1851
1852                if t1 != Type::Err && t2 != Type::Err && !t1.eq_user(&t2) {
1853                    let op =
1854                        if matches!(input.variant, AssertVariant::AssertEq(..)) { "assert_eq" } else { "assert_neq" };
1855                    self.emit_err(TypeCheckerError::operation_types_mismatch(op, t1, t2, input.span()));
1856                }
1857            }
1858        }
1859    }
1860
1861    fn visit_assign(&mut self, input: &AssignStatement) {
1862        let lhs_type = self.visit_expression_assign(&input.place);
1863
1864        self.visit_expression(&input.value, &Some(lhs_type.clone()));
1865    }
1866
1867    fn visit_block(&mut self, input: &Block) {
1868        self.in_scope(input.id, |slf| {
1869            input.statements.iter().for_each(|stmt| slf.visit_statement(stmt));
1870        });
1871    }
1872
1873    fn visit_conditional(&mut self, input: &ConditionalStatement) {
1874        self.visit_expression(&input.condition, &Some(Type::Boolean));
1875
1876        let mut then_block_has_return = false;
1877        let mut otherwise_block_has_return = false;
1878
1879        // Set the `has_return` flag for the then-block.
1880        let previous_has_return = core::mem::replace(&mut self.scope_state.has_return, then_block_has_return);
1881        // Set the `is_conditional` flag.
1882        let previous_is_conditional = core::mem::replace(&mut self.scope_state.is_conditional, true);
1883
1884        // Visit block.
1885        self.in_conditional_scope(|slf| slf.visit_block(&input.then));
1886
1887        // Store the `has_return` flag for the then-block.
1888        then_block_has_return = self.scope_state.has_return;
1889
1890        if let Some(otherwise) = &input.otherwise {
1891            // Set the `has_return` flag for the otherwise-block.
1892            self.scope_state.has_return = otherwise_block_has_return;
1893
1894            match &**otherwise {
1895                Statement::Block(stmt) => {
1896                    // Visit the otherwise-block.
1897                    self.in_conditional_scope(|slf| slf.visit_block(stmt));
1898                }
1899                Statement::Conditional(stmt) => self.visit_conditional(stmt),
1900                _ => unreachable!("Else-case can only be a block or conditional statement."),
1901            }
1902
1903            // Store the `has_return` flag for the otherwise-block.
1904            otherwise_block_has_return = self.scope_state.has_return;
1905        }
1906
1907        // Restore the previous `has_return` flag.
1908        self.scope_state.has_return = previous_has_return || (then_block_has_return && otherwise_block_has_return);
1909        // Restore the previous `is_conditional` flag.
1910        self.scope_state.is_conditional = previous_is_conditional;
1911    }
1912
1913    fn visit_const(&mut self, input: &ConstDeclaration) {
1914        self.visit_type(&input.type_);
1915
1916        // For now, consts that contain optional types are not supported.
1917        // TODO: remove this restriction by supporting const evaluation of optionals including `None`.
1918        if self.contains_optional_type(&input.type_) {
1919            self.emit_err(TypeCheckerError::const_cannot_be_optional(input.span));
1920        }
1921
1922        // Check that the type of the definition is not a unit type, singleton tuple type, or nested tuple type.
1923        match &input.type_ {
1924            // If the type is an empty tuple, return an error.
1925            Type::Unit => self.emit_err(TypeCheckerError::lhs_must_be_identifier_or_tuple(input.span)),
1926            // If the type is a singleton tuple, return an error.
1927            Type::Tuple(tuple) => match tuple.length() {
1928                0 | 1 => unreachable!("Parsing guarantees that tuple types have at least two elements."),
1929                _ => {
1930                    if tuple.elements().iter().any(|type_| matches!(type_, Type::Tuple(_))) {
1931                        self.emit_err(TypeCheckerError::nested_tuple_type(input.span))
1932                    }
1933                }
1934            },
1935            Type::Mapping(_) | Type::Err => unreachable!(
1936                "Parsing guarantees that `mapping` and `err` types are not present at this location in the AST."
1937            ),
1938            // Otherwise, the type is valid.
1939            _ => (), // Do nothing
1940        }
1941
1942        // Check the expression on the right-hand side.
1943        self.visit_expression(&input.value, &Some(input.type_.clone()));
1944
1945        if self.scope_state.function.is_some() {
1946            // Global consts have already been added to the symbol table, so only
1947            // add this one if it's local.
1948            if let Err(err) = self.state.symbol_table.insert_variable(
1949                self.scope_state.program_name.unwrap(),
1950                &[input.place.name],
1951                VariableSymbol { type_: input.type_.clone(), span: input.place.span, declaration: VariableType::Const },
1952            ) {
1953                self.state.handler.emit_err(err);
1954            }
1955        }
1956    }
1957
1958    fn visit_definition(&mut self, input: &DefinitionStatement) {
1959        // Check that the type annotation of the definition is valid, if provided.
1960        if let Some(ty) = &input.type_ {
1961            self.visit_type(ty);
1962            self.assert_type_is_valid(ty, input.span);
1963        }
1964
1965        // Check that the type of the definition is not a unit type, singleton tuple type, or nested tuple type.
1966        match &input.type_ {
1967            // If the type is a singleton tuple, return an error.
1968            Some(Type::Tuple(tuple)) => match tuple.length() {
1969                0 | 1 => unreachable!("Parsing guarantees that tuple types have at least two elements."),
1970                _ => {
1971                    for type_ in tuple.elements() {
1972                        if matches!(type_, Type::Tuple(_)) {
1973                            self.emit_err(TypeCheckerError::nested_tuple_type(input.span))
1974                        }
1975                    }
1976                }
1977            },
1978            Some(Type::Mapping(_)) | Some(Type::Err) => unreachable!(
1979                "Parsing guarantees that `mapping` and `err` types are not present at this location in the AST."
1980            ),
1981            // Otherwise, the type is valid.
1982            _ => (), // Do nothing
1983        }
1984
1985        // Check the expression on the right-hand side. If we could not resolve `Type::Numeric`, then just give up.
1986        // We could do better in the future by potentially looking at consumers of this variable and inferring type
1987        // information from them.
1988        let inferred_type = self.visit_expression_reject_numeric(&input.value, &input.type_);
1989
1990        // Insert the variables into the symbol table.
1991        match &input.place {
1992            DefinitionPlace::Single(identifier) => {
1993                self.insert_variable(
1994                    Some(inferred_type.clone()),
1995                    identifier,
1996                    // If no type annotation is provided, then just use `inferred_type`.
1997                    input.type_.clone().unwrap_or(inferred_type),
1998                    identifier.span,
1999                );
2000            }
2001            DefinitionPlace::Multiple(identifiers) => {
2002                // Get the tuple type either from `input.type_` or from `inferred_type`.
2003                let tuple_type = match (&input.type_, inferred_type.clone()) {
2004                    (Some(Type::Tuple(tuple_type)), _) => tuple_type.clone(),
2005                    (None, Type::Tuple(tuple_type)) => tuple_type.clone(),
2006                    _ => {
2007                        // This is an error but should have been emitted earlier. Just exit here.
2008                        return;
2009                    }
2010                };
2011
2012                // Ensure the number of identifiers we're defining is the same as the number of tuple elements, as
2013                // indicated by `tuple_type`
2014                if identifiers.len() != tuple_type.length() {
2015                    return self.emit_err(TypeCheckerError::incorrect_num_tuple_elements(
2016                        identifiers.len(),
2017                        tuple_type.length(),
2018                        input.span(),
2019                    ));
2020                }
2021
2022                // Now just insert each tuple element as a separate variable
2023                for (i, identifier) in identifiers.iter().enumerate() {
2024                    let inferred = if let Type::Tuple(inferred_tuple) = &inferred_type {
2025                        inferred_tuple.elements().get(i).cloned().unwrap_or_default()
2026                    } else {
2027                        Type::Err
2028                    };
2029                    self.insert_variable(Some(inferred), identifier, tuple_type.elements()[i].clone(), identifier.span);
2030                }
2031            }
2032        }
2033    }
2034
2035    fn visit_expression_statement(&mut self, input: &ExpressionStatement) {
2036        // Expression statements can only be function calls.
2037        if !matches!(input.expression, Expression::Call(_) | Expression::AssociatedFunction(_)) {
2038            self.emit_err(TypeCheckerError::expression_statement_must_be_function_call(input.span()));
2039        } else {
2040            // Check the expression.
2041            self.visit_expression(&input.expression, &None);
2042        }
2043    }
2044
2045    fn visit_iteration(&mut self, input: &IterationStatement) {
2046        // Ensure the type annotation is an integer type
2047        if let Some(ty) = &input.type_ {
2048            self.visit_type(ty);
2049            self.assert_int_type(ty, input.variable.span);
2050        }
2051
2052        // These are the types of the start and end expressions of the iterator range. `visit_expression` will make
2053        // sure they match `input.type_` (i.e. the iterator type annotation) if available.
2054        let start_ty = self.visit_expression(&input.start, &input.type_.clone());
2055        let stop_ty = self.visit_expression(&input.stop, &input.type_.clone());
2056
2057        // Ensure both types are integer types
2058        self.assert_int_type(&start_ty, input.start.span());
2059        self.assert_int_type(&stop_ty, input.stop.span());
2060
2061        if start_ty != stop_ty {
2062            // Emit an error if the types of the range bounds do not match
2063            self.emit_err(TypeCheckerError::range_bounds_type_mismatch(input.start.span() + input.stop.span()));
2064        }
2065
2066        // Now, just set the type of the iterator variable to `start_ty` if `input.type_` is not available. If `stop_ty`
2067        // does not match `start_ty` and `input.type_` is not available, the we just recover with `start_ty` anyways
2068        // and continue.
2069        let iterator_ty = input.type_.clone().unwrap_or(start_ty);
2070        self.state.type_table.insert(input.variable.id(), iterator_ty.clone());
2071
2072        self.in_scope(input.id(), |slf| {
2073            // Add the loop variable to the scope of the loop body.
2074            if let Err(err) = slf.state.symbol_table.insert_variable(
2075                slf.scope_state.program_name.unwrap(),
2076                &[input.variable.name],
2077                VariableSymbol { type_: iterator_ty.clone(), span: input.span(), declaration: VariableType::Const },
2078            ) {
2079                slf.state.handler.emit_err(err);
2080            }
2081
2082            let prior_has_return = core::mem::take(&mut slf.scope_state.has_return);
2083            let prior_has_finalize = core::mem::take(&mut slf.scope_state.has_called_finalize);
2084
2085            slf.visit_block(&input.block);
2086
2087            if slf.scope_state.has_return {
2088                slf.emit_err(TypeCheckerError::loop_body_contains_return(input.span()));
2089            }
2090
2091            if slf.scope_state.has_called_finalize {
2092                slf.emit_err(TypeCheckerError::loop_body_contains_async("function call", input.span()));
2093            }
2094
2095            if slf.scope_state.already_contains_an_async_block {
2096                slf.emit_err(TypeCheckerError::loop_body_contains_async("block expression", input.span()));
2097            }
2098
2099            slf.scope_state.has_return = prior_has_return;
2100            slf.scope_state.has_called_finalize = prior_has_finalize;
2101        });
2102    }
2103
2104    fn visit_return(&mut self, input: &ReturnStatement) {
2105        if self.async_block_id.is_some() {
2106            return self.emit_err(TypeCheckerError::async_block_cannot_return(input.span()));
2107        }
2108
2109        if self.scope_state.is_constructor {
2110            // It must return a unit value; nothing else to check.
2111            if !matches!(input.expression, Expression::Unit(..)) {
2112                self.emit_err(TypeCheckerError::constructor_can_only_return_unit(&input.expression, input.span));
2113            }
2114            return;
2115        }
2116
2117        let caller_name = self.scope_state.function.expect("`self.function` is set every time a function is visited.");
2118        let caller_path =
2119            self.scope_state.module_name.iter().cloned().chain(std::iter::once(caller_name)).collect::<Vec<Symbol>>();
2120
2121        let func_symbol = self
2122            .state
2123            .symbol_table
2124            .lookup_function(&Location::new(self.scope_state.program_name.unwrap(), caller_path.clone()))
2125            .expect("The symbol table creator should already have visited all functions.");
2126
2127        let mut return_type = func_symbol.function.output_type.clone();
2128
2129        if self.scope_state.variant == Some(Variant::AsyncTransition) && self.scope_state.has_called_finalize {
2130            let inferred_future_type = Future(FutureType::new(
2131                if let Some(finalizer) = &func_symbol.finalizer { finalizer.inferred_inputs.clone() } else { vec![] },
2132                Some(Location::new(self.scope_state.program_name.unwrap(), caller_path)),
2133                true,
2134            ));
2135
2136            // Need to modify return type since the function signature is just default future, but the actual return
2137            // type is the fully inferred future of the finalize input type.
2138            let inferred = match return_type.clone() {
2139                Future(_) => inferred_future_type,
2140                Tuple(tuple) => Tuple(TupleType::new(
2141                    tuple
2142                        .elements()
2143                        .iter()
2144                        .map(|t| if matches!(t, Future(_)) { inferred_future_type.clone() } else { t.clone() })
2145                        .collect::<Vec<Type>>(),
2146                )),
2147                _ => {
2148                    return self.emit_err(TypeCheckerError::async_transition_missing_future_to_return(input.span()));
2149                }
2150            };
2151
2152            // Check that the explicit type declared in the function output signature matches the inferred type.
2153            return_type = self.assert_and_return_type(inferred, &Some(return_type), input.span());
2154        }
2155
2156        if matches!(input.expression, Expression::Unit(..)) {
2157            // Manually type check rather than using one of the assert functions for a better error message.
2158            if return_type != Type::Unit {
2159                // TODO - This is a bit hackish. We're reusing an existing error, because
2160                // we have too many errors in TypeCheckerError without hitting the recursion
2161                // limit for macros. But the error message to the user should still be pretty clear.
2162                return self.emit_err(TypeCheckerError::missing_return(input.span()));
2163            }
2164        }
2165
2166        self.visit_expression(&input.expression, &Some(return_type));
2167
2168        // Set the `has_return` flag after processing `input.expression` so that we don't error out
2169        // on something like `return async { .. }`.
2170        self.scope_state.has_return = true;
2171    }
2172}