leo_lang/cli/commands/
test.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use super::*;
18
19use leo_compiler::run_with_ledger;
20use leo_package::{Package, ProgramData};
21use leo_span::Symbol;
22
23use snarkvm::prelude::TestnetV0;
24
25use colored::Colorize as _;
26use indexmap::IndexSet;
27use std::{fs, path::PathBuf};
28
// Note: with clap's derive, the `///` line below becomes the command's description in `--help`,
// so explanatory notes for maintainers in this struct use `//` comments instead.
/// Test a leo program.
#[derive(Parser, Debug)]
pub struct LeoTest {
    // Optional positional filter. The default empty string is forwarded unchanged to
    // `leo_interpreter::find_and_run_tests` by `handle_test` (presumably matching every test —
    // confirm against the interpreter's matching rule).
    #[clap(
        name = "TEST_NAME",
        help = "If specified, run only tests whose qualified name matches against this string.",
        default_value = ""
    )]
    pub(crate) test_name: String,

    // Ordinary build options; `LeoTest::prelude` copies them and forces `build_tests = true`
    // before running the build.
    #[clap(flatten)]
    pub(crate) compiler_options: BuildOptions,
}
42
43impl Command for LeoTest {
44    type Input = <LeoBuild as Command>::Output;
45    type Output = ();
46
47    fn log_span(&self) -> Span {
48        tracing::span!(tracing::Level::INFO, "Leo")
49    }
50
51    fn prelude(&self, context: Context) -> Result<Self::Input> {
52        let mut options = self.compiler_options.clone();
53        options.build_tests = true;
54        (LeoBuild { env_override: Default::default(), options }).execute(context)
55    }
56
57    fn apply(self, context: Context, input: Self::Input) -> Result<Self::Output> {
58        handle_test(self, context, input)
59    }
60}
61
62fn handle_test(command: LeoTest, context: Context, package: Package) -> Result<()> {
63    // Get the private key.
64    let private_key = context.get_private_key::<TestnetV0>(&None)?;
65    let address = Address::try_from(&private_key)?;
66
67    // Get the paths of all local dependencies.
68    let leo_paths: Vec<PathBuf> = package
69        .programs
70        .iter()
71        .flat_map(|program| match &program.data {
72            ProgramData::SourcePath(path) => Some(path.clone()),
73            ProgramData::Bytecode(..) => None,
74        })
75        .collect();
76    let local_dependency_symbols: IndexSet<Symbol> = package
77        .programs
78        .iter()
79        .flat_map(|program| match &program.data {
80            ProgramData::SourcePath(..) => {
81                // It's a local dependency.
82                Some(program.name)
83            }
84            ProgramData::Bytecode(..) => {
85                // It's a network dependency.
86                None
87            }
88        })
89        .collect();
90    let imports_directory = package.imports_directory();
91
92    // Get the paths to .aleo files in `imports` - but filter out the ones corresponding to local dependencies.
93    let aleo_paths: Vec<PathBuf> = imports_directory
94        .read_dir()
95        .ok()
96        .into_iter()
97        .flatten()
98        .flat_map(|maybe_filename| maybe_filename.ok())
99        .filter(|entry| entry.file_type().ok().map(|filetype| filetype.is_file()).unwrap_or(false))
100        .flat_map(|entry| {
101            let path = entry.path();
102            if let Some(filename) = leo_package::filename_no_aleo_extension(&path) {
103                let symbol = Symbol::intern(filename);
104                if local_dependency_symbols.contains(&symbol) { None } else { Some(path) }
105            } else {
106                None
107            }
108        })
109        .collect();
110
111    let (native_test_functions, interpreter_result) =
112        leo_interpreter::find_and_run_tests(&leo_paths, &aleo_paths, address, 0u32, &command.test_name)?;
113
114    // Now for native tests.
115
116    let program_name = package.manifest.program.strip_suffix(".aleo").unwrap();
117    let program_name_symbol = Symbol::intern(program_name);
118    let build_directory = package.build_directory();
119
120    let credits = Symbol::intern("credits");
121
122    // Get bytecode and name for all programs, either directly or from the filesystem if they were compiled.
123    let programs: Vec<run_with_ledger::Program> = package
124        .programs
125        .iter()
126        .filter_map(|program| {
127            // Skip credits.aleo so we don't try to deploy it again.
128            if program.name == credits {
129                return None;
130            }
131            let bytecode = match &program.data {
132                ProgramData::Bytecode(c) => c.clone(),
133                ProgramData::SourcePath(..) => {
134                    // This was not a network dependency, so get its bytecode from the filesystem.
135                    let aleo_path = if program.name == program_name_symbol {
136                        build_directory.join("main.aleo")
137                    } else {
138                        imports_directory.join(format!("{}.aleo", program.name))
139                    };
140                    fs::read_to_string(&aleo_path)
141                        .unwrap_or_else(|e| panic!("Failed to read Aleo file at {}: {}", aleo_path.display(), e))
142                }
143            };
144            Some(run_with_ledger::Program { bytecode, name: program.name.to_string() })
145        })
146        .collect();
147
148    let should_fails: Vec<bool> = native_test_functions.iter().map(|test_function| test_function.should_fail).collect();
149    let cases: Vec<run_with_ledger::Case> = native_test_functions
150        .into_iter()
151        .map(|test_function| run_with_ledger::Case {
152            program_name: format!("{}.aleo", test_function.program),
153            function: test_function.function,
154            private_key: test_function.private_key,
155            input: Vec::new(),
156        })
157        .collect();
158
159    let (handler, buf) = Handler::new_with_buf();
160
161    let outcomes = run_with_ledger::run_with_ledger(
162        &run_with_ledger::Config { seed: 0, min_height: 1, programs },
163        &cases,
164        &handler,
165        &buf,
166    )?;
167
168    let native_results: Vec<Option<String>> = outcomes
169        .into_iter()
170        .zip(should_fails)
171        .map(|(outcome, should_fail)| match (&outcome.status, should_fail) {
172            (run_with_ledger::Status::Accepted, false) => None,
173            (run_with_ledger::Status::Accepted, true) => Some("Test succeeded when failure was expected.".to_string()),
174            (_, true) => None,
175            (_, false) => Some(format!("{} -- {}", outcome.status, outcome.errors)),
176        })
177        .collect();
178
179    // All tests are run. Report results.
180    let total = interpreter_result.iter().count() + native_results.len();
181    let total_passed = interpreter_result.iter().filter(|(_, test_result)| matches!(test_result, Ok(()))).count()
182        + native_results.iter().filter(|x| x.is_none()).count();
183
184    if total == 0 {
185        println!("No tests run.");
186        Ok(())
187    } else {
188        println!("{total_passed} / {total} tests passed.");
189        let failed = "FAILED".bold().red();
190        let passed = "PASSED".bold().green();
191        for (id, id_result) in interpreter_result.iter() {
192            // Wasteful to make this, but fill will work.
193            let str_id = format!("{id}");
194            if let Err(err) = id_result {
195                println!("{failed}: {str_id:<30} | {err}");
196            } else {
197                println!("{passed}: {str_id}");
198            }
199        }
200
201        for (case, case_result) in cases.iter().zip(native_results) {
202            let str_id = format!("{}/{}", case.program_name, case.function);
203            if let Some(err_str) = case_result {
204                println!("{failed}: {str_id:<30} | {err_str}");
205            } else {
206                println!("{passed}: {str_id}");
207            }
208        }
209
210        Ok(())
211    }
212}