# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Convert caffe prototxt to symbol
"""
from __future__ import print_function
import argparse
import re
import mxnet as mx
import caffe_parser


def _get_input(proto):
    """Get input size

    The input dimensions are read from ``proto.input_dim``, else from
    ``proto.input_shape[0].dim``, else from a leading layer of type
    ``"Input"`` (which is then removed from the layer list).

    Parameters
    ----------
    proto : object
        Parsed prototxt, as accepted by ``caffe_parser.get_layers``.

    Returns
    -------
    tuple
        ``(input_name, input_dim, layers)`` where ``input_name`` is the first
        bottom blob of the first remaining layer.

    Raises
    ------
    ValueError
        If no input size can be found, or no layer is left to consume the input.
    """
    layer = caffe_parser.get_layers(proto)
    if len(proto.input_dim) > 0:
        input_dim = proto.input_dim
    elif len(proto.input_shape) > 0:
        input_dim = proto.input_shape[0].dim
    elif len(layer) > 0 and layer[0].type == "Input":
        input_dim = layer[0].input_param.shape[0].dim
        layer.pop(0)
    else:
        raise ValueError('Cannot find input size')

    # Without this guard an input-only network fails with a bare IndexError below.
    if len(layer) == 0:
        raise ValueError('Cannot find any layer that consumes the input')

    assert layer[0].type != "Input", 'only support single input'
    # We assume the first bottom blob of first layer is the output from data layer
    input_name = layer[0].bottom[0]
    return input_name, input_dim, layer


def _convert_conv_param(param):
    """
    Convert convolution layer parameter from Caffe to MXNet

    Parameters
    ----------
    param : object
        The layer's ``convolution_param``. ``pad``, ``kernel_size``, ``stride``
        and ``dilation`` may each be a plain int or a sequence; for a sequence
        only the first element is used, for both spatial dimensions.

    Returns
    -------
    str
        Comma separated keyword arguments for the generated operator call.
    """
    param_string = "num_filter=%d" % param.num_output

    pad_w = 0
    pad_h = 0
    if isinstance(param.pad, int):
        pad = param.pad
        param_string += ", pad=(%d, %d)" % (pad, pad)
    elif len(param.pad) > 0:
        pad = param.pad[0]
        param_string += ", pad=(%d, %d)" % (pad, pad)
    else:
        # No generic pad given: fall back to the per-axis fields
        if isinstance(param.pad_w, int):
            pad_w = param.pad_w
        if isinstance(param.pad_h, int):
            pad_h = param.pad_h
        param_string += ", pad=(%d, %d)" % (pad_h, pad_w)

    if isinstance(param.kernel_size, int):
        kernel_size = param.kernel_size
        param_string += ", kernel=(%d,%d)" % (kernel_size, kernel_size)
    elif len(param.kernel_size) > 0:
        kernel_size = param.kernel_size[0]
        param_string += ", kernel=(%d,%d)" % (kernel_size, kernel_size)
    else:
        assert isinstance(param.kernel_w, int)
        kernel_w = param.kernel_w
        assert isinstance(param.kernel_h, int)
        kernel_h = param.kernel_h
        param_string += ", kernel=(%d,%d)" % (kernel_h, kernel_w)

    if isinstance(param.stride, int):
        stride = param.stride
    else:
        stride = 1 if len(param.stride) == 0 else param.stride[0]

    param_string += ", stride=(%d,%d)" % (stride, stride)

    dilate = 1
    if hasattr(param, 'dilation'):
        if isinstance(param.dilation, int):
            dilate = param.dilation
        else:
            dilate = 1 if len(param.dilation) == 0 else param.dilation[0]

    param_string += ", no_bias=%s" % (not param.bias_term)

    # deal with dilation. Won't be in deconvolution
    if dilate > 1:
        param_string += ", dilate=(%d, %d)" % (dilate, dilate)

    if isinstance(param.group, int):
        if param.group != 1:
            param_string += ", num_group=%d" % param.group

    return param_string


def _convert_pooling_param(param):
    """Convert the pooling layer parameter

    Parameters
    ----------
    param : object
        The layer's ``pooling_param``. ``pool`` value 0 is emitted as max
        pooling and 1 as average pooling.

    Returns
    -------
    str
        Comma separated keyword arguments for the generated pooling call.

    Raises
    ------
    ValueError
        If ``param.pool`` is neither 0 nor 1.
    """
    param_string = "pooling_convention='full', "
    if param.global_pooling:
        param_string += "global_pool=True, kernel=(1,1)"
    else:
        param_string += "pad=(%d,%d), kernel=(%d,%d), stride=(%d,%d)" % (
            param.pad, param.pad, param.kernel_size, param.kernel_size,
            param.stride, param.stride)
    if param.pool == 0:
        param_string += ", pool_type='max'"
    elif param.pool == 1:
        param_string += ", pool_type='avg'"
    else:
        raise ValueError("Unknown Pooling Method!")
    return param_string


def _parse_proto(prototxt_fname):
    """Parse Caffe prototxt into symbol string

    Parameters
    ----------
    prototxt_fname : str
        Filename of the prototxt file

    Returns
    -------
    tuple
        ``(symbol_string, output_name, input_dim)``: the generated Python
        source that builds the network, the list of variable names in that
        source whose top blob is never consumed by another layer, and the
        input dimensions reported by ``_get_input``.

    Raises
    ------
    ValueError
        If a layer type is not handled here.
    """
    proto = caffe_parser.read_prototxt(prototxt_fname)

    # process data layer
    input_name, input_dim, layers = _get_input(proto)
    # only support single input, so always use `data` as the input data
    mapping = {input_name: 'data'}
    # need_flatten is always looked up through `mapping`, i.e. by generated
    # variable name, so the input must be registered as 'data' as well;
    # otherwise any input blob not literally called "data" raises KeyError.
    need_flatten = {input_name: False, 'data': False}
    symbol_string = "import mxnet as mx\ndata = mx.symbol.Variable(name='data')\n"

    flatten_count = 0
    prev_name = None
    _output_name = {}

    # convert the rest of the layers one by one
    for i, layer in enumerate(layers):
        type_string = ''
        param_string = ''
        skip_layer = False
        name = re.sub('[-/]', '_', layer.name)

        # Track how often every blob is referenced; a top that is never
        # referenced again (count stays 0) is reported as a network output.
        for blob in layer.bottom:
            if blob in _output_name:
                _output_name[blob]['count'] += 1
            else:
                _output_name[blob] = {'count': 0}
        for blob in layer.top:
            if blob in _output_name:
                _output_name[blob]['count'] += 1
            else:
                _output_name[blob] = {'count': 0, 'name': name}

        # layer.type is either the string name or the legacy integer id
        if layer.type == 'Convolution' or layer.type == 4:
            type_string = 'mx.symbol.Convolution'
            param_string = _convert_conv_param(layer.convolution_param)
            need_flatten[name] = True
        elif layer.type == 'Deconvolution' or layer.type == 39:
            type_string = 'mx.symbol.Deconvolution'
            param_string = _convert_conv_param(layer.convolution_param)
            need_flatten[name] = True
        elif layer.type == 'Pooling' or layer.type == 17:
            type_string = 'mx.symbol.Pooling'
            param_string = _convert_pooling_param(layer.pooling_param)
            need_flatten[name] = True
        elif layer.type == 'ReLU' or layer.type == 18:
            type_string = 'mx.symbol.Activation'
            param_string = "act_type='relu'"
            param = layer.relu_param
            if hasattr(param, 'negative_slope') and param.negative_slope > 0:
                type_string = 'mx.symbol.LeakyReLU'
                param_string = "act_type='leaky', slope=%f" % param.negative_slope
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'TanH' or layer.type == 23:
            type_string = 'mx.symbol.Activation'
            param_string = "act_type='tanh'"
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'Sigmoid' or layer.type == 19:
            type_string = 'mx.symbol.Activation'
            param_string = "act_type='sigmoid'"
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'LRN' or layer.type == 15:
            type_string = 'mx.symbol.LRN'
            param = layer.lrn_param
            param_string = "alpha=%f, beta=%f, knorm=%f, nsize=%d" % (
                param.alpha, param.beta, param.k, param.local_size)
            need_flatten[name] = True
        elif layer.type == 'InnerProduct' or layer.type == 14:
            type_string = 'mx.symbol.FullyConnected'
            param = layer.inner_product_param
            param_string = "num_hidden=%d, no_bias=%s" % (
                param.num_output, not param.bias_term)
            need_flatten[name] = False
        elif layer.type == 'Dropout' or layer.type == 6:
            type_string = 'mx.symbol.Dropout'
            param = layer.dropout_param
            param_string = "p=%f" % param.dropout_ratio
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'Softmax' or layer.type == 20:
            type_string = 'mx.symbol.SoftmaxOutput'
        elif layer.type == 'Flatten' or layer.type == 8:
            type_string = 'mx.symbol.Flatten'
            need_flatten[name] = False
        elif layer.type == 'Split' or layer.type == 22:
            # No operator is generated for Split. Its tops are mapped to
            # `name` below, so `name` has to exist in the generated source:
            # emit it as an alias of the bottom, like Scale is handled.
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
            skip_layer = True
            prev_name = mapping[layer.bottom[0]]
        elif layer.type == 'Concat' or layer.type == 3:
            type_string = 'mx.symbol.Concat'
            need_flatten[name] = True
        elif layer.type == 'Crop':
            type_string = 'mx.symbol.Crop'
            need_flatten[name] = True
            param_string = 'center_crop=True'
        elif layer.type == 'BatchNorm':
            type_string = 'mx.symbol.BatchNorm'
            param = layer.batch_norm_param
            # CuDNN requires eps to be greater than 1e-05
            # We compensate for this change in convert_model
            epsilon = param.eps
            if epsilon <= 1e-05:
                epsilon = 1e-04
            # if next layer is scale, don't fix gamma; a BatchNorm that is
            # the very last layer has no successor to look at
            followed_by_scale = i + 1 < len(layers) and layers[i + 1].type == 'Scale'
            fix_gamma = not followed_by_scale
            param_string = 'use_global_stats=%s, fix_gamma=%s, eps=%f' % (
                param.use_global_stats, fix_gamma, epsilon)
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'Scale':
            # i > 0 keeps layers[i-1] from wrapping around to the last layer
            assert i > 0 and layers[i - 1].type == 'BatchNorm'
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
            skip_layer = True
            prev_name = re.sub('[-/]', '_', layers[i - 1].name)
        elif layer.type == 'PReLU':
            type_string = 'mx.symbol.LeakyReLU'
            param = layer.prelu_param
            param_string = "act_type='prelu', slope=%f" % param.filler.value
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]
        elif layer.type == 'Eltwise':
            # NOTE(review): every Eltwise operation is emitted as an addition;
            # confirm whether non-sum operations need their own operator.
            type_string = 'mx.symbol.broadcast_add'
            param = layer.eltwise_param
            param_string = ""
            need_flatten[name] = False
        elif layer.type == 'Reshape':
            type_string = 'mx.symbol.Reshape'
            need_flatten[name] = False
            param = layer.reshape_param
            # the dims are numbers, str.join only accepts strings
            param_string = "shape=(%s)" % (','.join(str(dim) for dim in param.shape.dim),)
        elif layer.type == 'AbsVal':
            type_string = 'mx.symbol.abs'
            need_flatten[name] = need_flatten[mapping[layer.bottom[0]]]

        if skip_layer:
            # the layer is only an alias of an already generated symbol
            assert len(layer.bottom) == 1
            symbol_string += "%s = %s\n" % (name, prev_name)
        elif type_string == '':
            raise ValueError('Unknown layer %s!' % layer.type)
        else:
            bottom = layer.bottom
            if param_string != "":
                param_string = ", " + param_string
            if len(bottom) == 1:
                if need_flatten[mapping[bottom[0]]] and type_string == 'mx.symbol.FullyConnected':
                    # insert an explicit Flatten in front of the FullyConnected
                    flatten_name = "flatten_%d" % flatten_count
                    symbol_string += "%s=mx.symbol.Flatten(name='%s', data=%s)\n" % (
                        flatten_name, flatten_name, mapping[bottom[0]])
                    flatten_count += 1
                    need_flatten[flatten_name] = False
                    bottom[0] = flatten_name
                    mapping[bottom[0]] = bottom[0]
                symbol_string += "%s = %s(name='%s', data=%s %s)\n" % (
                    name, type_string, name, mapping[bottom[0]], param_string)
            elif layer.type == 'Eltwise' and param.operation == 1 and len(param.coeff) > 0:
                # weighted sum: spell it out as "a * c0 + b * c1 + ..."
                symbol_string += "%s = " % name
                symbol_string += " + ".join(
                    ["%s * %s" % (mapping[bottom[k]], coeff)
                     for k, coeff in enumerate(param.coeff)])
                symbol_string += "\n"
            else:
                symbol_string += "%s = %s(name='%s', *[%s] %s)\n" % (
                    name, type_string, name, ','.join(
                        [mapping[x] for x in bottom]), param_string)
        # from now on every top blob of this layer refers to `name`
        for top in layer.top:
            mapping[top] = name

    output_name = [entry['name'] for entry in _output_name.values()
                   if 'name' in entry and entry['count'] == 0]

    return symbol_string, output_name, input_dim


def convert_symbol(prototxt_fname):
    """Convert caffe model definition into Symbol

    Parameters
    ----------
    prototxt_fname : str
        Filename of the prototxt file

    Returns
    -------
    Symbol
        Converted Symbol
    tuple
        Input shape
    """
    sym, output_name, input_dim = _parse_proto(prototxt_fname)
    # Run the generated source in its own namespace. Reading the results back
    # through locals() depends on exec() mutating the function's locals, which
    # Python does not guarantee.
    # NOTE: layer names from the prototxt end up inside the executed source,
    # so only convert prototxt files from a trusted origin.
    namespace = {}
    exec(sym, namespace)  # pylint: disable=exec-used
    outputs = [namespace[name] for name in output_name]
    ret = mx.sym.Group(outputs)
    return ret, input_dim


def main():
    """Command line entry point: convert a prototxt and save the symbol."""
    parser = argparse.ArgumentParser(
        description='Convert caffe prototxt into Symbol')
    parser.add_argument('prototxt', help='The prototxt filename')
    parser.add_argument('output', help='filename for the output json file')
    args = parser.parse_args()

    sym, _ = convert_symbol(args.prototxt)
    sym.save(args.output)


if __name__ == '__main__':
    main()