1
+ import inspect
1
2
import re
2
3
import unittest .mock
3
4
import warnings
4
- from typing import Pattern
5
5
6
6
from _pytest .unittest import TestCaseFunction , UnitTestCase
7
7
8
8
try :
9
- from torch .testing ._internal .common_device_type import get_device_type_test_bases
10
- from torch .testing ._internal .common_utils import TestCase as PyTorchTestCaseTemplate
9
+ from torch .testing ._internal .common_utils import TestCase as TestCaseTemplate
11
10
12
11
TORCH_AVAILABLE = True
13
12
except ImportError :
17
16
"Disabling the plugin 'pytest-pytorch', because 'torch' could not be imported."
18
17
)
19
18
20
- def get_device_type_test_bases ():
21
- return []
22
19
23
- class PyTorchTestCaseTemplate :
24
- pass
20
+ class TemplatedName (str ):
21
+ def __new__ (cls , name , template_name ):
22
+ self = super ().__new__ (cls , name )
23
+ self ._template_name = template_name
24
+ return self
25
25
26
+ def __eq__ (self , other ):
27
+ exact_match = str .__eq__ (self , other )
28
+ if exact_match :
29
+ return True
26
30
27
- class PytestPyTorchInternalError (Exception ):
28
- def __init__ (self , msg ):
29
- super ().__init__ (
30
- f"{ msg } \n "
31
- f"This is an internal error of the pytest plugin 'pytest-pytorch'."
32
- f"If you encounter this during normal operation, please file an issue "
33
- f"https://github.com/Quansight/pytest-pytorch/issues."
34
- )
35
-
36
-
37
- TEST_BASE_DEVICE_PATTERN = re .compile (r"(?P<device>\w*?)TestBase$" )
38
-
39
-
40
- def _get_devices ():
41
- devices = []
42
- for test_base in get_device_type_test_bases ():
43
- match = TEST_BASE_DEVICE_PATTERN .match (test_base .__name__ )
44
- if not match :
45
- raise PytestPyTorchInternalError (
46
- f"Unable to extract device name from { test_base .__name__ } "
47
- )
48
-
49
- devices .append (match .group ("device" ))
31
+ if not self ._template_name :
32
+ return False
50
33
51
- return devices
34
+ return str . __eq__ ( self . _template_name , other )
52
35
36
+ def __hash__ (self ):
37
+ return super ().__hash__ ()
53
38
54
- DEVICES = _get_devices ()
55
39
40
+ class TemplatedTestCaseFunction (TestCaseFunction ):
41
+ _TEMPLATE_NAME_PATTERN = re .compile (r"def (?P<template_name>test_\w+)\(" )
56
42
57
- class TemplateName (str ):
58
- _TEMPLATE_NAME_PATTERN : Pattern
43
+ @classmethod
44
+ def _extract_template_name (cls , callobj ):
45
+ if not callobj :
46
+ return None
59
47
60
- def __init__ (self , _ ):
61
- super ().__init__ ()
62
- match = self ._TEMPLATE_NAME_PATTERN .match (self )
48
+ match = cls ._TEMPLATE_NAME_PATTERN .search (inspect .getsource (callobj ))
63
49
if not match :
64
- raise PytestPyTorchInternalError (
65
- f"Unable to extract template name from { self } "
66
- )
67
- self ._template_name = match .group ("template_name" )
68
-
69
- def __eq__ (self , other ):
70
- return str .__eq__ (self , other ) or str .__eq__ (self ._template_name , other )
71
-
72
- def __hash__ (self ) -> int :
73
- return super ().__hash__ ()
74
-
75
-
76
- class TestCaseFunctionTemplateName (TemplateName ):
77
- _TEMPLATE_NAME_PATTERN = re .compile (
78
- fr"(?P<template_name>\w*?)_({ '|' .join ([device .lower () for device in DEVICES ])} )"
79
- )
50
+ return None
80
51
52
+ return match .group ("template_name" )
81
53
82
- class PyTorchTestCaseFunction (TestCaseFunction ):
83
54
@classmethod
84
- def from_parent (cls , parent , * , name , ** kw ):
55
+ def from_parent (cls , parent , * , name , callobj , ** kw ):
85
56
return super ().from_parent (
86
- parent , name = TestCaseFunctionTemplateName (name ), ** kw
57
+ parent , name = TemplatedName (name , cls . _extract_template_name ( callobj ) ), ** kw
87
58
)
88
59
89
60
90
- class TestCaseTemplateName (TemplateName ):
91
- _TEMPLATE_NAME_PATTERN = re .compile (
92
- fr"(?P<template_name>\w*?)({ '|' .join ([device .upper () for device in DEVICES ])} )"
93
- )
61
+ class TemplatedTestCase (UnitTestCase ):
62
+ @classmethod
63
+ def _extract_template_name (cls , name , obj ):
64
+ if not obj :
65
+ return None
66
+
67
+ if not hasattr (obj , "device_type" ):
68
+ return None
94
69
70
+ return name [: - len (obj .device_type )]
95
71
96
- class PyTorchTestCase (UnitTestCase ):
97
72
@classmethod
98
73
def from_parent (cls , parent , * , name , obj = None ):
99
- return super ().from_parent (parent , name = TestCaseTemplateName (name ), obj = obj )
74
+ return super ().from_parent (
75
+ parent ,
76
+ name = TemplatedName (name , cls ._extract_template_name (name , obj )),
77
+ obj = obj ,
78
+ )
100
79
101
80
def collect (self ):
102
81
# Yes, this is a bad practice. Unfortunately, there is no other option to
103
82
# inject our custom 'TestCaseFunction' without duplicating everything in
104
83
# 'UnitTestCase.collect()'
105
84
with unittest .mock .patch (
106
- "_pytest.unittest.TestCaseFunction" , new = PyTorchTestCaseFunction
85
+ "_pytest.unittest.TestCaseFunction" , new = TemplatedTestCaseFunction
107
86
):
108
87
yield from super ().collect ()
109
88
@@ -113,12 +92,9 @@ def pytest_pycollect_makeitem(collector, name, obj):
113
92
return None
114
93
115
94
try :
116
- if (
117
- not issubclass (obj , PyTorchTestCaseTemplate )
118
- or obj is PyTorchTestCaseTemplate
119
- ):
95
+ if not issubclass (obj , TestCaseTemplate ) or obj is TestCaseTemplate :
120
96
return None
121
97
except Exception :
122
98
return None
123
99
124
- return PyTorchTestCase .from_parent (collector , name = name , obj = obj )
100
+ return TemplatedTestCase .from_parent (collector , name = name , obj = obj )
0 commit comments