
Commit 9cf2d18

Update code style for Python 3.9
1 parent 7903236 commit 9cf2d18


61 files changed: +475 −541 lines (large commit; only a subset of the changed files is reproduced below)
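Every hunk below applies the same Python 3.9 style change (PEP 585): the builtins `list`, `dict`, `tuple`, and `type` are now subscriptable, so the deprecated `typing.List`, `typing.Dict`, `typing.Tuple`, and `typing.Type` aliases can go, and abstract container types such as `Sequence` and `Mapping` are imported from `collections.abc` instead of `typing`. A minimal before/after sketch of the pattern (the helper function here is hypothetical, not from this commit):

from collections.abc import Sequence
from typing import Optional


# Python 3.8 spelling of the same signature would have been:
#   from typing import Dict, List, Sequence, Tuple
#   def count_words(words: Sequence[str]) -> Tuple[Dict[str, int], List[str]]: ...
def count_words(words: Sequence[str]) -> tuple[dict[str, int], list[str]]:
    """Count occurrences and collect the words that repeat."""
    counts: dict[str, int] = {}
    repeated: list[str] = []
    for word in words:
        counts[word] = counts.get(word, 0) + 1
        if counts[word] == 2:
            repeated.append(word)
    return counts, repeated


# `Optional` and `Union` have no builtin replacement until Python 3.10's
# `X | Y` syntax, so they stay imported from `typing` throughout the diff.
maybe_words: Optional[list[str]] = None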

pytensor/compile/builders.py (+8 −7)

@@ -1,8 +1,9 @@
 """Define new Ops from existing Ops"""
 from collections import OrderedDict
+from collections.abc import Sequence
 from copy import copy
 from functools import partial
-from typing import Dict, List, Optional, Sequence, Tuple, cast
+from typing import Optional, cast
 
 import pytensor.tensor as at
 from pytensor import function
@@ -83,11 +84,11 @@ def local_traverse(out):
 
 def construct_nominal_fgraph(
     inputs: Sequence[Variable], outputs: Sequence[Variable]
-) -> Tuple[
+) -> tuple[
     FunctionGraph,
     Sequence[Variable],
-    Dict[Variable, Variable],
-    Dict[Variable, Variable],
+    dict[Variable, Variable],
+    dict[Variable, Variable],
 ]:
     """Construct an inner-`FunctionGraph` with ordered nominal inputs."""
     dummy_inputs = []
@@ -306,13 +307,13 @@ def _filter_rop_var(inpJ, out):
 
     def __init__(
         self,
-        inputs: List[Variable],
-        outputs: List[Variable],
+        inputs: list[Variable],
+        outputs: list[Variable],
         inline: bool = False,
         lop_overrides: str = "default",
         grad_overrides: str = "default",
         rop_overrides: str = "default",
-        connection_pattern: Optional[List[List[bool]]] = None,
+        connection_pattern: Optional[list[list[bool]]] = None,
         name: Optional[str] = None,
         **kwargs,
     ):

pytensor/compile/function/pfunc.py (+2 −1)

@@ -3,8 +3,9 @@
 
 """
 
+from collections.abc import Sequence
 from copy import copy
-from typing import Optional, Sequence, Union, overload
+from typing import Optional, Union, overload
 
 from pytensor.compile.function.types import Function, UnusedInputError, orig_function
 from pytensor.compile.io import In, Out

pytensor/compile/function/types.py (+5 −5)

@@ -6,7 +6,7 @@
 import time
 import warnings
 from itertools import chain
-from typing import TYPE_CHECKING, List, Optional, Tuple, Type
+from typing import TYPE_CHECKING, Optional
 
 import numpy as np
 
@@ -170,13 +170,13 @@ def validate(self, fgraph):
 
 
 def std_fgraph(
-    input_specs: List[SymbolicInput],
-    output_specs: List[SymbolicOutput],
+    input_specs: list[SymbolicInput],
+    output_specs: list[SymbolicOutput],
     accept_inplace: bool = False,
     fgraph: Optional[FunctionGraph] = None,
-    features: List[Type[Feature]] = [PreserveVariableAttributes],
+    features: list[type[Feature]] = [PreserveVariableAttributes],
     force_clone=False,
-) -> Tuple[FunctionGraph, List[SymbolicOutput]]:
+) -> tuple[FunctionGraph, list[SymbolicOutput]]:
     """Make or set up `FunctionGraph` corresponding to the input specs and the output specs.
 
     Any `SymbolicInput` in the `input_specs`, if its `update` field is not
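The `features` parameter in `std_fgraph` also swaps `typing.Type` for the builtin `type`, which is likewise subscriptable on 3.9+ and annotates class objects rather than instances. A small illustrative sketch (these class names are invented, not PyTensor's):

class Feature:
    """Stand-in for a feature base class."""


class PreserveNames(Feature):
    pass


def instantiate_all(feature_classes: list[type[Feature]]) -> list[Feature]:
    # Each element must be Feature itself or a subclass, passed as a class object.
    return [cls() for cls in feature_classes]


print([type(f).__name__ for f in instantiate_all([Feature, PreserveNames])])
# -> ['Feature', 'PreserveNames']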

pytensor/compile/mode.py (+3 −3)

@@ -5,7 +5,7 @@
 
 import logging
 import warnings
-from typing import Literal, Optional, Tuple, Union
+from typing import Literal, Optional, Union
 
 from pytensor.compile.function.types import Supervisor
 from pytensor.configdefaults import config
@@ -260,7 +260,7 @@ def apply(self, fgraph):
 # final pass just to make sure
 optdb.register("merge3", MergeOptimizer(), "fast_run", "merge", position=100)
 
-_tags: Union[Tuple[str, str], Tuple]
+_tags: Union[tuple[str, str], tuple]
 
 if config.check_stack_trace in ("raise", "warn", "log"):
     _tags = ("fast_run", "fast_compile")
@@ -548,7 +548,7 @@ def register_mode(name, mode):
     predefined_modes[name] = mode
 
 
-def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]:
+def get_target_language(mode=None) -> tuple[Literal["py", "c", "numba", "jax"], ...]:
     """Get the compilation target language."""
 
     if mode is None:
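Two different tuple annotations appear above: `tuple[str, str]` in the `_tags` hunk pins exactly two elements, while the variadic `tuple[X, ...]` returned by `get_target_language` allows any length. A quick sketch of the distinction (the function here is hypothetical, not the real API):

from typing import Literal, Union

Lang = Literal["py", "c", "numba", "jax"]


def default_targets() -> tuple[Lang, ...]:
    # Variadic: any number of elements, each a valid Lang literal.
    return ("c", "py")


# Exactly two strings, or a bare (possibly empty) tuple, mirroring `_tags` above.
tags: Union[tuple[str, str], tuple] = ("fast_run", "fast_compile")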

pytensor/compile/ops.py (+4 −5)

@@ -8,7 +8,6 @@
 import copy
 import pickle
 import warnings
-from typing import Dict, Tuple
 
 from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
@@ -44,8 +43,8 @@ class ViewOp(COp):
     # Mapping from Type to C code (and version) to use.
     # In the C code, the name of the input variable is %(iname)s,
     # the output variable is %(oname)s.
-    c_code_and_version: Dict = {}
-    __props__: Tuple = ()
+    c_code_and_version: dict = {}
+    __props__: tuple = ()
     _f16_ok: bool = True
 
     def make_node(self, x):
@@ -150,10 +149,10 @@ class DeepCopyOp(COp):
     # Mapping from Type to C code (and version) to use.
     # In the C code, the name of the input variable is %(iname)s,
     # the output variable is %(oname)s.
-    c_code_and_version: Dict = {}
+    c_code_and_version: dict = {}
 
     check_input: bool = False
-    __props__: Tuple = ()
+    __props__: tuple = ()
     _f16_ok: bool = True
 
     def __init__(self):

pytensor/compile/profiling.py (+9 −9)

@@ -16,7 +16,7 @@
 import time
 from collections import defaultdict
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import numpy as np
 
@@ -48,7 +48,7 @@ def extended_open(filename, mode="r"):
 total_graph_rewrite_time: float = 0.0
 total_time_linker: float = 0.0
 
-_atexit_print_list: List["ProfileStats"] = []
+_atexit_print_list: list["ProfileStats"] = []
 _atexit_registered: bool = False
 
 
@@ -234,27 +234,27 @@ def reset(self):
     # Total time spent in Function.vm.__call__
     #
 
-    apply_time: Optional[Dict[Union["FunctionGraph", Variable], float]] = None
+    apply_time: Optional[dict[Union["FunctionGraph", Variable], float]] = None
 
-    apply_callcount: Optional[Dict[Union["FunctionGraph", Variable], int]] = None
+    apply_callcount: Optional[dict[Union["FunctionGraph", Variable], int]] = None
 
-    apply_cimpl: Optional[Dict[Apply, bool]] = None
+    apply_cimpl: Optional[dict[Apply, bool]] = None
     # dict from node -> bool (1 if c, 0 if py)
     #
 
     message: Optional[str] = None
     # pretty string to print in summary, to identify this output
     #
 
-    variable_shape: Dict[Variable, Any] = {}
+    variable_shape: dict[Variable, Any] = {}
     # Variable -> shapes
     #
 
-    variable_strides: Dict[Variable, Any] = {}
+    variable_strides: dict[Variable, Any] = {}
     # Variable -> strides
     #
 
-    variable_offset: Dict[Variable, Any] = {}
+    variable_offset: dict[Variable, Any] = {}
     # Variable -> offset
     #
 
@@ -274,7 +274,7 @@ def reset(self):
 
     linker_node_make_thunks: float = 0.0
 
-    linker_make_thunk_time: Dict = {}
+    linker_make_thunk_time: dict = {}
 
     line_width = config.profiling__output_line_width
 

pytensor/compile/sharedvalue.py (+2 −2)

@@ -3,7 +3,7 @@
 import copy
 from contextlib import contextmanager
 from functools import singledispatch
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, Optional
 
 from pytensor.graph.basic import Variable
 from pytensor.graph.utils import add_tag_trace
@@ -15,7 +15,7 @@
     from pytensor.graph.type import Type
 
 
-__SHARED_CONTEXT__: Optional[List[Variable]] = None
+__SHARED_CONTEXT__: Optional[list[Variable]] = None
 
 
 @contextmanager

pytensor/configparser.py (+3 −2)

@@ -3,6 +3,7 @@
 import shlex
 import sys
 import warnings
+from collections.abc import Sequence
 from configparser import (
     ConfigParser,
     InterpolationError,
@@ -12,7 +13,7 @@
 )
 from functools import wraps
 from io import StringIO
-from typing import Callable, Dict, Optional, Sequence, Union
+from typing import Callable, Optional, Union
 
 from pytensor.utils import hash_from_code
 
@@ -72,7 +73,7 @@ def __init__(self, flags_dict: dict, pytensor_cfg, pytensor_raw_cfg):
         self._flags_dict = flags_dict
         self._pytensor_cfg = pytensor_cfg
         self._pytensor_raw_cfg = pytensor_raw_cfg
-        self._config_var_dict: Dict = {}
+        self._config_var_dict: dict = {}
         super().__init__()
 
     def __str__(self, print_doc=True):

pytensor/gradient.py (+13 −25)

@@ -2,21 +2,9 @@
 
 import time
 import warnings
+from collections.abc import Mapping, MutableSequence, Sequence
 from functools import partial, reduce
-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Mapping,
-    MutableSequence,
-    Optional,
-    Sequence,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Callable, Literal, Optional, TypeVar, Union
 
 import numpy as np
 
@@ -44,7 +32,7 @@
 # TODO: Add `overload` variants
 def as_list_or_tuple(
     use_list: bool, use_tuple: bool, outputs: Union[V, Sequence[V]]
-) -> Union[V, List[V], Tuple[V, ...]]:
+) -> Union[V, list[V], tuple[V, ...]]:
     """Return either a single object or a list/tuple of objects.
 
     If `use_list` is True, `outputs` is returned as a list (if `outputs`
@@ -206,17 +194,17 @@ def Rop(
     """
 
     if not isinstance(wrt, (list, tuple)):
-        _wrt: List[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
     if not isinstance(eval_points, (list, tuple)):
-        _eval_points: List[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
     if not isinstance(f, (list, tuple)):
-        _f: List[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
 
@@ -237,7 +225,7 @@ def Rop(
         # Tensor, Sparse have the ndim attribute
         pass
 
-    seen_nodes: Dict[Apply, Sequence[Variable]] = {}
+    seen_nodes: dict[Apply, Sequence[Variable]] = {}
 
     def _traverse(node):
         """TODO: writeme"""
@@ -310,7 +298,7 @@ def _traverse(node):
     for out in _f:
         _traverse(out.owner)
 
-    rval: List[Optional[Variable]] = []
+    rval: list[Optional[Variable]] = []
     for out in _f:
         if out in _wrt:
             rval.append(_eval_points[_wrt.index(out)])
@@ -394,19 +382,19 @@ def Lop(
     If `f` is a list/tuple, then return a list/tuple with the results.
     """
     if not isinstance(eval_points, (list, tuple)):
-        _eval_points: List[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
     if not isinstance(f, (list, tuple)):
-        _f: List[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
 
     grads = list(_eval_points)
 
     if not isinstance(wrt, (list, tuple)):
-        _wrt: List[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
@@ -504,7 +492,7 @@ def grad(
         raise TypeError("Cost must be a scalar.")
 
     if not isinstance(wrt, Sequence):
-        _wrt: List[Variable] = [wrt]
+        _wrt: list[Variable] = [wrt]
     else:
         _wrt = list(wrt)
 
@@ -1677,7 +1665,7 @@ def mode_not_slow(mode):
 
 def verify_grad(
     fun: Callable,
-    pt: List[np.ndarray],
+    pt: list[np.ndarray],
     n_tests: int = 2,
     rng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
     eps: Optional[float] = None,
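Note that the `grad` hunk (`@@ -504`) uses `Sequence` both as an annotation and in an `isinstance` check. That works because `collections.abc.Sequence` is the runtime ABC, and since Python 3.9 it is also subscriptable in annotations, which is what makes the `typing` alias redundant here. A reduced sketch of the dispatch pattern (illustrative only, not PyTensor code):

from collections.abc import Sequence


def normalize_wrt(wrt: object) -> list:
    # Wrap a single item; copy an existing sequence into a fresh list.
    # (Real code may need to special-case `str`, which is also a Sequence.)
    if not isinstance(wrt, Sequence):
        return [wrt]
    return list(wrt)


print(normalize_wrt(42))      # [42]
print(normalize_wrt((1, 2)))  # [1, 2]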
