Describe the bug
Clustering the weights of a custom layer fails. When a layer implements ClusterableLayer, it must override get_clusterable_weights, but the later call to get_weight_from_layer raises an AttributeError.
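To make the failure mode concrete, here is a TensorFlow-free sketch of the lookup as it appears in the traceback in this report: the wrapper takes each name returned by get_clusterable_weights and resolves it with getattr on the wrapped layer. FakeVariable, FakeLayer and the standalone get_weight_from_layer function are hypothetical stand-ins written for illustration, not tfmot code; only the getattr call is taken from the traceback.

```python
# TensorFlow-free model of the lookup seen in the traceback.
# FakeVariable and FakeLayer are hypothetical stand-ins, not tfmot classes.


class FakeVariable:
    """Stands in for a tf.Variable; only the .name attribute matters here."""

    def __init__(self, name):
        self.name = name


class FakeLayer:
    """Stands in for MyCustomLayer: the kernel is stored as self.kernel."""

    def __init__(self):
        self.kernel = FakeVariable("my_custom_layer/kernel:0")
        self.trainable_weights = [self.kernel]

    def get_clusterable_weights(self):
        # Same pattern as the reproduction code: the *variable* name is returned.
        return [(w.name, w) for w in self.trainable_weights]


def get_weight_from_layer(layer, weight_name):
    # Mirrors the last traceback frame: `return getattr(self.layer, weight_name)`.
    return getattr(layer, weight_name)


layer = FakeLayer()
errors = []
for name, _ in layer.get_clusterable_weights():
    try:
        get_weight_from_layer(layer, name)
    except AttributeError as err:
        errors.append(str(err))
print(errors)
```

With a variable name such as 'my_custom_layer/kernel:0', getattr finds nothing, because the layer stores that variable under the attribute kernel.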
System information
OS: Linux Mint 19.3 Tricia x86_64
Host: Z390 AORUS MASTER
Kernel: 5.4.0-81-generic
Uptime: 12 hours, 21 mins
Packages: 4694
Shell: bash 4.4.20
Resolution: 3840x2160, 3840x2160
DE: Cinnamon 4.4.8
WM: Mutter (Muffin)
WM Theme: Linux Mint (Mint-Y-Dark)
Theme: Mint-Y-Dark [GTK2/3]
Icons: Mint-Y [GTK2/3]
Terminal: gnome-terminal
CPU: Intel i9-9900K (16) @ 5.000GHz
GPU: NVIDIA GeForce GTX 1080 Ti
Memory: 23072MiB / 64320MiB
TensorFlow version (installed from source or binary): binary, installed with pip (tensorflow-gpu 2.6.0)
TensorFlow Model Optimization version (installed from source or binary): binary, installed with pip (tensorflow-model-optimization 0.6.0)
Python version: Python 3.6.9
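Describe the expected behavior
cluster_weights(model, **clustering_params) should return a clustered model in which MyCustomLayer is wrapped and the weights returned by its get_clusterable_weights are clustered.
Describe the current behavior
cluster_weights raises AttributeError: 'MyCustomLayer' object has no attribute 'my_custom_layer/kernel:0' (full output below).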
Code to reproduce the issue
import numpy as np
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
    'number_of_clusters': 3,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}


class MyCustomLayer(keras.layers.Layer, tfmot.clustering.keras.ClusterableLayer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyCustomLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            name='kernel',
            shape=(input_shape[1], self.output_dim),
            initializer='normal',
            trainable=True
        )
        super(MyCustomLayer, self).build(input_shape)

    def call(self, input_data):
        return keras.backend.dot(input_data, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

    def get_clusterable_weights(self):
        # Returns (variable name, value) pairs, e.g. ('my_custom_layer/kernel:0', <tensor>).
        clusterable_weights = []
        for weight in self.trainable_weights:
            clusterable_weights.append((weight.name, weight.read_value()))
        return clusterable_weights


def get_model():
    # Create a simple model.
    model = keras.Sequential(
        [
            keras.Input(shape=(32,)),
            MyCustomLayer(32, input_shape=(32,)),
            keras.layers.Dense(2, activation="relu", name="layer1"),
            keras.layers.Dense(3, activation="relu", name="layer2"),
            keras.layers.Dense(4, name="layer3"),
        ]
    )
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model


model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 1))
model.fit(test_input, test_target)

# summary() prints the table itself and returns None, hence the "None" in the output.
print(model.summary())

# Print all weights in model.
for weight in model.weights:
    print(weight.name)  # , weight.read_value())

clustered_model = cluster_weights(model, **clustering_params)
clustered_model.summary(line_length=180, positions=[0.25, 0.60, 0.70, 1.0])
The output is:
(bug) dellboy@thunderstruck:~/git/tisma/tf-learn$ python simple_model.py
2021-08-23 18:40:26.443329: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.465328: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.465628: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.466063: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-23 18:40:26.466461: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.466787: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.467059: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.063134: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.063747: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.064148: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.064580: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6866 MB memory: -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1
2021-08-23 18:40:27.214358: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
4/4 [==============================] - 0s 664us/step - loss: 0.2719
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
my_custom_layer (MyCustomLay (None, 32)                1024
_________________________________________________________________
layer1 (Dense)               (None, 2)                 66
_________________________________________________________________
layer2 (Dense)               (None, 3)                 9
_________________________________________________________________
layer3 (Dense)               (None, 4)                 16
=================================================================
Total params: 1,115
Trainable params: 1,115
Non-trainable params: 0
_________________________________________________________________
None
my_custom_layer/kernel:0
layer1/kernel:0
layer1/bias:0
layer2/kernel:0
layer2/bias:0
layer3/kernel:0
layer3/bias:0
Traceback (most recent call last):
  File "simple_model.py", line 69, in <module>
    clustered_model = cluster_weights(model, **clustering_params)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster.py", line 133, in cluster_weights
    **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster.py", line 261, in _cluster_weights
    to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/models.py", line 449, in clone_model
    model, input_tensors=input_tensors, layer_fn=clone_function)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/models.py", line 332, in _clone_sequential_model
    cloned_model = Sequential(layers=layers, name=model.name)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/sequential.py", line 134, in __init__
    self.add(layer)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/sequential.py", line 217, in add
    output_tensor = layer(self.outputs[0])
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 977, in __call__
    input_list)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 1115, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 886, in _infer_output_signature
    self._maybe_build(inputs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 2659, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py", line 160, in build
    original_weight = self.get_weight_from_layer(weight_name)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py", line 146, in get_weight_from_layer
    return getattr(self.layer, weight_name)
AttributeError: 'MyCustomLayer' object has no attribute 'my_custom_layer/kernel:0'
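Assuming the wrapper only ever resolves names through getattr(self.layer, weight_name), as the last frame above shows, the lookup succeeds when get_clusterable_weights returns the name of the attribute that holds the weight instead of the variable's full name. In MyCustomLayer that would presumably mean returning [('kernel', self.kernel)]; this is a reading of the traceback, not confirmed tfmot behavior. The sketch below uses hypothetical stand-in classes so it runs without TensorFlow.

```python
# Hypothetical variant: return the attribute name ("kernel") instead of the
# variable name ("my_custom_layer/kernel:0").
# FakeVariable and FakeLayer are TensorFlow-free stand-ins, not tfmot classes.


class FakeVariable:
    def __init__(self, name):
        self.name = name


class FakeLayer:
    def __init__(self):
        self.kernel = FakeVariable("my_custom_layer/kernel:0")

    def get_clusterable_weights(self):
        # Pair each weight with the name of the attribute that holds it.
        return [("kernel", self.kernel)]


def get_weight_from_layer(layer, weight_name):
    # Same resolution rule as the last traceback frame.
    return getattr(layer, weight_name)


layer = FakeLayer()
resolved = [get_weight_from_layer(layer, name)
            for name, _ in layer.get_clusterable_weights()]
print([w.name for w in resolved])
```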
But if I run this snippet:
# Print all weights in the model.
for weight in model.weights:
    print(weight.name)  # , weight.read_value())

print([layer.name for layer in model.layers])

# Example weight name that should be found in the model.
weight_name = "my_custom_layer/kernel:0"

# This is the correct way to get it (layers[0] is used as an example).
for weight in model.layers[0].weights:
    if weight.name == weight_name:
        print("FOUND WEIGHT: ", weight.name, weight.read_value())
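The snippet above matches on weight.name instead of using getattr. A hypothetical get_weight_from_layer that keeps the attribute lookup from the traceback but falls back to that kind of name match could look like the following. It only illustrates the idea under stated assumptions (stand-in classes, no TensorFlow) and is not a proposed patch to tfmot.

```python
# Hypothetical fallback lookup: try the attribute first (the behavior shown in
# the traceback), then match the layer's weights by variable name, like the
# snippet above. FakeVariable and FakeLayer are stand-ins, not tfmot classes.


class FakeVariable:
    def __init__(self, name):
        self.name = name


class FakeLayer:
    def __init__(self):
        self.kernel = FakeVariable("my_custom_layer/kernel:0")
        self.weights = [self.kernel]


def get_weight_from_layer(layer, weight_name):
    # 1) Attribute lookup, e.g. "kernel".
    if hasattr(layer, weight_name):
        return getattr(layer, weight_name)
    # 2) Fallback: full variable name, e.g. "my_custom_layer/kernel:0".
    for weight in layer.weights:
        if weight.name == weight_name:
            return weight
    raise AttributeError(
        f"{type(layer).__name__!r} object has no weight named {weight_name!r}")


layer = FakeLayer()
by_attribute = get_weight_from_layer(layer, "kernel")
by_variable_name = get_weight_from_layer(layer, "my_custom_layer/kernel:0")
print(by_attribute is by_variable_name)
```

Trying the attribute first keeps existing layers that already return attribute names working, while the fallback covers names taken from weight.name.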
It finds the weight:
my_custom_layer/kernel:0
layer1/kernel:0
layer1/bias:0
layer2/kernel:0
layer2/bias:0
layer3/kernel:0
layer3/bias:0
['my_custom_layer', 'layer1', 'layer2', 'layer3']
FOUND WEIGHT: my_custom_layer/kernel:0 tf.Tensor(
[[-0.01710013 0.08641329 0.00445064 ... -0.00947034 -0.02543414
-0.02332742]
[-0.0146245 -0.002941 -0.01422382 ... 0.02857029 -0.04331051
-0.00299862]
[-0.07127763 0.07367716 -0.06753001 ... -0.06001836 0.04888764
0.1081293 ]
...
[ 0.04297659 0.0334582 -0.09708535 ... 0.00098922 0.05463797
-0.0092663 ]
[ 0.03690836 -0.061338 0.01662921 ... -0.03843782 -0.08734126
0.00209901]
[ 0.10478324 0.07971404 0.05170573 ... 0.05777165 -0.08564453
0.04021074]], shape=(32, 32), dtype=float32)
My assumption is that the implementation of get_weight_from_layer(self, weight_name):