Skip to content

Commit d06772e

Browse files
Merge pull request #815 from DLPerf:master
PiperOrigin-RevId: 399853582
2 parents 84ff062 + d3d5474 commit d06772e

File tree

1 file changed

+2
-1
lines changed
  • tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research

1 file changed

+2
-1
lines changed

tensorflow_model_optimization/python/core/internal/tensor_encoding/stages/research/kashin.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,10 +195,11 @@ def encode(self, x, encode_params):
195195
x.dtype)
196196

197197
# Compute the Kashin coefficients.
198+
ep_trans = tf.cast(encode_params[self.ETA_PARAMS_KEY], x.dtype)
198199
for _ in range(self._num_iters - 1):
199200
residual, kashin_coefficients = self._kashin_iter(
200201
residual, kashin_coefficients, signs, clip_level)
201-
clip_level *= tf.cast(encode_params[self.ETA_PARAMS_KEY], x.dtype)
202+
clip_level *= ep_trans
202203
# The last iteration can be with or without clipping.
203204
kashin_coefficients += self._kashin_forward(residual, signs, clip_level,
204205
last_iter_clip)

0 commit comments

Comments
 (0)