leo_passes/common/graph/mod.rs
1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_span::Symbol;
18
19use indexmap::{IndexMap, IndexSet};
20use std::{fmt::Debug, hash::Hash};
21
22/// A struct dependency graph.
23pub type StructGraph = DiGraph<Symbol>;
24
25/// A call graph.
26pub type CallGraph = DiGraph<Symbol>;
27
28/// An import dependency graph.
29pub type ImportGraph = DiGraph<Symbol>;
30
31/// A node in a graph.
32pub trait Node: Copy + 'static + Eq + PartialEq + Debug + Hash {}
33
34impl Node for Symbol {}
35
36/// Errors in directed graph operations.
37#[derive(Debug)]
38pub enum DiGraphError<N: Node> {
39 /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
40 CycleDetected(Vec<N>),
41}
42
43/// A directed graph.
44#[derive(Clone, Debug, Default, PartialEq, Eq)]
45pub struct DiGraph<N: Node> {
46 /// The set of nodes in the graph.
47 nodes: IndexSet<N>,
48 /// The directed edges in the graph.
49 /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
50 edges: IndexMap<N, IndexSet<N>>,
51}
52
53impl<N: Node> DiGraph<N> {
54 /// Initializes a new `DiGraph` from a vector of source nodes.
55 pub fn new(nodes: IndexSet<N>) -> Self {
56 Self { nodes, edges: IndexMap::new() }
57 }
58
59 pub fn add_node(&mut self, node: N) {
60 self.nodes.insert(node);
61 }
62
63 /// Adds an edge to the graph.
64 pub fn add_edge(&mut self, from: N, to: N) {
65 // Add `from` and `to` to the set of nodes if they are not already in the set.
66 self.nodes.insert(from);
67 self.nodes.insert(to);
68
69 // Add the edge to the adjacency list.
70 let entry = self.edges.entry(from).or_default();
71 entry.insert(to);
72 }
73
74 /// Removes a node and all associated edges from the graph.
75 pub fn remove_node(&mut self, node: &N) -> bool {
76 let existed = self.nodes.shift_remove(node);
77
78 if existed {
79 // Remove all outgoing edges from the node
80 self.edges.shift_remove(node);
81
82 // Remove all incoming edges to the node
83 for targets in self.edges.values_mut() {
84 targets.shift_remove(node);
85 }
86 }
87
88 existed
89 }
90
91 /// Returns the set of immediate neighbors of a given node.
92 pub fn neighbors(&self, node: &N) -> Option<IndexSet<N>> {
93 self.edges.get(node).cloned()
94 }
95
96 /// Returns `true` if the graph contains the given node.
97 pub fn contains_node(&self, node: N) -> bool {
98 self.nodes.contains(&node)
99 }
100
101 /// Returns the post-order ordering of the graph.
102 /// Detects if there is a cycle in the graph.
103 pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
104 self.post_order_from_entry_points(|_| true)
105 }
106
107 /// Returns the post-order ordering of the graph but only considering a subset of the nodes as "entry points".
108 /// Meaning, nodes that are not accessible from the entry nodes are not considered in the post order.
109 ///
110 /// Detects if there is a cycle in the graph.
111 pub fn post_order_from_entry_points<F>(&self, is_entry_point: F) -> Result<IndexSet<N>, DiGraphError<N>>
112 where
113 F: Fn(&N) -> bool,
114 {
115 // The set of nodes that do not need to be visited again.
116 let mut finished: IndexSet<N> = IndexSet::with_capacity(self.nodes.len());
117
118 // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
119 // `is_entry_point`.
120 for node in self.nodes.iter().filter(|n| is_entry_point(n)) {
121 // If the node has not been explored, explore it.
122 if !finished.contains(node) {
123 // The set of nodes that are on the path to the current node in the search.
124 let mut discovered: IndexSet<N> = IndexSet::new();
125 // Check if there is a cycle in the graph starting from `node`.
126 if let Some(node) = self.contains_cycle_from(*node, &mut discovered, &mut finished) {
127 let mut path = vec![node];
128 // Backtrack through the discovered nodes to find the cycle.
129 while let Some(next) = discovered.pop() {
130 // Add the node to the path.
131 path.push(next);
132 // If the node is the same as the first node in the path, we have found the cycle.
133 if next == node {
134 break;
135 }
136 }
137 // Reverse the path to get the cycle in the correct order.
138 path.reverse();
139 // A cycle was detected. Return the path of the cycle.
140 return Err(DiGraphError::CycleDetected(path));
141 }
142 }
143 }
144 // No cycle was found. Return the set of nodes in topological order.
145 Ok(finished)
146 }
147
148 /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
149 pub fn retain_nodes(&mut self, nodes: &IndexSet<N>) {
150 // Remove the nodes from the set of nodes.
151 self.nodes.retain(|node| nodes.contains(node));
152 self.edges.retain(|node, _| nodes.contains(node));
153 // Remove the edges that reference the nodes.
154 for (_, children) in self.edges.iter_mut() {
155 children.retain(|child| nodes.contains(child));
156 }
157 }
158
159 // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
160 // If there is no cycle, returns `None`.
161 // If there is a cycle, returns the node that was most recently discovered.
162 // Nodes are added to `finished` in post-order order.
163 fn contains_cycle_from(&self, node: N, discovered: &mut IndexSet<N>, finished: &mut IndexSet<N>) -> Option<N> {
164 // Add the node to the set of discovered nodes.
165 discovered.insert(node);
166
167 // Check each outgoing edge of the node.
168 if let Some(children) = self.edges.get(&node) {
169 for child in children.iter() {
170 // If the node already been discovered, there is a cycle.
171 if discovered.contains(child) {
172 // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
173 // Note that this case is always hit when there is a cycle.
174 return Some(*child);
175 }
176 // If the node has not been explored, explore it.
177 if !finished.contains(child) {
178 if let Some(child) = self.contains_cycle_from(*child, discovered, finished) {
179 return Some(child);
180 }
181 }
182 }
183 }
184
185 // Remove the node from the set of discovered nodes.
186 discovered.pop();
187 // Add the node to the set of finished nodes.
188 finished.insert(node);
189
190 None
191 }
192}
193
#[cfg(test)]
mod test {
    use super::*;

    // Plain integers serve as nodes in the tests below.
    impl Node for u32 {}

    /// Builds a graph by adding the given `(from, to)` edges in order.
    fn graph_from_edges(edges: &[(u32, u32)]) -> DiGraph<u32> {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());
        for &(from, to) in edges {
            graph.add_edge(from, to);
        }
        graph
    }

    /// Asserts that the graph has no cycle and that its post-order is exactly `expected`.
    fn check_post_order<N: Node>(graph: &DiGraph<N>, expected: &[N]) {
        let result = graph.post_order();
        assert!(result.is_ok());

        let order = result.unwrap().into_iter().collect::<Vec<N>>();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        // A diamond (1 -> {2, 3} -> 4) followed by a tail (4 -> 5).
        let diamond = graph_from_edges(&[(1, 2), (1, 3), (2, 4), (3, 4), (4, 5)]);
        check_post_order(&diamond, &[5, 4, 2, 3, 1]);

        // A tree, with the letters A..=I mapped to 1..=9.
        let tree = graph_from_edges(&[
            // F -> B
            (6, 2),
            // B -> A
            (2, 1),
            // B -> D
            (2, 4),
            // D -> C
            (4, 3),
            // D -> E
            (4, 5),
            // F -> G
            (6, 7),
            // G -> I
            (7, 9),
            // I -> H
            (9, 8),
        ]);

        // A, C, E, D, B, H, I, G, F.
        check_post_order(&tree, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        // 1 -> 2 -> 4 -> 1 forms a cycle; 3 is a dead end hanging off 2.
        let graph = graph_from_edges(&[(1, 2), (2, 3), (2, 4), (4, 1)]);

        let result = graph.post_order();
        assert!(result.is_err());

        let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
        let expected = vec![1u32, 2, 4, 1];
        assert_eq!(cycle, expected);
    }

    #[test]
    fn test_unconnected_graph() {
        // Without edges, the post-order is the order in which the nodes were supplied.
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = graph_from_edges(&[(1, 2), (1, 3), (1, 5), (2, 3), (2, 4), (2, 5), (3, 4), (4, 5)]);

        let nodes = IndexSet::from([1, 2, 3]);
        graph.retain_nodes(&nodes);

        let mut expected = graph_from_edges(&[(1, 2), (1, 3), (2, 3)]);
        // Node 3 keeps its adjacency entry, which is empty once the edge 3 -> 4 is removed.
        expected.edges.insert(3, IndexSet::new());

        assert_eq!(graph, expected);
    }
}
295}