From 143b415f865912c618712a87461c1e267cc911b6 Mon Sep 17 00:00:00 2001 From: Rino Lee Date: Mon, 21 Jun 2021 07:54:52 -0700 Subject: [PATCH] Add support for Pruning callback function in Model garden trainer PiperOrigin-RevId: 380577575 --- .../core/sparsity/keras/pruning_callbacks.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py index a51b3faa6..246a1309e 100644 --- a/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py +++ b/tensorflow_model_optimization/python/core/sparsity/keras/pruning_callbacks.py @@ -136,3 +136,24 @@ def on_epoch_begin(self, epoch, logs=None): pruning_logs.update({threshold.name + '/threshold': threshold_value}) self._log_pruning_metrics(pruning_logs, '', iteration) + + def generate_log(self): + pruning_logs = {} + params = [] + prunable_layers = pruning_wrapper.collect_prunable_layers(self.model) + for layer in prunable_layers: + for _, mask, threshold in layer.pruning_vars: + params.append(mask) + params.append(threshold) + + values = K.batch_get_value(params) + + param_value_pairs = list(zip(params, values)) + + for mask, mask_value in param_value_pairs[::2]: + pruning_logs.update({mask.name + '/sparsity': 1 - np.mean(mask_value)}) + + for threshold, threshold_value in param_value_pairs[1::2]: + pruning_logs.update({threshold.name + '/threshold': threshold_value}) + + return pruning_logs