Skip to content

Commit 4612eed

Browse files
authored
Fix 1D squeeze_batch_dim test
1 parent 244b2e7 commit 4612eed

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

xbatcher/tests/test_generators.py

+8-8
Original file line numberDiff line numberDiff line change
@@ -170,23 +170,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
170170

171171

172172
@pytest.mark.parametrize('bsize', [5, 10])
173-
def test_batch_3d_squeeze_batch_dim(sample_ds_3d, bsize):
173+
def test_batch_1d_squeeze_batch_dim(sample_ds_1d, bsize):
174174
xbsize = 20
175175
bg = BatchGenerator(
176-
sample_ds_3d,
177-
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
176+
sample_ds_1d,
177+
input_dims={'x': xbsize},
178178
squeeze_batch_dim=False,
179179
)
180180
for ds_batch in bg:
181-
assert ds_batch['x'].shape == [1, bsize, xbsize]
181+
assert list(ds_batch['foo'].shape) == [1, xbsize]
182182

183183
bg2 = BatchGenerator(
184-
sample_ds_3d,
185-
input_dims={'time': 1, 'y': bsize, 'x': xbsize},
184+
sample_ds_1d,
185+
input_dims={'x': xbsize},
186186
squeeze_batch_dim=True,
187187
)
188-
for ds_batch in bg:
189-
assert ds_batch['x'].shape == [bsize, xbsize]
188+
for ds_batch in bg2:
189+
assert list(ds_batch['foo'].shape) == [xbsize]
190190

191191

192192
def test_preload_batch_false(sample_ds_1d):

0 commit comments

Comments
 (0)