Description
Describe the bug
Training a sparse model that includes a Masking layer followed by an LSTM, and then stripping the model with strip_pruning(), produces models that handle the masking differently. This means models trained with pruning are not usable for inference after stripping.
System information
- TensorFlow installed from (source or binary): binary (anaconda, mkl_py366ha38f243_0)
- TensorFlow version: 2.0.0
- TensorFlow Model Optimization version: 0.3.0
- Python version: 3.6.9 |Intel Corporation| (default, Sep 11 2019, 11:39:53) [GCC 4.2.1 Compatible Apple LLVM 7.3.0 (clang-703.0.31)]
Describe the expected behavior
A model with a prune_low_magnitude LSTM layer should generate the same output as its stripped counterpart.
Describe the current behavior
A model in which a prune_low_magnitude LSTM layer is fed by a Masking layer appears to handle the masking differently than its stripped version does. Moreover, the difference is not simply that one of them ignores the sequence mask.
Code to reproduce the issue
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Bidirectional, LSTM
from tensorflow_model_optimization.sparsity import keras as sparsity
# Create some fake masked sequences (I'm sure there's a cleaner way to do this)
pad_val = 7
samples = 10
timesteps = 9
features = 6
data = np.full((samples, timesteps, features), pad_val, dtype=np.float32)
rng = np.random.default_rng()
num_len = rng.integers(low=3, high=timesteps, size=(samples,1))
for i in range(samples):
    seq_len = int(num_len[i])  # avoid shadowing the builtin len()
    x = np.random.uniform(low=-1, high=1, size=(seq_len,))
    for j in range(seq_len):
        data[i, j, :] = x[j]
# Create a model incorporating a masking layer and a prunable LSTM.
pruning_model = keras.Sequential(
    [keras.layers.Masking(mask_value=pad_val, input_shape=(timesteps, features)),
     sparsity.prune_low_magnitude(keras.layers.LSTM(2))])
# Get the output of the model, as well as just the LSTM (skipping the masking)
x1_a = pruning_model(data).numpy()  # LSTM output with the mask applied
x1_b = pruning_model.layers[1](data).numpy()  # LSTM output, bypassing the Masking layer
# Now strip the model, and check weights are the same
depruned_model = sparsity.strip_pruning(pruning_model)
for i, w in enumerate(depruned_model.layers[1].weights):
    assert np.allclose(w.numpy(), pruning_model.layers[1].weights[i].numpy())
x2_a = depruned_model(data).numpy()  # stripped model: LSTM output with the mask applied
x2_b = depruned_model.layers[1](data).numpy()  # stripped model: LSTM output, bypassing the Masking layer
# Check the behaviour of the outputs, with and without the masking:
print(np.allclose(x1_a, x2_a))  # Returns False: with masking, the outputs of the two models differ
print(np.allclose(x1_b, x2_b))  # Returns True: skipping the masking yields the same output from both models
print(np.allclose(x1_a, x1_b))  # Returns False: the prunable model isn't just ignoring the mask
# Check that the masks are the same:
m1 = pruning_model.layers[0](data)._keras_mask
m2 = depruned_model.layers[0](data)._keras_mask
print(np.all(m1 == m2))  # Returns True: the same masks are being passed to the LSTM layer
A quick inspection shows that x1_a, x1_b, and x2_a are very different from each other. I've also confirmed that the layer configurations are the same in the two models.
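For completeness, the config comparison was along these lines (a minimal sketch; it assumes the pruning wrapper exposes the inner LSTM as .layer, as the Keras Wrapper base class does):
lstm_cfg = pruning_model.layers[1].layer.get_config()  # inner LSTM inside the pruning wrapper
stripped_cfg = depruned_model.layers[1].get_config()
print(lstm_cfg == stripped_cfg)  # Returns True: the configurations match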
Additional context
I didn't see anything in the documentation about behaviour with masked inputs to LSTM layers, but the current behaviour (handling the mask differently from both the usual masked behaviour and ignoring the mask completely) seems counter-intuitive.
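To illustrate what I mean by the 'usual' behaviour: for a plain (unwrapped) LSTM, running on the padded-and-masked input should match running directly on the truncated sequence. A quick sanity check along those lines (a sketch reusing the variables defined above):
ref_lstm = keras.layers.LSTM(2)
ref_model = keras.Sequential(
    [keras.layers.Masking(mask_value=pad_val, input_shape=(timesteps, features)),
     ref_lstm])
seq_len = int(num_len[0])  # true length of the first sequence
full = ref_model(data[:1]).numpy()  # padded input, mask applied
truncated = ref_lstm(data[:1, :seq_len, :]).numpy()  # same weights, padding removed
print(np.allclose(full, truncated))  # Expect True under the usual masking semantics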
I found this problem after training and pruning a masked LSTM model, stripping it, and finding that its output was not consistent with the pruned model's.
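For reference, the training flow was roughly the standard prune-then-strip recipe; a sketch with placeholder labels, optimizer, and loss (not my exact code):
targets = np.random.uniform(size=(samples, 2)).astype(np.float32)  # placeholder labels
pruning_model.compile(optimizer='adam', loss='mse')
pruning_model.fit(data, targets, epochs=2,
                  callbacks=[sparsity.UpdatePruningStep()])
final_model = sparsity.strip_pruning(pruning_model)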
Thank you in advance for any help you can offer.