
[clustering] Possible wrong implementation of get_weight_from_layer #799

Open
@tisma


Describe the bug
Clustering fails for the weights of a custom layer. When a layer implements ClusterableLayer it has to override get_clusterable_weights, but the later call to get_weight_from_layer raises an AttributeError.
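A TensorFlow-free sketch of the failure mode follows. It assumes, based on the last frame of the traceback in this report, that the wrapper resolves each name returned by get_clusterable_weights with getattr on the wrapped layer. FakeVariable, FakeLayer and lookup_like_wrapper are hypothetical stand-ins, not tfmot classes.

```python
from dataclasses import dataclass


@dataclass
class FakeVariable:
    """Stand-in for tf.Variable: only the fully qualified name matters here."""
    name: str


class FakeLayer:
    """Stand-in for MyCustomLayer: the kernel is stored under the attribute `kernel`."""

    def __init__(self):
        self.kernel = FakeVariable("my_custom_layer/kernel:0")

    @property
    def trainable_weights(self):
        return [self.kernel]

    def get_clusterable_weights_by_full_name(self):
        # What the repro does: pair each weight with its full variable name.
        return [(w.name, w) for w in self.trainable_weights]

    def get_clusterable_weights_by_attribute(self):
        # Alternative: pair the weight with the attribute name it is stored under.
        return [("kernel", self.kernel)]


def lookup_like_wrapper(layer, weight_name):
    # Mirrors the lookup in the traceback: getattr(self.layer, weight_name).
    return getattr(layer, weight_name)


layer = FakeLayer()

# Attribute name: the lookup succeeds.
name, _ = layer.get_clusterable_weights_by_attribute()[0]
print(lookup_like_wrapper(layer, name).name)  # my_custom_layer/kernel:0

# Full variable name: the lookup raises AttributeError, as in the report.
name, _ = layer.get_clusterable_weights_by_full_name()[0]
try:
    lookup_like_wrapper(layer, name)
except AttributeError as err:
    print("AttributeError:", err)
```

Under this assumption, the attribute name ('kernel') is something the lookup can resolve, while the full variable name ('my_custom_layer/kernel:0') never is.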

System information

- OS: Linux Mint 19.3 Tricia x86_64
- Host: Z390 AORUS MASTER
- Kernel: 5.4.0-81-generic
- Shell: bash 4.4.20
- Terminal: gnome-terminal
- CPU: Intel i9-9900K (16) @ 5.000GHz
- GPU: NVIDIA GeForce GTX 1080 Ti
- Memory: 23072MiB / 64320MiB

TensorFlow version (installed from source or binary):
Installed with pip, tensorflow-gpu 2.6.0

TensorFlow Model Optimization version (installed from source or binary):
Installed with pip, tensorflow-model-optimization 0.6.0

Python version: Python 3.6.9

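Describe the expected behavior

cluster_weights should wrap MyCustomLayer and cluster the weights it returns from get_clusterable_weights.

Describe the current behavior

cluster_weights raises AttributeError: 'MyCustomLayer' object has no attribute 'my_custom_layer/kernel:0' while building the clustering wrapper (full traceback below).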

Code to reproduce the issue

import numpy as np

import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

cluster_weights = tfmot.clustering.keras.cluster_weights
CentroidInitialization = tfmot.clustering.keras.CentroidInitialization

clustering_params = {
  'number_of_clusters': 3,
  'cluster_centroids_init': CentroidInitialization.DENSITY_BASED
}

class MyCustomLayer(keras.layers.Layer, tfmot.clustering.keras.ClusterableLayer):
    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyCustomLayer, self).__init__(**kwargs)
        
    def build(self, input_shape):
        self.kernel = self.add_weight(
            name = 'kernel',
            shape = (input_shape[1], self.output_dim),
            initializer = 'normal',
            trainable = True
        )
        super(MyCustomLayer, self).build(input_shape)

    def call(self, input_data):
        return keras.backend.dot(input_data, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    
    def get_clusterable_weights(self):
        clusterable_weights = []
        for weight in self.trainable_weights:
            clusterable_weights.append((weight.name, weight.read_value()))
        return clusterable_weights


def get_model():
    # Create a simple model.
    model = keras.Sequential(
        [
            keras.Input(shape=(32,)),
            MyCustomLayer(32, input_shape=(32,)),
            keras.layers.Dense(2, activation="relu", name="layer1"),
            keras.layers.Dense(3, activation="relu", name="layer2"),
            keras.layers.Dense(4, name="layer3"),
        ]
    )
    
    model.compile(optimizer="adam", loss="mean_squared_error")
    return model

model = get_model()

# Train the model.
test_input = np.random.random((128, 32))
test_target = np.random.random((128, 4))  # match the 4-unit output of layer3
model.fit(test_input, test_target)
print(model.summary())

# Print all weights in model.
for weight in model.weights:
    print(weight.name)  # append weight.read_value() to also print the values

clustered_model = cluster_weights(model, **clustering_params)

clustered_model.summary(line_length=180, positions=[0.25, 0.60, 0.70, 1.0])

The output is:

(bug) dellboy@thunderstruck:~/git/tisma/tf-learn$ python simple_model.py 
2021-08-23 18:40:26.443329: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.465328: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.465628: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.466063: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2021-08-23 18:40:26.466461: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.466787: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:26.467059: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.063134: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.063747: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.064148: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2021-08-23 18:40:27.064580: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1510] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 6866 MB memory:  -> device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:01:00.0, compute capability: 6.1
2021-08-23 18:40:27.214358: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
4/4 [==============================] - 0s 664us/step - loss: 0.2719
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
my_custom_layer (MyCustomLay (None, 32)                1024      
_________________________________________________________________
layer1 (Dense)               (None, 2)                 66        
_________________________________________________________________
layer2 (Dense)               (None, 3)                 9         
_________________________________________________________________
layer3 (Dense)               (None, 4)                 16        
=================================================================
Total params: 1,115
Trainable params: 1,115
Non-trainable params: 0
_________________________________________________________________
None
my_custom_layer/kernel:0
layer1/kernel:0
layer1/bias:0
layer2/kernel:0
layer2/bias:0
layer3/kernel:0
layer3/bias:0
Traceback (most recent call last):
  File "simple_model.py", line 69, in <module>
    clustered_model = cluster_weights(model, **clustering_params)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster.py", line 133, in cluster_weights
    **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster.py", line 261, in _cluster_weights
    to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/models.py", line 449, in clone_model
    model, input_tensors=input_tensors, layer_fn=clone_function)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/models.py", line 332, in _clone_sequential_model
    cloned_model = Sequential(layers=layers, name=model.name)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/sequential.py", line 134, in __init__
    self.add(layer)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow/python/training/tracking/base.py", line 530, in _method_wrapper
    result = method(self, *args, **kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/sequential.py", line 217, in add
    output_tensor = layer(self.outputs[0])
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 977, in __call__
    input_list)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 1115, in _functional_construction_call
    inputs, input_masks, args, kwargs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 848, in _keras_tensor_symbolic_call
    return self._infer_output_signature(inputs, args, kwargs, input_masks)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 886, in _infer_output_signature
    self._maybe_build(inputs)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/keras/engine/base_layer.py", line 2659, in _maybe_build
    self.build(input_shapes)  # pylint:disable=not-callable
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py", line 160, in build
    original_weight = self.get_weight_from_layer(weight_name)
  File "/home/dellboy/git/tisma/tf-learn/bug/lib/python3.6/site-packages/tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py", line 146, in get_weight_from_layer
    return getattr(self.layer, weight_name)
AttributeError: 'MyCustomLayer' object has no attribute 'my_custom_layer/kernel:0'

But if I run this snippet:

# Print all weights in model.
for weight in model.weights:
    print(weight.name)  # append weight.read_value() to also print the values

print([layer.name for layer in model.layers])

# Example weight that should return from model.
weight_name = "my_custom_layer/kernel:0"

# This is the correct way to get it (layers[0] is used as an example)
for weight in model.layers[0].weights:
    if weight.name == weight_name:
        print("FOUND WEIGHT: ", weight.name, weight.read_value())

it finds the weight:

my_custom_layer/kernel:0
layer1/kernel:0
layer1/bias:0
layer2/kernel:0
layer2/bias:0
layer3/kernel:0
layer3/bias:0
['my_custom_layer', 'layer1', 'layer2', 'layer3']
FOUND WEIGHT:  my_custom_layer/kernel:0 tf.Tensor(
[[-0.01710013  0.08641329  0.00445064 ... -0.00947034 -0.02543414
  -0.02332742]
 [-0.0146245  -0.002941   -0.01422382 ...  0.02857029 -0.04331051
  -0.00299862]
 [-0.07127763  0.07367716 -0.06753001 ... -0.06001836  0.04888764
   0.1081293 ]
 ...
 [ 0.04297659  0.0334582  -0.09708535 ...  0.00098922  0.05463797
  -0.0092663 ]
 [ 0.03690836 -0.061338    0.01662921 ... -0.03843782 -0.08734126
   0.00209901]
 [ 0.10478324  0.07971404  0.05170573 ...  0.05777165 -0.08564453
   0.04021074]], shape=(32, 32), dtype=float32)
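The printed names follow a <layer name>/<weight name>:<index> pattern. A hypothetical helper, short_weight_name (not part of TensorFlow or tfmot), can strip the scope and the index. This only recovers the layer attribute when the attribute and the name passed to add_weight agree, as they do for kernel in MyCustomLayer.

```python
def short_weight_name(full_name: str) -> str:
    """Turn a Keras variable name such as 'my_custom_layer/kernel:0' into 'kernel'.

    Assumes the '<scope>/<name>:<index>' format printed above.
    """
    without_index = full_name.rsplit(":", 1)[0]  # drop the ':0' output index
    return without_index.rsplit("/", 1)[-1]      # drop the layer scope prefix


for name in ["my_custom_layer/kernel:0", "layer1/bias:0"]:
    print(name, "->", short_weight_name(name))
```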

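My assumption is that the current implementation of get_weight_from_layer is incorrect:

def get_weight_from_layer(self, weight_name):
    return getattr(self.layer, weight_name)

It looks the weight up as an attribute of the wrapped layer, so a full variable name such as my_custom_layer/kernel:0 is never found, even though the weight is present in layer.weights.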
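One way to make the lookup tolerant of both kinds of names is sketched below. get_weight_from_layer here is a hypothetical free function, not the tfmot method: it tries the attribute lookup first and then falls back to scanning layer.weights by full name, as the snippet above does by hand. SimpleNamespace objects stand in for the Keras layer and variable.

```python
from types import SimpleNamespace


def get_weight_from_layer(layer, weight_name):
    """Hypothetical name-tolerant variant of the wrapper's lookup."""
    # 1. Attribute lookup, as in the current implementation.
    if hasattr(layer, weight_name):
        return getattr(layer, weight_name)
    # 2. Fallback: match the full variable name, e.g. 'my_custom_layer/kernel:0'.
    for weight in layer.weights:
        if weight.name == weight_name:
            return weight
    raise AttributeError(
        f"{type(layer).__name__!r} has no attribute or weight named {weight_name!r}"
    )


# Stand-ins for MyCustomLayer and its kernel variable.
kernel = SimpleNamespace(name="my_custom_layer/kernel:0")
layer = SimpleNamespace(kernel=kernel, weights=[kernel])

print(get_weight_from_layer(layer, "kernel").name)                    # by attribute
print(get_weight_from_layer(layer, "my_custom_layer/kernel:0").name)  # by full name
```

If the wrapper also writes the clustered weight back under the same name (for example with setattr), changing only the lookup would not be enough; that is an assumption about code not shown in this report.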

Labels: bug (Something isn't working), technique:clustering (Regarding tfmot.clustering.keras APIs and docs)
