# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Inception V3, designed for images of around 299 x 299.

The classifier head uses global average pooling, so inputs of other sizes
(at least 75 x 75, enough to survive all stride-2 / unpadded reductions)
also yield a fully connected layer of the same width.

Reference:

Szegedy, Christian, et al. "Rethinking the Inception Architecture for
Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
"""
import mxnet as mx


def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    """Convolution (no bias) -> BatchNorm -> ReLU, the basic unit of the network.

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_filter : int
        Number of output channels of the convolution.
    kernel, stride, pad : tuple of int
        Passed unchanged to ``mx.sym.Convolution``.
    name, suffix : str
        The three layers are named ``'<name><suffix>_conv2d'``,
        ``'<name><suffix>_batchnorm'`` and ``'<name><suffix>_relu'``.
        MXNet derives parameter names from these, so they must stay stable
        for existing checkpoints to load.

    Returns
    -------
    Symbol
        Output of the ReLU activation.
    """
    prefix = '%s%s' % (name, suffix)
    # No bias: the following BatchNorm's shift term makes a conv bias redundant.
    conv = mx.sym.Convolution(data=data, num_filter=num_filter, kernel=kernel,
                              stride=stride, pad=pad, no_bias=True,
                              name='%s_conv2d' % prefix)
    # fix_gamma=True keeps the BatchNorm scale fixed (gamma is not learned).
    bn = mx.sym.BatchNorm(data=conv, name='%s_batchnorm' % prefix, fix_gamma=True)
    act = mx.sym.Activation(data=bn, act_type='relu', name='%s_relu' % prefix)
    return act


def Inception7A(data,
                num_1x1,
                num_3x3_red, num_3x3_1, num_3x3_2,
                num_5x5_red, num_5x5,
                pool, proj,
                name):
    """Inception module with 1x1, 5x5, double-3x3 and pooling branches.

    Every branch uses stride 1 with "same" padding, so the spatial size of
    ``data`` is preserved.  Branch outputs are concatenated with
    ``mx.sym.Concat`` (default axis 1) in the order
    1x1, 5x5, double-3x3, pool-projection.

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_1x1 : int
        Filters of the plain 1x1 branch.
    num_3x3_red, num_3x3_1, num_3x3_2 : int
        1x1 reduction, then two stacked 3x3 convolutions.
    num_5x5_red, num_5x5 : int
        1x1 reduction, then a 5x5 convolution.
    pool : str
        Pooling type of the pooling branch (e.g. ``"avg"``).
    proj : int
        Filters of the 1x1 projection after pooling.
    name : str
        Prefix used for all layer names in the module.

    Returns
    -------
    Symbol
        The concatenated module output.
    """
    tower_1x1 = Conv(data, num_1x1, name='%s_conv' % name)

    tower_5x5 = Conv(data, num_5x5_red, name='%s_tower' % name, suffix='_conv')
    tower_5x5 = Conv(tower_5x5, num_5x5, kernel=(5, 5), pad=(2, 2),
                     name='%s_tower' % name, suffix='_conv_1')

    tower_3x3 = Conv(data, num_3x3_red, name='%s_tower_1' % name, suffix='_conv')
    tower_3x3 = Conv(tower_3x3, num_3x3_1, kernel=(3, 3), pad=(1, 1),
                     name='%s_tower_1' % name, suffix='_conv_1')
    tower_3x3 = Conv(tower_3x3, num_3x3_2, kernel=(3, 3), pad=(1, 1),
                     name='%s_tower_1' % name, suffix='_conv_2')

    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                             pool_type=pool, name='%s_pool_%s_pool' % (pool, name))
    cproj = Conv(pooling, proj, name='%s_tower_2' % name, suffix='_conv')

    concat = mx.sym.Concat(*[tower_1x1, tower_5x5, tower_3x3, cproj],
                           name='ch_concat_%s_chconcat' % name)
    return concat


# First Downsample
def Inception7B(data,
                num_3x3,
                num_d3x3_red, num_d3x3_1, num_d3x3_2,
                pool,
                name):
    """Grid-reduction module: halves the spatial size with three stride-2 branches.

    Branches:
    a stride-2 3x3 convolution;
    a 1x1 -> 3x3 (pad 1) -> stride-2 3x3 stack;
    a stride-2 3x3 pooling.
    None of the stride-2 layers is padded, so an H x W input becomes
    ((H - 3) // 2 + 1) x ((W - 3) // 2 + 1) (35 -> 17 for a 299 x 299 image).
    Outputs are concatenated with ``mx.sym.Concat`` (default axis 1).

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_3x3 : int
        Filters of the single stride-2 3x3 branch.
    num_d3x3_red, num_d3x3_1, num_d3x3_2 : int
        Filters of the 1x1 reduction and the two 3x3 convolutions in the
        double-3x3 branch.
    pool : str
        Pooling type of the pass-through branch (``get_symbol`` passes ``"max"``).
    name : str
        Prefix used for all layer names in the module.

    Returns
    -------
    Symbol
        The concatenated module output.
    """
    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
                     name='%s_conv' % name)

    tower_d3x3 = Conv(data, num_d3x3_red, name='%s_tower' % name, suffix='_conv')
    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1),
                      name='%s_tower' % name, suffix='_conv_1')
    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
                      name='%s_tower' % name, suffix='_conv_2')

    # Honour `pool` (as Inception7D does).  With pool="max" the layer name is
    # 'max_pool_<name>_pool', so existing symbols/checkpoints are unaffected.
    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0, 0),
                             pool_type=pool, name='%s_pool_%s_pool' % (pool, name))

    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling],
                           name='ch_concat_%s_chconcat' % name)
    return concat


def Inception7C(data,
                num_1x1,
                num_d7_red, num_d7_1, num_d7_2,
                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
                pool, proj,
                name):
    """Inception module with 7x7 convolutions factorized into 1x7 / 7x1 pairs.

    Branches (all stride 1, "same" padding, so spatial size is preserved):
    1x1;
    1x1 -> 1x7 -> 7x1;
    1x1 -> 7x1 -> 1x7 -> 7x1 -> 1x7;
    3x3 pooling -> 1x1 projection.
    Outputs are concatenated with ``mx.sym.Concat`` (default axis 1).

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_1x1 : int
        Filters of the plain 1x1 branch.
    num_d7_red, num_d7_1, num_d7_2 : int
        Filters of the reduction and the two factorized convs of the
        single-7x7 branch.
    num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4 : int
        Filters of the reduction and the four factorized convs of the
        double-7x7 branch.
    pool : str
        Pooling type of the pooling branch.
    proj : int
        Filters of the 1x1 projection after pooling.
    name : str
        Prefix used for all layer names in the module.

    Returns
    -------
    Symbol
        The concatenated module output.
    """
    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name='%s_conv' % name)

    tower_d7 = Conv(data=data, num_filter=num_d7_red, name='%s_tower' % name, suffix='_conv')
    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3),
                    name='%s_tower' % name, suffix='_conv_1')
    tower_d7 = Conv(data=tower_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0),
                    name='%s_tower' % name, suffix='_conv_2')

    tower_q7 = Conv(data=data, num_filter=num_q7_red, name='%s_tower_1' % name, suffix='_conv')
    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_1, kernel=(7, 1), pad=(3, 0),
                    name='%s_tower_1' % name, suffix='_conv_1')
    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_2, kernel=(1, 7), pad=(0, 3),
                    name='%s_tower_1' % name, suffix='_conv_2')
    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_3, kernel=(7, 1), pad=(3, 0),
                    name='%s_tower_1' % name, suffix='_conv_3')
    tower_q7 = Conv(data=tower_q7, num_filter=num_q7_4, kernel=(1, 7), pad=(0, 3),
                    name='%s_tower_1' % name, suffix='_conv_4')

    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                             pool_type=pool, name='%s_pool_%s_pool' % (pool, name))
    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1),
                 name='%s_tower_2' % name, suffix='_conv')

    concat = mx.sym.Concat(*[tower_1x1, tower_d7, tower_q7, cproj],
                           name='ch_concat_%s_chconcat' % name)
    return concat


def Inception7D(data,
                num_3x3_red, num_3x3,
                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
                pool,
                name):
    """Second grid-reduction module (17 -> 8 for a 299 x 299 image).

    Branches:
    1x1 -> stride-2 3x3;
    1x1 -> 1x7 -> 7x1 -> stride-2 3x3;
    stride-2 3x3 pooling.
    The stride-2 layers are unpadded, so an H x W input becomes
    ((H - 3) // 2 + 1) x ((W - 3) // 2 + 1).
    Outputs are concatenated with ``mx.sym.Concat`` (default axis 1).

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_3x3_red, num_3x3 : int
        Filters of the reduction and stride-2 conv of the 3x3 branch.
    num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3 : int
        Filters of the reduction, the 1x7 and 7x1 convs, and the final
        stride-2 3x3 conv of the 7x7 branch.
    pool : str
        Pooling type of the pass-through branch.
    name : str
        Prefix used for all layer names in the module.

    Returns
    -------
    Symbol
        The concatenated module output.
    """
    tower_3x3 = Conv(data=data, num_filter=num_3x3_red, name='%s_tower' % name, suffix='_conv')
    tower_3x3 = Conv(data=tower_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2),
                     name='%s_tower' % name, suffix='_conv_1')

    tower_d7_3x3 = Conv(data=data, num_filter=num_d7_3x3_red,
                        name='%s_tower_1' % name, suffix='_conv')
    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3),
                        name='%s_tower_1' % name, suffix='_conv_1')
    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0),
                        name='%s_tower_1' % name, suffix='_conv_2')
    tower_d7_3x3 = Conv(data=tower_d7_3x3, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2),
                        name='%s_tower_1' % name, suffix='_conv_3')

    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2),
                             pool_type=pool, name='%s_pool_%s_pool' % (pool, name))

    concat = mx.sym.Concat(*[tower_3x3, tower_d7_3x3, pooling],
                           name='ch_concat_%s_chconcat' % name)
    return concat


def Inception7E(data,
                num_1x1,
                num_d3_red, num_d3_1, num_d3_2,
                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
                pool, proj,
                name):
    """Inception module with an expanded filter bank (split 1x3 / 3x1 outputs).

    Branches (all stride 1, "same" padding, so spatial size is preserved):
    1x1;
    1x1 reduction feeding parallel 1x3 and 3x1 convs;
    1x1 -> 3x3 feeding parallel 1x3 and 3x1 convs;
    3x3 pooling -> 1x1 projection.
    All six outputs are concatenated with ``mx.sym.Concat`` (default axis 1).

    Parameters
    ----------
    data : Symbol
        Input symbol.
    num_1x1 : int
        Filters of the plain 1x1 branch.
    num_d3_red, num_d3_1, num_d3_2 : int
        Filters of the reduction and the 1x3 / 3x1 pair of the first split branch.
    num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2 : int
        Filters of the reduction, the 3x3 conv and the 1x3 / 3x1 pair of the
        second split branch.
    pool : str
        Pooling type of the pooling branch.
    proj : int
        Filters of the 1x1 projection after pooling.
    name : str
        Prefix used for all layer names in the module.

    Returns
    -------
    Symbol
        The concatenated module output.
    """
    tower_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name='%s_conv' % name)

    tower_d3 = Conv(data=data, num_filter=num_d3_red, name='%s_tower' % name, suffix='_conv')
    tower_d3_a = Conv(data=tower_d3, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1),
                      name='%s_tower' % name, suffix='_mixed_conv')
    tower_d3_b = Conv(data=tower_d3, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0),
                      name='%s_tower' % name, suffix='_mixed_conv_1')

    tower_3x3_d3 = Conv(data=data, num_filter=num_3x3_d3_red,
                        name='%s_tower_1' % name, suffix='_conv')
    tower_3x3_d3 = Conv(data=tower_3x3_d3, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1),
                        name='%s_tower_1' % name, suffix='_conv_1')
    tower_3x3_d3_a = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1),
                          name='%s_tower_1' % name, suffix='_mixed_conv')
    tower_3x3_d3_b = Conv(data=tower_3x3_d3, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0),
                          name='%s_tower_1' % name, suffix='_mixed_conv_1')

    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                             pool_type=pool, name='%s_pool_%s_pool' % (pool, name))
    cproj = Conv(data=pooling, num_filter=proj, kernel=(1, 1),
                 name='%s_tower_2' % name, suffix='_conv')

    concat = mx.sym.Concat(*[tower_1x1, tower_d3_a, tower_d3_b,
                             tower_3x3_d3_a, tower_3x3_d3_b, cproj],
                           name='ch_concat_%s_chconcat' % name)
    return concat


def get_symbol(num_classes=1000, **kwargs):
    """Build the Inception-v3 classification network as an MXNet symbol.

    Parameters
    ----------
    num_classes : int
        Number of outputs of the final fully connected layer ``'fc1'``.
    **kwargs
        Accepted and ignored (presumably so callers can pass a common set of
        options to every ``get_symbol``; confirm against the caller).

    Returns
    -------
    Symbol
        A ``SoftmaxOutput`` named ``'softmax'`` whose input variable is ``'data'``.
    """
    data = mx.sym.Variable(name="data")

    # Spatial sizes in the comments below assume a 299 x 299 input and
    # MXNet's default 'valid' pooling convention.

    # stage 1: 299 -> 149 -> 147 -> 147 -> 73
    conv = Conv(data, 32, kernel=(3, 3), stride=(2, 2), name="conv")
    conv_1 = Conv(conv, 32, kernel=(3, 3), name="conv_1")
    conv_2 = Conv(conv_1, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
    pool = mx.sym.Pooling(data=conv_2, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")

    # stage 2: 73 -> 73 -> 71 -> 35
    conv_3 = Conv(pool, 80, kernel=(1, 1), name="conv_3")
    conv_4 = Conv(conv_3, 192, kernel=(3, 3), name="conv_4")
    pool1 = mx.sym.Pooling(data=conv_4, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")

    # stage 3: 35 x 35, then grid reduction to 17 x 17
    in3a = Inception7A(pool1, 64,
                       64, 96, 96,
                       48, 64,
                       "avg", 32, "mixed")
    in3b = Inception7A(in3a, 64,
                       64, 96, 96,
                       48, 64,
                       "avg", 64, "mixed_1")
    in3c = Inception7A(in3b, 64,
                       64, 96, 96,
                       48, 64,
                       "avg", 64, "mixed_2")
    in3d = Inception7B(in3c, 384,
                       64, 96, 96,
                       "max", "mixed_3")

    # stage 4: 17 x 17, then grid reduction to 8 x 8
    in4a = Inception7C(in3d, 192,
                       128, 128, 192,
                       128, 128, 128, 128, 192,
                       "avg", 192, "mixed_4")
    in4b = Inception7C(in4a, 192,
                       160, 160, 192,
                       160, 160, 160, 160, 192,
                       "avg", 192, "mixed_5")
    in4c = Inception7C(in4b, 192,
                       160, 160, 192,
                       160, 160, 160, 160, 192,
                       "avg", 192, "mixed_6")
    in4d = Inception7C(in4c, 192,
                       192, 192, 192,
                       192, 192, 192, 192, 192,
                       "avg", 192, "mixed_7")
    in4e = Inception7D(in4d, 192, 320,
                       192, 192, 192, 192,
                       "max", "mixed_8")

    # stage 5: 8 x 8
    in5a = Inception7E(in4e, 320,
                       384, 384, 384,
                       448, 384, 384, 384,
                       "avg", 192, "mixed_9")
    in5b = Inception7E(in5a, 320,
                       384, 384, 384,
                       448, 384, 384, 384,
                       "max", 192, "mixed_10")

    # Classifier head.  global_pool=True averages over whatever spatial extent
    # remains (8 x 8 for a 299 x 299 input, identical to a fixed 8x8 kernel),
    # so 'fc1' always receives one value per channel regardless of input size.
    # `kernel` is kept for MXNet versions where it is a required argument.
    global_pool = mx.sym.Pooling(data=in5b, kernel=(8, 8), stride=(1, 1), global_pool=True,
                                 pool_type="avg", name="global_pool")
    flatten = mx.sym.Flatten(data=global_pool, name="flatten")
    fc1 = mx.sym.FullyConnected(data=flatten, num_hidden=num_classes, name='fc1')
    softmax = mx.sym.SoftmaxOutput(data=fc1, name='softmax')
    return softmax