1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""
Inception V3, suitable for images of around 299 x 299 pixels.
20
21Reference:
22
23Szegedy, Christian, et al. "Rethinking the Inception Architecture for Computer Vision." arXiv preprint arXiv:1512.00567 (2015).
24"""
25import mxnet as mx
26
def Conv(data, num_filter, kernel=(1, 1), stride=(1, 1), pad=(0, 0), name=None, suffix=''):
    """Stack a bias-free convolution, a batch norm and a ReLU on top of ``data``.

    Parameters
    ----------
    data : input symbol fed to the convolution.
    num_filter : ``num_filter`` passed to ``mx.sym.Convolution``.
    kernel, stride, pad : forwarded unchanged to ``mx.sym.Convolution``.
    name, suffix : concatenated (``name`` then ``suffix``) to form the prefix of
        the three layer names: ``<prefix>_conv2d``, ``<prefix>_batchnorm`` and
        ``<prefix>_relu``.  With the default ``name=None`` the prefix literally
        starts with the text ``None``.

    Returns
    -------
    The ReLU activation symbol (the last layer of the stack).
    """
    convolved = mx.sym.Convolution(
        data=data,
        num_filter=num_filter,
        kernel=kernel,
        stride=stride,
        pad=pad,
        no_bias=True,
        name='%s%s_conv2d' % (name, suffix),
    )
    normalized = mx.sym.BatchNorm(
        data=convolved,
        name='%s%s_batchnorm' % (name, suffix),
        fix_gamma=True,
    )
    return mx.sym.Activation(
        data=normalized,
        act_type='relu',
        name='%s%s_relu' % (name, suffix),
    )
32
33
def Inception7A(data,
                num_1x1,
                num_3x3_red, num_3x3_1, num_3x3_2,
                num_5x5_red, num_5x5,
                pool, proj,
                name):
    """Build a four-branch Inception module whose layers all use stride 1.

    Every branch starts from ``data``:

    * a single 1x1 ``Conv`` with ``num_1x1`` filters;
    * a 1x1 ``Conv`` (``num_5x5_red``) followed by a padded 5x5 ``Conv``
      (``num_5x5``);
    * a 1x1 ``Conv`` (``num_3x3_red``) followed by two padded 3x3 ``Conv``
      units (``num_3x3_1`` then ``num_3x3_2``);
    * a padded 3x3 pooling of type ``pool`` followed by a 1x1 ``Conv`` with
      ``proj`` filters.

    ``name`` is embedded in every layer name.  Returns the ``mx.sym.Concat`` of
    the four branch outputs, in the order listed above.
    """
    prefix_5x5 = '%s_tower' % name
    prefix_3x3 = '%s_tower_1' % name

    # Branch 1: plain 1x1 convolution.
    branch_1x1 = Conv(data, num_1x1, name=('%s_conv' % name))

    # Branch 2: 1x1 reduction, then 5x5.
    branch_5x5 = Conv(data, num_5x5_red, name=prefix_5x5, suffix='_conv')
    branch_5x5 = Conv(branch_5x5, num_5x5, kernel=(5, 5), pad=(2, 2),
                      name=prefix_5x5, suffix='_conv_1')

    # Branch 3: 1x1 reduction, then two 3x3 convolutions.
    branch_3x3 = Conv(data, num_3x3_red, name=prefix_3x3, suffix='_conv')
    for filters, conv_suffix in ((num_3x3_1, '_conv_1'), (num_3x3_2, '_conv_2')):
        branch_3x3 = Conv(branch_3x3, filters, kernel=(3, 3), pad=(1, 1),
                          name=prefix_3x3, suffix=conv_suffix)

    # Branch 4: pooling followed by a 1x1 projection.
    pooled = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                            pool_type=pool,
                            name=('%s_pool_%s_pool' % (pool, name)))
    branch_proj = Conv(pooled, proj, name=('%s_tower_2' % name), suffix='_conv')

    return mx.sym.Concat(branch_1x1, branch_5x5, branch_3x3, branch_proj,
                         name='ch_concat_%s_chconcat' % name)
50
51# First Downsample
def Inception7B(data,
                num_3x3,
                num_d3x3_red, num_d3x3_1, num_d3x3_2,
                pool,
                name):
    """Build a three-branch Inception module whose branches all end with stride 2.

    Every branch starts from ``data``:

    * an unpadded 3x3 ``Conv`` with stride 2 and ``num_3x3`` filters;
    * a 1x1 ``Conv`` (``num_d3x3_red``), a padded stride-1 3x3 ``Conv``
      (``num_d3x3_1``) and an unpadded stride-2 3x3 ``Conv`` (``num_d3x3_2``);
    * an unpadded 3x3 pooling with stride 2 whose type is given by ``pool``.

    Parameters
    ----------
    data : input symbol.
    num_3x3, num_d3x3_red, num_d3x3_1, num_d3x3_2 : filter counts for the
        convolutions described above.
    pool : pooling type passed as ``pool_type`` to ``mx.sym.Pooling`` and used
        as the prefix of the pooling layer's name.  With ``"max"`` the layer
        is named ``max_pool_<name>_pool``.
    name : text embedded in every layer name.

    Returns
    -------
    The ``mx.sym.Concat`` of the three branch outputs, in the order listed.
    """
    tower_3x3 = Conv(data, num_3x3, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_conv' % name))
    tower_d3x3 = Conv(data, num_d3x3_red, name=('%s_tower' % name), suffix='_conv')
    tower_d3x3 = Conv(tower_d3x3, num_d3x3_1, kernel=(3, 3), pad=(1, 1), stride=(1, 1), name=('%s_tower' % name), suffix='_conv_1')
    tower_d3x3 = Conv(tower_d3x3, num_d3x3_2, kernel=(3, 3), pad=(0, 0), stride=(2, 2), name=('%s_tower' % name), suffix='_conv_2')
    # Honor the caller-supplied pooling type instead of hard-coding "max", the
    # same way the other Inception7* modules in this file treat `pool`.
    pooling = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pad=(0, 0), pool_type=pool, name=('%s_pool_%s_pool' % (pool, name)))
    concat = mx.sym.Concat(*[tower_3x3, tower_d3x3, pooling], name='ch_concat_%s_chconcat' % name)
    return concat
64
def Inception7C(data,
                num_1x1,
                num_d7_red, num_d7_1, num_d7_2,
                num_q7_red, num_q7_1, num_q7_2, num_q7_3, num_q7_4,
                pool, proj,
                name):
    """Build a four-branch Inception module based on 1x7 / 7x1 convolutions.

    Every branch starts from ``data`` and uses stride 1 throughout:

    * a single 1x1 ``Conv`` with ``num_1x1`` filters;
    * a 1x1 ``Conv`` (``num_d7_red``), then a 1x7 (``num_d7_1``) and a 7x1
      (``num_d7_2``) ``Conv``;
    * a 1x1 ``Conv`` (``num_q7_red``), then 7x1, 1x7, 7x1 and 1x7 ``Conv``
      units with ``num_q7_1`` .. ``num_q7_4`` filters respectively;
    * a padded 3x3 pooling of type ``pool`` followed by a 1x1 ``Conv`` with
      ``proj`` filters.

    ``name`` is embedded in every layer name.  Returns the ``mx.sym.Concat`` of
    the four branch outputs, in the order listed above.
    """
    # Kernel/padding pairs for the two factorized orientations.
    row_conv = {'kernel': (1, 7), 'pad': (0, 3)}
    col_conv = {'kernel': (7, 1), 'pad': (3, 0)}
    prefix_d7 = '%s_tower' % name
    prefix_q7 = '%s_tower_1' % name

    branch_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))

    # Double-7 branch: reduce, 1x7, 7x1.
    branch_d7 = Conv(data=data, num_filter=num_d7_red, name=prefix_d7, suffix='_conv')
    branch_d7 = Conv(data=branch_d7, num_filter=num_d7_1, name=prefix_d7, suffix='_conv_1', **row_conv)
    branch_d7 = Conv(data=branch_d7, num_filter=num_d7_2, name=prefix_d7, suffix='_conv_2', **col_conv)

    # Quadruple-7 branch: reduce, then 7x1, 1x7, 7x1, 1x7.
    branch_q7 = Conv(data=data, num_filter=num_q7_red, name=prefix_q7, suffix='_conv')
    q7_layers = (
        (num_q7_1, col_conv, '_conv_1'),
        (num_q7_2, row_conv, '_conv_2'),
        (num_q7_3, col_conv, '_conv_3'),
        (num_q7_4, row_conv, '_conv_4'),
    )
    for filters, geometry, conv_suffix in q7_layers:
        branch_q7 = Conv(data=branch_q7, num_filter=filters, name=prefix_q7, suffix=conv_suffix, **geometry)

    # Pooling branch with 1x1 projection.
    pooled = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                            pool_type=pool,
                            name=('%s_pool_%s_pool' % (pool, name)))
    branch_proj = Conv(data=pooled, num_filter=proj, kernel=(1, 1),
                       name=('%s_tower_2' % name), suffix='_conv')

    return mx.sym.Concat(branch_1x1, branch_d7, branch_q7, branch_proj,
                         name='ch_concat_%s_chconcat' % name)
85
def Inception7D(data,
                num_3x3_red, num_3x3,
                num_d7_3x3_red, num_d7_1, num_d7_2, num_d7_3x3,
                pool,
                name):
    """Build a three-branch Inception module whose branches all end with stride 2.

    Every branch starts from ``data``:

    * a 1x1 ``Conv`` (``num_3x3_red``) followed by an unpadded stride-2 3x3
      ``Conv`` (``num_3x3``);
    * a 1x1 ``Conv`` (``num_d7_3x3_red``), a 1x7 ``Conv`` (``num_d7_1``), a 7x1
      ``Conv`` (``num_d7_2``) and a stride-2 3x3 ``Conv`` (``num_d7_3x3``) that
      relies on ``Conv``'s default padding;
    * a 3x3 pooling of type ``pool`` with stride 2 (no explicit padding).

    ``name`` is embedded in every layer name.  Returns the ``mx.sym.Concat`` of
    the three branch outputs, in the order listed above.
    """
    prefix_3x3 = '%s_tower' % name
    prefix_d7 = '%s_tower_1' % name

    # Branch 1: reduce, then strided 3x3.
    branch_3x3 = Conv(data=data, num_filter=num_3x3_red, name=prefix_3x3, suffix='_conv')
    branch_3x3 = Conv(data=branch_3x3, num_filter=num_3x3, kernel=(3, 3), pad=(0, 0),
                      stride=(2, 2), name=prefix_3x3, suffix='_conv_1')

    # Branch 2: reduce, 1x7, 7x1, then strided 3x3.
    branch_d7 = Conv(data=data, num_filter=num_d7_3x3_red, name=prefix_d7, suffix='_conv')
    branch_d7 = Conv(data=branch_d7, num_filter=num_d7_1, kernel=(1, 7), pad=(0, 3),
                     name=prefix_d7, suffix='_conv_1')
    branch_d7 = Conv(data=branch_d7, num_filter=num_d7_2, kernel=(7, 1), pad=(3, 0),
                     name=prefix_d7, suffix='_conv_2')
    branch_d7 = Conv(data=branch_d7, num_filter=num_d7_3x3, kernel=(3, 3), stride=(2, 2),
                     name=prefix_d7, suffix='_conv_3')

    # Branch 3: strided pooling straight from the input.
    pooled = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(2, 2), pool_type=pool,
                            name=('%s_pool_%s_pool' % (pool, name)))

    return mx.sym.Concat(branch_3x3, branch_d7, pooled,
                         name='ch_concat_%s_chconcat' % name)
101
def Inception7E(data,
                num_1x1,
                num_d3_red, num_d3_1, num_d3_2,
                num_3x3_d3_red, num_3x3, num_3x3_d3_1, num_3x3_d3_2,
                pool, proj,
                name):
    """Build an Inception module whose 3x3 branches fan out into 1x3 / 3x1 pairs.

    Every branch starts from ``data`` and uses stride 1 throughout:

    * a single 1x1 ``Conv`` with ``num_1x1`` filters;
    * a 1x1 ``Conv`` (``num_d3_red``) whose output feeds two parallel units:
      a 1x3 ``Conv`` (``num_d3_1``) and a 3x1 ``Conv`` (``num_d3_2``);
    * a 1x1 ``Conv`` (``num_3x3_d3_red``) and a padded 3x3 ``Conv``
      (``num_3x3``) whose output feeds two parallel units: a 1x3 ``Conv``
      (``num_3x3_d3_1``) and a 3x1 ``Conv`` (``num_3x3_d3_2``);
    * a padded 3x3 pooling of type ``pool`` followed by a 1x1 ``Conv`` with
      ``proj`` filters.

    ``name`` is embedded in every layer name.  Returns the ``mx.sym.Concat`` of
    the six resulting outputs: 1x1, the two outputs of the second branch, the
    two outputs of the third branch, and the pooled projection.
    """
    prefix_d3 = '%s_tower' % name
    prefix_3x3_d3 = '%s_tower_1' % name

    branch_1x1 = Conv(data=data, num_filter=num_1x1, kernel=(1, 1), name=('%s_conv' % name))

    # Second branch: shared 1x1 reduction, then a 1x3 / 3x1 split.
    d3_stem = Conv(data=data, num_filter=num_d3_red, name=prefix_d3, suffix='_conv')
    d3_wide = Conv(data=d3_stem, num_filter=num_d3_1, kernel=(1, 3), pad=(0, 1),
                   name=prefix_d3, suffix='_mixed_conv')
    d3_tall = Conv(data=d3_stem, num_filter=num_d3_2, kernel=(3, 1), pad=(1, 0),
                   name=prefix_d3, suffix='_mixed_conv_1')

    # Third branch: 1x1 reduction and 3x3, then a 1x3 / 3x1 split.
    deep_stem = Conv(data=data, num_filter=num_3x3_d3_red, name=prefix_3x3_d3, suffix='_conv')
    deep_stem = Conv(data=deep_stem, num_filter=num_3x3, kernel=(3, 3), pad=(1, 1),
                     name=prefix_3x3_d3, suffix='_conv_1')
    deep_wide = Conv(data=deep_stem, num_filter=num_3x3_d3_1, kernel=(1, 3), pad=(0, 1),
                     name=prefix_3x3_d3, suffix='_mixed_conv')
    deep_tall = Conv(data=deep_stem, num_filter=num_3x3_d3_2, kernel=(3, 1), pad=(1, 0),
                     name=prefix_3x3_d3, suffix='_mixed_conv_1')

    # Pooling branch with 1x1 projection.
    pooled = mx.sym.Pooling(data=data, kernel=(3, 3), stride=(1, 1), pad=(1, 1),
                            pool_type=pool,
                            name=('%s_pool_%s_pool' % (pool, name)))
    branch_proj = Conv(data=pooled, num_filter=proj, kernel=(1, 1),
                       name=('%s_tower_2' % name), suffix='_conv')

    return mx.sym.Concat(branch_1x1, d3_wide, d3_tall, deep_wide, deep_tall, branch_proj,
                         name='ch_concat_%s_chconcat' % name)
121
122# In[49]:
123
def get_symbol(num_classes=1000, **kwargs):
    """Assemble the complete Inception V3 symbol.

    Parameters
    ----------
    num_classes : number of hidden units of the final fully connected layer
        (``fc1``).
    **kwargs : accepted but not used by this function.

    Returns
    -------
    A ``SoftmaxOutput`` symbol named ``softmax`` on top of the network, whose
    input is a variable named ``data``.
    """
    net = mx.symbol.Variable(name="data")

    # stage 1: three plain convolutions followed by max pooling
    net = Conv(net, 32, kernel=(3, 3), stride=(2, 2), name="conv")
    net = Conv(net, 32, kernel=(3, 3), name="conv_1")
    net = Conv(net, 64, kernel=(3, 3), pad=(1, 1), name="conv_2")
    net = mx.sym.Pooling(data=net, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool")

    # stage 2: two more convolutions followed by max pooling
    net = Conv(net, 80, kernel=(1, 1), name="conv_3")
    net = Conv(net, 192, kernel=(3, 3), name="conv_4")
    net = mx.sym.Pooling(data=net, kernel=(3, 3), stride=(2, 2), pool_type="max", name="pool1")

    # stage 3: three 7A modules (only the projection width and name differ),
    # then the 7B module
    for proj_filters, block_name in ((32, "mixed"), (64, "mixed_1"), (64, "mixed_2")):
        net = Inception7A(net, 64,
                          64, 96, 96,
                          48, 64,
                          "avg", proj_filters, block_name)
    net = Inception7B(net, 384,
                      64, 96, 96,
                      "max", "mixed_3")

    # stage 4: four 7C modules that differ only in their inner width and name,
    # then the 7D module
    stage4_blocks = ((128, "mixed_4"), (160, "mixed_5"), (160, "mixed_6"), (192, "mixed_7"))
    for width, block_name in stage4_blocks:
        net = Inception7C(net, 192,
                          width, width, 192,
                          width, width, width, width, 192,
                          "avg", 192, block_name)
    net = Inception7D(net, 192, 320,
                      192, 192, 192, 192,
                      "max", "mixed_8")

    # stage 5: two 7E modules; the first pools with "avg", the second with "max"
    for pool_kind, block_name in (("avg", "mixed_9"), ("max", "mixed_10")):
        net = Inception7E(net, 320,
                          384, 384, 384,
                          448, 384, 384, 384,
                          pool_kind, 192, block_name)

    # classifier head: 8x8 average pooling, flatten, fully connected, softmax
    net = mx.sym.Pooling(data=net, kernel=(8, 8), stride=(1, 1), pool_type="avg", name="global_pool")
    net = mx.sym.Flatten(data=net, name="flatten")
    net = mx.symbol.FullyConnected(data=net, num_hidden=num_classes, name='fc1')
    return mx.symbol.SoftmaxOutput(data=net, name='softmax')
186