1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Convert caffe model
19"""
20from __future__ import print_function
21import argparse
22import sys
23import re
24import numpy as np
25import caffe_parser
26import mxnet as mx
27from convert_symbol import convert_symbol
28
29def prob_label(arg_names):
30    candidates = [arg for arg in arg_names if
31                  not arg.endswith('data') and
32                  not arg.endswith('_weight') and
33                  not arg.endswith('_bias') and
34                  not arg.endswith('_gamma') and
35                  not arg.endswith('_beta')]
36    if len(candidates) == 0:
37        return 'prob_label'
38    return candidates[-1]
39
40def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
41    """Convert caffe model
42
43    Parameters
44    ----------
45
46    prototxt_fname : str
47         Filename of the prototxt model definition
48    caffemodel_fname : str
49         Filename of the binary caffe model
50    output_prefix : str, optinoal
51         If given, then save the converted MXNet into output_prefx+'.json' and
52         output_prefx+'.params'
53
54    Returns
55    -------
56    sym : Symbol
57         Symbol convereted from prototxt
58    arg_params : list of NDArray
59         Argument parameters
60    aux_params : list of NDArray
61         Aux parameters
62    input_dim : tuple
63         Input dimension
64    """
65    sym, input_dim = convert_symbol(prototxt_fname)
66    arg_shapes, _, aux_shapes = sym.infer_shape(data=tuple(input_dim))
67    arg_names = sym.list_arguments()
68    aux_names = sym.list_auxiliary_states()
69    arg_shape_dic = dict(zip(arg_names, arg_shapes))
70    aux_shape_dic = dict(zip(aux_names, aux_shapes))
71    arg_params = {}
72    aux_params = {}
73    first_conv = True
74
75    layers, names = caffe_parser.read_caffemodel(prototxt_fname, caffemodel_fname)
76    layer_iter = caffe_parser.layer_iter(layers, names)
77    layers_proto = caffe_parser.get_layers(caffe_parser.read_prototxt(prototxt_fname))
78
79    for layer_name, layer_type, layer_blobs in layer_iter:
80        if layer_type in ('Convolution', 'InnerProduct', 4, 14, 'PReLU', 'Deconvolution',
81                          39):
82            if layer_type == 'PReLU':
83                assert (len(layer_blobs) == 1)
84                weight_name = layer_name + '_gamma'
85                wmat = np.array(layer_blobs[0].data).reshape(arg_shape_dic[weight_name])
86                arg_params[weight_name] = mx.nd.zeros(wmat.shape)
87                arg_params[weight_name][:] = wmat
88                continue
89            wmat_dim = []
90            if getattr(layer_blobs[0].shape, 'dim', None) is not None:
91                if len(layer_blobs[0].shape.dim) > 0:
92                    wmat_dim = layer_blobs[0].shape.dim
93                else:
94                    wmat_dim = [layer_blobs[0].num, layer_blobs[0].channels,
95                                layer_blobs[0].height, layer_blobs[0].width]
96            else:
97                wmat_dim = list(layer_blobs[0].shape)
98            wmat = np.array(layer_blobs[0].data).reshape(wmat_dim)
99
100            channels = wmat_dim[1]
101            if channels in (3, 4):  # RGB or RGBA
102                if first_conv:
103                    # Swapping BGR of caffe into RGB in mxnet
104                    wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :]
105
106            assert(wmat.flags['C_CONTIGUOUS'] is True)
107            sys.stdout.write('converting layer {0}, wmat shape = {1}'.format(
108                layer_name, wmat.shape))
109            if len(layer_blobs) == 2:
110                bias = np.array(layer_blobs[1].data)
111                bias = bias.reshape((bias.shape[0], 1))
112                assert(bias.flags['C_CONTIGUOUS'] is True)
113                bias_name = layer_name + "_bias"
114
115                if bias_name not in arg_shape_dic:
116                    print(bias_name + ' not found in arg_shape_dic.')
117                    continue
118                bias = bias.reshape(arg_shape_dic[bias_name])
119                arg_params[bias_name] = mx.nd.zeros(bias.shape)
120                arg_params[bias_name][:] = bias
121                sys.stdout.write(', bias shape = {}'.format(bias.shape))
122
123            sys.stdout.write('\n')
124            sys.stdout.flush()
125            wmat = wmat.reshape((wmat.shape[0], -1))
126            weight_name = layer_name + "_weight"
127
128            if weight_name not in arg_shape_dic:
129                print(weight_name + ' not found in arg_shape_dic.')
130                continue
131            wmat = wmat.reshape(arg_shape_dic[weight_name])
132            arg_params[weight_name] = mx.nd.zeros(wmat.shape)
133            arg_params[weight_name][:] = wmat
134
135            if first_conv and layer_type in ('Convolution', 4):
136                first_conv = False
137
138        elif layer_type == 'Scale':
139            if 'scale' in layer_name:
140                bn_name = layer_name.replace('scale', 'bn')
141            elif 'sc' in layer_name:
142                bn_name = layer_name.replace('sc', 'bn')
143            else:
144                assert False, 'Unknown name convention for bn/scale'
145
146            gamma = np.array(layer_blobs[0].data)
147            beta = np.array(layer_blobs[1].data)
148            # beta = np.expand_dims(beta, 1)
149            beta_name = '{}_beta'.format(bn_name)
150            gamma_name = '{}_gamma'.format(bn_name)
151
152            beta = beta.reshape(arg_shape_dic[beta_name])
153            gamma = gamma.reshape(arg_shape_dic[gamma_name])
154            arg_params[beta_name] = mx.nd.zeros(beta.shape)
155            arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
156            arg_params[beta_name][:] = beta
157            arg_params[gamma_name][:] = gamma
158
159            assert gamma.flags['C_CONTIGUOUS'] is True
160            assert beta.flags['C_CONTIGUOUS'] is True
161            print('converting scale layer, beta shape = {}, gamma shape = {}'.format(
162                beta.shape, gamma.shape))
163        elif layer_type == 'BatchNorm':
164            bn_name = layer_name
165            mean = np.array(layer_blobs[0].data)
166            var = np.array(layer_blobs[1].data)
167            rescale_factor = layer_blobs[2].data[0]
168            if rescale_factor != 0:
169                rescale_factor = 1 / rescale_factor
170            mean_name = '{}_moving_mean'.format(bn_name)
171            var_name = '{}_moving_var'.format(bn_name)
172            mean = mean.reshape(aux_shape_dic[mean_name])
173            var = var.reshape(aux_shape_dic[var_name])
174            aux_params[mean_name] = mx.nd.zeros(mean.shape)
175            aux_params[var_name] = mx.nd.zeros(var.shape)
176            # Get the original epsilon
177            for idx, layer in enumerate(layers_proto):
178                if layer.name == bn_name or re.sub('[-/]', '_', layer.name) == bn_name:
179                    bn_index = idx
180            eps_caffe = layers_proto[bn_index].batch_norm_param.eps
181            # Compensate for the epsilon shift performed in convert_symbol
182            eps_symbol = float(sym.attr_dict()[bn_name + '_moving_mean']['eps'])
183            eps_correction = eps_caffe - eps_symbol
184            # Fill parameters
185            aux_params[mean_name][:] = mean * rescale_factor
186            aux_params[var_name][:] = var * rescale_factor + eps_correction
187            assert var.flags['C_CONTIGUOUS'] is True
188            assert mean.flags['C_CONTIGUOUS'] is True
189            print('converting batchnorm layer, mean shape = {}, var shape = {}'.format(
190                mean.shape, var.shape))
191
192            fix_gamma = layers_proto[bn_index+1].type != 'Scale'
193            if fix_gamma:
194                gamma_name = '{}_gamma'.format(bn_name)
195                gamma = np.array(np.ones(arg_shape_dic[gamma_name]))
196                beta_name = '{}_beta'.format(bn_name)
197                beta = np.array(np.zeros(arg_shape_dic[beta_name]))
198                arg_params[beta_name] = mx.nd.zeros(beta.shape)
199                arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
200                arg_params[beta_name][:] = beta
201                arg_params[gamma_name][:] = gamma
202                assert gamma.flags['C_CONTIGUOUS'] is True
203                assert beta.flags['C_CONTIGUOUS'] is True
204
205        else:
206            print('\tskipping layer {} of type {}'.format(layer_name, layer_type))
207            assert len(layer_blobs) == 0
208
209    if output_prefix is not None:
210        model = mx.mod.Module(symbol=sym, label_names=[prob_label(arg_names), ])
211        model.bind(data_shapes=[('data', tuple(input_dim))])
212        model.init_params(arg_params=arg_params, aux_params=aux_params)
213        model.save_checkpoint(output_prefix, 0)
214
215    return sym, arg_params, aux_params, input_dim
216
217def main():
218    parser = argparse.ArgumentParser(
219        description='Caffe prototxt to mxnet model parameter converter.')
220    parser.add_argument('prototxt', help='The prototxt filename')
221    parser.add_argument('caffemodel', help='The binary caffemodel filename')
222    parser.add_argument('save_model_name', help='The name of the output model prefix')
223    args = parser.parse_args()
224
225    convert_model(args.prototxt, args.caffemodel, args.save_model_name)
226    print('Saved model successfully to {}'.format(args.save_model_name))
227
228if __name__ == '__main__':
229    main()
230