From 8ab56423693f6a61c86b4b840742c422c86d35ff Mon Sep 17 00:00:00 2001
From: Rino Lee
Date: Thu, 26 Aug 2021 07:40:49 -0700
Subject: [PATCH] Add 2x4 structured sparsity model to Model Garden

PiperOrigin-RevId: 393122268
---
 .../sparsity/keras/prune_integration_test.py  | 54 +++++++++++++++----
 .../core/sparsity/keras/pruning_wrapper.py    |  7 ---
 .../sparsity/keras/pruning_wrapper_test.py    | 12 +++--
 3 files changed, 53 insertions(+), 20 deletions(-)

diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
index 93527dc5e..5c36aeb14 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/prune_integration_test.py
@@ -365,6 +365,13 @@ def testPrunesSingleLayer_ReachesTargetSparsity(self, layer_type):
           'input_shape': [(8)],
           'm_by_n': (1, 2),
       },
+      {
+          'testcase_name': 'DepthwiseConv_2by4',
+          'layer_type': tf.keras.layers.DepthwiseConv2D,
+          'layer_arg': [3],
+          'input_shape': (7, 7, 32),
+          'm_by_n': (2, 4),
+      },
   )
 
   def testMbyNSparsityPruning_SupportedLayers(self,
@@ -392,18 +399,45 @@ def testMbyNSparsityPruning_SupportedLayers(self,
     test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
     self._check_strip_pruning_matches_original(model, sparsity_ratio)
 
-  def testSparsityPruningMbyN_NonSupportedLayers(self):
-    """Check layer that is not supported for m by n sparsity."""
-    self.params.update({'sparsity_m_by_n': (2, 4)})
-
-    model = keras.Sequential()
-    layer_type = tf.keras.layers.SeparableConv1D
-    args, input_shape = ([4, 3], (3, 6))
+  def testSparsityPruningMbyN_SupportedSubclassLayers(self):
+    """Check subclass layer that is supported for m by n sparsity."""
+    m_by_n = (2, 4)
+    self.params.update({'sparsity_m_by_n': m_by_n})
+    class SubclassLayer(tf.keras.layers.Layer):
+
+      def __init__(self):
+        super(SubclassLayer, self).__init__()
+        self.conv1 = tf.keras.layers.Conv2D(
+            2, 3, activation='relu', padding='same', input_shape=[7, 7, 3])
+        self.conv2 = tf.keras.layers.DepthwiseConv2D(3)
+        self.flatten = keras.layers.Flatten()
+        self.dense = layers.Dense(10, activation='sigmoid')
+
+      def call(self, inputs):
+        x = self.conv1(inputs)
+        x = self.conv2(x)
+        x = self.flatten(x)
+        x = self.dense(x)
+        return x
+
+    inputs = keras.Input(shape=(7, 7, 3))
+    outputs = SubclassLayer()(inputs)
+    model = keras.Model(inputs, outputs)
 
     with self.assertRaises(ValueError):
-      model.add(
-          prune.prune_low_magnitude(
-              layer_type(*args), input_shape=input_shape, **self.params))
+      model = prune.prune_low_magnitude(model, **self.params)
+
+    model.compile(
+        loss='categorical_crossentropy', optimizer='sgd', metrics=['accuracy'])
+
+    test_utils.assert_model_sparsity(self, 0.0, model)
+    model.fit(
+        np.random.randn(*self._batch(model.input.get_shape().as_list(), 32)),
+        np.random.randn(*self._batch(model.output.get_shape().as_list(), 32)),
+        callbacks=[pruning_callbacks.UpdatePruningStep()])
+
+    test_utils.assert_model_sparsity_m_by_n(self, model, m_by_n)
+    self._check_strip_pruning_matches_original(model, 0.5)
 
   @parameterized.parameters(prune_registry.PruneRegistry._RNN_LAYERS -
                             {keras.layers.RNN})
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
index a8918058e..d6d14ea25 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py
@@ -129,13 +129,6 @@ def __init__(self,
 
     self.sparsity_m_by_n = None
     if sparsity_m_by_n:
-      # Sparsity m_by_n can be applied only to Conv2D and Dense layers.
-      if (not isinstance(layer, tf.keras.layers.Conv2D) and
-          not isinstance(layer, tf.keras.layers.Dense)):
-        raise ValueError('Structural sparsity M by N is applicable only '
-                         'to `Conv2D` and `Dense` layers. You passed: '
-                         '{input}'.format(input=layer.__class__))
-
       self.sparsity_m_by_n = convert_to_tuple_of_two_int(
           sparsity_m_by_n, 'sparsity_m_by_n')
 
diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
index 3573d18ab..89ff765f7 100644
--- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
+++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py
@@ -156,12 +156,18 @@ def testCollectPrunableLayers(self):
 
     self.assertLen(pruning_wrapper.collect_prunable_layers(self.model), 5)
 
-  def testConv3DNonPrunableWithSparsityMbyN(self):
+  def testConv3DWeightNotPrunedWithSparsityMbyN(self):
     layer = keras.layers.Conv3D(2, 3)
     inputs = keras.layers.Input(shape=(4, 28, 28, 28, 1))
     _ = layer(inputs)
-    with self.assertRaises(ValueError):
-      pruning_wrapper.PruneLowMagnitude(layer, sparsity_m_by_n=(2, 4))
+    self.model.add(Prune(layer, sparsity_m_by_n=(2, 4)))
+
+    pruned_layers = pruning_wrapper.collect_prunable_layers(self.model)
+
+    self.assertLen(pruned_layers, 1)
+    # Only rank-2 (e.g., Dense) or rank-4 (e.g., Conv2D) weights are pruned
+    # with M-by-N sparsity.
+    self.assertLen(pruned_layers[0].pruning_vars, 0)
 
 
 if __name__ == '__main__':
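
Note (not part of the patch): the hunks above drop the Conv2D/Dense-only check in PruneLowMagnitude, so requesting M-by-N sparsity no longer raises for other layer types; instead, only weights of rank 2 (e.g. Dense kernels) or rank 4 (e.g. Conv2D and DepthwiseConv2D kernels) are pruned with the M-by-N pattern, and other weights (such as a Conv3D rank-5 kernel) are simply left unpruned. The sketch below is a minimal, illustrative usage example, assuming an installed tensorflow-model-optimization release whose public prune_low_magnitude accepts the sparsity_m_by_n argument exercised in the tests above; the model and data in it are made up for illustration.

    import numpy as np
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    # Toy model: the Conv2D/DepthwiseConv2D kernels (rank 4) and the Dense
    # kernel (rank 2) are eligible for 2:4 structured sparsity; weights of
    # other ranks would be skipped rather than rejected.
    model = tf.keras.Sequential([
        tf.keras.layers.Conv2D(16, 3, activation='relu', input_shape=(28, 28, 1)),
        tf.keras.layers.DepthwiseConv2D(3, activation='relu'),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation='softmax'),
    ])

    # Request 2-out-of-4 structured sparsity for the whole model. The
    # sparsity_m_by_n keyword is assumed to be forwarded to the pruning
    # wrapper, as in the integration tests above.
    pruned_model = tfmot.sparsity.keras.prune_low_magnitude(
        model, sparsity_m_by_n=(2, 4))

    pruned_model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

    # UpdatePruningStep must be registered so the pruning masks get updated
    # during training.
    x = np.random.rand(32, 28, 28, 1).astype('float32')
    y = np.random.randint(0, 10, size=(32,))
    pruned_model.fit(
        x, y, epochs=1,
        callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

    # Remove the pruning wrappers before export; the 2:4 pattern remains in
    # the pruned kernels.
    final_model = tfmot.sparsity.keras.strip_pruning(pruned_model)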