21
21
from absl .testing import parameterized
22
22
import tensorflow as tf
23
23
24
- from tensorflow .python .keras import keras_parameterized
25
24
from tensorflow_model_optimization .python .core .clustering .keras import cluster
26
25
from tensorflow_model_optimization .python .core .clustering .keras import cluster_config
27
26
from tensorflow_model_optimization .python .core .clustering .keras import cluster_wrapper
@@ -162,15 +161,13 @@ def _count_clustered_layers(self, model):
162
161
count += 1
163
162
return count
164
163
165
- @keras_parameterized .run_all_keras_modes
166
164
def testClusterKerasClusterableLayer (self ):
167
165
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly."""
168
166
wrapped_layer = self ._build_clustered_layer_model (
169
167
self .keras_clusterable_layer )
170
168
171
169
self ._validate_clustered_layer (self .keras_clusterable_layer , wrapped_layer )
172
170
173
- @keras_parameterized .run_all_keras_modes
174
171
def testClusterKerasClusterableLayerWithSparsityPreservation (self ):
175
172
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly when sparsity preservation is enabled."""
176
173
preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -180,7 +177,6 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self):
180
177
181
178
self ._validate_clustered_layer (self .keras_clusterable_layer , wrapped_layer )
182
179
183
- @keras_parameterized .run_all_keras_modes
184
180
def testClusterKerasNonClusterableLayer (self ):
185
181
"""Verifies that a built-in keras layer not marked as clusterable is not being clustered."""
186
182
wrapped_layer = self ._build_clustered_layer_model (
@@ -190,7 +186,6 @@ def testClusterKerasNonClusterableLayer(self):
190
186
wrapped_layer )
191
187
self .assertEqual ([], wrapped_layer .layer .get_clusterable_weights ())
192
188
193
- @keras_parameterized .run_all_keras_modes
194
189
def testDepthwiseConv2DLayerNonClusterable (self ):
195
190
"""Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196
191
wrapped_layer = self ._build_clustered_layer_model (
@@ -200,7 +195,6 @@ def testDepthwiseConv2DLayerNonClusterable(self):
200
195
wrapped_layer )
201
196
self .assertEqual ([], wrapped_layer .layer .get_clusterable_weights ())
202
197
203
- @keras_parameterized .run_all_keras_modes
204
198
def testDenseLayer (self ):
205
199
"""Verifies that we can cluster a Dense layer."""
206
200
input_shape = (28 , 1 )
@@ -214,7 +208,6 @@ def testDenseLayer(self):
214
208
self .assertEqual ([1 , 10 ],
215
209
wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
216
210
217
- @keras_parameterized .run_all_keras_modes
218
211
def testConv1DLayer (self ):
219
212
"""Verifies that we can cluster a Conv1D layer."""
220
213
input_shape = (28 , 1 )
@@ -227,7 +220,6 @@ def testConv1DLayer(self):
227
220
self .assertEqual ([5 , 1 , 3 ],
228
221
wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
229
222
230
- @keras_parameterized .run_all_keras_modes
231
223
def testConv1DTransposeLayer (self ):
232
224
"""Verifies that we can cluster a Conv1DTranspose layer."""
233
225
input_shape = (28 , 1 )
@@ -240,7 +232,6 @@ def testConv1DTransposeLayer(self):
240
232
self .assertEqual ([5 , 3 , 1 ],
241
233
wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
242
234
243
- @keras_parameterized .run_all_keras_modes
244
235
def testConv2DLayer (self ):
245
236
"""Verifies that we can cluster a Conv2D layer."""
246
237
input_shape = (28 , 28 , 1 )
@@ -253,7 +244,6 @@ def testConv2DLayer(self):
253
244
self .assertEqual ([4 , 5 , 1 , 3 ],
254
245
wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
255
246
256
- @keras_parameterized .run_all_keras_modes
257
247
def testConv2DTransposeLayer (self ):
258
248
"""Verifies that we can cluster a Conv2DTranspose layer."""
259
249
input_shape = (28 , 28 , 1 )
@@ -266,7 +256,6 @@ def testConv2DTransposeLayer(self):
266
256
self .assertEqual ([4 , 5 , 3 , 1 ],
267
257
wrapped_layer .layer .get_clusterable_weights ()[0 ][1 ].shape )
268
258
269
- @keras_parameterized .run_all_keras_modes
270
259
def testConv3DLayer (self ):
271
260
"""Verifies that we can cluster a Conv3D layer."""
272
261
input_shape = (28 , 28 , 28 , 1 )
@@ -287,7 +276,6 @@ def testClusterKerasUnsupportedLayer(self):
287
276
with self .assertRaises (ValueError ):
288
277
cluster .cluster_weights (keras_unsupported_layer , ** self .params )
289
278
290
- @keras_parameterized .run_all_keras_modes
291
279
def testClusterCustomClusterableLayer (self ):
292
280
"""Verifies that a custom clusterable layer is being clustered correctly."""
293
281
wrapped_layer = self ._build_clustered_layer_model (
@@ -297,7 +285,6 @@ def testClusterCustomClusterableLayer(self):
297
285
self .assertEqual ([('kernel' , wrapped_layer .layer .kernel )],
298
286
wrapped_layer .layer .get_clusterable_weights ())
299
287
300
- @keras_parameterized .run_all_keras_modes
301
288
def testClusterCustomClusterableLayerWithSparsityPreservation (self ):
302
289
"""Verifies that a custom clusterable layer is being clustered correctly when sparsity preservation is enabled."""
303
290
preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -424,7 +411,6 @@ def testStripClusteringSequentialModelWithBiasConstraint(self):
424
411
keras_file = os .path .join (tmp_dir_name , 'cluster_test' )
425
412
stripped_model .save (keras_file , save_traces = True )
426
413
427
- @keras_parameterized .run_all_keras_modes
428
414
def testClusterSequentialModelSelectively (self ):
429
415
clustered_model = keras .Sequential ()
430
416
clustered_model .add (
@@ -437,7 +423,6 @@ def testClusterSequentialModelSelectively(self):
437
423
self .assertNotIsInstance (clustered_model .layers [1 ],
438
424
cluster_wrapper .ClusterWeights )
439
425
440
- @keras_parameterized .run_all_keras_modes
441
426
def testClusterSequentialModelSelectivelyWithSparsityPreservation (self ):
442
427
"""Verifies that layers within a sequential model can be clustered selectively when sparsity preservation is enabled."""
443
428
preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -454,7 +439,6 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
454
439
self .assertNotIsInstance (clustered_model .layers [1 ],
455
440
cluster_wrapper .ClusterWeights )
456
441
457
- @keras_parameterized .run_all_keras_modes
458
442
def testClusterFunctionalModelSelectively (self ):
459
443
"""Verifies that layers within a functional model can be clustered selectively."""
460
444
i1 = keras .Input (shape = (10 ,))
@@ -469,7 +453,6 @@ def testClusterFunctionalModelSelectively(self):
469
453
self .assertNotIsInstance (clustered_model .layers [3 ],
470
454
cluster_wrapper .ClusterWeights )
471
455
472
- @keras_parameterized .run_all_keras_modes
473
456
def testClusterFunctionalModelSelectivelyWithSparsityPreservation (self ):
474
457
"""Verifies that layers within a functional model can be clustered selectively when sparsity preservation is enabled."""
475
458
preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -486,7 +469,6 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
486
469
self .assertNotIsInstance (clustered_model .layers [3 ],
487
470
cluster_wrapper .ClusterWeights )
488
471
489
- @keras_parameterized .run_all_keras_modes
490
472
def testClusterModelValidLayersSuccessful (self ):
491
473
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered."""
492
474
model = keras .Sequential ([
@@ -500,7 +482,6 @@ def testClusterModelValidLayersSuccessful(self):
500
482
for layer , clustered_layer in zip (model .layers , clustered_model .layers ):
501
483
self ._validate_clustered_layer (layer , clustered_layer )
502
484
503
- @keras_parameterized .run_all_keras_modes
504
485
def testClusterModelValidLayersSuccessfulWithSparsityPreservation (self ):
505
486
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered when sparsity preservation is enabled."""
506
487
preserve_sparsity_params = {'preserve_sparsity' : True }
@@ -540,7 +521,6 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
540
521
self .custom_clusterable_layer , custom_non_clusterable_layer
541
522
]), ** self .params )
542
523
543
- @keras_parameterized .run_all_keras_modes
544
524
def testClusterModelDoesNotWrapAlreadyWrappedLayer (self ):
545
525
"""Verifies that clustering a model that contains an already clustered layer does not result in wrapping the clustered layer into another cluster_wrapper."""
546
526
model = keras .Sequential ([
@@ -579,7 +559,6 @@ def testClusterSequentialModelNoInput(self):
579
559
clustered_model = cluster .cluster_weights (model , ** self .params )
580
560
self .assertEqual (self ._count_clustered_layers (clustered_model ), 2 )
581
561
582
- @keras_parameterized .run_all_keras_modes
583
562
def testClusterSequentialModelWithInput (self ):
584
563
"""Verifies that a sequential model with an input layer is being clustered correctly."""
585
564
# With InputLayer
@@ -607,7 +586,6 @@ def testClusterSequentialModelPreservesBuiltStateNoInput(self):
607
586
json .loads (clustered_model .to_json ()))
608
587
self .assertEqual (loaded_model .built , False )
609
588
610
- @keras_parameterized .run_all_keras_modes
611
589
def testClusterSequentialModelPreservesBuiltStateWithInput (self ):
612
590
"""Verifies that clustering a sequential model with an input layer preserves the built state of the model."""
613
591
# With InputLayer
@@ -625,7 +603,6 @@ def testClusterSequentialModelPreservesBuiltStateWithInput(self):
625
603
json .loads (clustered_model .to_json ()))
626
604
self .assertEqual (loaded_model .built , True )
627
605
628
- @keras_parameterized .run_all_keras_modes
629
606
def testClusterFunctionalModelPreservesBuiltState (self ):
630
607
"""Verifies that clustering a functional model preserves the built state of the model."""
631
608
i1 = keras .Input (shape = (10 ,))
@@ -644,7 +621,6 @@ def testClusterFunctionalModelPreservesBuiltState(self):
644
621
json .loads (clustered_model .to_json ()))
645
622
self .assertEqual (loaded_model .built , True )
646
623
647
- @keras_parameterized .run_all_keras_modes
648
624
def testClusterFunctionalModel (self ):
649
625
"""Verifies that a functional model is being clustered correctly."""
650
626
i1 = keras .Input (shape = (10 ,))
@@ -656,7 +632,6 @@ def testClusterFunctionalModel(self):
656
632
clustered_model = cluster .cluster_weights (model , ** self .params )
657
633
self .assertEqual (self ._count_clustered_layers (clustered_model ), 3 )
658
634
659
- @keras_parameterized .run_all_keras_modes
660
635
def testClusterFunctionalModelWithLayerReused (self ):
661
636
"""Verifies that a layer reused within a functional model multiple times is only being clustered once."""
662
637
# The model reuses the Dense() layer. Make sure it's only clustered once.
@@ -668,22 +643,19 @@ def testClusterFunctionalModelWithLayerReused(self):
668
643
clustered_model = cluster .cluster_weights (model , ** self .params )
669
644
self .assertEqual (self ._count_clustered_layers (clustered_model ), 1 )
670
645
671
- @keras_parameterized .run_all_keras_modes
672
646
def testClusterSubclassModel (self ):
673
647
"""Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception."""
674
648
model = TestModel ()
675
649
with self .assertRaises (ValueError ):
676
650
_ = cluster .cluster_weights (model , ** self .params )
677
651
678
- @keras_parameterized .run_all_keras_modes
679
652
def testClusterSubclassModelAsSubmodel (self ):
680
653
"""Verifies that attempting to cluster a model with submodel that is a subclass throws an exception."""
681
654
model_subclass = TestModel ()
682
655
model = keras .Sequential ([layers .Dense (10 ), model_subclass ])
683
656
with self .assertRaisesRegex (ValueError , 'Subclassed models.*' ):
684
657
_ = cluster .cluster_weights (model , ** self .params )
685
658
686
- @keras_parameterized .run_all_keras_modes
687
659
def testStripClusteringSequentialModel (self ):
688
660
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
689
661
model = keras .Sequential ([
@@ -697,7 +669,6 @@ def testStripClusteringSequentialModel(self):
697
669
self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
698
670
self .assertEqual (model .get_config (), stripped_model .get_config ())
699
671
700
- @keras_parameterized .run_all_keras_modes
701
672
def testClusterStrippingFunctionalModel (self ):
702
673
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
703
674
i1 = keras .Input (shape = (10 ,))
@@ -713,7 +684,6 @@ def testClusterStrippingFunctionalModel(self):
713
684
self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
714
685
self .assertEqual (model .get_config (), stripped_model .get_config ())
715
686
716
- @keras_parameterized .run_all_keras_modes
717
687
def testClusterWeightsStrippedWeights (self ):
718
688
"""Verifies that stripping the clustering wrappers from a functional model preserves the clustered weights."""
719
689
i1 = keras .Input (shape = (10 ,))
@@ -728,7 +698,6 @@ def testClusterWeightsStrippedWeights(self):
728
698
self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
729
699
self .assertLen (stripped_model .get_weights (), cluster_weight_length )
730
700
731
- @keras_parameterized .run_all_keras_modes
732
701
def testStrippedKernel (self ):
733
702
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734
703
i1 = keras .Input (shape = (1 , 1 , 1 ))
@@ -746,7 +715,6 @@ def testStrippedKernel(self):
746
715
self .assertIsNot (stripped_conv2d_layer .kernel , clustered_kernel )
747
716
self .assertIn (stripped_conv2d_layer .kernel , stripped_conv2d_layer .weights )
748
717
749
- @keras_parameterized .run_all_keras_modes
750
718
def testStripSelectivelyClusteredFunctionalModel (self ):
751
719
"""Verifies that invoking strip_clustering() on a selectively clustered functional model strips the clustering wrappers from the clustered layers."""
752
720
i1 = keras .Input (shape = (10 ,))
@@ -761,7 +729,6 @@ def testStripSelectivelyClusteredFunctionalModel(self):
761
729
self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
762
730
self .assertIsInstance (stripped_model .layers [2 ], layers .Dense )
763
731
764
- @keras_parameterized .run_all_keras_modes
765
732
def testStripSelectivelyClusteredSequentialModel (self ):
766
733
"""Verifies that invoking strip_clustering() on a selectively clustered sequential model strips the clustering wrappers from the clustered layers."""
767
734
clustered_model = keras .Sequential ([
@@ -775,7 +742,6 @@ def testStripSelectivelyClusteredSequentialModel(self):
775
742
self .assertEqual (self ._count_clustered_layers (stripped_model ), 0 )
776
743
self .assertIsInstance (stripped_model .layers [0 ], layers .Dense )
777
744
778
- @keras_parameterized .run_all_keras_modes
779
745
def testStripClusteringAndSetOriginalWeightsBack (self ):
780
746
"""Verifies that we can set_weights onto the stripped model."""
781
747
model = keras .Sequential ([
0 commit comments