leo_ast/common/graph/
mod.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
23/// A struct dependency graph.
/// Each node is a `Vec<Symbol>` holding the absolute path to a struct.
25pub type StructGraph = DiGraph<Vec<Symbol>>;
26
27/// A call graph.
28pub type CallGraph = DiGraph<Location>;
29
30/// An import dependency graph.
31pub type ImportGraph = DiGraph<Symbol>;
32
/// Marker trait for types that can be used as nodes of a [`DiGraph`].
///
/// A node type must contain no non-`'static` borrows and must be cloneable, comparable for
/// equality, hashable and debug-printable. `PartialEq` is implied by the `Eq` bound.
pub trait GraphNode: 'static + Clone + Debug + Eq + Hash {}

/// Blanket implementation: every type that meets the bounds is automatically a [`GraphNode`].
impl<T> GraphNode for T where T: 'static + Clone + Debug + Eq + Hash {}
37
38/// Errors in directed graph operations.
39#[derive(Debug)]
40pub enum DiGraphError<N: GraphNode> {
41    /// An error that is emitted when a cycle is detected in the directed graph. Contains the path of the cycle.
42    CycleDetected(Vec<N>),
43}
44
45/// A directed graph using reference-counted nodes.
46#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct DiGraph<N: GraphNode> {
48    /// The set of nodes in the graph.
49    nodes: IndexSet<Rc<N>>,
50
51    /// The directed edges in the graph.
52    /// Each entry in the map is a node in the graph, and the set of nodes that it points to.
53    edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
54}
55
56impl<N: GraphNode> Default for DiGraph<N> {
57    fn default() -> Self {
58        Self { nodes: IndexSet::new(), edges: IndexMap::new() }
59    }
60}
61
62impl<N: GraphNode> DiGraph<N> {
63    /// Initializes a new `DiGraph` from a set of source nodes.
64    pub fn new(nodes: IndexSet<N>) -> Self {
65        let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
66        Self { nodes, edges: IndexMap::new() }
67    }
68
69    /// Adds a node to the graph.
70    pub fn add_node(&mut self, node: N) {
71        self.nodes.insert(Rc::new(node));
72    }
73
74    /// Returns an iterator over the nodes in the graph.
75    pub fn nodes(&self) -> impl Iterator<Item = &N> {
76        self.nodes.iter().map(|rc| rc.as_ref())
77    }
78
79    /// Adds an edge to the graph.
80    pub fn add_edge(&mut self, from: N, to: N) {
81        // Add `from` and `to` to the set of nodes if they are not already in the set.
82        let from_rc = self.get_or_insert(from);
83        let to_rc = self.get_or_insert(to);
84
85        // Add the edge to the adjacency list.
86        self.edges.entry(from_rc).or_default().insert(to_rc);
87    }
88
89    /// Removes a node and all associated edges from the graph.
90    pub fn remove_node(&mut self, node: &N) -> bool {
91        if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
92            // Remove all outgoing edges from the node
93            self.edges.shift_remove(&rc_node);
94
95            // Remove all incoming edges to the node
96            for targets in self.edges.values_mut() {
97                targets.shift_remove(&rc_node);
98            }
99            true
100        } else {
101            false
102        }
103    }
104
105    /// Returns an iterator to the immediate neighbors of a given node.
106    pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
107        self.edges
108            .get(node) // ← no Rc::from() needed!
109            .into_iter()
110            .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
111    }
112
113    /// Returns `true` if the graph contains the given node.
114    pub fn contains_node(&self, node: N) -> bool {
115        self.nodes.contains(&Rc::new(node))
116    }
117
118    /// Returns the post-order ordering of the graph.
119    /// Detects if there is a cycle in the graph.
120    pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
121        self.post_order_with_filter(|_| true)
122    }
123
124    /// Returns the post-order ordering of the graph but only considering a subset of the nodes that
125    /// satisfy the given filter.
126    ///
127    /// Detects if there is a cycle in the graph.
128    pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
129    where
130        F: Fn(&N) -> bool,
131    {
132        // The set of nodes that do not need to be visited again.
133        let mut finished = IndexSet::with_capacity(self.nodes.len());
134
135        // Perform a depth-first search of the graph, starting from `node`, for each node in the graph that satisfies
136        // `is_entry_point`.
137        for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
138            // If the node has not been explored, explore it.
139            if !finished.contains(node_rc) {
140                // The set of nodes that are on the path to the current node in the searc
141                let mut discovered = IndexSet::new();
142                // Check if there is a cycle in the graph starting from `node`.
143                if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
144                    let mut path = vec![cycle_node.as_ref().clone()];
145                    // Backtrack through the discovered nodes to find the cycle.
146                    while let Some(next) = discovered.pop() {
147                        // Add the node to the path.
148                        path.push(next.as_ref().clone());
149                        // If the node is the same as the first node in the path, we have found the cycle.
150                        if Rc::ptr_eq(&next, &cycle_node) {
151                            break;
152                        }
153                    }
154                    // Reverse the path to get the cycle in the correct order.
155                    path.reverse();
156                    // A cycle was detected. Return the path of the cycle.
157                    return Err(DiGraphError::CycleDetected(path));
158                }
159            }
160        }
161
162        // No cycle was found. Return the set of nodes in topological order.
163        Ok(finished.iter().map(|rc| (**rc).clone()).collect())
164    }
165
166    /// Retains a subset of the nodes, and removes all edges in which the source or destination is not in the subset.
167    pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
168        let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
169        // Remove the nodes from the set of nodes.
170        self.nodes.retain(|n| keep.contains(n));
171        self.edges.retain(|n, _| keep.contains(n));
172        // Remove the edges that reference the nodes.
173        for targets in self.edges.values_mut() {
174            targets.retain(|t| keep.contains(t));
175        }
176    }
177
178    // Detects if there is a cycle in the graph starting from the given node, via a recursive depth-first search.
179    // If there is no cycle, returns `None`.
180    // If there is a cycle, returns the node that was most recently discovered.
181    // Nodes are added to `finished` in post-order order.
182    fn contains_cycle_from(
183        &self,
184        node: &Rc<N>,
185        discovered: &mut IndexSet<Rc<N>>,
186        finished: &mut IndexSet<Rc<N>>,
187    ) -> Option<Rc<N>> {
188        // Add the node to the set of discovered nodes.
189        discovered.insert(node.clone());
190
191        // Check each outgoing edge of the node.
192        if let Some(children) = self.edges.get(node) {
193            for child in children {
194                // If the node already been discovered, there is a cycle.
195                if discovered.contains(child) {
196                    // Insert the child node into the set of discovered nodes; this is used to reconstruct the cycle.
197                    // Note that this case is always hit when there is a cycle.
198                    return Some(child.clone());
199                }
200                // If the node has not been explored, explore it.
201                if !finished.contains(child) {
202                    if let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished) {
203                        return Some(cycle_node);
204                    }
205                }
206            }
207        }
208
209        // Remove the node from the set of discovered nodes.
210        discovered.pop();
211        // Add the node to the set of finished nodes.
212        finished.insert(node.clone());
213        None
214    }
215
216    /// Helper: get or insert Rc<N> into the graph.
217    fn get_or_insert(&mut self, node: N) -> Rc<N> {
218        if let Some(existing) = self.nodes.get(&node) {
219            return existing.clone();
220        }
221        let rc = Rc::new(node);
222        self.nodes.insert(rc.clone());
223        rc
224    }
225}
226
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `graph` is acyclic and that its post-order traversal equals `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        let result = graph.post_order();
        assert!(result.is_ok());

        let order: Vec<N> = result.unwrap().into_iter().collect();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        // The letters A..=I below stand for the nodes 1..=9.
        // F -> B
        graph.add_edge(6, 2);
        // B -> A
        graph.add_edge(2, 1);
        // B -> D
        graph.add_edge(2, 4);
        // D -> C
        graph.add_edge(4, 3);
        // D -> E
        graph.add_edge(4, 5);
        // F -> G
        graph.add_edge(6, 7);
        // G -> I
        graph.add_edge(7, 9);
        // I -> H
        graph.add_edge(9, 8);

        // A, C, E, D, B, H, I, G, F.
        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_post_order_with_filter() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        // A cycle between 1 and 2, and a separate acyclic component 3 -> 4.
        graph.add_edge(1, 2);
        graph.add_edge(2, 1);
        graph.add_edge(3, 4);

        // Without a filter the cycle is found.
        assert!(graph.post_order().is_err());

        // Starting only from 3, the cycle is never reached, and 4 is included even though it
        // does not pass the filter itself.
        let order: Vec<u32> = graph.post_order_with_filter(|n| *n == 3).unwrap().into_iter().collect();
        assert_eq!(order, vec![4, 3]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        let result = graph.post_order();
        assert!(result.is_err());

        let DiGraphError::CycleDetected(cycle) = result.unwrap_err();
        let expected = Vec::from([1u32, 2, 4, 1]);
        assert_eq!(cycle, expected);
    }

    #[test]
    fn test_self_loop_is_a_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 1);

        let DiGraphError::CycleDetected(cycle) = graph.post_order().unwrap_err();
        assert_eq!(cycle, vec![1u32, 1]);
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_contains_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_node(7);
        assert!(graph.contains_node(7));
        assert!(!graph.contains_node(8));

        // Adding an edge also adds both of its endpoints.
        graph.add_edge(8, 9);
        assert!(graph.contains_node(8));
        assert!(graph.contains_node(9));

        // Adding an existing node again does not duplicate it.
        graph.add_node(7);
        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), vec![7, 8, 9]);
    }

    #[test]
    fn test_neighbors() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        // A duplicate edge is recorded only once.
        graph.add_edge(1, 2);

        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), vec![2, 3]);
        // A node without outgoing edges has no neighbors.
        assert_eq!(graph.neighbors(&2).count(), 0);
        // Neither has a node that is not in the graph.
        assert_eq!(graph.neighbors(&42).count(), 0);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(3, 1);

        // Removing a node that is not in the graph changes nothing.
        assert!(!graph.remove_node(&4));
        assert!(graph.post_order().is_err());

        // Removing 3 deletes both `2 -> 3` and `3 -> 1`, which breaks the cycle.
        assert!(graph.remove_node(&3));
        assert!(!graph.contains_node(3));
        assert_eq!(graph.nodes().copied().collect::<Vec<_>>(), vec![1, 2]);
        assert_eq!(graph.neighbors(&2).count(), 0);
        assert_eq!(graph.neighbors(&3).count(), 0);
        check_post_order(&graph, &[2, 1]);

        // Removing the same node twice reports that it is no longer present.
        assert!(!graph.remove_node(&3));
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let mut nodes = IndexSet::new();
        nodes.insert(1);
        nodes.insert(2);
        nodes.insert(3);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // Node 3 keeps its adjacency list even though all of its targets were removed.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}