leo_ast/types/
type_.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    ArrayType,
19    CompositeType,
20    FutureType,
21    Identifier,
22    IntegerType,
23    MappingType,
24    OptionalType,
25    Path,
26    TupleType,
27    VectorType,
28};
29
30use itertools::Itertools;
31use leo_span::Symbol;
32use serde::{Deserialize, Serialize};
33use snarkvm::prelude::{
34    LiteralType,
35    Network,
36    PlaintextType,
37    PlaintextType::{Array, Literal, Struct},
38};
39use std::fmt;
40
/// Explicit type used for defining a variable or expression type.
///
/// `Err` is the `#[default]` variant, so `Type::default()` yields the error placeholder.
/// NOTE(review): variant order is observable through the derived serde impls (variant indices in
/// non-self-describing formats) — avoid reordering variants.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Type {
    /// The `address` type.
    Address,
    /// The array type, e.g. `[u8; 4]`; its length may be undetermined until const propagation resolves it.
    Array(ArrayType),
    /// The `bool` type.
    Boolean,
    /// A composite (`struct`) type, identified by its path, any const generic arguments, and an optional
    /// program of origin.
    Composite(CompositeType),
    /// The `field` type.
    Field,
    /// The `future` type. A future may be explicit (with known input types) or not.
    Future(FutureType),
    /// The `group` type.
    Group,
    /// A type referred to by an identifier.
    /// NOTE(review): previously documented as "a reference to a built in type" — confirm intended meaning.
    Identifier(Identifier),
    /// An integer type (`u8`..`u128`, `i8`..`i128`).
    Integer(IntegerType),
    /// A mapping type from a key type to a value type.
    Mapping(MappingType),
    /// A nullable type wrapping an inner type.
    Optional(OptionalType),
    /// The `scalar` type.
    Scalar,
    /// The `signature` type.
    Signature,
    /// The `string` type.
    String,
    /// A static tuple of at least one type.
    Tuple(TupleType),
    /// The vector type.
    Vector(VectorType),
    /// Numeric type which should be resolved to `Field`, `Group`, `Integer(_)`, or `Scalar`.
    Numeric,
    /// The `unit` type, displayed as `()`.
    Unit,
    /// Placeholder for a type that could not be resolved or was not well-formed.
    /// Will eventually lead to a compile error.
    #[default]
    Err,
}
85
86impl Type {
87    /// Are the types considered equal as far as the Leo user is concerned?
88    ///
89    /// In particular, any comparison involving an `Err` is `true`, and Futures which aren't explicit compare equal to
90    /// other Futures.
91    ///
92    /// An array with an undetermined length (e.g., one that depends on a `const`) is considered equal to other arrays
93    /// if their element types match. This allows const propagation to potentially resolve the length before type
94    /// checking is performed again.
95    ///
96    /// Composite types are considered equal if their names and resolved program names match. If either side still has
97    /// const generic arguments, they are treated as equal unconditionally since monomorphization and other passes of
98    /// type-checking will handle mismatches later.
99    pub fn eq_user(&self, other: &Type) -> bool {
100        match (self, other) {
101            (Type::Err, _)
102            | (_, Type::Err)
103            | (Type::Address, Type::Address)
104            | (Type::Boolean, Type::Boolean)
105            | (Type::Field, Type::Field)
106            | (Type::Group, Type::Group)
107            | (Type::Scalar, Type::Scalar)
108            | (Type::Signature, Type::Signature)
109            | (Type::String, Type::String)
110            | (Type::Unit, Type::Unit) => true,
111            (Type::Array(left), Type::Array(right)) => {
112                (match (left.length.as_u32(), right.length.as_u32()) {
113                    (Some(l1), Some(l2)) => l1 == l2,
114                    _ => {
115                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
116                        // equal to other arrays because their lengths _may_ eventually be proven equal.
117                        true
118                    }
119                }) && left.element_type().eq_user(right.element_type())
120            }
121            (Type::Identifier(left), Type::Identifier(right)) => left.name == right.name,
122            (Type::Integer(left), Type::Integer(right)) => left == right,
123            (Type::Mapping(left), Type::Mapping(right)) => {
124                left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
125            }
126            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
127            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
128                .elements()
129                .iter()
130                .zip_eq(right.elements().iter())
131                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
132            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
133            (Type::Composite(left), Type::Composite(right)) => {
134                // If either composite still has const generic arguments, treat them as equal.
135                // Type checking will run again after monomorphization.
136                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
137                    return true;
138                }
139
140                // Two composite types are the same if their programs and their _absolute_ paths match.
141                (left.program == right.program)
142                    && match (&left.path.try_absolute_path(), &right.path.try_absolute_path()) {
143                        (Some(l), Some(r)) => l == r,
144                        _ => false,
145                    }
146            }
147            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
148            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
149                .inputs()
150                .iter()
151                .zip_eq(right.inputs().iter())
152                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
153            _ => false,
154        }
155    }
156
157    /// Returns `true` if the self `Type` is equal to the other `Type` in all aspects besides composite program of origin.
158    ///
159    /// In the case of futures, it also makes sure that if both are not explicit, they are equal.
160    ///
161    /// Flattens array syntax: `[[u8; 1]; 2] == [u8; (2, 1)] == true`
162    ///
163    /// Composite types are considered equal if their names match. If either side still has const generic arguments,
164    /// they are treated as equal unconditionally since monomorphization and other passes of type-checking will handle
165    /// mismatches later.
166    pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
167        match (self, other) {
168            (Type::Address, Type::Address)
169            | (Type::Boolean, Type::Boolean)
170            | (Type::Field, Type::Field)
171            | (Type::Group, Type::Group)
172            | (Type::Scalar, Type::Scalar)
173            | (Type::Signature, Type::Signature)
174            | (Type::String, Type::String)
175            | (Type::Unit, Type::Unit) => true,
176            (Type::Array(left), Type::Array(right)) => {
177                // Two arrays are equal if their element types are the same and if their lengths
178                // are the same, assuming the lengths can be extracted as `u32`.
179                (match (left.length.as_u32(), right.length.as_u32()) {
180                    (Some(l1), Some(l2)) => l1 == l2,
181                    _ => {
182                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
183                        // equal to other arrays because their lengths _may_ eventually be proven equal.
184                        true
185                    }
186                }) && left.element_type().eq_flat_relaxed(right.element_type())
187            }
188            (Type::Identifier(left), Type::Identifier(right)) => left.matches(right),
189            (Type::Integer(left), Type::Integer(right)) => left.eq(right),
190            (Type::Mapping(left), Type::Mapping(right)) => {
191                left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
192            }
193            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
194            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
195                .elements()
196                .iter()
197                .zip_eq(right.elements().iter())
198                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
199            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
200            (Type::Composite(left), Type::Composite(right)) => {
201                // If either composite still has const generic arguments, treat them as equal.
202                // Type checking will run again after monomorphization.
203                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
204                    return true;
205                }
206
207                // Two composite types are the same if their _absolute_ paths match.
208                // If the absolute paths are not available, then we really can't compare the two
209                // types and we just return `false` to be conservative.
210                match (&left.path.try_absolute_path(), &right.path.try_absolute_path()) {
211                    (Some(l), Some(r)) => l == r,
212                    _ => false,
213                }
214            }
215            // Don't type check when type hasn't been explicitly defined.
216            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
217            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
218                .inputs()
219                .iter()
220                .zip_eq(right.inputs().iter())
221                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
222            _ => false,
223        }
224    }
225
226    pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program: Option<Symbol>) -> Self {
227        match t {
228            Literal(lit) => (*lit).into(),
229            Struct(s) => Type::Composite(CompositeType {
230                path: {
231                    let ident = Identifier::from(s);
232                    Path::from(ident).with_absolute_path(Some(vec![ident.name]))
233                },
234                const_arguments: Vec::new(),
235                program,
236            }),
237            Array(array) => Type::Array(ArrayType::from_snarkvm(array, program)),
238        }
239    }
240
241    // Attempts to convert `self` to a snarkVM `PlaintextType`.
242    pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
243        match self {
244            Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
245            Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
246            Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
247            Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
248            Type::Integer(int_type) => match int_type {
249                IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
250                IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
251                IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
252                IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
253                IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
254                IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
255                IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
256                IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
257                IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
258                IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
259            },
260            Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
261            Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
262            Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
263            _ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
264        }
265    }
266
267    // A helper function to get the size in bits of the input type.
268    pub fn size_in_bits<N: Network, F>(&self, is_raw: bool, get_structs: F) -> anyhow::Result<usize>
269    where
270        F: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
271    {
272        match is_raw {
273            false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs),
274            true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs),
275        }
276    }
277
278    /// Determines whether `self` can be coerced to the `expected` type.
279    ///
280    /// This method checks if the current type can be implicitly coerced to the expected type
281    /// according to specific rules:
282    /// - `Optional<T>` can be coerced to `Optional<T>`.
283    /// - `T` can be coerced to `Optional<T>`.
284    /// - Arrays `[T; N]` can be coerced to `[Optional<T>; N]` if lengths match or are unknown,
285    ///   and element types are coercible.
286    /// - Falls back to an equality check for other types.
287    ///
288    /// # Arguments
289    /// * `expected` - The type to which `self` is being coerced.
290    ///
291    /// # Returns
292    /// `true` if coercion is allowed; `false` otherwise.
293    pub fn can_coerce_to(&self, expected: &Type) -> bool {
294        use Type::*;
295
296        match (self, expected) {
297            // Allow Optional<T> → Optional<T>
298            (Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
299
300            // Allow T → Optional<T>
301            (a, Optional(opt)) => a.can_coerce_to(&opt.inner),
302
303            // Allow [T; N] → [Optional<T>; N]
304            (Array(a_arr), Array(e_arr)) => {
305                let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
306                    (Some(l1), Some(l2)) => l1 == l2,
307                    _ => true,
308                };
309
310                lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
311            }
312
313            // Fallback: check for exact match
314            _ => self.eq_user(expected),
315        }
316    }
317
318    pub fn is_optional(&self) -> bool {
319        matches!(self, Self::Optional(_))
320    }
321
322    pub fn is_vector(&self) -> bool {
323        matches!(self, Self::Vector(_))
324    }
325
326    pub fn is_mapping(&self) -> bool {
327        matches!(self, Self::Mapping(_))
328    }
329
330    pub fn to_optional(&self) -> Type {
331        Type::Optional(OptionalType { inner: Box::new(self.clone()) })
332    }
333}
334
335impl From<LiteralType> for Type {
336    fn from(value: LiteralType) -> Self {
337        match value {
338            LiteralType::Address => Type::Address,
339            LiteralType::Boolean => Type::Boolean,
340            LiteralType::Field => Type::Field,
341            LiteralType::Group => Type::Group,
342            LiteralType::U8 => Type::Integer(IntegerType::U8),
343            LiteralType::U16 => Type::Integer(IntegerType::U16),
344            LiteralType::U32 => Type::Integer(IntegerType::U32),
345            LiteralType::U64 => Type::Integer(IntegerType::U64),
346            LiteralType::U128 => Type::Integer(IntegerType::U128),
347            LiteralType::I8 => Type::Integer(IntegerType::I8),
348            LiteralType::I16 => Type::Integer(IntegerType::I16),
349            LiteralType::I32 => Type::Integer(IntegerType::I32),
350            LiteralType::I64 => Type::Integer(IntegerType::I64),
351            LiteralType::I128 => Type::Integer(IntegerType::I128),
352            LiteralType::Scalar => Type::Scalar,
353            LiteralType::Signature => Type::Signature,
354            LiteralType::String => Type::String,
355        }
356    }
357}
358
359impl fmt::Display for Type {
360    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
361        match *self {
362            Type::Address => write!(f, "address"),
363            Type::Array(ref array_type) => write!(f, "{array_type}"),
364            Type::Boolean => write!(f, "bool"),
365            Type::Field => write!(f, "field"),
366            Type::Future(ref future_type) => write!(f, "{future_type}"),
367            Type::Group => write!(f, "group"),
368            Type::Identifier(ref variable) => write!(f, "{variable}"),
369            Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
370            Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
371            Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
372            Type::Scalar => write!(f, "scalar"),
373            Type::Signature => write!(f, "signature"),
374            Type::String => write!(f, "string"),
375            Type::Composite(ref struct_type) => write!(f, "{struct_type}"),
376            Type::Tuple(ref tuple) => write!(f, "{tuple}"),
377            Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
378            Type::Numeric => write!(f, "numeric"),
379            Type::Unit => write!(f, "()"),
380            Type::Err => write!(f, "error"),
381        }
382    }
383}