# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Convert a Caffe model to MXNet format.
"""
from __future__ import print_function
import argparse
import sys
import re
import numpy as np
import caffe_parser
import mxnet as mx
from convert_symbol import convert_symbol

def prob_label(arg_names):
    """Pick the symbol argument that acts as the label variable.

    Heuristic: the label is the argument that is neither the data input
    nor a weight/bias/gamma/beta parameter; fall back to 'prob_label'.
    """
    candidates = [arg for arg in arg_names if
                  not arg.endswith('data') and
                  not arg.endswith('_weight') and
                  not arg.endswith('_bias') and
                  not arg.endswith('_gamma') and
                  not arg.endswith('_beta')]
    if len(candidates) == 0:
        return 'prob_label'
    return candidates[-1]
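# Illustration (hypothetical argument list, not taken from a real model):
# given arg_names == ['data', 'conv1_weight', 'conv1_bias', 'fc1_weight',
# 'fc1_bias', 'softmax_label'], the filter above keeps only 'softmax_label',
# which is then returned as the label variable.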
def convert_model(prototxt_fname, caffemodel_fname, output_prefix=None):
    """Convert a Caffe model to MXNet format

    Parameters
    ----------
    prototxt_fname : str
        Filename of the prototxt model definition
    caffemodel_fname : str
        Filename of the binary caffe model
    output_prefix : str, optional
        If given, save the converted model checkpoint to
        output_prefix+'-symbol.json' and output_prefix+'-0000.params'

    Returns
    -------
    sym : Symbol
        Symbol converted from prototxt
    arg_params : dict of str to NDArray
        Argument parameters
    aux_params : dict of str to NDArray
        Auxiliary parameters
    input_dim : tuple
        Input dimension
    """
    sym, input_dim = convert_symbol(prototxt_fname)
    arg_shapes, _, aux_shapes = sym.infer_shape(data=tuple(input_dim))
    arg_names = sym.list_arguments()
    aux_names = sym.list_auxiliary_states()
    arg_shape_dic = dict(zip(arg_names, arg_shapes))
    aux_shape_dic = dict(zip(aux_names, aux_shapes))
    arg_params = {}
    aux_params = {}
    first_conv = True

    layers, names = caffe_parser.read_caffemodel(prototxt_fname, caffemodel_fname)
    layer_iter = caffe_parser.layer_iter(layers, names)
    layers_proto = caffe_parser.get_layers(caffe_parser.read_prototxt(prototxt_fname))

    for layer_name, layer_type, layer_blobs in layer_iter:
        # 4, 14 and 39 are the legacy V1LayerParameter enum values for
        # Convolution, InnerProduct and Deconvolution respectively
        if layer_type in ('Convolution', 'InnerProduct', 4, 14, 'PReLU',
                          'Deconvolution', 39):
            if layer_type == 'PReLU':
                assert len(layer_blobs) == 1
                weight_name = layer_name + '_gamma'
                wmat = np.array(layer_blobs[0].data).reshape(arg_shape_dic[weight_name])
                arg_params[weight_name] = mx.nd.zeros(wmat.shape)
                arg_params[weight_name][:] = wmat
                continue
            # Newer caffemodel files store the blob shape in shape.dim;
            # older ones use the num/channels/height/width fields
            wmat_dim = []
            if getattr(layer_blobs[0].shape, 'dim', None) is not None:
                if len(layer_blobs[0].shape.dim) > 0:
                    wmat_dim = layer_blobs[0].shape.dim
                else:
                    wmat_dim = [layer_blobs[0].num, layer_blobs[0].channels,
                                layer_blobs[0].height, layer_blobs[0].width]
            else:
                wmat_dim = list(layer_blobs[0].shape)
            wmat = np.array(layer_blobs[0].data).reshape(wmat_dim)

            channels = wmat_dim[1]
            if channels in (3, 4):  # RGB or RGBA
                if first_conv:
                    # Swap Caffe's BGR channel order to MXNet's RGB
                    wmat[:, [0, 2], :, :] = wmat[:, [2, 0], :, :]

            assert wmat.flags['C_CONTIGUOUS'] is True
            sys.stdout.write('converting layer {0}, wmat shape = {1}'.format(
                layer_name, wmat.shape))
            if len(layer_blobs) == 2:
                bias = np.array(layer_blobs[1].data)
                bias = bias.reshape((bias.shape[0], 1))
                assert bias.flags['C_CONTIGUOUS'] is True
                bias_name = layer_name + "_bias"

                if bias_name not in arg_shape_dic:
                    print(bias_name + ' not found in arg_shape_dic.')
                    continue
                bias = bias.reshape(arg_shape_dic[bias_name])
                arg_params[bias_name] = mx.nd.zeros(bias.shape)
                arg_params[bias_name][:] = bias
                sys.stdout.write(', bias shape = {}'.format(bias.shape))

            sys.stdout.write('\n')
            sys.stdout.flush()
            wmat = wmat.reshape((wmat.shape[0], -1))
            weight_name = layer_name + "_weight"

            if weight_name not in arg_shape_dic:
                print(weight_name + ' not found in arg_shape_dic.')
                continue
            wmat = wmat.reshape(arg_shape_dic[weight_name])
            arg_params[weight_name] = mx.nd.zeros(wmat.shape)
            arg_params[weight_name][:] = wmat

            if first_conv and layer_type in ('Convolution', 4):
                first_conv = False

        elif layer_type == 'Scale':
            # check 'scale' before 'sc', since 'sc' is a substring of 'scale'
            if 'scale' in layer_name:
                bn_name = layer_name.replace('scale', 'bn')
            elif 'sc' in layer_name:
                bn_name = layer_name.replace('sc', 'bn')
            else:
                assert False, 'Unknown name convention for bn/scale'

            gamma = np.array(layer_blobs[0].data)
            beta = np.array(layer_blobs[1].data)
            beta_name = '{}_beta'.format(bn_name)
            gamma_name = '{}_gamma'.format(bn_name)

            beta = beta.reshape(arg_shape_dic[beta_name])
            gamma = gamma.reshape(arg_shape_dic[gamma_name])
            arg_params[beta_name] = mx.nd.zeros(beta.shape)
            arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
            arg_params[beta_name][:] = beta
            arg_params[gamma_name][:] = gamma

            assert gamma.flags['C_CONTIGUOUS'] is True
            assert beta.flags['C_CONTIGUOUS'] is True
            print('converting scale layer, beta shape = {}, gamma shape = {}'.format(
                beta.shape, gamma.shape))
        elif layer_type == 'BatchNorm':
            bn_name = layer_name
            mean = np.array(layer_blobs[0].data)
            var = np.array(layer_blobs[1].data)
            # blobs[2] holds Caffe's moving-average scale factor; divide the
            # stored statistics by it to recover the true mean and variance
            rescale_factor = layer_blobs[2].data[0]
            if rescale_factor != 0:
                rescale_factor = 1 / rescale_factor
            mean_name = '{}_moving_mean'.format(bn_name)
            var_name = '{}_moving_var'.format(bn_name)
            mean = mean.reshape(aux_shape_dic[mean_name])
            var = var.reshape(aux_shape_dic[var_name])
            aux_params[mean_name] = mx.nd.zeros(mean.shape)
            aux_params[var_name] = mx.nd.zeros(var.shape)
            # Get the original epsilon
            for idx, layer in enumerate(layers_proto):
                if layer.name == bn_name or re.sub('[-/]', '_', layer.name) == bn_name:
                    bn_index = idx
            eps_caffe = layers_proto[bn_index].batch_norm_param.eps
            # Compensate for the epsilon shift performed in convert_symbol
            eps_symbol = float(sym.attr_dict()[bn_name + '_moving_mean']['eps'])
            eps_correction = eps_caffe - eps_symbol
            # Fill parameters
            aux_params[mean_name][:] = mean * rescale_factor
            aux_params[var_name][:] = var * rescale_factor + eps_correction
            assert var.flags['C_CONTIGUOUS'] is True
            assert mean.flags['C_CONTIGUOUS'] is True
            print('converting batchnorm layer, mean shape = {}, var shape = {}'.format(
                mean.shape, var.shape))

            # Without a following Scale layer, emit identity gamma/beta so
            # the BatchNorm reduces to pure normalization
            fix_gamma = layers_proto[bn_index+1].type != 'Scale'
            if fix_gamma:
                gamma_name = '{}_gamma'.format(bn_name)
                gamma = np.ones(arg_shape_dic[gamma_name])
                beta_name = '{}_beta'.format(bn_name)
                beta = np.zeros(arg_shape_dic[beta_name])
                arg_params[beta_name] = mx.nd.zeros(beta.shape)
                arg_params[gamma_name] = mx.nd.zeros(gamma.shape)
                arg_params[beta_name][:] = beta
                arg_params[gamma_name][:] = gamma
                assert gamma.flags['C_CONTIGUOUS'] is True
                assert beta.flags['C_CONTIGUOUS'] is True

        else:
            print('\tskipping layer {} of type {}'.format(layer_name, layer_type))
            assert len(layer_blobs) == 0

    if output_prefix is not None:
        model = mx.mod.Module(symbol=sym, label_names=[prob_label(arg_names), ])
        model.bind(data_shapes=[('data', tuple(input_dim))])
        model.init_params(arg_params=arg_params, aux_params=aux_params)
        model.save_checkpoint(output_prefix, 0)

    return sym, arg_params, aux_params, input_dim
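# A minimal usage sketch (the filenames below are placeholders, not files
# shipped with this tool):
#
#   sym, arg_params, aux_params, input_dim = convert_model(
#       'deploy.prototxt', 'weights.caffemodel', output_prefix='converted')
#
# Because output_prefix is given, this also writes 'converted-symbol.json'
# and 'converted-0000.params' through Module.save_checkpoint.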
def main():
    parser = argparse.ArgumentParser(
        description='Caffe prototxt to mxnet model parameter converter.')
    parser.add_argument('prototxt', help='The prototxt filename')
    parser.add_argument('caffemodel', help='The binary caffemodel filename')
    parser.add_argument('save_model_name', help='The name of the output model prefix')
    args = parser.parse_args()

    convert_model(args.prototxt, args.caffemodel, args.save_model_name)
    print('Saved model successfully to {}'.format(args.save_model_name))

if __name__ == '__main__':
    main()
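# Example invocation (placeholder filenames):
#
#   python convert_model.py deploy.prototxt weights.caffemodel converted
#
# The converted symbol and parameters are saved under the 'converted' prefix.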