1
1
import graphblas as gb
2
- from typing import TYPE_CHECKING , Optional , Union , Hashable
2
+ from typing import TYPE_CHECKING , Optional , Hashable
3
3
import functools as ft
4
4
import torch
5
- from torch_sparse import SparseTensor
6
5
import warnings
6
+ from graph_diffusers import _extras
7
7
8
8
if TYPE_CHECKING :
9
9
from graph_diffusers ._typing import _GraphblasModule as gb
10
10
from numpy .typing import DTypeLike
11
11
12
# torch_sparse is an optional dependency: import the real SparseTensor when it
# is available, otherwise define an empty stand-in so isinstance checks against
# _SparseTensor stay valid (and simply never match).
if _extras.TORCH_SPARSE:
    from torch_sparse import SparseTensor as _SparseTensor
else:

    class _SparseTensor: ...
13
17
14
18
15
19
def torch_to_graphblas (
@@ -19,7 +23,7 @@ def torch_to_graphblas(
19
23
weighted : bool = False ,
20
24
dtype : "Optional[DTypeLike]" = None ,
21
25
) -> gb .Matrix :
22
- if isinstance (edge_index , SparseTensor ):
26
+ if isinstance (edge_index , _SparseTensor ):
23
27
return torch_sparse_tensor_to_graphblas (edge_index , weighted = weighted , dtype = dtype )
24
28
if edge_index .is_sparse_csr :
25
29
return torch_sparse_csr_to_graphblas (edge_index , weighted = weighted , dtype = dtype )
@@ -37,30 +41,33 @@ def torch_sparse_csr_to_graphblas(
37
41
return _torch_sparse_csr_to_graphblas (adj_t , weighted = weighted , dtype = dtype )
38
42
39
43
40
- def torch_sparse_tensor_to_graphblas (
41
- adj_t : SparseTensor , * , weighted : bool = False , dtype : "Optional[DTypeLike]" = None
42
- ) -> gb .Matrix :
43
- return torch_sparse_csr_to_graphblas (
44
- adj_t .to_torch_sparse_csr_tensor (),
45
- weighted = weighted ,
46
- dtype = dtype ,
47
- )
48
-
49
-
50
44
def torch_edge_index_to_graphblas(
    edge_index: torch.Tensor,
    *,
    num_nodes: Optional[int] = None,
    dtype: "Optional[DTypeLike]" = None,
) -> gb.Matrix:
    """Convert a COO edge-index tensor to a ``gb.Matrix``.

    Delegates to the cached helper when possible; an unhashable ``dtype``
    cannot serve as a cache key, so that case warns and calls the
    undecorated function instead.
    """
    if isinstance(dtype, Hashable):
        # Hashable dtype: route through the cached helper.
        return _torch_edge_index_to_graphblas(edge_index, num_nodes=num_nodes, dtype=dtype)
    warnings.warn(
        f"Unhashable dtype {dtype} passed when converting from torch to graphblas. The result will not be cached."
    )
    # __wrapped__ bypasses the caching decorator entirely.
    return _torch_edge_index_to_graphblas.__wrapped__(edge_index, num_nodes=num_nodes, dtype=dtype)
62
56
63
57
58
if _extras.TORCH_SPARSE:
    import torch_sparse

    def torch_sparse_tensor_to_graphblas(
        adj_t: torch_sparse.SparseTensor, *, weighted: bool = False, dtype: "Optional[DTypeLike]" = None
    ) -> gb.Matrix:
        """Convert a ``torch_sparse.SparseTensor`` to a ``gb.Matrix``.

        Only defined when the optional torch_sparse dependency is installed.
        The tensor is first materialized in CSR form, then handed to the
        CSR conversion path.
        """
        as_csr = adj_t.to_torch_sparse_csr_tensor()
        return torch_sparse_csr_to_graphblas(as_csr, weighted=weighted, dtype=dtype)
64
71
@ft .lru_cache (maxsize = 1 )
65
72
def _torch_sparse_csr_to_graphblas (
66
73
adj_t : torch .Tensor ,
0 commit comments