Skip to content

Implement condensation graph generation #1337

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
features:
- |
Added a new condensation() function that works for both directed and undirected graphs. For directed graphs, it returns the condensation (quotient graph) where each node is a strongly connected component (SCC). For undirected graphs, each node is a connected component. The returned graph has a 'node_map' attribute mapping each original node index to the index of the condensed node it belongs to.
1 change: 1 addition & 0 deletions rustworkx/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ from .rustworkx import number_strongly_connected_components as number_strongly_c
from .rustworkx import number_weakly_connected_components as number_weakly_connected_components
from .rustworkx import node_connected_component as node_connected_component
from .rustworkx import strongly_connected_components as strongly_connected_components
from .rustworkx import condensation as condensation
from .rustworkx import weakly_connected_components as weakly_connected_components
from .rustworkx import digraph_adjacency_matrix as digraph_adjacency_matrix
from .rustworkx import graph_adjacency_matrix as graph_adjacency_matrix
Expand Down
1 change: 1 addition & 0 deletions rustworkx/rustworkx.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ def number_strongly_connected_components(graph: PyDiGraph, /) -> int: ...
def number_weakly_connected_components(graph: PyDiGraph, /) -> int: ...
def node_connected_component(graph: PyGraph, node: int, /) -> set[int]: ...
def strongly_connected_components(graph: PyDiGraph, /) -> list[list[int]]: ...
def condensation(graph: PyDiGraph, /, sccs: list[int] | None = ...) -> PyDiGraph: ...
def weakly_connected_components(graph: PyDiGraph, /) -> list[set[int]]: ...
def digraph_adjacency_matrix(
graph: PyDiGraph[_S, _T],
Expand Down
165 changes: 161 additions & 4 deletions src/connectivity/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,20 +22,22 @@ use super::{

use hashbrown::{HashMap, HashSet};
use indexmap::IndexSet;
use petgraph::algo;
use petgraph::algo::condensation;
use petgraph::graph::DiGraph;
use petgraph::graph::{DiGraph, IndexType};
use petgraph::stable_graph::NodeIndex;
use petgraph::unionfind::UnionFind;
use petgraph::visit::{EdgeRef, IntoEdgeReferences, NodeCount, NodeIndexable, Visitable};
use petgraph::{algo, Graph};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::PyDict;
use pyo3::BoundObject;
use pyo3::IntoPyObject;
use pyo3::Python;
use rayon::prelude::*;

use ndarray::prelude::*;
use numpy::{IntoPyArray, PyArray2};
use petgraph::prelude::StableGraph;

use crate::iterators::{
AllPairsMultiplePathMapping, BiconnectedComponents, Chains, EdgeList, NodeIndices,
Expand Down Expand Up @@ -192,6 +194,161 @@ pub fn is_strongly_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
Ok(algo::kosaraju_scc(&graph.graph).len() == 1)
}

/// Compute the condensation of a graph (directed or undirected).
///
/// For directed graphs, this returns the condensation (quotient graph) where each node
/// represents a strongly connected component (SCC) of the input graph. For undirected graphs,
/// each node represents a connected component.
///
/// The returned graph has a node attribute 'node_map' which is a list mapping each original
/// node index to the index of the condensed node it belongs to.
///
/// :param graph: The input graph (PyDiGraph or PyGraph)
/// :param sccs: (Optional, directed only) List of SCCs to use instead of computing them
/// :returns: The condensed graph (PyDiGraph or PyGraph) with a 'node_map' attribute
/// :rtype: PyDiGraph or PyGraph
fn condensation_inner<'py, N, E, Ty, Ix>(
py: Python<'py>,
g: Graph<N, E, Ty, Ix>,
make_acyclic: bool,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This argument is always true, I'd rather remove the argument & the logic for when it is false

sccs: Option<Vec<Vec<usize>>>,
) -> PyResult<(StablePyGraph<Ty>, Vec<usize>)>
where
Ty: EdgeType,
Ix: IndexType,
N: IntoPyObject<'py, Target = PyAny> + Clone,
E: IntoPyObject<'py, Target = PyAny> + Clone,
{
// For directed graphs, use SCCs; for undirected, use connected components
let components: Vec<Vec<NodeIndex<Ix>>> = if Ty::is_directed() {
if let Some(sccs) = sccs {
sccs.into_iter()
.map(|row| row.into_iter().map(NodeIndex::new).collect())
.collect()
} else {
algo::kosaraju_scc(&g)
}
} else {
connectivity::connected_components(&g)
.into_iter()
.map(|set| set.into_iter().collect())
.collect()
};

// Convert all NodeIndex<Ix> to NodeIndex<usize> for the output graph
let components_usize: Vec<Vec<NodeIndex<usize>>> = components
.iter()
.map(|comp| comp.iter().map(|ix| NodeIndex::new(ix.index())).collect())
.collect();

let mut condensed: StableGraph<Vec<N>, E, Ty, u32> =
StableGraph::with_capacity(components_usize.len(), g.edge_count());

// Build a map from old indices to new ones.
let mut node_map = vec![usize::MAX; g.node_count()];
for comp in components_usize.iter() {
let new_nix = condensed.add_node(Vec::new());
for nix in comp {
node_map[nix.index()] = new_nix.index();
}
}

// Consume nodes and edges of the old graph and insert them into the new one.
let (nodes, edges) = g.into_nodes_edges();
for (nix, node) in nodes.into_iter().enumerate() {
let idx = node_map.get(nix).copied().unwrap_or(usize::MAX);
if idx != usize::MAX {
condensed[NodeIndex::new(idx)].push(node.weight);
}
}
for edge in edges {
let source = node_map
.get(edge.source().index())
.copied()
.unwrap_or(usize::MAX);
let target = node_map
.get(edge.target().index())
.copied()
.unwrap_or(usize::MAX);
if source == usize::MAX || target == usize::MAX {
continue;
}
let source = NodeIndex::new(source);
let target = NodeIndex::new(target);
if make_acyclic && Ty::is_directed() {
if source != target {
condensed.update_edge(source, target, edge.weight);
}
} else {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is always false, remove this statmenet

condensed.add_edge(source, target, edge.weight);
}
}

let mapped = condensed.map(
|_, w| match w.clone().into_pyobject(py) {
Ok(bound) => bound.unbind(),
Err(_) => PyValueError::new_err("Node conversion failed")
.into_pyobject(py)
.unwrap()
.unbind()
.into(),
},
|_, w| match w.clone().into_pyobject(py) {
Ok(bound) => bound.unbind(),
Err(_) => PyValueError::new_err("Edge conversion failed")
.into_pyobject(py)
.unwrap()
.unbind()
.into(),
},
);
Ok((mapped, node_map))
}

#[pyfunction]
#[pyo3(text_signature = "(graph, /, sccs=None)", signature=(graph, sccs=None))]
pub fn condensation(
py: Python,
graph: PyObject,
sccs: Option<Vec<Vec<usize>>>,
) -> PyResult<PyObject> {
if let Ok(digraph) = graph.extract::<digraph::PyDiGraph>(py) {
let g = digraph.graph.clone();
let (condensed, node_map) = condensation_inner(py, g.into(), true, sccs)?;
let mut result = digraph::PyDiGraph {
graph: condensed,
cycle_state: algo::DfsSpace::default(),
check_cycle: false,
node_removed: false,
multigraph: true,
attrs: PyDict::new(py).into(),
};
let node_map_py = node_map.into_pyobject(py)?;
let attrs = PyDict::new(py);
attrs.set_item("node_map", node_map_py)?;
result.attrs = attrs.into();
Ok(result.into_pyobject(py)?.into())
} else if let Ok(pygraph) = graph.extract::<graph::PyGraph>(py) {
let g = pygraph.graph.clone();
let (condensed, node_map) = condensation_inner(py, g.into(), false, None)?;
let mut result = graph::PyGraph {
graph: condensed,
node_removed: false,
multigraph: pygraph.multigraph,
attrs: PyDict::new(py).into(),
};
let node_map_py = node_map.into_pyobject(py)?;
let attrs = PyDict::new(py);
attrs.set_item("node_map", node_map_py)?;
result.attrs = attrs.into();
Ok(result.into_pyobject(py)?.into())
} else {
Err(PyValueError::new_err(
"Input must be a PyDiGraph or PyGraph",
))
}
}

/// Return the first cycle encountered during DFS of a given PyDiGraph,
/// empty list is returned if no cycle is found
///
Expand Down Expand Up @@ -480,7 +637,7 @@ pub fn is_semi_connected(graph: &digraph::PyDiGraph) -> PyResult<bool> {
temp_graph.add_edge(node_map[source.index()], node_map[target.index()], ());
}

let condensed = condensation(temp_graph, true);
let condensed = algo::condensation(temp_graph, true);
let n = condensed.node_count();
let weight_fn =
|_: petgraph::graph::EdgeReference<()>| Ok::<usize, std::convert::Infallible>(1usize);
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(number_strongly_connected_components))?;
m.add_wrapped(wrap_pyfunction!(strongly_connected_components))?;
m.add_wrapped(wrap_pyfunction!(is_strongly_connected))?;
m.add_wrapped(wrap_pyfunction!(condensation))?;
m.add_wrapped(wrap_pyfunction!(digraph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(graph_dfs_edges))?;
m.add_wrapped(wrap_pyfunction!(digraph_find_cycle))?;
Expand Down
75 changes: 75 additions & 0 deletions tests/digraph/test_strongly_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,78 @@ def test_is_strongly_connected_null_graph(self):
graph = rustworkx.PyDiGraph()
with self.assertRaises(rustworkx.NullGraph):
rustworkx.is_strongly_connected(graph)


class TestCondensation(unittest.TestCase):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing test case is excellent! But we need to add one test covering the case where we pass a list to sccs as an argument

def setUp(self):
# Set up the graph
self.graph = rustworkx.PyDiGraph()
self.node_a = self.graph.add_node("a")
self.node_b = self.graph.add_node("b")
self.node_c = self.graph.add_node("c")
self.node_d = self.graph.add_node("d")
self.node_e = self.graph.add_node("e")
self.node_f = self.graph.add_node("f")
self.node_g = self.graph.add_node("g")
self.node_h = self.graph.add_node("h")

# Add edges
self.graph.add_edge(self.node_a, self.node_b, "a->b")
self.graph.add_edge(self.node_b, self.node_c, "b->c")
self.graph.add_edge(self.node_c, self.node_d, "c->d")
self.graph.add_edge(self.node_d, self.node_a, "d->a") # Cycle: a -> b -> c -> d -> a

self.graph.add_edge(self.node_b, self.node_e, "b->e")

self.graph.add_edge(self.node_e, self.node_f, "e->f")
self.graph.add_edge(self.node_f, self.node_g, "f->g")
self.graph.add_edge(self.node_g, self.node_h, "g->h")
self.graph.add_edge(self.node_h, self.node_e, "h->e") # Cycle: e -> f -> g -> h -> e

def test_condensation(self):
# Call the condensation function
condensed_graph = rustworkx.condensation(self.graph)

# Check the number of nodes (two cycles should be condensed into one node each)
self.assertEqual(
len(condensed_graph.node_indices()), 2
) # [SCC(a, b, c, d), SCC(e, f, g, h)]

# Check the number of edges
self.assertEqual(
len(condensed_graph.edge_indices()), 1
) # Edge: [SCC(a, b, c, d)] -> [SCC(e, f, g, h)]

# Check the contents of the condensed nodes
nodes = list(condensed_graph.nodes())
scc1 = nodes[0]
scc2 = nodes[1]
self.assertTrue(set(scc1) == {"a", "b", "c", "d"} or set(scc2) == {"a", "b", "c", "d"})
self.assertTrue(set(scc1) == {"e", "f", "g", "h"} or set(scc2) == {"e", "f", "g", "h"})

# Check the contents of the edge
weight = condensed_graph.edges()[0]
self.assertIn("b->e", weight) # Ensure the correct edge remains in the condensed graph

def test_condensation_with_sccs_argument(self):
# Compute SCCs manually
sccs = rustworkx.strongly_connected_components(self.graph)
# Call condensation with explicit sccs argument
condensed_graph = rustworkx.condensation(self.graph, sccs=sccs)
condensed_graph.attrs["node_map"]

# Check the number of nodes (should match SCC count)
self.assertEqual(len(condensed_graph.node_indices()), len(sccs))

# Check the number of edges
self.assertEqual(len(condensed_graph.edge_indices()), 1)

# Check the contents of the condensed nodes
nodes = list(condensed_graph.nodes())
scc_sets = [set(n) for n in nodes]
self.assertIn(set(["a", "b", "c", "d"]), scc_sets)
self.assertIn(set(["e", "f", "g", "h"]), scc_sets)

# Check the contents of the edge
weight = condensed_graph.edges()[0]
self.assertIn("b->e", weight)