Skip to content

Commit 053183f

Browse files
committed
Add EfficientNet.
0 parents  commit 053183f

10 files changed

+2410
-0
lines changed

README.md

+69
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# EfficientNets
2+
3+
[1] Mingxing Tan and Quoc V. Le. EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML 2019.
4+
Arxiv link: https://arxiv.org/abs/1905.11946.
5+
6+
7+
## 1. About EfficientNet Models
8+
9+
EfficientNets are a family of image classification models, which achieve state-of-the-art accuracy, yet being an order-of-magnitude smaller and faster than previous models.
10+
11+
We develop EfficientNets based on AutoML and Compound Scaling. In particular, we first use [AutoML Mobile framework](https://ai.googleblog.com/2018/08/mnasnet-towards-automating-design-of.html) to develop a mobile-size baseline network, named as EfficientNet-B0; Then, we use the compound scaling method to scale up this baseline to obtain EfficientNet-B1 to B7.
12+
13+
<table border="0">
14+
<tr>
15+
<td>
16+
<img src="./g3doc/params.png" width="100%" />
17+
</td>
18+
<td>
19+
<img src="./g3doc/flops.png", width="90%" />
20+
</td>
21+
</tr>
22+
</table>
23+
24+
EfficientNets achieve state-of-the-art accuracy on ImageNet with an order of magnitude better efficiency:
25+
26+
27+
* In high-accuracy regime, our EfficientNet-B7 achieves state-of-the-art 84.4% top-1 / 97.1% top-5 accuracy on ImageNet with 66M parameters and 37B FLOPS, being 8.4x smaller and 6.1x faster on CPU inference than previous best [Gpipe](https://arxiv.org/abs/1811.06965).
28+
29+
* In middle-accuracy regime, our EfficientNet-B1 is 7.6x smaller and 5.7x faster on CPU inference than [ResNet-152](https://arxiv.org/abs/1512.03385), with similar ImageNet accuracy.
30+
31+
* Compared with the widely used [ResNet-50](https://arxiv.org/abs/1512.03385), our EfficientNet-B4 improves the top-1 accuracy from 76.3% of ResNet-50 to 82.6% (+6.3%), under similar FLOPS constraint.
32+
33+
## 2. Using Pretrained EfficientNet Checkpoints
34+
35+
We have provided a list of EfficientNet checkpoints for [EfficientNet-B0](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/efficientnet-b0.tar.gz), [EfficientNet-B1](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/efficientnet-b1.tar.gz), [EfficientNet-B2](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/efficientnet-b2.tar.gz), and [EfficientNet-B3](https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/efficientnet-b3.tar.gz). A quick way to use these checkpoints is to run:
36+
37+
$ export MODEL=efficientnet-b0
38+
$ wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/${MODEL}.tar.gz
39+
$ tar zxf ${MODEL}.tar.gz
40+
$ wget https://upload.wikimedia.org/wikipedia/commons/f/fe/Giant_Panda_in_Beijing_Zoo_1.JPG -O panda.jpg
41+
$ wget https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/eval_data/labels_map.txt
42+
$ python eval_ckpt_main.py --model_name=$MODEL --ckpt_dir=$MODEL --example_img=panda.jpg --labels_map_file=labels_map.txt
43+
44+
Please refer to the following colab for more instructions on how to obtain and use those checkpoints.
45+
46+
* [`eval_ckpt_example.ipynb`](eval_ckpt_example.ipynb): A colab example to load
47+
EfficientNet pretrained checkpoints files and use the restored model to classify images.
48+
49+
50+
## 3. Training EfficientNets on TPUs.
51+
52+
53+
To train this model on Cloud TPU, you will need:
54+
55+
* A GCE VM instance with an associated Cloud TPU resource
56+
* A GCS bucket to store your training checkpoints (the "model directory")
57+
* Install TensorFlow version >= 1.13 for both GCE VM and Cloud.
58+
59+
Then train the model:
60+
61+
$ export PYTHONPATH="$PYTHONPATH:/path/to/models"
62+
$ python main.py --tpu=TPU_NAME --data_dir=DATA_DIR --model_dir=MODEL_DIR
63+
64+
# TPU_NAME is the name of the TPU node, the same name that appears when you run gcloud compute tpus list, or ctpu ls.
65+
# MODEL_DIR is a GCS location (a URL starting with gs:// where both the GCE VM and the associated Cloud TPU have write access
66+
# DATA_DIR is a GCS location to which both the GCE VM and associated Cloud TPU have read access.
67+
68+
69+
For more instructions, please refer to our tutorial: https://cloud.google.com/tpu/docs/tutorials/efficientnet

efficientnet_builder.py

+231
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Model Builder for EfficientNet."""
16+
17+
from __future__ import absolute_import
18+
from __future__ import division
19+
from __future__ import print_function
20+
21+
import os
22+
import re
23+
import tensorflow as tf
24+
25+
import efficientnet_model
26+
27+
28+
def efficientnet_params(model_name):
29+
"""Get efficientnet params based on model name."""
30+
params_dict = {
31+
# (width_coefficient, depth_coefficient, resolution, dropout_rate)
32+
'efficientnet-b0': (1.0, 1.0, 224, 0.2),
33+
'efficientnet-b1': (1.0, 1.1, 240, 0.2),
34+
'efficientnet-b2': (1.1, 1.2, 260, 0.3),
35+
'efficientnet-b3': (1.2, 1.4, 300, 0.3),
36+
'efficientnet-b4': (1.4, 1.8, 380, 0.4),
37+
'efficientnet-b5': (1.6, 2.2, 456, 0.4),
38+
'efficientnet-b6': (1.8, 2.6, 528, 0.5),
39+
'efficientnet-b7': (2.0, 3.1, 600, 0.5),
40+
}
41+
return params_dict[model_name]
42+
43+
44+
class BlockDecoder(object):
45+
"""Block Decoder for readability."""
46+
47+
def _decode_block_string(self, block_string):
48+
"""Gets a block through a string notation of arguments."""
49+
assert isinstance(block_string, str)
50+
ops = block_string.split('_')
51+
options = {}
52+
for op in ops:
53+
splits = re.split(r'(\d.*)', op)
54+
if len(splits) >= 2:
55+
key, value = splits[:2]
56+
options[key] = value
57+
58+
if 's' not in options or len(options['s']) != 2:
59+
raise ValueError('Strides options should be a pair of integers.')
60+
61+
return efficientnet_model.BlockArgs(
62+
kernel_size=int(options['k']),
63+
num_repeat=int(options['r']),
64+
input_filters=int(options['i']),
65+
output_filters=int(options['o']),
66+
expand_ratio=int(options['e']),
67+
id_skip=('noskip' not in block_string),
68+
se_ratio=float(options['se']) if 'se' in options else None,
69+
strides=[int(options['s'][0]), int(options['s'][1])])
70+
71+
def _encode_block_string(self, block):
72+
"""Encodes a block to a string."""
73+
args = [
74+
'r%d' % block.num_repeat,
75+
'k%d' % block.kernel_size,
76+
's%d%d' % (block.strides[0], block.strides[1]),
77+
'e%s' % block.expand_ratio,
78+
'i%d' % block.input_filters,
79+
'o%d' % block.output_filters
80+
]
81+
if block.se_ratio > 0 and block.se_ratio <= 1:
82+
args.append('se%s' % block.se_ratio)
83+
if block.id_skip is False:
84+
args.append('noskip')
85+
return '_'.join(args)
86+
87+
def decode(self, string_list):
88+
"""Decodes a list of string notations to specify blocks inside the network.
89+
90+
Args:
91+
string_list: a list of strings, each string is a notation of block.
92+
93+
Returns:
94+
A list of namedtuples to represent blocks arguments.
95+
"""
96+
assert isinstance(string_list, list)
97+
blocks_args = []
98+
for block_string in string_list:
99+
blocks_args.append(self._decode_block_string(block_string))
100+
return blocks_args
101+
102+
def encode(self, blocks_args):
103+
"""Encodes a list of Blocks to a list of strings.
104+
105+
Args:
106+
blocks_args: A list of namedtuples to represent blocks arguments.
107+
Returns:
108+
a list of strings, each string is a notation of block.
109+
"""
110+
block_strings = []
111+
for block in blocks_args:
112+
block_strings.append(self._encode_block_string(block))
113+
return block_strings
114+
115+
116+
def efficientnet(width_coefficient=None,
117+
depth_coefficient=None,
118+
dropout_rate=0.2,
119+
drop_connect_rate=0.2):
120+
"""Creates a efficientnet model."""
121+
blocks_args = [
122+
'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25',
123+
'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25',
124+
'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25',
125+
'r1_k3_s11_e6_i192_o320_se0.25',
126+
]
127+
global_params = efficientnet_model.GlobalParams(
128+
batch_norm_momentum=0.99,
129+
batch_norm_epsilon=1e-3,
130+
dropout_rate=dropout_rate,
131+
drop_connect_rate=drop_connect_rate,
132+
data_format='channels_last',
133+
num_classes=1000,
134+
width_coefficient=width_coefficient,
135+
depth_coefficient=depth_coefficient,
136+
depth_divisor=8,
137+
min_depth=None)
138+
decoder = BlockDecoder()
139+
return decoder.decode(blocks_args), global_params
140+
141+
142+
def get_model_params(model_name, override_params):
143+
"""Get the block args and global params for a given model."""
144+
if model_name.startswith('efficientnet'):
145+
width_coefficient, depth_coefficient, _, dropout_rate = (
146+
efficientnet_params(model_name))
147+
blocks_args, global_params = efficientnet(
148+
width_coefficient, depth_coefficient, dropout_rate)
149+
else:
150+
raise NotImplementedError('model name is not pre-defined: %s' % model_name)
151+
152+
if override_params:
153+
# ValueError will be raised here if override_params has fields not included
154+
# in global_params.
155+
global_params = global_params._replace(**override_params)
156+
157+
tf.logging.info('global_params= %s', global_params)
158+
tf.logging.info('blocks_args= %s', blocks_args)
159+
return blocks_args, global_params
160+
161+
162+
def build_model(images,
163+
model_name,
164+
training,
165+
override_params=None,
166+
model_dir=None):
167+
"""A helper functiion to creates a model and returns predicted logits.
168+
169+
Args:
170+
images: input images tensor.
171+
model_name: string, the predefined model name.
172+
training: boolean, whether the model is constructed for training.
173+
override_params: A dictionary of params for overriding. Fields must exist in
174+
efficientnet_model.GlobalParams.
175+
model_dir: string, optional model dir for saving configs.
176+
177+
Returns:
178+
logits: the logits tensor of classes.
179+
endpoints: the endpoints for each layer.
180+
181+
Raises:
182+
When model_name specified an undefined model, raises NotImplementedError.
183+
When override_params has invalid fields, raises ValueError.
184+
"""
185+
assert isinstance(images, tf.Tensor)
186+
blocks_args, global_params = get_model_params(model_name, override_params)
187+
188+
if model_dir:
189+
param_file = os.path.join(model_dir, 'model_params.txt')
190+
if not tf.gfile.Exists(param_file):
191+
with tf.gfile.GFile(param_file, 'w') as f:
192+
tf.logging.info('writing to %s' % param_file)
193+
f.write('model_name= %s\n\n' % model_name)
194+
f.write('global_params= %s\n\n' % str(global_params))
195+
f.write('blocks_args= %s\n\n' % str(blocks_args))
196+
197+
with tf.variable_scope(model_name):
198+
model = efficientnet_model.Model(blocks_args, global_params)
199+
logits = model(images, training=training)
200+
201+
logits = tf.identity(logits, 'logits')
202+
return logits, model.endpoints
203+
204+
205+
def build_model_base(images, model_name, training, override_params=None):
206+
"""A helper functiion to create a base model and return global_pool.
207+
208+
Args:
209+
images: input images tensor.
210+
model_name: string, the model name of a pre-defined MnasNet.
211+
training: boolean, whether the model is constructed for training.
212+
override_params: A dictionary of params for overriding. Fields must exist in
213+
mnasnet_model.GlobalParams.
214+
215+
Returns:
216+
features: global pool features.
217+
endpoints: the endpoints for each layer.
218+
219+
Raises:
220+
When model_name specified an undefined model, raises NotImplementedError.
221+
When override_params has invalid fields, raises ValueError.
222+
"""
223+
assert isinstance(images, tf.Tensor)
224+
blocks_args, global_params = get_model_params(model_name, override_params)
225+
226+
with tf.variable_scope(model_name):
227+
model = efficientnet_model.Model(blocks_args, global_params)
228+
features = model(images, training=training, features_only=True)
229+
230+
features = tf.identity(features, 'global_pool')
231+
return features, model.endpoints

0 commit comments

Comments
 (0)