Describe the bug
When I use the `PruneForLatencyOnXNNPack()` policy to prune MobileNet, I get the following error:
ValueError: Could not find `Conv2D 3x3` layer with stride 2x2, `input filters == 3` and `VALID` padding and preceding `ZeroPadding2D` with `padding == 1` in all input branches of the model
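For reference, the error message spells out the layer pattern the policy expects at the start of every input branch. The snippet below is a rough, hypothetical model of that condition in plain Python (no TensorFlow). `LayerSpec` and `matches_entry_pattern` are made-up names that only mirror the wording of the error; this is not the actual tfmot check, and it makes no claim about MobileNet's real layer configuration.

```python
# Hypothetical sketch: encodes the entry pattern described in the ValueError
# as plain data. It is NOT the tfmot implementation.
from dataclasses import dataclass
from typing import Optional, Sequence, Tuple, Union

Pad = Tuple[Tuple[int, int], Tuple[int, int]]


@dataclass
class LayerSpec:
    kind: str                                  # e.g. "ZeroPadding2D" or "Conv2D"
    padding: Union[Pad, str, None] = None      # Pad for ZeroPadding2D, "valid"/"same" for Conv2D
    kernel_size: Optional[Tuple[int, int]] = None
    strides: Optional[Tuple[int, int]] = None
    in_channels: Optional[int] = None


def matches_entry_pattern(layers: Sequence[LayerSpec]) -> bool:
    """True if a branch starts with ZeroPadding2D(padding=1) followed by a
    3x3, stride-2x2, VALID Conv2D whose input has 3 channels."""
    if len(layers) < 2:
        return False
    pad, conv = layers[0], layers[1]
    return (
        pad.kind == "ZeroPadding2D"
        and pad.padding == ((1, 1), (1, 1))  # padding == 1 on every side
        and conv.kind == "Conv2D"
        and conv.kernel_size == (3, 3)
        and conv.strides == (2, 2)
        and conv.padding == "valid"
        and conv.in_channels == 3
    )


# Hypothetical branches, for illustration only
ok = [LayerSpec("ZeroPadding2D", ((1, 1), (1, 1))),
      LayerSpec("Conv2D", "valid", (3, 3), (2, 2), 3)]
asymmetric = [LayerSpec("ZeroPadding2D", ((0, 1), (0, 1))),
              LayerSpec("Conv2D", "valid", (3, 3), (2, 2), 3)]
same_pad = [LayerSpec("Conv2D", "same", (3, 3), (2, 2), 3)]

print(matches_entry_pattern(ok))          # True
print(matches_entry_pattern(asymmetric))  # False: padding is not 1 on every side
print(matches_entry_pattern(same_pad))    # False: no preceding ZeroPadding2D
```

Under this reading, a stem that uses asymmetric zero padding, or `SAME` padding with no explicit `ZeroPadding2D`, would not satisfy the condition even though it computes a similar strided 3x3 convolution.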
System information
OS: Arch Linux; training runs on the CPU only (no GPU). Inference target: the Raspberry Pi's CPU.
TensorFlow version (installed from source or binary): 2.5.0
TensorFlow Model Optimization version (installed from source or binary): 0.5.1-dev
Python version: 3.9.5
Describe the expected behavior
I would expect MobileNet to be prunable with and without the XNNPACK policy.
Describe the current behavior
MobileNet can be pruned without the XNNPACK policy, but pruning fails with the error above when `PruneForLatencyOnXNNPack()` is used.
Code to reproduce the issue
Relevant part of source:
import tensorflow as tf
import tensorflow_model_optimization as tfmot

# Placeholder values; the real ones come from my dataset
img_width, img_height = 224, 224
num_classes = 10

# Construct model (Keras expects height, width, channels)
inputs = tf.keras.layers.Input(shape=(img_height, img_width, 3))
# Pretrained MobileNet backbone without the classification head
mobilenet = tf.keras.applications.MobileNet(weights='imagenet', include_top=False)
# Ramp sparsity from 0% to 80% over the first 600 steps
pruning_schedule = tfmot.sparsity.keras.PolynomialDecay(
    initial_sparsity=0.0, final_sparsity=0.8, begin_step=0, end_step=600)
pruning_policy = tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
# Fails here with the ValueError shown above
pruned_mobilenet = tfmot.sparsity.keras.prune_low_magnitude(
    mobilenet, pruning_schedule=pruning_schedule,
    pruning_policy=pruning_policy)
# Convert MobileNet feature maps into class predictions
avg = tf.keras.layers.GlobalAveragePooling2D()(pruned_mobilenet(inputs))
outputs = tf.keras.layers.Dense(num_classes, activation='softmax')(avg)
# Create model
model = tf.keras.Model(inputs, outputs)
Additional context
Once this issue is resolved, I would be happy to contribute to help others in the future. Documentation explaining why a 3x3 Conv2D is required to mark the start of a sparse subgraph would be very helpful (I have not read the paper yet, so I may be missing something). I am reluctant to add such a layer myself, because I do not want to ruin my pretrained weights.