1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Convert caffe prototxt to symbol
19"""
20from __future__ import print_function
21import argparse
22import re
23import mxnet as mx
24import caffe_parser
25
26
27def _get_input(proto):
28    """Get input size
29    """
30    layer = caffe_parser.get_layers(proto)
31    if len(proto.input_dim) > 0:
32        input_dim = proto.input_dim
33    elif len(proto.input_shape) > 0:
34        input_dim = proto.input_shape[0].dim
35    elif layer[0].type == "Input":
36        input_dim = layer[0].input_param.shape[0].dim
37        layer.pop(0)
38    else:
39        raise ValueError('Cannot find input size')
40
41    assert layer[0].type != "Input", 'only support single input'
42    # We assume the first bottom blob of first layer is the output from data layer
43    input_name = layer[0].bottom[0]
44    return input_name, input_dim, layer
45
46def _convert_conv_param(param):
47    """
48    Convert convolution layer parameter from Caffe to MXNet
49    """
50    param_string = "num_filter=%d" % param.num_output
51
52    pad_w = 0
53    pad_h = 0
54    if isinstance(param.pad, int):
55        pad = param.pad
56        param_string += ", pad=(%d, %d)" % (pad, pad)
57    else:
58        if len(param.pad) > 0:
59            pad = param.pad[0]
60            param_string += ", pad=(%d, %d)" % (pad, pad)
61        else:
62            if isinstance(param.pad_w, int):
63                pad_w = param.pad_w
64            if isinstance(param.pad_h, int):
65                pad_h = param.pad_h
66            param_string += ", pad=(%d, %d)" % (pad_h, pad_w)
67
68    if isinstance(param.kernel_size, int):
69        kernel_size = param.kernel_size
70        param_string += ", kernel=(%d,%d)" % (kernel_size, kernel_size)
71    else:
72        if len(param.kernel_size) > 0:
73            kernel_size = param.kernel_size[0]
74            param_string += ", kernel=(%d,%d)" % (kernel_size, kernel_size)
75        else:
76            assert isinstance(param.kernel_w, int)
77            kernel_w = param.kernel_w
78            assert isinstance(param.kernel_h, int)
79            kernel_h = param.kernel_h
80            param_string += ", kernel=(%d,%d)" % (kernel_h, kernel_w)
81
82    stride = 1
83    if isinstance(param.stride, int):
84        stride = param.stride
85    else:
86        stride = 1 if len(param.stride) == 0 else param.stride[0]
87
88    param_string += ", stride=(%d,%d)" % (stride, stride)
89
90    dilate = 1
91    if hasattr(param, 'dilation'):
92        if isinstance(param.dilation, int):
93            dilate = param.dilation
94        else:
95            dilate = 1 if len(param.dilation) == 0 else param.dilation[0]
96
97    param_string += ", no_bias=%s" % (not param.bias_term)
98
99    # deal with dilation. Won't be in deconvolution
100    if dilate > 1:
101        param_string += ", dilate=(%d, %d)" % (dilate, dilate)
102
103    if isinstance(param.group, int):
104        if param.group != 1:
105            param_string += ", num_group=%d" % param.group
106
107    return param_string
108
109def _convert_pooling_param(param):
110    """Convert the pooling layer parameter
111    """
112    param_string = "pooling_convention='full', "
113    if param.global_pooling:
114        param_string += "global_pool=True, kernel=(1,1)"
115    else:
116        param_string += "pad=(%d,%d), kernel=(%d,%d), stride=(%d,%d)" % (
117            param.pad, param.pad, param.kernel_size, param.kernel_size,
118            param.stride, param.stride)
119    if param.pool == 0:
120        param_string += ", pool_type='max'"
121    elif param.pool == 1:
122        param_string += ", pool_type='avg'"
123    else:
124        raise ValueError("Unknown Pooling Method!")
125    return param_string
126
127def _parse_proto(prototxt_fname):
128    """Parse Caffe prototxt into symbol string
129    """
130    proto = caffe_parser.read_prototxt(prototxt_fname)
131
132    # process data layer
133    input_name, input_dim, layers = _get_input(proto)
134    # only support single input, so always use `data` as the input data
135    mapping = {input_name: 'data'}
136    need_flatten = {input_name: False}
137    symbol_string = "import mxnet as mx\ndata = mx.symbol.Variable(name='data')\n"
138
139    flatten_count = 0
140    output_name = ""
141    prev_name = None
142    _output_name = {}
143
144    # convert reset layers one by one
145    for i, layer in enumerate(layers):
146        type_string = ''
147        param_string = ''
148        skip_layer = False
149        name = re.sub('[-/]', '_', layer.name)
150        for k in range(len(layer.bottom)):
151            if layer.bottom[k] in _output_name:
152                _output_name[layer.bottom[k]]['count'] = _output_name[layer.bottom[k]]['count']+1
153            else:
154                _output_name[layer.bottom[k]] = {'count':0}
155        for k in range(len(layer.top)):
156            if layer.top[k] in _output_name:
157                _output_name[layer.top[k]]['count'] = _output_name[layer.top[k]]['count']+1
158            else:
159                _output_name[layer.top[k]] = {'count':0, 'name':name}
160        if layer.type == 'Convolution' or layer.type == 4:
161            type_string = 'mx.symbol.Convolution'
162            param_string = _convert_conv_param(layer.convolution_param)
163            need_flatten[name] = True
164        if layer.type == 'Deconvolution' or layer.type == 39:
165            type_string = 'mx.symbol.Deconvolution'
166            param_string = _convert_conv_param(layer.convolution_param)
167            need_flatten[name] = True
168        if layer.type == 'Pooling' or layer.type == 17:
169            type_string = 'mx.symbol.Pooling'
170            param_string = _convert_pooling_param(layer.pooling_param)
171            need_flatten[name] = True
172        if layer.type == 'ReLU' or layer.type == 18:
173            type_string = 'mx.symbol.Activation'
174            param_string = "act_type='relu'"
175            param = layer.relu_param
176            if hasattr(param, 'negative_slope'):
177                if param.negative_slope > 0:
178                    type_string = 'mx.symbol.LeakyReLU'
179                    param_string = "act_type='leaky', slope=%f" % param.negative_slope
180            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
181        if layer.type == 'TanH' or layer.type == 23:
182            type_string = 'mx.symbol.Activation'
183            param_string = "act_type='tanh'"
184            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
185        if layer.type == 'Sigmoid' or layer.type == 19:
186            type_string = 'mx.symbol.Activation'
187            param_string = "act_type='sigmoid'"
188            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
189        if layer.type == 'LRN' or layer.type == 15:
190            type_string = 'mx.symbol.LRN'
191            param = layer.lrn_param
192            param_string = "alpha=%f, beta=%f, knorm=%f, nsize=%d" % (
193                param.alpha, param.beta, param.k, param.local_size)
194            need_flatten[name] = True
195        if layer.type == 'InnerProduct' or layer.type == 14:
196            type_string = 'mx.symbol.FullyConnected'
197            param = layer.inner_product_param
198            param_string = "num_hidden=%d, no_bias=%s" % (
199                param.num_output, not param.bias_term)
200            need_flatten[name] = False
201        if layer.type == 'Dropout' or layer.type == 6:
202            type_string = 'mx.symbol.Dropout'
203            param = layer.dropout_param
204            param_string = "p=%f" % param.dropout_ratio
205            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
206        if layer.type == 'Softmax' or layer.type == 20:
207            type_string = 'mx.symbol.SoftmaxOutput'
208        if layer.type == 'Flatten' or layer.type == 8:
209            type_string = 'mx.symbol.Flatten'
210            need_flatten[name] = False
211        if layer.type == 'Split' or layer.type == 22:
212            type_string = 'split'  # will process later
213        if layer.type == 'Concat' or layer.type == 3:
214            type_string = 'mx.symbol.Concat'
215            need_flatten[name] = True
216        if layer.type == 'Crop':
217            type_string = 'mx.symbol.Crop'
218            need_flatten[name] = True
219            param_string = 'center_crop=True'
220        if layer.type == 'BatchNorm':
221            type_string = 'mx.symbol.BatchNorm'
222            param = layer.batch_norm_param
223            # CuDNN requires eps to be greater than 1e-05
224            # We compensate for this change in convert_model
225            epsilon = param.eps
226            if (epsilon <= 1e-05):
227                epsilon = 1e-04
228            # if next layer is scale, don't fix gamma
229            fix_gamma = layers[i+1].type != 'Scale'
230            param_string = 'use_global_stats=%s, fix_gamma=%s, eps=%f' % (
231                param.use_global_stats, fix_gamma, epsilon)
232            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
233        if layer.type == 'Scale':
234            assert layers[i-1].type == 'BatchNorm'
235            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
236            skip_layer = True
237            prev_name = re.sub('[-/]', '_', layers[i-1].name)
238        if layer.type == 'PReLU':
239            type_string = 'mx.symbol.LeakyReLU'
240            param = layer.prelu_param
241            param_string = "act_type='prelu', slope=%f" % param.filler.value
242            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
243        if layer.type == 'Eltwise':
244            type_string = 'mx.symbol.broadcast_add'
245            param = layer.eltwise_param
246            param_string = ""
247            need_flatten[name] = False
248        if layer.type == 'Reshape':
249            type_string = 'mx.symbol.Reshape'
250            need_flatten[name] = False
251            param = layer.reshape_param
252            param_string = "shape=(%s)" % (','.join(param.shape.dim),)
253        if layer.type == 'AbsVal':
254            type_string = 'mx.symbol.abs'
255            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
256
257        if skip_layer:
258            assert len(layer.bottom) == 1
259            symbol_string += "%s = %s\n" % (name, prev_name)
260        elif type_string == '':
261            raise ValueError('Unknown layer %s!' % layer.type)
262        elif type_string != 'split':
263            bottom = layer.bottom
264            if param_string != "":
265                param_string = ", " + param_string
266            if len(bottom) == 1:
267                if need_flatten[mapping[bottom[0]]] and type_string == 'mx.symbol.FullyConnected':
268                    flatten_name = "flatten_%d" % flatten_count
269                    symbol_string += "%s=mx.symbol.Flatten(name='%s', data=%s)\n" % (
270                        flatten_name, flatten_name, mapping[bottom[0]])
271                    flatten_count += 1
272                    need_flatten[flatten_name] = False
273                    bottom[0] = flatten_name
274                    mapping[bottom[0]] = bottom[0]
275                symbol_string += "%s = %s(name='%s', data=%s %s)\n" % (
276                    name, type_string, name, mapping[bottom[0]], param_string)
277            else:
278                if layer.type == 'Eltwise' and param.operation == 1 and len(param.coeff) > 0:
279                    symbol_string += "%s = " % name
280                    symbol_string += " + ".join(["%s * %s" % (
281                        mapping[bottom[i]], param.coeff[i]) for i in range(len(param.coeff))])
282                    symbol_string += "\n"
283                else:
284                    symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
285                        name, type_string, name, ','.join(
286                            [mapping[x] for x in bottom]), param_string)
287        for j in range(len(layer.top)):
288            mapping[layer.top[j]] = name
289        output_name = name
290    output_name = []
291    for i in _output_name:
292        if 'name' in _output_name[i] and _output_name[i]['count'] == 0:
293            output_name.append(_output_name[i]['name'])
294
295    return symbol_string, output_name, input_dim
296
297def convert_symbol(prototxt_fname):
298    """Convert caffe model definition into Symbol
299
300    Parameters
301    ----------
302    prototxt_fname : str
303        Filename of the prototxt file
304
305    Returns
306    -------
307    Symbol
308        Converted Symbol
309    tuple
310        Input shape
311    """
312    sym, output_name, input_dim = _parse_proto(prototxt_fname)
313    exec(sym)                   # pylint: disable=exec-used
314    _locals = locals()
315    ret = []
316    for i in  output_name:
317        exec("ret = " + i, globals(), _locals)  # pylint: disable=exec-used
318        ret.append(_locals['ret'])
319    ret = mx.sym.Group(ret)
320    return ret, input_dim
321
322def main():
323    parser = argparse.ArgumentParser(
324        description='Convert caffe prototxt into Symbol')
325    parser.add_argument('prototxt', help='The prototxt filename')
326    parser.add_argument('output', help='filename for the output json file')
327    args = parser.parse_args()
328
329    sym, _ = convert_symbol(args.prototxt)
330    sym.save(args.output)
331
332if __name__ == '__main__':
333    main()
334