
Behaviour of stripped models and sequence masking #403

Open
@captainproton1971

Description


Describe the bug
Training a sparse model that includes a Masking layer followed by an LSTM, and then stripping the model with strip_pruning(), produces a model that handles the mask differently from the original. This means models trained with pruning are not usable for inference after stripping.

System information

  • TensorFlow installed from (source or binary): binary (anaconda, mkl_py366ha38f243_0)
  • TensorFlow version: 2.0.0
  • TensorFlow Model Optimization version: 0.3.0

Python version:

  • 3.6.9 |Intel Corporation| (default, Sep 11 2019, 11:39:53)
  • [GCC 4.2.1 Compatible Apple LLVM 7.3.0 (clang-703.0.31)]

Describe the expected behavior
A model with a prune_low_magnitude LSTM layer should generate the same output as its stripped counterpart.

Describe the current behavior
A model in which a prune_low_magnitude LSTM layer is fed by a Masking layer handles the mask differently from its stripped version. Moreover, the pruned model is not simply ignoring the sequence mask.

Code to reproduce the issue

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Bidirectional, LSTM
from tensorflow_model_optimization.sparsity import keras as sparsity

# Create some fake masked sequences (I'm sure there's a cleaner way to do this)
pad_val = 7
samples = 10
timesteps = 9
features = 6

data = np.full((samples, timesteps, features), pad_val, dtype=np.float32)
rng = np.random.default_rng()
num_len = rng.integers(low=3, high=timesteps, size=(samples, 1))

for i in range(samples):
    seq_len = int(num_len[i])  # avoid shadowing the built-in len()
    x = rng.uniform(low=-1, high=1, size=seq_len)
    for j in range(seq_len):
        data[i, j, :] = x[j]

# Create a model incorporating a masking layer and a pruneable LSTM.
pruning_model = keras.Sequential(
    [keras.layers.Masking(mask_value=pad_val, input_shape=(timesteps, features)),
     sparsity.prune_low_magnitude(keras.layers.LSTM(2))])

# Get the output of the full model, as well as of just the LSTM (bypassing the masking)
x1_a = pruning_model(data).numpy()  # LSTM output with the mask applied
x1_b = pruning_model.layers[1](data).numpy()  # LSTM output, Masking layer bypassed

# Now strip the model, and check weights are the same
depruned_model = sparsity.strip_pruning(pruning_model)

for i, w in enumerate(depruned_model.layers[1].weights):
    assert np.allclose(w.numpy(), pruning_model.layers[1].weights[i].numpy())

x2_a = depruned_model(data).numpy()
x2_b = depruned_model.layers[1](data).numpy()

# Check the behaviour of the outputs.

# Including the masking:
print(np.allclose(x1_a, x2_a))  # False: with masking, the two models' outputs differ
print(np.allclose(x1_b, x2_b))  # True: bypassing the Masking layer, the outputs match
print(np.allclose(x1_a, x1_b))  # False: the prunable model isn't simply ignoring the mask

# Check that the masks are the same:
m1 = pruning_model.layers[0](data)._keras_mask
m2 = depruned_model.layers[0](data)._keras_mask
print(np.all(m1 == m2))  # True: the same mask is passed to the LSTM layer in both models

A quick inspection shows that x1_a, x1_b, and x2_a are very different from each other. I've also confirmed that the layer configurations are the same in the two models.
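
For reference, this is roughly how I compared the configurations (a minimal sketch; it assumes the pruning wrapper exposes the inner LSTM via .layer, as Keras Wrapper subclasses do):

# Compare the wrapped LSTM's config against the stripped LSTM's config.
# Assumes the wrapper exposes the inner layer as `.layer`.
inner_config = pruning_model.layers[1].layer.get_config()
stripped_config = depruned_model.layers[1].get_config()
print(inner_config == stripped_config)  # True: configurations match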

Additional context
I didn't see anything in the documentation about the behaviour of masked inputs to pruned LSTM layers, but the current behaviour (handling the mask differently from both the usual behaviour and from ignoring it completely) seems counter-intuitive.
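
One further probe that might help localise the issue (a sketch; plain keras.layers.LSTM accepts an explicit mask argument in its call):

# Feed the Masking layer's computed mask explicitly to the stripped LSTM;
# plain LSTM layers accept a `mask` argument in call().
explicit_mask = depruned_model.layers[0].compute_mask(data)
y = depruned_model.layers[1](data, mask=explicit_mask).numpy()
print(np.allclose(y, x2_a))  # expected True if mask propagation itself is intact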

I found this problem after training and pruning a masked LSTM model, stripping it, and finding that the stripped model's output was not consistent with the pruned model's.
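
As a possible stop-gap for inference (a sketch under the assumption that the stripped weights load cleanly into a freshly built copy of the architecture; not a fix for the underlying issue):

# Rebuild the architecture with a plain LSTM and copy in the stripped
# weights, so inference goes through standard Keras mask handling.
fresh_model = keras.Sequential(
    [keras.layers.Masking(mask_value=pad_val, input_shape=(timesteps, features)),
     keras.layers.LSTM(2)])
fresh_model.set_weights(depruned_model.get_weights())
print(np.allclose(fresh_model(data).numpy(), x2_a))  # should match the stripped model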

Thank you in advance for any help you can offer.

Labels

bug (Something isn't working), technique:pruning (Regarding tfmot.sparsity.keras APIs and docs)
