File tree 1 file changed +8
-8
lines changed
1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -170,23 +170,23 @@ def test_batch_3d_2d_input_concat(sample_ds_3d, bsize):
170
170
171
171
172
172
@pytest .mark .parametrize ('bsize' , [5 , 10 ])
173
- def test_batch_3d_squeeze_batch_dim ( sample_ds_3d , bsize ):
173
+ def test_batch_1d_squeeze_batch_dim ( sample_ds_1d , bsize ):
174
174
xbsize = 20
175
175
bg = BatchGenerator (
176
- sample_ds_3d ,
177
- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
176
+ sample_ds_1d ,
177
+ input_dims = {'x' : xbsize },
178
178
squeeze_batch_dim = False ,
179
179
)
180
180
for ds_batch in bg :
181
- assert ds_batch ['x ' ].shape == [1 , bsize , xbsize ]
181
+ assert list ( ds_batch ['foo ' ].shape ) == [1 , xbsize ]
182
182
183
183
bg2 = BatchGenerator (
184
- sample_ds_3d ,
185
- input_dims = {'time' : 1 , 'y' : bsize , ' x' : xbsize },
184
+ sample_ds_1d ,
185
+ input_dims = {'x' : xbsize },
186
186
squeeze_batch_dim = True ,
187
187
)
188
- for ds_batch in bg :
189
- assert ds_batch ['x ' ].shape == [bsize , xbsize ]
188
+ for ds_batch in bg2 :
189
+ assert list ( ds_batch ['foo ' ].shape ) == [xbsize ]
190
190
191
191
192
192
def test_preload_batch_false (sample_ds_1d ):
You can’t perform that action at this time.
0 commit comments