|
| 1 | +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | +# ============================================================================== |
| 15 | +"""Model Builder for EfficientNet.""" |
| 16 | + |
| 17 | +from __future__ import absolute_import |
| 18 | +from __future__ import division |
| 19 | +from __future__ import print_function |
| 20 | + |
| 21 | +import os |
| 22 | +import re |
| 23 | +import tensorflow as tf |
| 24 | + |
| 25 | +import efficientnet_model |
| 26 | + |
| 27 | + |
| 28 | +def efficientnet_params(model_name): |
| 29 | + """Get efficientnet params based on model name.""" |
| 30 | + params_dict = { |
| 31 | + # (width_coefficient, depth_coefficient, resolution, dropout_rate) |
| 32 | + 'efficientnet-b0': (1.0, 1.0, 224, 0.2), |
| 33 | + 'efficientnet-b1': (1.0, 1.1, 240, 0.2), |
| 34 | + 'efficientnet-b2': (1.1, 1.2, 260, 0.3), |
| 35 | + 'efficientnet-b3': (1.2, 1.4, 300, 0.3), |
| 36 | + 'efficientnet-b4': (1.4, 1.8, 380, 0.4), |
| 37 | + 'efficientnet-b5': (1.6, 2.2, 456, 0.4), |
| 38 | + 'efficientnet-b6': (1.8, 2.6, 528, 0.5), |
| 39 | + 'efficientnet-b7': (2.0, 3.1, 600, 0.5), |
| 40 | + } |
| 41 | + return params_dict[model_name] |
| 42 | + |
| 43 | + |
| 44 | +class BlockDecoder(object): |
| 45 | + """Block Decoder for readability.""" |
| 46 | + |
| 47 | + def _decode_block_string(self, block_string): |
| 48 | + """Gets a block through a string notation of arguments.""" |
| 49 | + assert isinstance(block_string, str) |
| 50 | + ops = block_string.split('_') |
| 51 | + options = {} |
| 52 | + for op in ops: |
| 53 | + splits = re.split(r'(\d.*)', op) |
| 54 | + if len(splits) >= 2: |
| 55 | + key, value = splits[:2] |
| 56 | + options[key] = value |
| 57 | + |
| 58 | + if 's' not in options or len(options['s']) != 2: |
| 59 | + raise ValueError('Strides options should be a pair of integers.') |
| 60 | + |
| 61 | + return efficientnet_model.BlockArgs( |
| 62 | + kernel_size=int(options['k']), |
| 63 | + num_repeat=int(options['r']), |
| 64 | + input_filters=int(options['i']), |
| 65 | + output_filters=int(options['o']), |
| 66 | + expand_ratio=int(options['e']), |
| 67 | + id_skip=('noskip' not in block_string), |
| 68 | + se_ratio=float(options['se']) if 'se' in options else None, |
| 69 | + strides=[int(options['s'][0]), int(options['s'][1])]) |
| 70 | + |
| 71 | + def _encode_block_string(self, block): |
| 72 | + """Encodes a block to a string.""" |
| 73 | + args = [ |
| 74 | + 'r%d' % block.num_repeat, |
| 75 | + 'k%d' % block.kernel_size, |
| 76 | + 's%d%d' % (block.strides[0], block.strides[1]), |
| 77 | + 'e%s' % block.expand_ratio, |
| 78 | + 'i%d' % block.input_filters, |
| 79 | + 'o%d' % block.output_filters |
| 80 | + ] |
| 81 | + if block.se_ratio > 0 and block.se_ratio <= 1: |
| 82 | + args.append('se%s' % block.se_ratio) |
| 83 | + if block.id_skip is False: |
| 84 | + args.append('noskip') |
| 85 | + return '_'.join(args) |
| 86 | + |
| 87 | + def decode(self, string_list): |
| 88 | + """Decodes a list of string notations to specify blocks inside the network. |
| 89 | +
|
| 90 | + Args: |
| 91 | + string_list: a list of strings, each string is a notation of block. |
| 92 | +
|
| 93 | + Returns: |
| 94 | + A list of namedtuples to represent blocks arguments. |
| 95 | + """ |
| 96 | + assert isinstance(string_list, list) |
| 97 | + blocks_args = [] |
| 98 | + for block_string in string_list: |
| 99 | + blocks_args.append(self._decode_block_string(block_string)) |
| 100 | + return blocks_args |
| 101 | + |
| 102 | + def encode(self, blocks_args): |
| 103 | + """Encodes a list of Blocks to a list of strings. |
| 104 | +
|
| 105 | + Args: |
| 106 | + blocks_args: A list of namedtuples to represent blocks arguments. |
| 107 | + Returns: |
| 108 | + a list of strings, each string is a notation of block. |
| 109 | + """ |
| 110 | + block_strings = [] |
| 111 | + for block in blocks_args: |
| 112 | + block_strings.append(self._encode_block_string(block)) |
| 113 | + return block_strings |
| 114 | + |
| 115 | + |
| 116 | +def efficientnet(width_coefficient=None, |
| 117 | + depth_coefficient=None, |
| 118 | + dropout_rate=0.2, |
| 119 | + drop_connect_rate=0.2): |
| 120 | + """Creates a efficientnet model.""" |
| 121 | + blocks_args = [ |
| 122 | + 'r1_k3_s11_e1_i32_o16_se0.25', 'r2_k3_s22_e6_i16_o24_se0.25', |
| 123 | + 'r2_k5_s22_e6_i24_o40_se0.25', 'r3_k3_s22_e6_i40_o80_se0.25', |
| 124 | + 'r3_k5_s11_e6_i80_o112_se0.25', 'r4_k5_s22_e6_i112_o192_se0.25', |
| 125 | + 'r1_k3_s11_e6_i192_o320_se0.25', |
| 126 | + ] |
| 127 | + global_params = efficientnet_model.GlobalParams( |
| 128 | + batch_norm_momentum=0.99, |
| 129 | + batch_norm_epsilon=1e-3, |
| 130 | + dropout_rate=dropout_rate, |
| 131 | + drop_connect_rate=drop_connect_rate, |
| 132 | + data_format='channels_last', |
| 133 | + num_classes=1000, |
| 134 | + width_coefficient=width_coefficient, |
| 135 | + depth_coefficient=depth_coefficient, |
| 136 | + depth_divisor=8, |
| 137 | + min_depth=None) |
| 138 | + decoder = BlockDecoder() |
| 139 | + return decoder.decode(blocks_args), global_params |
| 140 | + |
| 141 | + |
| 142 | +def get_model_params(model_name, override_params): |
| 143 | + """Get the block args and global params for a given model.""" |
| 144 | + if model_name.startswith('efficientnet'): |
| 145 | + width_coefficient, depth_coefficient, _, dropout_rate = ( |
| 146 | + efficientnet_params(model_name)) |
| 147 | + blocks_args, global_params = efficientnet( |
| 148 | + width_coefficient, depth_coefficient, dropout_rate) |
| 149 | + else: |
| 150 | + raise NotImplementedError('model name is not pre-defined: %s' % model_name) |
| 151 | + |
| 152 | + if override_params: |
| 153 | + # ValueError will be raised here if override_params has fields not included |
| 154 | + # in global_params. |
| 155 | + global_params = global_params._replace(**override_params) |
| 156 | + |
| 157 | + tf.logging.info('global_params= %s', global_params) |
| 158 | + tf.logging.info('blocks_args= %s', blocks_args) |
| 159 | + return blocks_args, global_params |
| 160 | + |
| 161 | + |
| 162 | +def build_model(images, |
| 163 | + model_name, |
| 164 | + training, |
| 165 | + override_params=None, |
| 166 | + model_dir=None): |
| 167 | + """A helper functiion to creates a model and returns predicted logits. |
| 168 | +
|
| 169 | + Args: |
| 170 | + images: input images tensor. |
| 171 | + model_name: string, the predefined model name. |
| 172 | + training: boolean, whether the model is constructed for training. |
| 173 | + override_params: A dictionary of params for overriding. Fields must exist in |
| 174 | + efficientnet_model.GlobalParams. |
| 175 | + model_dir: string, optional model dir for saving configs. |
| 176 | +
|
| 177 | + Returns: |
| 178 | + logits: the logits tensor of classes. |
| 179 | + endpoints: the endpoints for each layer. |
| 180 | +
|
| 181 | + Raises: |
| 182 | + When model_name specified an undefined model, raises NotImplementedError. |
| 183 | + When override_params has invalid fields, raises ValueError. |
| 184 | + """ |
| 185 | + assert isinstance(images, tf.Tensor) |
| 186 | + blocks_args, global_params = get_model_params(model_name, override_params) |
| 187 | + |
| 188 | + if model_dir: |
| 189 | + param_file = os.path.join(model_dir, 'model_params.txt') |
| 190 | + if not tf.gfile.Exists(param_file): |
| 191 | + with tf.gfile.GFile(param_file, 'w') as f: |
| 192 | + tf.logging.info('writing to %s' % param_file) |
| 193 | + f.write('model_name= %s\n\n' % model_name) |
| 194 | + f.write('global_params= %s\n\n' % str(global_params)) |
| 195 | + f.write('blocks_args= %s\n\n' % str(blocks_args)) |
| 196 | + |
| 197 | + with tf.variable_scope(model_name): |
| 198 | + model = efficientnet_model.Model(blocks_args, global_params) |
| 199 | + logits = model(images, training=training) |
| 200 | + |
| 201 | + logits = tf.identity(logits, 'logits') |
| 202 | + return logits, model.endpoints |
| 203 | + |
| 204 | + |
| 205 | +def build_model_base(images, model_name, training, override_params=None): |
| 206 | + """A helper functiion to create a base model and return global_pool. |
| 207 | +
|
| 208 | + Args: |
| 209 | + images: input images tensor. |
| 210 | + model_name: string, the model name of a pre-defined MnasNet. |
| 211 | + training: boolean, whether the model is constructed for training. |
| 212 | + override_params: A dictionary of params for overriding. Fields must exist in |
| 213 | + mnasnet_model.GlobalParams. |
| 214 | +
|
| 215 | + Returns: |
| 216 | + features: global pool features. |
| 217 | + endpoints: the endpoints for each layer. |
| 218 | +
|
| 219 | + Raises: |
| 220 | + When model_name specified an undefined model, raises NotImplementedError. |
| 221 | + When override_params has invalid fields, raises ValueError. |
| 222 | + """ |
| 223 | + assert isinstance(images, tf.Tensor) |
| 224 | + blocks_args, global_params = get_model_params(model_name, override_params) |
| 225 | + |
| 226 | + with tf.variable_scope(model_name): |
| 227 | + model = efficientnet_model.Model(blocks_args, global_params) |
| 228 | + features = model(images, training=training, features_only=True) |
| 229 | + |
| 230 | + features = tf.identity(features, 'global_pool') |
| 231 | + return features, model.endpoints |
0 commit comments