diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py index 59df68c1..af71b85b 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_layer.py @@ -54,6 +54,7 @@ def __init__(self, quantizer, **kwargs): self.quantizer = quantizer def build(self, input_shape): + super(QuantizeLayer, self).build(input_shape) if self.quantizer: self.quantizer_vars = self.quantizer.build( input_shape, self.name, self) diff --git a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py index 32a6e2de..6424ddf3 100644 --- a/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py +++ b/tensorflow_model_optimization/python/core/quantization/keras/quantize_wrapper.py @@ -252,8 +252,8 @@ def losses(self): class QuantizeWrapperV2(QuantizeWrapper): def build(self, input_shape): - self._trainable_weights.extend(self.layer.trainable_weights) super(QuantizeWrapperV2, self).build(input_shape) + self._trainable_weights = self.layer.trainable_weights + self._trainable_weights @property def trainable_weights(self):