1use crate::Location;
18use leo_span::Symbol;
19
20use indexmap::{IndexMap, IndexSet};
21use std::{fmt::Debug, hash::Hash, rc::Rc};
22
/// A directed graph whose nodes are `Vec<Symbol>` values.
/// NOTE(review): judging by the name, presumably used to track dependencies between
/// structs (nodes being struct paths) — confirm at the call sites.
pub type StructGraph = DiGraph<Vec<Symbol>>;

/// A directed graph whose nodes are `Location`s.
/// NOTE(review): judging by the name, presumably records which functions call which — confirm.
pub type CallGraph = DiGraph<Location>;

/// A directed graph whose nodes are `Symbol`s.
/// NOTE(review): judging by the name, presumably records program imports — confirm.
pub type ImportGraph = DiGraph<Symbol>;
32
/// Marker trait for values that can be stored as nodes of a `DiGraph`.
///
/// A node must be cloneable, comparable for equality, hashable, printable with
/// `{:?}`, and must not borrow anything (`'static`).
pub trait GraphNode: 'static + Clone + Debug + Eq + Hash + PartialEq {}

/// Any type meeting the bounds above is automatically a `GraphNode`.
impl<T: 'static + Clone + Debug + Eq + Hash + PartialEq> GraphNode for T {}
37
38#[derive(Debug)]
40pub enum DiGraphError<N: GraphNode> {
41 CycleDetected(Vec<N>),
43}
44
45#[derive(Clone, Debug, PartialEq, Eq)]
47pub struct DiGraph<N: GraphNode> {
48 nodes: IndexSet<Rc<N>>,
50
51 edges: IndexMap<Rc<N>, IndexSet<Rc<N>>>,
54}
55
56impl<N: GraphNode> Default for DiGraph<N> {
57 fn default() -> Self {
58 Self { nodes: IndexSet::new(), edges: IndexMap::new() }
59 }
60}
61
62impl<N: GraphNode> DiGraph<N> {
63 pub fn new(nodes: IndexSet<N>) -> Self {
65 let nodes: IndexSet<_> = nodes.into_iter().map(Rc::new).collect();
66 Self { nodes, edges: IndexMap::new() }
67 }
68
69 pub fn add_node(&mut self, node: N) {
71 self.nodes.insert(Rc::new(node));
72 }
73
74 pub fn nodes(&self) -> impl Iterator<Item = &N> {
76 self.nodes.iter().map(|rc| rc.as_ref())
77 }
78
79 pub fn add_edge(&mut self, from: N, to: N) {
81 let from_rc = self.get_or_insert(from);
83 let to_rc = self.get_or_insert(to);
84
85 self.edges.entry(from_rc).or_default().insert(to_rc);
87 }
88
89 pub fn remove_node(&mut self, node: &N) -> bool {
91 if let Some(rc_node) = self.nodes.shift_take(&Rc::new(node.clone())) {
92 self.edges.shift_remove(&rc_node);
94
95 for targets in self.edges.values_mut() {
97 targets.shift_remove(&rc_node);
98 }
99 true
100 } else {
101 false
102 }
103 }
104
105 pub fn neighbors(&self, node: &N) -> impl Iterator<Item = &N> {
107 self.edges
108 .get(node) .into_iter()
110 .flat_map(|neighbors| neighbors.iter().map(|rc| rc.as_ref()))
111 }
112
113 pub fn contains_node(&self, node: N) -> bool {
115 self.nodes.contains(&Rc::new(node))
116 }
117
118 pub fn post_order(&self) -> Result<IndexSet<N>, DiGraphError<N>> {
121 self.post_order_with_filter(|_| true)
122 }
123
124 pub fn post_order_with_filter<F>(&self, filter: F) -> Result<IndexSet<N>, DiGraphError<N>>
129 where
130 F: Fn(&N) -> bool,
131 {
132 let mut finished = IndexSet::with_capacity(self.nodes.len());
134
135 for node_rc in self.nodes.iter().filter(|n| filter(n.as_ref())) {
138 if !finished.contains(node_rc) {
140 let mut discovered = IndexSet::new();
142 if let Some(cycle_node) = self.contains_cycle_from(node_rc, &mut discovered, &mut finished) {
144 let mut path = vec![cycle_node.as_ref().clone()];
145 while let Some(next) = discovered.pop() {
147 path.push(next.as_ref().clone());
149 if Rc::ptr_eq(&next, &cycle_node) {
151 break;
152 }
153 }
154 path.reverse();
156 return Err(DiGraphError::CycleDetected(path));
158 }
159 }
160 }
161
162 Ok(finished.iter().map(|rc| (**rc).clone()).collect())
164 }
165
166 pub fn retain_nodes(&mut self, keep: &IndexSet<N>) {
168 let keep: IndexSet<_> = keep.iter().map(|n| Rc::new(n.clone())).collect();
169 self.nodes.retain(|n| keep.contains(n));
171 self.edges.retain(|n, _| keep.contains(n));
172 for targets in self.edges.values_mut() {
174 targets.retain(|t| keep.contains(t));
175 }
176 }
177
178 fn contains_cycle_from(
183 &self,
184 node: &Rc<N>,
185 discovered: &mut IndexSet<Rc<N>>,
186 finished: &mut IndexSet<Rc<N>>,
187 ) -> Option<Rc<N>> {
188 discovered.insert(node.clone());
190
191 if let Some(children) = self.edges.get(node) {
193 for child in children {
194 if discovered.contains(child) {
196 return Some(child.clone());
199 }
200 if !finished.contains(child) {
202 if let Some(cycle_node) = self.contains_cycle_from(child, discovered, finished) {
203 return Some(cycle_node);
204 }
205 }
206 }
207 }
208
209 discovered.pop();
211 finished.insert(node.clone());
213 None
214 }
215
216 fn get_or_insert(&mut self, node: N) -> Rc<N> {
218 if let Some(existing) = self.nodes.get(&node) {
219 return existing.clone();
220 }
221 let rc = Rc::new(node);
222 self.nodes.insert(rc.clone());
223 rc
224 }
225}
226
#[cfg(test)]
mod test {
    use super::*;

    /// Asserts that `graph` is acyclic and that its post-order equals `expected`.
    fn check_post_order<N: GraphNode>(graph: &DiGraph<N>, expected: &[N]) {
        // `expect` reports the detected cycle on failure, unlike `assert!(result.is_ok())`.
        let order: Vec<N> = graph.post_order().expect("graph should be acyclic").into_iter().collect();
        assert_eq!(order, expected);
    }

    #[test]
    fn test_post_order() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(2, 4);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        check_post_order(&graph, &[5, 4, 2, 3, 1]);

        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(6, 2);
        graph.add_edge(2, 1);
        graph.add_edge(2, 4);
        graph.add_edge(4, 3);
        graph.add_edge(4, 5);
        graph.add_edge(6, 7);
        graph.add_edge(7, 9);
        graph.add_edge(9, 8);

        check_post_order(&graph, &[1, 3, 5, 4, 2, 8, 9, 7, 6]);
    }

    #[test]
    fn test_cycle() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(4, 1);

        let DiGraphError::CycleDetected(cycle) =
            graph.post_order().expect_err("graph contains the cycle 1 -> 2 -> 4 -> 1");
        assert_eq!(cycle, Vec::from([1u32, 2, 4, 1]));
    }

    #[test]
    fn test_self_loop() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 1);

        let DiGraphError::CycleDetected(cycle) = graph.post_order().expect_err("a self-loop is a cycle");
        assert_eq!(cycle, Vec::from([1u32, 1]));
    }

    #[test]
    fn test_unconnected_graph() {
        let graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3, 4, 5]));

        check_post_order(&graph, &[1, 2, 3, 4, 5]);
    }

    #[test]
    fn test_post_order_with_filter() {
        let mut graph = DiGraph::<u32>::new(IndexSet::from([1, 2, 3]));
        graph.add_edge(1, 2);

        // Only the starting points are filtered; nodes reachable from them are still visited.
        let order = graph.post_order_with_filter(|&n| n == 1).expect("graph should be acyclic");
        assert_eq!(order.into_iter().collect::<Vec<_>>(), Vec::from([2u32, 1]));

        let order = graph.post_order_with_filter(|&n| n != 1).expect("graph should be acyclic");
        assert_eq!(order.into_iter().collect::<Vec<_>>(), Vec::from([2u32, 3]));
    }

    #[test]
    fn test_add_node_is_idempotent() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_node(1);
        graph.add_node(1);

        assert!(graph.contains_node(1));
        assert_eq!(graph.nodes().count(), 1);
    }

    #[test]
    fn test_remove_node() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(2, 3);
        graph.add_edge(1, 3);

        assert!(graph.remove_node(&2));
        // Removing a node twice reports that it is no longer present.
        assert!(!graph.remove_node(&2));
        assert!(!graph.contains_node(2));
        assert!(graph.contains_node(1));
        assert_eq!(graph.neighbors(&1).copied().collect::<Vec<_>>(), Vec::from([3u32]));
        assert_eq!(graph.neighbors(&2).count(), 0);

        check_post_order(&graph, &[3, 1]);
    }

    #[test]
    fn test_retain_nodes() {
        let mut graph = DiGraph::<u32>::new(IndexSet::new());

        graph.add_edge(1, 2);
        graph.add_edge(1, 3);
        graph.add_edge(1, 5);
        graph.add_edge(2, 3);
        graph.add_edge(2, 4);
        graph.add_edge(2, 5);
        graph.add_edge(3, 4);
        graph.add_edge(4, 5);

        let mut nodes = IndexSet::new();
        nodes.insert(1);
        nodes.insert(2);
        nodes.insert(3);

        graph.retain_nodes(&nodes);

        let mut expected = DiGraph::<u32>::new(IndexSet::new());
        expected.add_edge(1, 2);
        expected.add_edge(1, 3);
        expected.add_edge(2, 3);
        // Node 3 keeps an (now empty) adjacency entry because it had outgoing edges before.
        expected.edges.insert(3.into(), IndexSet::new());

        assert_eq!(graph, expected);
    }
}