Skip to content

Commit 85aa820

Browse files
committed
adds versioning for release
1 parent 25a3da0 commit 85aa820

File tree

6 files changed

+102
-19
lines changed

6 files changed

+102
-19
lines changed

README.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
# graph-diffusers
2-
Diffusion patterns for graph machine learning
2+
Diffusion patterns for graph machine learning based on GraphBLAS.

graph_diffusers/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from graph_diffusers import models, operators, io
2+
from graph_diffusers._version import __version__
23

34
__all__ = [
45
"models",

graph_diffusers/_version.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
__version__ = "0.0.0"

graph_diffusers/io/torch.py

+23-16
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,19 @@
11
import graphblas as gb
2-
from typing import TYPE_CHECKING, Optional, Union, Hashable
2+
from typing import TYPE_CHECKING, Optional, Hashable
33
import functools as ft
44
import torch
5-
from torch_sparse import SparseTensor
65
import warnings
6+
from graph_diffusers import _extras
77

88
if TYPE_CHECKING:
99
from graph_diffusers._typing import _GraphblasModule as gb
1010
from numpy.typing import DTypeLike
1111

12-
# TODO check if torch_sparse installed
12+
if _extras.TORCH_SPARSE:
13+
from torch_sparse import SparseTensor as _SparseTensor
14+
else:
15+
16+
class _SparseTensor: ...
1317

1418

1519
def torch_to_graphblas(
@@ -19,7 +23,7 @@ def torch_to_graphblas(
1923
weighted: bool = False,
2024
dtype: "Optional[DTypeLike]" = None,
2125
) -> gb.Matrix:
22-
if isinstance(edge_index, SparseTensor):
26+
if isinstance(edge_index, _SparseTensor):
2327
return torch_sparse_tensor_to_graphblas(edge_index, weighted=weighted, dtype=dtype)
2428
if edge_index.is_sparse_csr:
2529
return torch_sparse_csr_to_graphblas(edge_index, weighted=weighted, dtype=dtype)
@@ -37,30 +41,33 @@ def torch_sparse_csr_to_graphblas(
3741
return _torch_sparse_csr_to_graphblas(adj_t, weighted=weighted, dtype=dtype)
3842

3943

40-
def torch_sparse_tensor_to_graphblas(
41-
adj_t: SparseTensor, *, weighted: bool = False, dtype: "Optional[DTypeLike]" = None
42-
) -> gb.Matrix:
43-
return torch_sparse_csr_to_graphblas(
44-
adj_t.to_torch_sparse_csr_tensor(),
45-
weighted=weighted,
46-
dtype=dtype,
47-
)
48-
49-
5044
def torch_edge_index_to_graphblas(
51-
edge_index: Union[torch.Tensor, SparseTensor],
45+
edge_index: torch.Tensor,
5246
*,
5347
num_nodes: Optional[int] = None,
5448
dtype: "Optional[DTypeLike]" = None,
5549
) -> gb.Matrix:
5650
if not isinstance(dtype, Hashable):
5751
warnings.warn(
58-
f"Unhashable dtype {dtype} passed when converting from torch to graphblas." "The result will not be cached."
52+
f"Unhashable dtype {dtype} passed when converting from torch to graphblas. The result will not be cached."
5953
)
6054
return _torch_edge_index_to_graphblas.__wrapped__(edge_index, num_nodes=num_nodes, dtype=dtype)
6155
return _torch_edge_index_to_graphblas(edge_index, num_nodes=num_nodes, dtype=dtype)
6256

6357

58+
if _extras.TORCH_SPARSE:
59+
import torch_sparse
60+
61+
def torch_sparse_tensor_to_graphblas(
62+
adj_t: torch_sparse.SparseTensor, *, weighted: bool = False, dtype: "Optional[DTypeLike]" = None
63+
) -> gb.Matrix:
64+
return torch_sparse_csr_to_graphblas(
65+
adj_t.to_torch_sparse_csr_tensor(),
66+
weighted=weighted,
67+
dtype=dtype,
68+
)
69+
70+
6471
@ft.lru_cache(maxsize=1)
6572
def _torch_sparse_csr_to_graphblas(
6673
adj_t: torch.Tensor,

pyproject.toml

+20-2
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,28 @@
11
[tool.poetry]
22
name = "graph-diffusers"
3-
version = "0.1.0-dev"
4-
description = "Diffusion operators for graph machine learning"
3+
version = "0.0.0"
4+
description = "Diffusion operators for graph machine learning based on GraphBLAS"
55
authors = ["Aaron Zolnai-Lucas <azolnailucas@gmail.com>"]
66
license = "MIT"
77
readme = "README.md"
8+
classifiers = [
9+
"License :: OSI Approved :: MIT License",
10+
"Operating System :: OS Independent",
11+
"Topic :: Scientific/Engineering",
12+
"Topic :: Scientific/Engineering :: Mathematics",
13+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
14+
"Intended Audience :: Developers",
15+
"Intended Audience :: Science/Research",
16+
"Intended Audience :: Education",
17+
"Programming Language :: Python :: 3",
18+
"Programming Language :: Python :: 3.8",
19+
"Programming Language :: Python :: 3.9",
20+
"Programming Language :: Python :: 3.10",
21+
"Programming Language :: Python :: 3.11",
22+
"Programming Language :: Python :: 3.12"
23+
24+
]
25+
urls = { Homepage = "https://github.com/aaronzo/graph-diffusers" }
826

927
[[tool.poetry.source]]
1028
name = "pytorch-cpu"

release.sh

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
#!/bin/sh
2+
set -e
3+
4+
repository="$1"
5+
version="$2"
6+
7+
cwd=$(pwd)
8+
cd `dirname $0`
9+
10+
if [ -z "$repository" ]; then
11+
echo "Usage: $0 <repository> <version>"
12+
exit 1
13+
fi
14+
15+
if [ -z "$version" ]; then
16+
echo "Usage: $0 <repository> <version>"
17+
exit 1
18+
fi
19+
20+
echo "Running poetry lock and install ..."
21+
poetry lock --no-update
22+
poetry install
23+
24+
echo "Setting package version ..."
25+
poetry version $2
26+
echo "__version__ = \"$2\"" > graph_diffusers/_version.py
27+
28+
echo "Running tests ..."
29+
poetry run pytest tests/
30+
31+
echo "Running tests without torch_sparse ..."
32+
poetry run pip uninstall -y torch_sparse
33+
poetry run pytest tests/
34+
35+
echo "Running tests without torch ..."
36+
poetry run pip uninstall -y torch
37+
poetry run pytest tests/
38+
39+
echo "Restoring environment ..."
40+
poetry install
41+
42+
echo "Cleaning up dist/ ..."
43+
rm -rf dist/
44+
45+
echo "Publishing to $repository ..."
46+
poetry publish \
47+
--repository $repository \
48+
--build \
49+
--no-interaction
50+
51+
git tag -m "Release $version" $version
52+
git push origin $version
53+
54+
echo "Done!"
55+
56+
cd $cwd

0 commit comments

Comments
 (0)