leo_ast/stub/
function_stub.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    Annotation,
19    CompositeType,
20    Function,
21    FutureType,
22    Identifier,
23    Input,
24    Location,
25    Mode,
26    Node,
27    NodeID,
28    Output,
29    ProgramId,
30    TupleType,
31    Type,
32    Variant,
33};
34use leo_span::{Span, Symbol, sym};
35
36use itertools::Itertools;
37use serde::{Deserialize, Serialize};
38use snarkvm::{
39    console::program::{
40        FinalizeType::{Future as FutureFinalizeType, Plaintext as PlaintextFinalizeType},
41        RegisterType::{ExternalRecord, Future, Plaintext, Record},
42    },
43    prelude::{Network, ValueType},
44    synthesizer::program::{ClosureCore, CommandTrait, FunctionCore, InstructionTrait},
45};
46use std::fmt;
47
/// A function stub definition.
///
/// A stub records a function's signature (annotations, variant, name, inputs and
/// outputs) but, as the field list shows, carries no function body.
#[derive(Clone, Serialize, Deserialize)]
pub struct FunctionStub {
    /// Annotations on the function.
    pub annotations: Vec<Annotation>,
    /// The kind of function this is: inline, script, (async) function or (async) transition.
    pub variant: Variant,
    /// The function identifier, e.g., `foo` in `function foo(...) { ... }`.
    pub identifier: Identifier,
    /// The function's input parameters.
    pub input: Vec<Input>,
    /// The function's output declarations.
    pub output: Vec<Output>,
    /// The function's output type.
    ///
    /// The constructors in this file derive it from `output`: `Type::Unit` for no
    /// outputs, the sole output's type for one, and a tuple of all output types otherwise.
    pub output_type: Type,
    /// The entire span of the function definition.
    pub span: Span,
    /// The ID of the node.
    pub id: NodeID,
}
68
69impl PartialEq for FunctionStub {
70    fn eq(&self, other: &Self) -> bool {
71        self.identifier == other.identifier
72    }
73}
74
// Marker impl: declares the `PartialEq` relation on `FunctionStub` to be a full equivalence relation.
impl Eq for FunctionStub {}
76
77impl FunctionStub {
78    /// Initialize a new function.
79    #[allow(clippy::too_many_arguments)]
80    pub fn new(
81        annotations: Vec<Annotation>,
82        _is_async: bool,
83        variant: Variant,
84        identifier: Identifier,
85        input: Vec<Input>,
86        output: Vec<Output>,
87        span: Span,
88        id: NodeID,
89    ) -> Self {
90        let output_type = match output.len() {
91            0 => Type::Unit,
92            1 => output[0].type_.clone(),
93            _ => Type::Tuple(TupleType::new(output.iter().map(|o| o.type_.clone()).collect())),
94        };
95
96        FunctionStub { annotations, variant, identifier, input, output, output_type, span, id }
97    }
98
99    /// Returns function name.
100    pub fn name(&self) -> Symbol {
101        self.identifier.name
102    }
103
104    /// Returns `true` if the function name is `main`.
105    pub fn is_main(&self) -> bool {
106        self.name() == sym::main
107    }
108
109    /// Private formatting method used for optimizing [fmt::Debug] and [fmt::Display] implementations.
110    fn format(&self, f: &mut fmt::Formatter) -> fmt::Result {
111        match self.variant {
112            Variant::Inline => write!(f, "inline ")?,
113            Variant::Script => write!(f, "script ")?,
114            Variant::Function | Variant::AsyncFunction => write!(f, "function ")?,
115            Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?,
116        }
117        write!(f, "{}", self.identifier)?;
118
119        let parameters = self.input.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",");
120        let returns = match self.output.len() {
121            0 => "()".to_string(),
122            1 => self.output[0].to_string(),
123            _ => self.output.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","),
124        };
125        write!(f, "({parameters}) -> {returns}")?;
126
127        Ok(())
128    }
129
130    /// Converts from snarkvm function type to leo FunctionStub, while also carrying the parent program name.
131    pub fn from_function_core<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>>(
132        function: &FunctionCore<N, Instruction, Command>,
133        program: Symbol,
134    ) -> Self {
135        let outputs = function
136            .outputs()
137            .iter()
138            .map(|output| match output.value_type() {
139                ValueType::Constant(val) => vec![Output {
140                    mode: Mode::Constant,
141                    type_: Type::from_snarkvm(val, None),
142                    span: Default::default(),
143                    id: Default::default(),
144                }],
145                ValueType::Public(val) => vec![Output {
146                    mode: Mode::Public,
147                    type_: Type::from_snarkvm(val, None),
148                    span: Default::default(),
149                    id: Default::default(),
150                }],
151                ValueType::Private(val) => vec![Output {
152                    mode: Mode::Private,
153                    type_: Type::from_snarkvm(val, None),
154                    span: Default::default(),
155                    id: Default::default(),
156                }],
157                ValueType::Record(id) => vec![Output {
158                    mode: Mode::None,
159                    type_: Type::Composite(CompositeType { id: Identifier::from(id), program: Some(program) }),
160                    span: Default::default(),
161                    id: Default::default(),
162                }],
163                ValueType::ExternalRecord(loc) => {
164                    vec![Output {
165                        mode: Mode::None,
166                        span: Default::default(),
167                        id: Default::default(),
168                        type_: Type::Composite(CompositeType {
169                            id: Identifier::from(loc.resource()),
170                            program: Some(ProgramId::from(loc.program_id()).name.name),
171                        }),
172                    }]
173                }
174                ValueType::Future(_) => vec![Output {
175                    mode: Mode::None,
176                    span: Default::default(),
177                    id: Default::default(),
178                    type_: Type::Future(FutureType::new(
179                        Vec::new(),
180                        Some(Location::new(program, Identifier::from(function.name()).name)),
181                        false,
182                    )),
183                }],
184            })
185            .collect_vec()
186            .concat();
187        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
188        let output_type = match output_vec.len() {
189            0 => Type::Unit,
190            1 => output_vec[0].clone(),
191            _ => Type::Tuple(TupleType::new(output_vec)),
192        };
193
194        Self {
195            annotations: Vec::new(),
196            variant: match function.finalize_logic().is_some() {
197                true => Variant::AsyncTransition,
198                false => Variant::Transition,
199            },
200            identifier: Identifier::from(function.name()),
201            input: function
202                .inputs()
203                .iter()
204                .enumerate()
205                .map(|(index, input)| {
206                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
207                    match input.value_type() {
208                        ValueType::Constant(val) => Input {
209                            identifier: arg_name,
210                            mode: Mode::Constant,
211                            type_: Type::from_snarkvm(val, None),
212                            span: Default::default(),
213                            id: Default::default(),
214                        },
215                        ValueType::Public(val) => Input {
216                            identifier: arg_name,
217                            mode: Mode::Public,
218                            type_: Type::from_snarkvm(val, None),
219                            span: Default::default(),
220                            id: Default::default(),
221                        },
222                        ValueType::Private(val) => Input {
223                            identifier: arg_name,
224                            mode: Mode::Private,
225                            type_: Type::from_snarkvm(val, None),
226                            span: Default::default(),
227                            id: Default::default(),
228                        },
229                        ValueType::Record(id) => Input {
230                            identifier: arg_name,
231                            mode: Mode::None,
232                            type_: Type::Composite(CompositeType { id: Identifier::from(id), program: Some(program) }),
233                            span: Default::default(),
234                            id: Default::default(),
235                        },
236                        ValueType::ExternalRecord(loc) => Input {
237                            identifier: arg_name,
238                            mode: Mode::None,
239                            span: Default::default(),
240                            id: Default::default(),
241                            type_: Type::Composite(CompositeType {
242                                id: Identifier::from(loc.resource()),
243                                program: Some(ProgramId::from(loc.program_id()).name.name),
244                            }),
245                        },
246                        ValueType::Future(_) => panic!("Functions do not contain futures as inputs"),
247                    }
248                })
249                .collect_vec(),
250            output: outputs,
251            output_type,
252            span: Default::default(),
253            id: Default::default(),
254        }
255    }
256
257    pub fn from_finalize<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>>(
258        function: &FunctionCore<N, Instruction, Command>,
259        key_name: Symbol,
260        program: Symbol,
261    ) -> Self {
262        Self {
263            annotations: Vec::new(),
264            variant: Variant::AsyncFunction,
265            identifier: Identifier::new(key_name, Default::default()),
266            input: function
267                .finalize_logic()
268                .unwrap()
269                .inputs()
270                .iter()
271                .enumerate()
272                .map(|(index, input)| Input {
273                    identifier: Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default()),
274                    mode: Mode::None,
275                    type_: match input.finalize_type() {
276                        PlaintextFinalizeType(val) => Type::from_snarkvm(val, Some(program)),
277                        FutureFinalizeType(val) => Type::Future(FutureType::new(
278                            Vec::new(),
279                            Some(Location::new(
280                                Identifier::from(val.program_id().name()).name,
281                                Symbol::intern(&format!("finalize/{}", val.resource())),
282                            )),
283                            false,
284                        )),
285                    },
286                    span: Default::default(),
287                    id: Default::default(),
288                })
289                .collect_vec(),
290            output: Vec::new(),
291            output_type: Type::Unit,
292            span: Default::default(),
293            id: 0,
294        }
295    }
296
297    pub fn from_closure<N: Network, Instruction: InstructionTrait<N>>(
298        closure: &ClosureCore<N, Instruction>,
299        program: Symbol,
300    ) -> Self {
301        let outputs = closure
302            .outputs()
303            .iter()
304            .map(|output| match output.register_type() {
305                Plaintext(val) => Output {
306                    mode: Mode::None,
307                    type_: Type::from_snarkvm(val, Some(program)),
308                    span: Default::default(),
309                    id: Default::default(),
310                },
311                Record(_) => panic!("Closures do not return records"),
312                ExternalRecord(_) => panic!("Closures do not return external records"),
313                Future(_) => panic!("Closures do not return futures"),
314            })
315            .collect_vec();
316        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
317        let output_type = match output_vec.len() {
318            0 => Type::Unit,
319            1 => output_vec[0].clone(),
320            _ => Type::Tuple(TupleType::new(output_vec)),
321        };
322        Self {
323            annotations: Vec::new(),
324            variant: Variant::Function,
325            identifier: Identifier::from(closure.name()),
326            input: closure
327                .inputs()
328                .iter()
329                .enumerate()
330                .map(|(index, input)| {
331                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
332                    match input.register_type() {
333                        Plaintext(val) => Input {
334                            identifier: arg_name,
335                            mode: Mode::None,
336                            type_: Type::from_snarkvm(val, None),
337                            span: Default::default(),
338                            id: Default::default(),
339                        },
340                        Record(_) => panic!("Closures do not contain records as inputs"),
341                        ExternalRecord(_) => panic!("Closures do not contain external records as inputs"),
342                        Future(_) => panic!("Closures do not contain futures as inputs"),
343                    }
344                })
345                .collect_vec(),
346            output: outputs,
347            output_type,
348            span: Default::default(),
349            id: Default::default(),
350        }
351    }
352}
353
354impl From<Function> for FunctionStub {
355    fn from(function: Function) -> Self {
356        Self {
357            annotations: function.annotations,
358            variant: function.variant,
359            identifier: function.identifier,
360            input: function.input,
361            output: function.output,
362            output_type: function.output_type,
363            span: function.span,
364            id: function.id,
365        }
366    }
367}
368
369impl fmt::Debug for FunctionStub {
370    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
371        self.format(f)
372    }
373}
374
375impl fmt::Display for FunctionStub {
376    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
377        self.format(f)
378    }
379}
380
// NOTE(review): presumably generates the `Node` trait impl for `FunctionStub` from its
// `span` and `id` fields — confirm against the macro definition.
crate::simple_node_impl!(FunctionStub);