# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable= arguments-differ,unused-argument
"""DenseNet, implemented in Gluon."""
__all__ = ['DenseNet', 'densenet121', 'densenet161', 'densenet169', 'densenet201']

from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
from mxnet.gluon.contrib.nn import HybridConcurrent, Identity


# Helpers
def _make_norm(norm_layer, norm_kwargs):
    """Instantiate ``norm_layer``, treating ``norm_kwargs=None`` as no kwargs.

    Centralizes the ``norm_layer(**({} if norm_kwargs is None else norm_kwargs))``
    expression that was previously duplicated at every normalization site.
    """
    return norm_layer(**({} if norm_kwargs is None else norm_kwargs))

def _make_dense_block(num_layers, bn_size, growth_rate, dropout, stage_index,
                      norm_layer, norm_kwargs):
    """Build one dense block: ``num_layers`` stacked dense layers.

    Each dense layer concatenates its input with its new features, so the
    channel count grows by ``growth_rate`` per layer.
    """
    out = nn.HybridSequential(prefix='stage%d_'%stage_index)
    with out.name_scope():
        for _ in range(num_layers):
            out.add(_make_dense_layer(growth_rate, bn_size, dropout, norm_layer, norm_kwargs))
    return out

def _make_dense_layer(growth_rate, bn_size, dropout, norm_layer, norm_kwargs):
    """Build one BN-ReLU-Conv(1x1)-BN-ReLU-Conv(3x3) bottleneck dense layer.

    The 1x1 conv first expands to ``bn_size * growth_rate`` channels
    (the "bottleneck"), then the 3x3 conv reduces to ``growth_rate``
    channels of new features.
    """
    new_features = nn.HybridSequential(prefix='')
    new_features.add(_make_norm(norm_layer, norm_kwargs))
    new_features.add(nn.Activation('relu'))
    new_features.add(nn.Conv2D(bn_size * growth_rate, kernel_size=1, use_bias=False))
    new_features.add(_make_norm(norm_layer, norm_kwargs))
    new_features.add(nn.Activation('relu'))
    new_features.add(nn.Conv2D(growth_rate, kernel_size=3, padding=1, use_bias=False))
    if dropout:
        new_features.add(nn.Dropout(dropout))

    # Dense connectivity: the layer output is [input, new_features]
    # concatenated along the channel axis. Identity carries the input through.
    out = HybridConcurrent(axis=1, prefix='')
    out.add(Identity())
    out.add(new_features)

    return out

def _make_transition(num_output_features, norm_layer, norm_kwargs):
    """Build a transition block between dense blocks.

    BN-ReLU-Conv(1x1) compresses channels to ``num_output_features``,
    then 2x2 average pooling halves the spatial resolution.
    """
    out = nn.HybridSequential(prefix='')
    out.add(_make_norm(norm_layer, norm_kwargs))
    out.add(nn.Activation('relu'))
    out.add(nn.Conv2D(num_output_features, kernel_size=1, use_bias=False))
    out.add(nn.AvgPool2D(pool_size=2, strides=2))
    return out

# Net
class DenseNet(HybridBlock):
    r"""Densenet-BC model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    num_init_features : int
        Number of filters to learn in the first convolution layer.
    growth_rate : int
        Number of filters to add each layer (`k` in the paper).
    block_config : list of int
        List of integers for numbers of layers in each pooling block.
    bn_size : int, default 4
        Multiplicative factor for number of bottle neck layers.
        (i.e. bn_size * k features in the bottleneck layer)
    dropout : float, default 0
        Rate of dropout after each dense layer.
    classes : int, default 1000
        Number of classification classes.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    def __init__(self, num_init_features, growth_rate, block_config,
                 bn_size=4, dropout=0, classes=1000,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(DenseNet, self).__init__(**kwargs)
        with self.name_scope():
            # Stem: 7x7/2 conv + 3x3/2 max pool => 4x spatial reduction.
            self.features = nn.HybridSequential(prefix='')
            self.features.add(nn.Conv2D(num_init_features, kernel_size=7,
                                        strides=2, padding=3, use_bias=False))
            self.features.add(_make_norm(norm_layer, norm_kwargs))
            self.features.add(nn.Activation('relu'))
            self.features.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))
            # Add dense blocks, tracking the running channel count so each
            # transition can compress it by half (the "C" in DenseNet-BC).
            num_features = num_init_features
            for i, num_layers in enumerate(block_config):
                self.features.add(_make_dense_block(
                    num_layers, bn_size, growth_rate, dropout, i+1, norm_layer, norm_kwargs))
                num_features = num_features + num_layers * growth_rate
                # No transition after the final dense block.
                if i != len(block_config) - 1:
                    self.features.add(_make_transition(num_features // 2, norm_layer, norm_kwargs))
                    num_features = num_features // 2
            self.features.add(_make_norm(norm_layer, norm_kwargs))
            self.features.add(nn.Activation('relu'))
            # NOTE: the fixed 7x7 pool assumes 224x224 input (feature map is
            # 7x7 here); other input sizes would need a global average pool.
            self.features.add(nn.AvgPool2D(pool_size=7))
            self.features.add(nn.Flatten())

            self.output = nn.Dense(classes)

    def hybrid_forward(self, F, x):
        """Compute class logits for a batch of images ``x``."""
        x = self.features(x)
        x = self.output(x)
        return x


# Specification: num_layers -> (num_init_features, growth_rate, block_config)
densenet_spec = {121: (64, 32, [6, 12, 24, 16]),
                 161: (96, 48, [6, 12, 36, 24]),
                 169: (64, 32, [6, 12, 32, 32]),
                 201: (64, 32, [6, 12, 48, 32])}


# Constructor
def get_densenet(num_layers, pretrained=False, ctx=cpu(),
                 root='~/.mxnet/models', **kwargs):
    r"""Densenet-BC model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    num_layers : int
        Number of layers for the variant of densenet. Options are 121, 161, 169, 201.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default $MXNET_HOME/models
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    num_init_features, growth_rate, block_config = densenet_spec[num_layers]
    net = DenseNet(num_init_features, growth_rate, block_config, **kwargs)
    if pretrained:
        # Imported lazily so constructing an untrained net has no extra deps.
        from .model_store import get_model_file
        net.load_parameters(get_model_file('densenet%d'%(num_layers),
                                           tag=pretrained, root=root), ctx=ctx)
        from ..data import ImageNet1kAttr
        attrib = ImageNet1kAttr()
        # Attach ImageNet class metadata so predictions can be decoded.
        net.synset = attrib.synset
        net.classes = attrib.classes
        net.classes_long = attrib.classes_long
    return net

def densenet121(**kwargs):
    r"""Densenet-BC 121-layer model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    return get_densenet(121, **kwargs)

def densenet161(**kwargs):
    r"""Densenet-BC 161-layer model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    return get_densenet(161, **kwargs)

def densenet169(**kwargs):
    r"""Densenet-BC 169-layer model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    return get_densenet(169, **kwargs)

def densenet201(**kwargs):
    r"""Densenet-BC 201-layer model from the
    `"Densely Connected Convolutional Networks" <https://arxiv.org/pdf/1608.06993.pdf>`_ paper.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '$MXNET_HOME/models'
        Location for keeping the model parameters.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    """
    return get_densenet(201, **kwargs)