Skip to content

Commit c46cd53

Browse files
committed
Add filter helper
1 parent 181f785 commit c46cd53

File tree

2 files changed

+43
-2
lines changed

2 files changed

+43
-2
lines changed

pytensor/loop/basic.py

+32
Original file line numberDiff line numberDiff line change
@@ -118,3 +118,35 @@ def reduce(
118118
if len(final_states) == 1:
119119
return final_states[0]
120120
return final_states
121+
122+
123+
def filter(
124+
fn,
125+
sequences,
126+
non_sequences=None,
127+
go_backwards=False,
128+
):
129+
if not isinstance(sequences, (tuple, list)):
130+
sequences = [sequences]
131+
132+
_, masks = scan(
133+
fn=fn,
134+
sequences=sequences,
135+
non_sequences=non_sequences,
136+
go_backwards=go_backwards,
137+
)
138+
139+
if not all(mask.dtype == "bool" for mask in masks):
140+
raise TypeError("The output of filter fn should be a boolean variable")
141+
if len(masks) == 1:
142+
masks = [masks[0]] * len(sequences)
143+
elif len(masks) != len(sequences):
144+
raise ValueError(
145+
"filter fn must return one variable or len(sequences), but it returned {len(masks)}"
146+
)
147+
148+
filtered_sequences = [seq[mask] for seq, mask in zip(sequences, masks)]
149+
150+
if len(filtered_sequences) == 1:
151+
return filtered_sequences[0]
152+
return filtered_sequences

tests/loop/basic.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
import numpy as np
22

33
import pytensor
4-
from pytensor.loop.basic import map, reduce, scan
5-
from pytensor.tensor import vector, zeros
4+
from pytensor.loop.basic import filter, map, reduce, scan
5+
from pytensor.tensor import eq, vector, zeros
66

77

88
def test_scan():
@@ -38,3 +38,12 @@ def test_reduce():
3838
np.testing.assert_almost_equal(
3939
y.eval({xs: np.arange(10)}), np.arange(10).cumsum()[-1]
4040
)
41+
42+
43+
def test_filter():
44+
xs = vector("xs")
45+
ys = filter(
46+
fn=lambda x: eq(x % 2, 0),
47+
sequences=xs,
48+
)
49+
np.testing.assert_array_equal(ys.eval({xs: np.arange(0, 20)}), np.arange(0, 20, 2))

0 commit comments

Comments
 (0)