Skip to content

Commit 25a3da0

Browse files
committed
init
0 parents  commit 25a3da0

14 files changed

+493
-0
lines changed

.gitignore

+73
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Byte-compiled / optimized / DLL files
2+
__pycache__/
3+
*.py[cod]
4+
*$py.class
5+
6+
# C extensions
7+
*.so
8+
9+
# Distribution / packaging
10+
.Python
11+
build/
12+
develop-eggs/
13+
dist/
14+
downloads/
15+
eggs/
16+
.eggs/
17+
lib/
18+
lib64/
19+
parts/
20+
sdist/
21+
var/
22+
wheels/
23+
share/python-wheels/
24+
*.egg-info/
25+
.installed.cfg
26+
*.egg
27+
MANIFEST
28+
29+
# Installer logs
30+
pip-log.txt
31+
pip-delete-this-directory.txt
32+
33+
# Unit test / coverage reports
34+
htmlcov/
35+
.tox/
36+
.nox/
37+
.coverage
38+
.coverage.*
39+
.cache
40+
nosetests.xml
41+
coverage.xml
42+
*.cover
43+
*.py,cover
44+
.hypothesis/
45+
.pytest_cache/
46+
cover/
47+
48+
# Jupyter Notebook
49+
.ipynb_checkpoints
50+
51+
# IPython
52+
profile_default/
53+
ipython_config.py
54+
55+
# poetry
56+
poetry.lock
57+
58+
# Environments
59+
.env
60+
.venv
61+
env/
62+
venv/
63+
ENV/
64+
env.bak/
65+
venv.bak/
66+
67+
# mypy
68+
.mypy_cache/
69+
.dmypy.json
70+
dmypy.json
71+
72+
# ruff
73+
.ruff_cache

LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2024 Aaron Z.
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

README.md

+2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
# graph-diffusers
2+
Diffusion patterns for graph machine learning

graph_diffusers/__init__.py

+7
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from graph_diffusers import models, operators, io
2+
3+
__all__ = [
4+
"models",
5+
"operators",
6+
"io",
7+
]

graph_diffusers/_extras.py

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from importlib.util import find_spec
2+
3+
TORCH = bool(find_spec("torch"))
4+
TORCH_SPARSE = bool(find_spec("torch_sparse"))

graph_diffusers/_typing.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from typing import Protocol
2+
from types import ModuleType
3+
4+
import graphblas.binary
5+
import graphblas.core.matrix
6+
import graphblas.core.vector
7+
import graphblas.io
8+
import graphblas.select
9+
import graphblas.unary
10+
11+
12+
class _GraphblasModule(Protocol):
13+
def __new__(cls) -> "_GraphblasModule":
14+
raise NotImplementedError
15+
16+
class Matrix(graphblas.core.matrix.Matrix, Protocol): ...
17+
18+
class Vector(graphblas.core.vector.Vector, Protocol): ...
19+
20+
unary: ModuleType = graphblas.unary
21+
binary: ModuleType = graphblas.binary
22+
select: ModuleType = graphblas.select
23+
io: ModuleType = graphblas.io

graph_diffusers/_utils.py

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import graphblas as gb
2+
from typing import TYPE_CHECKING, Optional, Union, Callable, Any
3+
4+
if TYPE_CHECKING:
5+
from graph_diffusers._typing import _GraphblasModule as gb
6+
from graphblas.core.matrix import MatrixExpression, TransposedMatrix
7+
8+
9+
def eval_expr(expr: "MatrixExpression", out: Optional[gb.Matrix] = None, **kw: Any) -> gb.Matrix:
10+
if out is None:
11+
return expr.new(**kw)
12+
out(**kw) << expr
13+
return out
14+
15+
16+
def transpose_if(
17+
condition: bool,
18+
matrix: "Union[gb.Matrix, MatrixExpression]",
19+
) -> "Union[MatrixExpression, MatrixExpression, TransposedMatrix]":
20+
return matrix.T if condition else matrix
21+
22+
23+
def _recover_triad_count(x: int) -> float:
24+
return (1 - x) / 2 if x % 2 else x
25+
26+
27+
recover_triad_count: Callable[[gb.Matrix], "MatrixExpression"] = gb.unary.register_anonymous(_recover_triad_count)

graph_diffusers/io/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from graph_diffusers import _extras
2+
3+
__all__ = []
4+
5+
if _extras.TORCH:
6+
from graph_diffusers.io import torch
7+
8+
__all__.append("torch")

graph_diffusers/io/torch.py

+88
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import graphblas as gb
2+
from typing import TYPE_CHECKING, Optional, Union, Hashable
3+
import functools as ft
4+
import torch
5+
from torch_sparse import SparseTensor
6+
import warnings
7+
8+
if TYPE_CHECKING:
9+
from graph_diffusers._typing import _GraphblasModule as gb
10+
from numpy.typing import DTypeLike
11+
12+
# TODO check if torch_sparse installed
13+
14+
15+
def torch_to_graphblas(
16+
edge_index: torch.Tensor,
17+
*,
18+
num_nodes: Optional[int] = None,
19+
weighted: bool = False,
20+
dtype: "Optional[DTypeLike]" = None,
21+
) -> gb.Matrix:
22+
if isinstance(edge_index, SparseTensor):
23+
return torch_sparse_tensor_to_graphblas(edge_index, weighted=weighted, dtype=dtype)
24+
if edge_index.is_sparse_csr:
25+
return torch_sparse_csr_to_graphblas(edge_index, weighted=weighted, dtype=dtype)
26+
return torch_edge_index_to_graphblas(edge_index, num_nodes=num_nodes, dtype=dtype)
27+
28+
29+
def torch_sparse_csr_to_graphblas(
30+
adj_t: torch.Tensor, *, weighted: bool = False, dtype: "Optional[DTypeLike]" = None
31+
) -> gb.Matrix:
32+
if not isinstance(dtype, Hashable):
33+
warnings.warn(
34+
f"Unhashable dtype {dtype} passed when converting from torch to graphblas." "The result will not be cached."
35+
)
36+
return _torch_edge_index_to_graphblas.__wrapped__(adj_t, weighted=weighted, dtype=dtype)
37+
return _torch_sparse_csr_to_graphblas(adj_t, weighted=weighted, dtype=dtype)
38+
39+
40+
def torch_sparse_tensor_to_graphblas(
41+
adj_t: SparseTensor, *, weighted: bool = False, dtype: "Optional[DTypeLike]" = None
42+
) -> gb.Matrix:
43+
return torch_sparse_csr_to_graphblas(
44+
adj_t.to_torch_sparse_csr_tensor(),
45+
weighted=weighted,
46+
dtype=dtype,
47+
)
48+
49+
50+
def torch_edge_index_to_graphblas(
51+
edge_index: Union[torch.Tensor, SparseTensor],
52+
*,
53+
num_nodes: Optional[int] = None,
54+
dtype: "Optional[DTypeLike]" = None,
55+
) -> gb.Matrix:
56+
if not isinstance(dtype, Hashable):
57+
warnings.warn(
58+
f"Unhashable dtype {dtype} passed when converting from torch to graphblas." "The result will not be cached."
59+
)
60+
return _torch_edge_index_to_graphblas.__wrapped__(edge_index, num_nodes=num_nodes, dtype=dtype)
61+
return _torch_edge_index_to_graphblas(edge_index, num_nodes=num_nodes, dtype=dtype)
62+
63+
64+
@ft.lru_cache(maxsize=1)
65+
def _torch_sparse_csr_to_graphblas(
66+
adj_t: torch.Tensor,
67+
weighted: bool,
68+
dtype: "Optional[DTypeLike]",
69+
) -> gb.Matrix:
70+
if not adj_t.is_sparse_csr:
71+
adj_t = adj_t.to_sparse_csr()
72+
return gb.Matrix.from_csr(
73+
indptr=adj_t.crow_indices().detach().cpu().numpy(),
74+
col_indices=adj_t.col_indices().detach().cpu().numpy(),
75+
values=1.0 if not weighted else adj_t.values().detach().cpu().numpy(),
76+
nrows=adj_t.shape[0],
77+
ncols=adj_t.shape[0],
78+
dtype=dtype,
79+
)
80+
81+
82+
@ft.lru_cache(maxsize=1)
83+
def _torch_edge_index_to_graphblas(
84+
edge_index: torch.Tensor,
85+
num_nodes: Optional[int],
86+
dtype: "Optional[DTypeLike]",
87+
) -> gb.Matrix:
88+
return gb.Matrix.from_coo(*edge_index, dtype=dtype, nrows=num_nodes, ncols=num_nodes)

graph_diffusers/models.py

+71
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
from typing import Union, Literal, TYPE_CHECKING
3+
import graphblas as gb
4+
from itertools import chain
5+
from abc import ABCMeta, abstractmethod
6+
import graph_diffusers.operators as op
7+
8+
if TYPE_CHECKING:
9+
from graph_diffusers._typing import _GraphblasModule as gb
10+
11+
12+
class Diffusion(metaclass=ABCMeta):
13+
@abstractmethod
14+
def num_features(self, in_features: int) -> int: ...
15+
16+
@abstractmethod
17+
def propagate(self, adj: gb.Matrix, X: np.ndarray) -> np.ndarray: ...
18+
19+
20+
class SimpleGCNDiffusion(Diffusion):
21+
def __init__(self, k: int) -> None:
22+
self.k = k
23+
24+
def num_features(self, in_features: int) -> int:
25+
return in_features
26+
27+
def propagate(self, adj: gb.Matrix, X: np.ndarray) -> np.ndarray:
28+
adj_gcn = op.gcn_norm(adj)
29+
return op.power(adj_gcn, self.k)(X)
30+
31+
32+
class SIGNDiffusion(Diffusion):
33+
def __init__(
34+
self,
35+
s: int,
36+
p: int,
37+
t: int,
38+
s_norm: Union[Literal["gcn"], Literal["rw"]] = "gcn",
39+
p_norm: Union[Literal["gcn"], Literal["rw"]] = "rw",
40+
t_norm: Union[Literal["gcn"], Literal["rw"]] = "rw",
41+
) -> None:
42+
self.s = s
43+
self.p = p
44+
self.t = t
45+
self.s_norm = s_norm
46+
self.p_norm = p_norm
47+
self.t_norm = t_norm
48+
if self.r < 1:
49+
raise ValueError
50+
51+
@property
52+
def r(self) -> int:
53+
return self.s + self.p + self.t
54+
55+
def num_features(self, in_features: int) -> int:
56+
return self.r * in_features
57+
58+
def propagate(self, adj: gb.Matrix, X: np.ndarray) -> np.ndarray:
59+
ops = []
60+
if s := self.s:
61+
simple_diffuser = op.simple(op.norm(adj, method=self.s_norm))
62+
ops.append(op.diffuse_powers(simple_diffuser, X, s))
63+
if p := self.p:
64+
ppr_diffuser = op.appnp(op.norm(adj, method=self.p_norm))
65+
ops.append(op.diffuse_powers(ppr_diffuser, X, p))
66+
if t := self.t:
67+
adj_triangle = op.triangle(adj, directed=False)
68+
triangle_diffuser = op.simple(op.norm(adj_triangle, method=self.t_norm, copy=False))
69+
ops.append(op.diffuse_powers(triangle_diffuser, X, t))
70+
71+
return np.hstack(tuple(chain(*ops)), dtype=X.dtype)

0 commit comments

Comments
 (0)