Skip to content

Commit 8c10cc0

Browse files
mingxingtancopybara-github
authored andcommitted
Open source EfficientNetV2.
PiperOrigin-RevId: 373612565
1 parent 15b9666 commit 8c10cc0

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

tensorflow_examples/lite/model_maker/third_party/efficientdet/Det-AdvProp.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
# Det-AdvProp
2+
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.sandbox.google.com/github/google/automl/blob/master/efficientdet
3+
/det_advprop_tutorial.ipynb)
24

35
[1] Xiangning Chen, Cihang Xie, Mingxing Tan, Li Zhang, Cho-Jui Hsieh, Boqing
46
Gong. CVPR 2021. Arxiv link: https://arxiv.org/abs/2103.13886

tensorflow_examples/lite/model_maker/third_party/efficientdet/utils.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def activation_fn(features: tf.Tensor, act_type: Text):
5757
def cross_replica_mean(t, num_shards_per_group=None):
5858
"""Calculates the average value of input tensor across TPU replicas."""
5959
num_shards = tpu_function.get_tpu_context().number_of_shards
60+
if not num_shards:
61+
return t
62+
6063
if not num_shards_per_group:
6164
return tf.tpu.cross_replica_sum(t) / tf.cast(num_shards, t.dtype)
6265

@@ -75,8 +78,9 @@ def cross_replica_mean(t, num_shards_per_group=None):
7578

7679
def get_ema_vars():
7780
"""Get all exponential moving average (ema) variables."""
78-
ema_vars = tf.trainable_variables() + \
79-
tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES)
81+
ema_vars = (
82+
tf.trainable_variables() +
83+
tf.get_collection(tf.GraphKeys.MOVING_AVERAGE_VARIABLES))
8084
for v in tf.global_variables():
8185
# We maintain mva for batch norm moving mean and variance as well.
8286
if 'moving_mean' in v.name or 'moving_variance' in v.name:
@@ -127,6 +131,8 @@ def get_ckpt_var_map(ckpt_path, ckpt_scope, var_scope, skip_mismatch=None):
127131
logging.info('skip {} -- does not match scope {}'.format(
128132
var_op_name, var_scope))
129133
ckpt_var = ckpt_scope + var_op_name[len(var_scope):]
134+
if 'global_step' in ckpt_var:
135+
continue
130136

131137
if (ckpt_var not in ckpt_var_names and
132138
var_op_name.endswith('/ExponentialMovingAverage')):

0 commit comments

Comments
 (0)