From ae68a5be70791c8369e5742b6618d6e735c56a30 Mon Sep 17 00:00:00 2001 From: Neel Kovelamudi Date: Wed, 18 Jan 2023 14:10:37 -0800 Subject: [PATCH] Changes external references to `keras.utils.serialize_keras_object/deserialize_keras_object` to legacy serialization API in preparation for switching all of Keras to new serialization format. PiperOrigin-RevId: 502975348 --- .../default_8bit_quantize_registry_test.py | 4 ++-- .../default_n_bit_quantize_registry_test.py | 4 ++-- .../python/core/quantization/keras/quantize.py | 12 +++++++----- .../core/quantization/keras/quantize_annotate.py | 4 ++-- .../keras/quantize_aware_activation_test.py | 4 ++-- .../core/quantization/keras/quantize_config.py | 9 ++++++--- .../core/quantization/keras/quantize_layer.py | 4 ++-- .../core/quantization/keras/quantize_wrapper.py | 4 ++-- .../core/quantization/keras/quantizers_test.py | 4 ++-- .../core/sparsity/keras/pruning_schedule_test.py | 14 ++++++++------ .../python/core/sparsity/keras/pruning_wrapper.py | 2 +- 11 files changed, 36 insertions(+), 29 deletions(-) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py index e603c2d7f..dca020cc7 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py @@ -31,8 +31,8 @@ K = tf.keras.backend l = tf.keras.layers -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object class _TestHelper(object): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py index cc59fef52..7cb0cdfa7 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py @@ -31,8 +31,8 @@ K = tf.keras.backend l = tf.keras.layers -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object class _TestHelper(object): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py index 38fdc6924..2e67d3ba9 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize.py @@ -525,7 +525,7 @@ def _wrap_fixed_range( 'init_min': init_min, 'init_max': init_max, 'narrow_range': narrow_range}) - return tf.keras.utils.serialize_keras_object(config) + return tf.keras.utils.legacy.serialize_keras_object(config) def _is_serialized_node_data(nested): @@ -601,8 +601,9 @@ def fix_input_output_range( init_min=input_min, init_max=input_max, narrow_range=narrow_range) - serialized_fixed_input_quantizer = tf.keras.utils.serialize_keras_object( - fixed_input_quantizer) + serialized_fixed_input_quantizer = ( + tf.keras.utils.legacy.serialize_keras_object(fixed_input_quantizer) + ) if _is_functional_model(model): input_layer_list = _nested_to_flatten_node_data_list(config['input_layers']) @@ -685,8 +686,9 @@ def remove_input_range(model): """ config = model.get_config() no_input_quantizer = quantizers.NoQuantizer() - serialized_input_quantizer = tf.keras.utils.serialize_keras_object( - no_input_quantizer) + serialized_input_quantizer = tf.keras.utils.legacy.serialize_keras_object( + no_input_quantizer + ) if _is_functional_model(model): input_layer_list = _nested_to_flatten_node_data_list(config['input_layers']) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py index e41686221..77fc8daf0 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py @@ -23,8 +23,8 @@ import tensorflow as tf -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object class QuantizeAnnotate(tf.keras.layers.Wrapper): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py index 2a3a5ad36..5fdf2a827 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_aware_activation_test.py @@ -29,8 +29,8 @@ keras = tf.keras activations = tf.keras.activations K = tf.keras.backend -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object QuantizeAwareActivation = quantize_aware_activation.QuantizeAwareActivation MovingAverageQuantizer = quantizers.MovingAverageQuantizer diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py index bf94e130d..036559070 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_config.py @@ -259,13 +259,16 @@ def get_output_quantizers(self, layer): def get_config(self): return { - 'config': tf.keras.utils.serialize_keras_object(self.config), + 'config': tf.keras.utils.legacy.serialize_keras_object(self.config), 'num_bits': self.num_bits, 'init_min': self.init_min, 'init_max': self.init_max, - 'narrow_range': self.narrow_range} + 'narrow_range': self.narrow_range, + } @classmethod def from_config(cls, config): - config['config'] = tf.keras.utils.deserialize_keras_object(config['config']) + config['config'] = tf.keras.utils.legacy.deserialize_keras_object( + config['config'] + ) return cls(**config) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py index be59458ca..b1a160dd8 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py @@ -27,8 +27,8 @@ from tensorflow_model_optimization.python.core.quantization.keras import quantizers -serialize_keras_object = tf.keras.utils.serialize_keras_object -deserialize_keras_object = tf.keras.utils.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object class QuantizeLayer(tf.keras.layers.Layer): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index 02aaeb958..f147d365d 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -34,8 +34,8 @@ from tensorflow_model_optimization.python.core.keras import utils from tensorflow_model_optimization.python.core.quantization.keras import quantize_aware_activation -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object class QuantizeWrapper(tf.keras.layers.Wrapper): diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py index 7df0567f2..dafabdec4 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py @@ -26,8 +26,8 @@ from tensorflow_model_optimization.python.core.keras import compat from tensorflow_model_optimization.python.core.quantization.keras import quantizers -deserialize_keras_object = tf.keras.utils.deserialize_keras_object -serialize_keras_object = tf.keras.utils.serialize_keras_object +deserialize_keras_object = tf.keras.utils.legacy.deserialize_keras_object +serialize_keras_object = tf.keras.utils.legacy.serialize_keras_object @parameterized.parameters( diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py index 916d080ab..4f39a7a66 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_schedule_test.py @@ -242,12 +242,13 @@ def testSerializeDeserialize(self): sparsity = pruning_schedule.ConstantSparsity(0.7, 10, 20, 10) config = sparsity.get_config() - sparsity_deserialized = tf.keras.utils.deserialize_keras_object( + sparsity_deserialized = tf.keras.utils.legacy.deserialize_keras_object( config, custom_objects={ 'ConstantSparsity': pruning_schedule.ConstantSparsity, - 'PolynomialDecay': pruning_schedule.PolynomialDecay - }) + 'PolynomialDecay': pruning_schedule.PolynomialDecay, + }, + ) self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__) @@ -278,12 +279,13 @@ def testSerializeDeserialize(self): sparsity = pruning_schedule.PolynomialDecay(0.2, 0.6, 10, 20, 5, 10) config = sparsity.get_config() - sparsity_deserialized = tf.keras.utils.deserialize_keras_object( + sparsity_deserialized = tf.keras.utils.legacy.deserialize_keras_object( config, custom_objects={ 'ConstantSparsity': pruning_schedule.ConstantSparsity, - 'PolynomialDecay': pruning_schedule.PolynomialDecay - }) + 'PolynomialDecay': pruning_schedule.PolynomialDecay, + }, + ) self.assertEqual(sparsity.__dict__, sparsity_deserialized.__dict__) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py index 0a03643ec..a4ed2fa76 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py @@ -318,7 +318,7 @@ def from_config(cls, config): config = config.copy() pruning_schedule = config.pop('pruning_schedule') - deserialize_keras_object = keras.utils.deserialize_keras_object # pylint: disable=g-import-not-at-top + deserialize_keras_object = keras.utils.legacy.deserialize_keras_object # pylint: disable=g-import-not-at-top # TODO(pulkitb): This should ideally be fetched from pruning_schedule, # which should maintain a list of all the pruning_schedules. custom_objects = {