leo_passes/common/graph/
mod.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use leo_span::Symbol;
18
19use indexmap::{IndexMap, IndexSet};
20use std::{fmt::Debug, hash::Hash};
21
// NOTE: the three aliases below all name the same type, `DiGraph<Symbol>`. They document the intended use of a
// graph but are interchangeable as far as the type checker is concerned.

/// A struct dependency graph, whose nodes are identified by `Symbol`s.
pub type StructGraph = DiGraph<Symbol>;

/// A call graph, whose nodes are identified by `Symbol`s.
pub type CallGraph = DiGraph<Symbol>;

/// An import dependency graph, whose nodes are identified by `Symbol`s.
pub type ImportGraph = DiGraph<Symbol>;
30
31/// A node in a graph.
32pub trait Node: Copy + 'static + Eq + PartialEq + Debug + Hash {}
33
34impl Node for Symbol {}
35
36/// Errors in directed graph operations.
37#[derive(Debug)]
38pub enum DiGraphError<N: Node> {
39    /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
40    CycleDetected(Vec<N>),
41}
42
43/// A directed graph.
44#[derive(Clone, Debug, Default, PartialEq, Eq)]
45pub struct DiGraph<N: Node> {
46    /// The set of nodes in the graph.
47    nodes: IndexSet<N>,
48    /// The directed edges in the graph.
49    /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
50    edges: IndexMap<N, IndexSet<N>>,
51}
52
53impl<N: Node> DiGraph<N> {
54    /// Initializes a new `DiGraph` from a vector of source nodes.
55    pub fn new(nodes: IndexSet<N>) -> Self {
56        Self { nodes, edges: IndexMap::new() }
57    }
58
59    pub fn add_node(&mut self, node: N) {
60        self.nodes.insert(node);
61    }
62
63    /// Adds an edge to the graph.
64    pub fn add_edge(&mut self, from: N, to: N) {
65        // Add `from` and `to` to the set of nodes if they are not already in the set.
66        self.nodes.insert(from);
67        self.nodes.insert(to);
68
69        // Add the edge to the adjacency list.
70        let entry = self.edges.entry(from).or_default();
71        entry.insert(to);
72    }
73
74    /// Removes a node and all associated edges from the graph.
75    pub fn remove_node(&mut self, node: &N) -> bool {
76        let existed = self.nodes.shift_remove(node);
77
78        if existed {
79            // Remove all outgoing edges from the node
80            self.edges.shift_remove(node);
81
82            // Remove all incoming edges to the node
83            for targets in self.edges.values_mut() {
84                targets.shift_remove(node);
85            }
86        }
87
88        existed
89    }
90
91    /// Returns the set of immediate neighbors of a given node.
92    pub fn neighbors(&self, node: &N) -> Option<IndexSet<N>> {
93        self.edges.get(node).cloned()
94    }
95
96    /// Returns `true` if the graph contains the given node.
97    pub fn contains_node(&self, node: N) -> bool {
98        self.nodes.contains(&node)
99    }
100
101    /// Returns the post-order ordering of the graph.
102    /// Detects if there is a cycle in the graph.
103    pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
104        self.post_order_from_entry_points(|_| true)
105    }
106
107    /// Returns the post-order ordering of the graph but only considering a subset of the nodes as "entry points".
108    /// Meaning, nodes that are not accessible from the entry nodes are not considered in the post order.
109    ///
110    /// Detects if there is a cycle in the graph.
111    pub fn post_order_from_entry_points<F>(&self, is_entry_point: F) -> Result<IndexSet<N>, DiGraphError<N>>
112    where
113        F: Fn(&N) -> bool,
114    {
115        // The set of nodes that do not need to be visited again.
116        let mut finished: IndexSet<N> = IndexSet::with_capacity(self.nodes.len());
117
118        // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
119        // `is_entry_point`.
120        for node in self.nodes.iter().filter(|n| is_entry_point(n)) {
121            // If the node has not been explored, explore it.
122            if !finished.contains(node) {
123                // The set of nodes that are on the path to the current node in the search.
124                let mut discovered: IndexSet<N> = IndexSet::new();
125                // Check if there is a cycle in the graph starting from `node`.
126                if let Some(node) = self.contains_cycle_from(*node, &mut discovered, &mut finished) {
127                    let mut path = vec![node];
128                    // Backtrack through the discovered nodes to find the cycle.
129                    while let Some(next) = discovered.pop() {
130                        // Add the node to the path.
131                        path.push(next);
132                        // If the node is the same as the first node in the path, we have found the cycle.
133                        if next == node {
134                            break;
135                        }
136                    }
137                    // Reverse the path to get the cycle in the correct order.
138                    path.reverse();
139                    // A cycle was detected. Return the path of the cycle.
140                    return Err(DiGraphError::CycleDetected(path));
141                }
142            }
143        }
144        // No cycle was found. Return the set of nodes in topological order.
145        Ok(finished)
146    }
147
148    /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
149    pub fn retain_nodes(&mut self, nodes: &IndexSet<N>) {
150        // Remove the nodes from the set of nodes.
151        self.nodes.retain(|node| nodes.contains(node));
152        self.edges.retain(|node, _| nodes.contains(node));
153        // Remove the edges that reference the nodes.
154        for (_, children) in self.edges.iter_mut() {
155            children.retain(|child| nodes.contains(child));
156        }
157    }
158
159    // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
160    // If there is no cycle, returns `None`.
161    // If there is a cycle, returns the node that was most recently discovered.
162    // Nodes are added to `finished` in post-order order.
163    fn contains_cycle_from(&self, node: N, discovered: &mut IndexSet<N>, finished: &mut IndexSet<N>) -> Option<N> {
164        // Add the node to the set of discovered nodes.
165        discovered.insert(node);
166
167        // Check each outgoing edge of the node.
168        if let Some(children) = self.edges.get(&node) {
169            for child in children.iter() {
170                // If the node already been discovered, there is a cycle.
171                if discovered.contains(child) {
172                    // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
173                    // Note that this case is always hit when there is a cycle.
174                    return Some(*child);
175                }
176                // If the node has not been explored, explore it.
177                if !finished.contains(child) {
178                    if let Some(child) = self.contains_cycle_from(*child, discovered, finished) {
179                        return Some(child);
180                    }
181                }
182            }
183        }
184
185        // Remove the node from the set of discovered nodes.
186        discovered.pop();
187        // Add the node to the set of finished nodes.
188        finished.insert(node);
189
190        None
191    }
192}
193
#[cfg(test)]
mod test {
    use super::*;

    impl Node for u32 {}

    /// Asserts that `graph` is acyclic and that its post-order is exactly `expected`.
    fn check_post_order<N: Node>(graph: &DiGraph<N>, expected: &[N]) {
        // `expect` (rather than `assert!(is_ok())`) prints the detected cycle if the graph is unexpectedly cyclic.
        let order: Vec<N> = graph.post_order().expect("expected an acyclic graph").into_iter().collect();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        // F -> B
        graph.add_edge(6, 2);
        // B -> A
        graph.add_edge(2, 1);
        // B -> D
        graph.add_edge(2, 4);
        // D -> C
        graph.add_edge(4, 3);
        // D -> E
        graph.add_edge(4, 5);
        // F -> G
        graph.add_edge(6, 7);
        // G -> I
        graph.add_edge(7, 9);
        // I -> H
        graph.add_edge(9, 8);

        // A, C, E, D, B, H, I, G, F.
        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        let DiGraphError::CycleDetected(cycle) = graph.post_order().expect_err("expected a cycle");
        assert_eq!(cycle, vec![1, 2, 4, 1]);
    }

    #[test]
    fn test_self_loop() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 1);

        let DiGraphError::CycleDetected(cycle) = graph.post_order().expect_err("expected a cycle");
        assert_eq!(cycle, vec![1, 1]);
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_entry_points() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(3, 4);
        graph.add_node(5);
        // A cycle that is unreachable from the entry point below.
        graph.add_edge(6, 7);
        graph.add_edge(7, 6);

        let order: Vec<u32> = graph
            .post_order_from_entry_points(|n| *n == 3)
            .expect("the cycle is not reachable from the entry point")
            .into_iter()
            .collect();
        assert_eq!(order, vec![4, 3]);

        // Considering every node as an entry point exposes the cycle.
        let DiGraphError::CycleDetected(cycle) = graph.post_order().expect_err("expected a cycle");
        assert_eq!(cycle, vec![6, 7, 6]);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 1);

        assert!(graph.remove_node(&3));
        assert!(!graph.remove_node(&3));
        assert!(!graph.contains_node(3));
        // The incoming edge `2 -> 3` is gone, leaving `2` with an empty successor set.
        assert_eq!(graph.neighbors(&2), Some(IndexSet::new()));
        assert_eq!(graph.neighbors(&3), None);

        // Removing `3` broke the cycle.
        check_post_order(&graph, &[2, 1]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let nodes = IndexSet::from([1, 2, 3]);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // `3` keeps its (now empty) adjacency entry because it was the source of the dropped edge `3 -> 4`.
        expected.edges.insert(3, IndexSet::new());

        assert_eq!(graph, expected);
    }
}