Skip to content

Commit 74e9344

Browse files
authored
support selection of tests that use op infos (#17)
1 parent fa9624c commit 74e9344

File tree

4 files changed

+91
-68
lines changed

4 files changed

+91
-68
lines changed

pytest_pytorch/plugin.py

+44-68
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
1+
import inspect
12
import re
23
import unittest.mock
34
import warnings
4-
from typing import Pattern
55

66
from _pytest.unittest import TestCaseFunction, UnitTestCase
77

88
try:
9-
from torch.testing._internal.common_device_type import get_device_type_test_bases
10-
from torch.testing._internal.common_utils import TestCase as PyTorchTestCaseTemplate
9+
from torch.testing._internal.common_utils import TestCase as TestCaseTemplate
1110

1211
TORCH_AVAILABLE = True
1312
except ImportError:
@@ -17,93 +16,73 @@
1716
"Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported."
1817
)
1918

20-
def get_device_type_test_bases():
21-
return []
2219

23-
class PyTorchTestCaseTemplate:
24-
pass
20+
class TemplatedName(str):
21+
def __new__(cls, name, template_name):
22+
self = super().__new__(cls, name)
23+
self._template_name = template_name
24+
return self
2525

26+
def __eq__(self, other):
27+
exact_match = str.__eq__(self, other)
28+
if exact_match:
29+
return True
2630

27-
class PytestPyTorchInternalError(Exception):
28-
def __init__(self, msg):
29-
super().__init__(
30-
f"{msg}\n"
31-
f"This is an internal error of the pytest plugin 'pytest-pytorch'."
32-
f"If you encounter this during normal operation, please file an issue "
33-
f"https://github.com/Quansight/pytest-pytorch/issues."
34-
)
35-
36-
37-
TEST_BASE_DEVICE_PATTERN = re.compile(r"(?P<device>\w*?)TestBase$")
38-
39-
40-
def _get_devices():
41-
devices = []
42-
for test_base in get_device_type_test_bases():
43-
match = TEST_BASE_DEVICE_PATTERN.match(test_base.__name__)
44-
if not match:
45-
raise PytestPyTorchInternalError(
46-
f"Unable to extract device name from {test_base.__name__}"
47-
)
48-
49-
devices.append(match.group("device"))
31+
if not self._template_name:
32+
return False
5033

51-
return devices
34+
return str.__eq__(self._template_name, other)
5235

36+
def __hash__(self):
37+
return super().__hash__()
5338

54-
DEVICES = _get_devices()
5539

40+
class TemplatedTestCaseFunction(TestCaseFunction):
41+
_TEMPLATE_NAME_PATTERN = re.compile(r"def (?P<template_name>test_\w+)\(")
5642

57-
class TemplateName(str):
58-
_TEMPLATE_NAME_PATTERN: Pattern
43+
@classmethod
44+
def _extract_template_name(cls, callobj):
45+
if not callobj:
46+
return None
5947

60-
def __init__(self, _):
61-
super().__init__()
62-
match = self._TEMPLATE_NAME_PATTERN.match(self)
48+
match = cls._TEMPLATE_NAME_PATTERN.search(inspect.getsource(callobj))
6349
if not match:
64-
raise PytestPyTorchInternalError(
65-
f"Unable to extract template name from {self}"
66-
)
67-
self._template_name = match.group("template_name")
68-
69-
def __eq__(self, other):
70-
return str.__eq__(self, other) or str.__eq__(self._template_name, other)
71-
72-
def __hash__(self) -> int:
73-
return super().__hash__()
74-
75-
76-
class TestCaseFunctionTemplateName(TemplateName):
77-
_TEMPLATE_NAME_PATTERN = re.compile(
78-
fr"(?P<template_name>\w*?)_({'|'.join([device.lower() for device in DEVICES])})"
79-
)
50+
return None
8051

52+
return match.group("template_name")
8153

82-
class PyTorchTestCaseFunction(TestCaseFunction):
8354
@classmethod
84-
def from_parent(cls, parent, *, name, **kw):
55+
def from_parent(cls, parent, *, name, callobj, **kw):
8556
return super().from_parent(
86-
parent, name=TestCaseFunctionTemplateName(name), **kw
57+
parent, name=TemplatedName(name, cls._extract_template_name(callobj)), **kw
8758
)
8859

8960

90-
class TestCaseTemplateName(TemplateName):
91-
_TEMPLATE_NAME_PATTERN = re.compile(
92-
fr"(?P<template_name>\w*?)({'|'.join([device.upper() for device in DEVICES])})"
93-
)
61+
class TemplatedTestCase(UnitTestCase):
62+
@classmethod
63+
def _extract_template_name(cls, name, obj):
64+
if not obj:
65+
return None
66+
67+
if not hasattr(obj, "device_type"):
68+
return None
9469

70+
return name[: -len(obj.device_type)]
9571

96-
class PyTorchTestCase(UnitTestCase):
9772
@classmethod
9873
def from_parent(cls, parent, *, name, obj=None):
99-
return super().from_parent(parent, name=TestCaseTemplateName(name), obj=obj)
74+
return super().from_parent(
75+
parent,
76+
name=TemplatedName(name, cls._extract_template_name(name, obj)),
77+
obj=obj,
78+
)
10079

10180
def collect(self):
10281
# Yes, this is a bad practice. Unfortunately, there is no other option to
10382
# inject our custom 'TestCaseFunction' without duplicating everything in
10483
# 'UnitTestCase.collect()'
10584
with unittest.mock.patch(
106-
"_pytest.unittest.TestCaseFunction", new=PyTorchTestCaseFunction
85+
"_pytest.unittest.TestCaseFunction", new=TemplatedTestCaseFunction
10786
):
10887
yield from super().collect()
10988

@@ -113,12 +92,9 @@ def pytest_pycollect_makeitem(collector, name, obj):
11392
return None
11493

11594
try:
116-
if (
117-
not issubclass(obj, PyTorchTestCaseTemplate)
118-
or obj is PyTorchTestCaseTemplate
119-
):
95+
if not issubclass(obj, TestCaseTemplate) or obj is TestCaseTemplate:
12096
return None
12197
except Exception:
12298
return None
12399

124-
return PyTorchTestCase.from_parent(collector, name=name, obj=obj)
100+
return TemplatedTestCase.from_parent(collector, name=name, obj=obj)

tests/assets/test_op_infos.py

+18
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import torch
2+
from torch.testing._internal.common_device_type import (
3+
dtypes,
4+
instantiate_device_type_tests,
5+
ops,
6+
)
7+
from torch.testing._internal.common_methods_invocations import OpInfo
8+
from torch.testing._internal.common_utils import TestCase
9+
10+
11+
class TestFoo(TestCase):
12+
@dtypes(torch.float16, torch.int32)
13+
@ops([OpInfo("add"), OpInfo("sub")])
14+
def test_bar(self, device, dtype, op):
15+
assert True
16+
17+
18+
instantiate_device_type_tests(TestFoo, globals())

tests/test_plugin.py

+27
Original file line numberDiff line numberDiff line change
@@ -205,3 +205,30 @@ def test_nested_names(testdir, file, cmds, outcomes):
205205
testdir.copy_example(file)
206206
result = testdir.runpytest(*cmds)
207207
result.assert_outcomes(**outcomes)
208+
209+
210+
@make_params(
211+
"test_op_infos.py",
212+
Config(
213+
"*testcase-*test-*op-*device-*dtype",
214+
new_cmds=(),
215+
legacy_cmds=(),
216+
passed=8,
217+
),
218+
Config(
219+
"1testcase-*test-*op-*device-*dtype",
220+
new_cmds="::TestFoo",
221+
legacy_cmds=("-k", "TestFoo"),
222+
passed=8,
223+
),
224+
Config(
225+
"1testcase-1test-*op-*device-*dtype",
226+
new_cmds="::TestFoo::test_bar",
227+
legacy_cmds=("-k", "TestFoo and test_bar"),
228+
passed=8,
229+
),
230+
)
231+
def test_op_infos(testdir, file, cmds, outcomes):
232+
testdir.copy_example(file)
233+
result = testdir.runpytest(*cmds)
234+
result.assert_outcomes(**outcomes)

tox.ini

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ deps =
3131
torch
3232
# The nightlies do not specify numpy as requirement
3333
numpy
34+
# Importing OpInfo requires scipy unconditionally
35+
scipy
3436
commands =
3537
pytest -c pytest.ini {posargs}
3638

0 commit comments

Comments
 (0)