Skip to content

Commit 31d0154

Browse files
nkovela1tensorflower-gardener
authored andcommitted
Switches Keras object serialization to new logic and changes public API for deserialize_keras_object/serialize_keras_object to the new functions.
PiperOrigin-RevId: 480676373
1 parent 0e08dea commit 31d0154

File tree

7 files changed

+7
-23
lines changed

7 files changed

+7
-23
lines changed

tensorflow_model_optimization/python/core/quantization/keras/default_8bit/default_8bit_quantize_registry_test.py

-2
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,6 @@ def testSerialization(self):
366366

367367
quantize_config_from_config = deserialize_keras_object(
368368
serialized_quantize_config,
369-
module_objects=globals(),
370369
custom_objects=default_8bit_quantize_registry._types_dict())
371370

372371
self.assertEqual(quantize_config, quantize_config_from_config)
@@ -482,7 +481,6 @@ def testSerialization(self):
482481

483482
quantize_config_from_config = deserialize_keras_object(
484483
serialized_quantize_config,
485-
module_objects=globals(),
486484
custom_objects=default_8bit_quantize_registry._types_dict())
487485

488486
self.assertEqual(self.quantize_config, quantize_config_from_config)

tensorflow_model_optimization/python/core/quantization/keras/experimental/default_n_bit/default_n_bit_quantize_registry_test.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -372,9 +372,7 @@ def testSerialization(self):
372372
self.assertEqual(expected_config, serialized_quantize_config)
373373

374374
quantize_config_from_config = deserialize_keras_object(
375-
serialized_quantize_config,
376-
module_objects=globals(),
377-
custom_objects=n_bit_registry._types_dict())
375+
serialized_quantize_config, custom_objects=n_bit_registry._types_dict())
378376

379377
self.assertEqual(quantize_config, quantize_config_from_config)
380378

@@ -491,9 +489,7 @@ def testSerialization(self):
491489
self.assertEqual(expected_config, serialized_quantize_config)
492490

493491
quantize_config_from_config = deserialize_keras_object(
494-
serialized_quantize_config,
495-
module_objects=globals(),
496-
custom_objects=n_bit_registry._types_dict())
492+
serialized_quantize_config, custom_objects=n_bit_registry._types_dict())
497493

498494
self.assertEqual(self.quantize_config, quantize_config_from_config)
499495

tensorflow_model_optimization/python/core/quantization/keras/quantize_annotate.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,7 @@ def from_config(cls, config):
108108
config = config.copy()
109109

110110
quantize_config = deserialize_keras_object(
111-
config.pop('quantize_config'),
112-
module_objects=globals(),
113-
custom_objects=None)
111+
config.pop('quantize_config'), custom_objects=None)
114112

115113
layer = tf.keras.layers.deserialize(config.pop('layer'))
116114

tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -93,8 +93,6 @@ def from_config(cls, config):
9393

9494
# Deserialization code should ensure Quantizer is in keras scope.
9595
quantizer = deserialize_keras_object(
96-
config.pop('quantizer'),
97-
module_objects=globals(),
98-
custom_objects=None)
96+
config.pop('quantizer'), custom_objects=None)
9997

10098
return cls(quantizer=quantizer, **config)

tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -197,9 +197,7 @@ def from_config(cls, config):
197197
# The deserialization code should ensure the QuantizeConfig is in keras
198198
# serialization scope.
199199
quantize_config = deserialize_keras_object(
200-
config.pop('quantize_config'),
201-
module_objects=globals(),
202-
custom_objects=None)
200+
config.pop('quantize_config'), custom_objects=None)
203201

204202
layer = tf.keras.layers.deserialize(config.pop('layer'))
205203

tensorflow_model_optimization/python/core/quantization/keras/quantizers_test.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ def testSerialization(self, quantizer_type):
9797
self.assertEqual(expected_config, serialized_quantizer)
9898

9999
quantizer_from_config = deserialize_keras_object(
100-
serialized_quantizer,
101-
module_objects=globals(),
102-
custom_objects=quantizers._types_dict())
100+
serialized_quantizer, custom_objects=quantizers._types_dict())
103101

104102
self.assertEqual(quantizer, quantizer_from_config)
105103

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -324,9 +324,7 @@ def from_config(cls, config):
324324
'PolynomialDecay': pruning_sched.PolynomialDecay
325325
}
326326
config['pruning_schedule'] = deserialize_keras_object(
327-
pruning_schedule,
328-
module_objects=globals(),
329-
custom_objects=custom_objects)
327+
pruning_schedule, custom_objects=custom_objects)
330328

331329
layer = keras.layers.deserialize(config.pop('layer'))
332330
config['layer'] = layer

0 commit comments

Comments
 (0)