Skip to content

Commit 0279a23

Browse files
authored
remove torch as install requirement (#14)
* remove torch as install requirement * make the warning more concise
1 parent bd98f6b commit 0279a23

File tree

4 files changed

+23
-4
lines changed

4 files changed

+23
-4
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ order_by_type = true
2828
combine_star = true
2929
filter_files = true
3030

31-
known_third_party = ["pytest"]
31+
known_third_party = ["pytest", "_pytest"]
3232
known_first_party = ["torch", "pytest_pytorch"]
3333
known_local_folder = ["tests"]
3434

pytest_pytorch/plugin.py

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,27 @@
11
import re
22
import unittest.mock
3+
import warnings
34
from typing import Pattern
45

56
from _pytest.unittest import TestCaseFunction, UnitTestCase
67

7-
from torch.testing._internal.common_device_type import get_device_type_test_bases
8-
from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate
8+
try:
9+
from torch.testing._internal.common_device_type import get_device_type_test_bases
10+
from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate
11+
12+
TORCH_AVAILABLE = True
13+
except ImportError:
14+
TORCH_AVAILABLE = False
15+
16+
warnings.warn(
17+
"Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported."
18+
)
19+
20+
def get_device_type_test_bases():
21+
return []
22+
23+
class PyTorchTestCaseTemplate:
24+
pass
925

1026

1127
class PytestPyTorchInternalError(Exception):
@@ -93,6 +109,9 @@ def collect(self):
93109

94110

95111
def pytest_pycollect_makeitem(collector, name, obj):
112+
if not TORCH_AVAILABLE:
113+
return None
114+
96115
try:
97116
if (
98117
not issubclass(obj, PyTorchTestCaseTemplate)

setup.cfg

-1
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ include_package_data = True
3131
python_requires = >=3.6
3232
install_requires =
3333
pytest
34-
torch
3534

3635
[options.packages.find]
3736
exclude =

tox.ini

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ pytorch_channel = nightly
2828
deps =
2929
pytest >= 6
3030
pytest-mock >= 3.1
31+
torch
3132
# The nightlies do not specify numpy as requirement
3233
numpy
3334
commands =

0 commit comments

Comments
 (0)