@@ -2,21 +2,9 @@
 
 import time
 import warnings
+from collections.abc import Mapping, MutableSequence, Sequence
 from functools import partial, reduce
-from typing import (
-    TYPE_CHECKING,
-    Callable,
-    Dict,
-    List,
-    Literal,
-    Mapping,
-    MutableSequence,
-    Optional,
-    Sequence,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import TYPE_CHECKING, Callable, Literal, Optional, TypeVar, Union
 
 import numpy as np
 
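A note for readers of this hunk: under PEP 585 (Python 3.9+) the builtins `list`, `dict`, and `tuple` can be subscripted directly, and the abstract container protocols are imported from `collections.abc`, which is why the `typing` aliases are dropped. A minimal before/after sketch, with a made-up `tally` function for illustration:

from collections.abc import Sequence

# Before: from typing import Dict, Sequence
#         def tally(xs: Sequence[str]) -> Dict[str, int]: ...
# After, in the style this commit adopts:
def tally(xs: Sequence[str]) -> dict[str, int]:
    counts: dict[str, int] = {}
    for x in xs:
        counts[x] = counts.get(x, 0) + 1
    return counts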
@@ -44,7 +32,7 @@
 # TODO: Add `overload` variants
 def as_list_or_tuple(
     use_list: bool, use_tuple: bool, outputs: Union[V, Sequence[V]]
-) -> Union[V, List[V], Tuple[V, ...]]:
+) -> Union[V, list[V], tuple[V, ...]]:
     """Return either a single object or a list/tuple of objects.
 
     If `use_list` is True, `outputs` is returned as a list (if `outputs`
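The hunk above only shows the signature of `as_list_or_tuple`; as a reading aid, here is a hedged sketch of the contract its docstring describes (illustrative, not the committed body):

def as_list_or_tuple_sketch(use_list: bool, use_tuple: bool, outputs):
    # Hedged reconstruction from the docstring: at most one flag may be set.
    if use_list and use_tuple:
        raise ValueError("use_list and use_tuple cannot both be True")
    seq = outputs if isinstance(outputs, (list, tuple)) else [outputs]
    if use_list:
        return list(seq)
    if use_tuple:
        return tuple(seq)
    # Neither flag: a single output comes back unwrapped.
    return seq[0] if len(seq) == 1 else seq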
@@ -206,17 +194,17 @@ def Rop(
     """
 
     if not isinstance(wrt, (list, tuple)):
-        _wrt: List[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
     if not isinstance(eval_points, (list, tuple)):
-        _eval_points: List[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
     if not isinstance(f, (list, tuple)):
-        _f: List[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
 
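An aside on the pattern these hunks keep touching: the annotation is written only on the first assignment to `_wrt`/`_eval_points`/`_f` because a checker like mypy takes a variable's declared type from that first branch; the `else` branch only needs to produce a compatible value. A self-contained sketch of the idiom, with hypothetical names:

from typing import Union

def normalize(arg: Union[int, list[int], tuple[int, ...]]) -> list[int]:
    # Accept a single item or a sequence; always hand back a list.
    if not isinstance(arg, (list, tuple)):
        _arg: list[int] = [arg]  # the declared type comes from this first assignment
    else:
        _arg = list(arg)  # must stay compatible with list[int]
    return _arg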
@@ -237,7 +225,7 @@ def Rop(
         # Tensor, Sparse have the ndim attribute
         pass
 
-    seen_nodes: Dict[Apply, Sequence[Variable]] = {}
+    seen_nodes: dict[Apply, Sequence[Variable]] = {}
 
     def _traverse(node):
         """TODO: writeme"""
@@ -310,7 +298,7 @@ def _traverse(node):
     for out in _f:
         _traverse(out.owner)
 
-    rval: List[Optional[Variable]] = []
+    rval: list[Optional[Variable]] = []
     for out in _f:
         if out in _wrt:
             rval.append(_eval_points[_wrt.index(out)])
@@ -394,19 +382,19 @@ def Lop(
         If `f` is a list/tuple, then return a list/tuple with the results.
     """
     if not isinstance(eval_points, (list, tuple)):
-        _eval_points: List[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
+        _eval_points: list[Variable] = [pytensor.tensor.as_tensor_variable(eval_points)]
     else:
         _eval_points = [pytensor.tensor.as_tensor_variable(x) for x in eval_points]
 
     if not isinstance(f, (list, tuple)):
-        _f: List[Variable] = [pytensor.tensor.as_tensor_variable(f)]
+        _f: list[Variable] = [pytensor.tensor.as_tensor_variable(f)]
     else:
         _f = [pytensor.tensor.as_tensor_variable(x) for x in f]
 
     grads = list(_eval_points)
 
     if not isinstance(wrt, (list, tuple)):
-        _wrt: List[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
+        _wrt: list[Variable] = [pytensor.tensor.as_tensor_variable(wrt)]
     else:
         _wrt = [pytensor.tensor.as_tensor_variable(x) for x in wrt]
 
@@ -504,7 +492,7 @@ def grad(
         raise TypeError("Cost must be a scalar.")
 
     if not isinstance(wrt, Sequence):
-        _wrt: List[Variable] = [wrt]
+        _wrt: list[Variable] = [wrt]
     else:
         _wrt = list(wrt)
 
@@ -1677,7 +1665,7 @@ def mode_not_slow(mode):
 
 def verify_grad(
     fun: Callable,
-    pt: List[np.ndarray],
+    pt: list[np.ndarray],
     n_tests: int = 2,
     rng: Optional[Union[np.random.Generator, np.random.RandomState]] = None,
     eps: Optional[float] = None,
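Finally, the `verify_grad` signature shown above (note that the `rng` union accepts `np.random.Generator`) is typically exercised along the lines of this hedged sketch; the function and test point are made up:

import numpy as np
from pytensor.gradient import verify_grad

rng = np.random.default_rng(42)
# Compare the symbolic gradient of a simple function against finite
# differences at a random test point; raises on mismatch.
verify_grad(lambda x: (x**2).sum(), [rng.standard_normal((3,))], rng=rng)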