Skip to content

Commit 300638f

Browse files
Replace tensorflow.python.keras with keras. tensorflow.python.keras is an old copy and is deprecated.
PiperOrigin-RevId: 485364304
1 parent 0e08dea commit 300638f

25 files changed

+7
-106
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_integration_test.py

-16
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
import numpy as np
2222
import tensorflow as tf
2323

24-
from tensorflow.python.keras import keras_parameterized
2524
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2625
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2726
from tensorflow_model_optimization.python.core.clustering.keras.experimental import cluster as experimental_cluster
@@ -153,7 +152,6 @@ def testDefaultClusteringInit(self):
153152
)["cluster_centroids_init"]
154153
self.assertEqual(init_method, CentroidInitialization.KMEANS_PLUS_PLUS)
155154

156-
@keras_parameterized.run_all_keras_modes
157155
def testValuesRemainClusteredAfterTraining(self):
158156
"""Verifies that training a clustered model does not destroy the clusters."""
159157
original_model = keras.Sequential([
@@ -175,7 +173,6 @@ def testValuesRemainClusteredAfterTraining(self):
175173
unique_weights = set(weights_as_list)
176174
self.assertLessEqual(len(unique_weights), self.params["number_of_clusters"])
177175

178-
@keras_parameterized.run_all_keras_modes
179176
def testSparsityIsPreservedDuringTraining(self):
180177
"""Set a specific random seed.
181178
@@ -230,7 +227,6 @@ def testSparsityIsPreservedDuringTraining(self):
230227
nr_of_unique_weights_after,
231228
clustering_params["number_of_clusters"])
232229

233-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
234230
def testEndToEndSequential(self):
235231
"""Test End to End clustering - sequential model."""
236232
original_model = keras.Sequential([
@@ -247,7 +243,6 @@ def clusters_check(stripped_model):
247243

248244
self.end_to_end_testing(original_model, clusters_check)
249245

250-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
251246
def testEndToEndConv1DAndConv1DTranspose(self):
252247
"""Test End to End clustering - model with Conv1D and Conv1DTranspose."""
253248
inp = layers.Input(batch_shape=(1, 16))
@@ -372,7 +367,6 @@ def clusters_check(stripped_model):
372367

373368
self.end_to_end_testing(original_model, clusters_check)
374369

375-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
376370
def testEndToEndFunctional(self):
377371
"""Test End to End clustering - functional model."""
378372
inputs = keras.layers.Input(shape=(5,))
@@ -389,7 +383,6 @@ def clusters_check(stripped_model):
389383

390384
self.end_to_end_testing(original_model, clusters_check)
391385

392-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
393386
def testEndToEndDeepLayer(self):
394387
"""Test End to End clustering for the model with deep layer."""
395388
internal_model = tf.keras.Sequential(
@@ -416,7 +409,6 @@ def clusters_check(stripped_model):
416409

417410
self.end_to_end_testing(original_model, clusters_check)
418411

419-
@keras_parameterized.run_all_keras_modes(always_skip_v1=True)
420412
def testEndToEndDeepLayer2(self):
421413
"""Test End to End clustering for the model with 2 deep layers."""
422414
internal_model = tf.keras.Sequential(
@@ -454,7 +446,6 @@ def clusters_check(stripped_model):
454446

455447
self.end_to_end_testing(original_model, clusters_check)
456448

457-
@keras_parameterized.run_all_keras_modes
458449
def testWeightsAreLearningDuringClustering(self):
459450
"""Verifies that weights are updated during training a clustered model.
460451
@@ -541,7 +532,6 @@ def _assertNbUniqueWeights(self, weight, expected_unique_weights):
541532
nr_unique_weights = len(np.unique(weight.numpy().flatten()))
542533
assert nr_unique_weights == expected_unique_weights
543534

544-
@keras_parameterized.run_all_keras_modes
545535
def testClusterSimpleRNN(self):
546536
model = keras.models.Sequential()
547537
model.add(keras.layers.Embedding(self.max_features, 16,
@@ -564,7 +554,6 @@ def testClusterSimpleRNN(self):
564554

565555
self._train(stripped_model)
566556

567-
@keras_parameterized.run_all_keras_modes
568557
def testClusterLSTM(self):
569558
model = keras.models.Sequential()
570559
model.add(keras.layers.Embedding(self.max_features, 16,
@@ -587,7 +576,6 @@ def testClusterLSTM(self):
587576

588577
self._train(stripped_model)
589578

590-
@keras_parameterized.run_all_keras_modes
591579
def testClusterGRU(self):
592580
model = keras.models.Sequential()
593581
model.add(keras.layers.Embedding(self.max_features, 16,
@@ -610,7 +598,6 @@ def testClusterGRU(self):
610598

611599
self._train(stripped_model)
612600

613-
@keras_parameterized.run_all_keras_modes
614601
def testClusterBidirectional(self):
615602
model = keras.models.Sequential()
616603
model.add(
@@ -634,7 +621,6 @@ def testClusterBidirectional(self):
634621
expected_unique_weights=self.params_clustering["number_of_clusters"],
635622
)
636623

637-
@keras_parameterized.run_all_keras_modes
638624
def testClusterStackedRNNCells(self):
639625
model = keras.models.Sequential()
640626
model.add(
@@ -685,7 +671,6 @@ def _get_model(self):
685671
model = tf.keras.Model(inputs=inp, outputs=out)
686672
return model
687673

688-
@keras_parameterized.run_all_keras_modes
689674
def testMHA(self):
690675
model = self._get_model()
691676

@@ -736,7 +721,6 @@ def _get_model(self):
736721
model = tf.keras.Model(inputs=inp, outputs=out)
737722
return model
738723

739-
@keras_parameterized.run_all_keras_modes
740724
def testPerChannel(self):
741725
model = self._get_model()
742726

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

-34
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from absl.testing import parameterized
2222
import tensorflow as tf
2323

24-
from tensorflow.python.keras import keras_parameterized
2524
from tensorflow_model_optimization.python.core.clustering.keras import cluster
2625
from tensorflow_model_optimization.python.core.clustering.keras import cluster_config
2726
from tensorflow_model_optimization.python.core.clustering.keras import cluster_wrapper
@@ -162,15 +161,13 @@ def _count_clustered_layers(self, model):
162161
count += 1
163162
return count
164163

165-
@keras_parameterized.run_all_keras_modes
166164
def testClusterKerasClusterableLayer(self):
167165
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly."""
168166
wrapped_layer = self._build_clustered_layer_model(
169167
self.keras_clusterable_layer)
170168

171169
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
172170

173-
@keras_parameterized.run_all_keras_modes
174171
def testClusterKerasClusterableLayerWithSparsityPreservation(self):
175172
"""Verifies that a built-in keras layer marked as clusterable is being clustered correctly when sparsity preservation is enabled."""
176173
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -180,7 +177,6 @@ def testClusterKerasClusterableLayerWithSparsityPreservation(self):
180177

181178
self._validate_clustered_layer(self.keras_clusterable_layer, wrapped_layer)
182179

183-
@keras_parameterized.run_all_keras_modes
184180
def testClusterKerasNonClusterableLayer(self):
185181
"""Verifies that a built-in keras layer not marked as clusterable is not being clustered."""
186182
wrapped_layer = self._build_clustered_layer_model(
@@ -190,7 +186,6 @@ def testClusterKerasNonClusterableLayer(self):
190186
wrapped_layer)
191187
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
192188

193-
@keras_parameterized.run_all_keras_modes
194189
def testDepthwiseConv2DLayerNonClusterable(self):
195190
"""Verifies that we don't cluster a DepthwiseConv2D layer, because clustering of this type of layer gives big unrecoverable accuracy loss."""
196191
wrapped_layer = self._build_clustered_layer_model(
@@ -200,7 +195,6 @@ def testDepthwiseConv2DLayerNonClusterable(self):
200195
wrapped_layer)
201196
self.assertEqual([], wrapped_layer.layer.get_clusterable_weights())
202197

203-
@keras_parameterized.run_all_keras_modes
204198
def testDenseLayer(self):
205199
"""Verifies that we can cluster a Dense layer."""
206200
input_shape = (28, 1)
@@ -214,7 +208,6 @@ def testDenseLayer(self):
214208
self.assertEqual([1, 10],
215209
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
216210

217-
@keras_parameterized.run_all_keras_modes
218211
def testConv1DLayer(self):
219212
"""Verifies that we can cluster a Conv1D layer."""
220213
input_shape = (28, 1)
@@ -227,7 +220,6 @@ def testConv1DLayer(self):
227220
self.assertEqual([5, 1, 3],
228221
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
229222

230-
@keras_parameterized.run_all_keras_modes
231223
def testConv1DTransposeLayer(self):
232224
"""Verifies that we can cluster a Conv1DTranspose layer."""
233225
input_shape = (28, 1)
@@ -240,7 +232,6 @@ def testConv1DTransposeLayer(self):
240232
self.assertEqual([5, 3, 1],
241233
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
242234

243-
@keras_parameterized.run_all_keras_modes
244235
def testConv2DLayer(self):
245236
"""Verifies that we can cluster a Conv2D layer."""
246237
input_shape = (28, 28, 1)
@@ -253,7 +244,6 @@ def testConv2DLayer(self):
253244
self.assertEqual([4, 5, 1, 3],
254245
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
255246

256-
@keras_parameterized.run_all_keras_modes
257247
def testConv2DTransposeLayer(self):
258248
"""Verifies that we can cluster a Conv2DTranspose layer."""
259249
input_shape = (28, 28, 1)
@@ -266,7 +256,6 @@ def testConv2DTransposeLayer(self):
266256
self.assertEqual([4, 5, 3, 1],
267257
wrapped_layer.layer.get_clusterable_weights()[0][1].shape)
268258

269-
@keras_parameterized.run_all_keras_modes
270259
def testConv3DLayer(self):
271260
"""Verifies that we can cluster a Conv3D layer."""
272261
input_shape = (28, 28, 28, 1)
@@ -287,7 +276,6 @@ def testClusterKerasUnsupportedLayer(self):
287276
with self.assertRaises(ValueError):
288277
cluster.cluster_weights(keras_unsupported_layer, **self.params)
289278

290-
@keras_parameterized.run_all_keras_modes
291279
def testClusterCustomClusterableLayer(self):
292280
"""Verifies that a custom clusterable layer is being clustered correctly."""
293281
wrapped_layer = self._build_clustered_layer_model(
@@ -297,7 +285,6 @@ def testClusterCustomClusterableLayer(self):
297285
self.assertEqual([('kernel', wrapped_layer.layer.kernel)],
298286
wrapped_layer.layer.get_clusterable_weights())
299287

300-
@keras_parameterized.run_all_keras_modes
301288
def testClusterCustomClusterableLayerWithSparsityPreservation(self):
302289
"""Verifies that a custom clusterable layer is being clustered correctly when sparsity preservation is enabled."""
303290
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -424,7 +411,6 @@ def testStripClusteringSequentialModelWithBiasConstraint(self):
424411
keras_file = os.path.join(tmp_dir_name, 'cluster_test')
425412
stripped_model.save(keras_file, save_traces=True)
426413

427-
@keras_parameterized.run_all_keras_modes
428414
def testClusterSequentialModelSelectively(self):
429415
clustered_model = keras.Sequential()
430416
clustered_model.add(
@@ -437,7 +423,6 @@ def testClusterSequentialModelSelectively(self):
437423
self.assertNotIsInstance(clustered_model.layers[1],
438424
cluster_wrapper.ClusterWeights)
439425

440-
@keras_parameterized.run_all_keras_modes
441426
def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
442427
"""Verifies that layers within a sequential model can be clustered selectively when sparsity preservation is enabled."""
443428
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -454,7 +439,6 @@ def testClusterSequentialModelSelectivelyWithSparsityPreservation(self):
454439
self.assertNotIsInstance(clustered_model.layers[1],
455440
cluster_wrapper.ClusterWeights)
456441

457-
@keras_parameterized.run_all_keras_modes
458442
def testClusterFunctionalModelSelectively(self):
459443
"""Verifies that layers within a functional model can be clustered selectively."""
460444
i1 = keras.Input(shape=(10,))
@@ -469,7 +453,6 @@ def testClusterFunctionalModelSelectively(self):
469453
self.assertNotIsInstance(clustered_model.layers[3],
470454
cluster_wrapper.ClusterWeights)
471455

472-
@keras_parameterized.run_all_keras_modes
473456
def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
474457
"""Verifies that layers within a functional model can be clustered selectively when sparsity preservation is enabled."""
475458
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -486,7 +469,6 @@ def testClusterFunctionalModelSelectivelyWithSparsityPreservation(self):
486469
self.assertNotIsInstance(clustered_model.layers[3],
487470
cluster_wrapper.ClusterWeights)
488471

489-
@keras_parameterized.run_all_keras_modes
490472
def testClusterModelValidLayersSuccessful(self):
491473
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered."""
492474
model = keras.Sequential([
@@ -500,7 +482,6 @@ def testClusterModelValidLayersSuccessful(self):
500482
for layer, clustered_layer in zip(model.layers, clustered_model.layers):
501483
self._validate_clustered_layer(layer, clustered_layer)
502484

503-
@keras_parameterized.run_all_keras_modes
504485
def testClusterModelValidLayersSuccessfulWithSparsityPreservation(self):
505486
"""Verifies that clustering a sequential model results in all clusterable layers within the model being clustered when sparsity preservation is enabled."""
506487
preserve_sparsity_params = {'preserve_sparsity': True}
@@ -540,7 +521,6 @@ def testClusterModelCustomNonClusterableLayerRaisesError(self):
540521
self.custom_clusterable_layer, custom_non_clusterable_layer
541522
]), **self.params)
542523

543-
@keras_parameterized.run_all_keras_modes
544524
def testClusterModelDoesNotWrapAlreadyWrappedLayer(self):
545525
"""Verifies that clustering a model that contains an already clustered layer does not result in wrapping the clustered layer into another cluster_wrapper."""
546526
model = keras.Sequential([
@@ -579,7 +559,6 @@ def testClusterSequentialModelNoInput(self):
579559
clustered_model = cluster.cluster_weights(model, **self.params)
580560
self.assertEqual(self._count_clustered_layers(clustered_model), 2)
581561

582-
@keras_parameterized.run_all_keras_modes
583562
def testClusterSequentialModelWithInput(self):
584563
"""Verifies that a sequential model with an input layer is being clustered correctly."""
585564
# With InputLayer
@@ -607,7 +586,6 @@ def testClusterSequentialModelPreservesBuiltStateNoInput(self):
607586
json.loads(clustered_model.to_json()))
608587
self.assertEqual(loaded_model.built, False)
609588

610-
@keras_parameterized.run_all_keras_modes
611589
def testClusterSequentialModelPreservesBuiltStateWithInput(self):
612590
"""Verifies that clustering a sequential model with an input layer preserves the built state of the model."""
613591
# With InputLayer
@@ -625,7 +603,6 @@ def testClusterSequentialModelPreservesBuiltStateWithInput(self):
625603
json.loads(clustered_model.to_json()))
626604
self.assertEqual(loaded_model.built, True)
627605

628-
@keras_parameterized.run_all_keras_modes
629606
def testClusterFunctionalModelPreservesBuiltState(self):
630607
"""Verifies that clustering a functional model preserves the built state of the model."""
631608
i1 = keras.Input(shape=(10,))
@@ -644,7 +621,6 @@ def testClusterFunctionalModelPreservesBuiltState(self):
644621
json.loads(clustered_model.to_json()))
645622
self.assertEqual(loaded_model.built, True)
646623

647-
@keras_parameterized.run_all_keras_modes
648624
def testClusterFunctionalModel(self):
649625
"""Verifies that a functional model is being clustered correctly."""
650626
i1 = keras.Input(shape=(10,))
@@ -656,7 +632,6 @@ def testClusterFunctionalModel(self):
656632
clustered_model = cluster.cluster_weights(model, **self.params)
657633
self.assertEqual(self._count_clustered_layers(clustered_model), 3)
658634

659-
@keras_parameterized.run_all_keras_modes
660635
def testClusterFunctionalModelWithLayerReused(self):
661636
"""Verifies that a layer reused within a functional model multiple times is only being clustered once."""
662637
# The model reuses the Dense() layer. Make sure it's only clustered once.
@@ -668,22 +643,19 @@ def testClusterFunctionalModelWithLayerReused(self):
668643
clustered_model = cluster.cluster_weights(model, **self.params)
669644
self.assertEqual(self._count_clustered_layers(clustered_model), 1)
670645

671-
@keras_parameterized.run_all_keras_modes
672646
def testClusterSubclassModel(self):
673647
"""Verifies that attempting to cluster an instance of a subclass of keras.Model raises an exception."""
674648
model = TestModel()
675649
with self.assertRaises(ValueError):
676650
_ = cluster.cluster_weights(model, **self.params)
677651

678-
@keras_parameterized.run_all_keras_modes
679652
def testClusterSubclassModelAsSubmodel(self):
680653
"""Verifies that attempting to cluster a model with submodel that is a subclass throws an exception."""
681654
model_subclass = TestModel()
682655
model = keras.Sequential([layers.Dense(10), model_subclass])
683656
with self.assertRaisesRegex(ValueError, 'Subclassed models.*'):
684657
_ = cluster.cluster_weights(model, **self.params)
685658

686-
@keras_parameterized.run_all_keras_modes
687659
def testStripClusteringSequentialModel(self):
688660
"""Verifies that stripping the clustering wrappers from a sequential model produces the expected config."""
689661
model = keras.Sequential([
@@ -697,7 +669,6 @@ def testStripClusteringSequentialModel(self):
697669
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
698670
self.assertEqual(model.get_config(), stripped_model.get_config())
699671

700-
@keras_parameterized.run_all_keras_modes
701672
def testClusterStrippingFunctionalModel(self):
702673
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""
703674
i1 = keras.Input(shape=(10,))
@@ -713,7 +684,6 @@ def testClusterStrippingFunctionalModel(self):
713684
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
714685
self.assertEqual(model.get_config(), stripped_model.get_config())
715686

716-
@keras_parameterized.run_all_keras_modes
717687
def testClusterWeightsStrippedWeights(self):
718688
"""Verifies that stripping the clustering wrappers from a functional model preserves the clustered weights."""
719689
i1 = keras.Input(shape=(10,))
@@ -728,7 +698,6 @@ def testClusterWeightsStrippedWeights(self):
728698
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
729699
self.assertLen(stripped_model.get_weights(), cluster_weight_length)
730700

731-
@keras_parameterized.run_all_keras_modes
732701
def testStrippedKernel(self):
733702
"""Verifies that stripping the clustering wrappers from a functional model restores the layers kernel and the layers weight array to the new clustered weight value."""
734703
i1 = keras.Input(shape=(1, 1, 1))
@@ -746,7 +715,6 @@ def testStrippedKernel(self):
746715
self.assertIsNot(stripped_conv2d_layer.kernel, clustered_kernel)
747716
self.assertIn(stripped_conv2d_layer.kernel, stripped_conv2d_layer.weights)
748717

749-
@keras_parameterized.run_all_keras_modes
750718
def testStripSelectivelyClusteredFunctionalModel(self):
751719
"""Verifies that invoking strip_clustering() on a selectively clustered functional model strips the clustering wrappers from the clustered layers."""
752720
i1 = keras.Input(shape=(10,))
@@ -761,7 +729,6 @@ def testStripSelectivelyClusteredFunctionalModel(self):
761729
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
762730
self.assertIsInstance(stripped_model.layers[2], layers.Dense)
763731

764-
@keras_parameterized.run_all_keras_modes
765732
def testStripSelectivelyClusteredSequentialModel(self):
766733
"""Verifies that invoking strip_clustering() on a selectively clustered sequential model strips the clustering wrappers from the clustered layers."""
767734
clustered_model = keras.Sequential([
@@ -775,7 +742,6 @@ def testStripSelectivelyClusteredSequentialModel(self):
775742
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
776743
self.assertIsInstance(stripped_model.layers[0], layers.Dense)
777744

778-
@keras_parameterized.run_all_keras_modes
779745
def testStripClusteringAndSetOriginalWeightsBack(self):
780746
"""Verifies that we can set_weights onto the stripped model."""
781747
model = keras.Sequential([

0 commit comments

Comments
 (0)