1#---------------------------------------------------------------------------------------------- 2# Copyright (c) Microsoft Corporation. All rights reserved. 3# Licensed under the MIT License. See License.txt in the project root for license information. 4#---------------------------------------------------------------------------------------------- 5 6import os 7 8import math 9import mxnet as mx 10import numpy as np 11from mmdnn.conversion.common.IR.IR_graph import IRGraph, IRGraphNode 12import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2 13from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType 14from mmdnn.conversion.common.DataStructure.emitter import Emitter 15from mmdnn.conversion.common.utils import * 16from mmdnn.conversion.rewriter.folder import Folder 17 18class MXNetEmitter(Emitter): 19 20 dtype_map = { 21 graph_pb2.DT_FLOAT16 : "float16", 22 graph_pb2.DT_FLOAT32 : "float32", 23 graph_pb2.DT_FLOAT64 : "float64", 24 graph_pb2.DT_INT32 : "int32", 25 graph_pb2.DT_UINT8 : "uint8" 26 } 27 28 activation_map = { 29 "relu" : "Relu", 30 "sigmoid" : "Sigmoid", 31 "tanh" : "Tanh", 32 "elu" : "Elu" 33 } 34 35 transpose_map = { 36 1 : 2, 37 2 : 3, 38 -1 : 1 39 } 40 41 naive_scope_pattern = [] 42 43 channels_last = ['NDHWC', 'NHWC'] 44 45 def __init__(self, model): 46 super(MXNetEmitter, self).__init__() 47 from six import string_types as _string_types 48 49 if isinstance(model, _string_types): 50 network_path = model 51 self.weight_loaded = False 52 elif len(model) == 3: 53 network_path = model[0] 54 weight_path = model[1] 55 self.output_weights_file = model[2] 56 self.output_weights = dict() 57 self._load_weights(weight_path) 58 self.weights = self.weights_dict 59 else: 60 raise ValueError("the # of input arguments [{}] is not supported" % len(model)) 61 62 self.IR_graph = IRGraph(network_path) 63 self.IR_graph.build() 64 65 folder = Folder(self.IR_graph, self.weights) 66 folder.fold() 67 68 @property 69 def header_code(self): 70 return """import mxnet as mx 71import numpy as np 72import math 73 74# mxnet-cpu only support channel first, default convert the model and weight as channel first 75 76def RefactorModel(): 77""" 78 79 80 def gen_code(self, phase): 81 self.IR_layer_map = dict() 82 self.add_body(0, self.header_code) 83 for layer in self.IR_graph.topological_sort: 84 self.IR_layer_map[layer] = self.IR_graph.get_node(layer) 85 86 shape = dict() 87 for layer in self.IR_graph.topological_sort: 88 current_node = self.IR_graph.get_node(layer) 89 node_type = current_node.type 90 91 92 if len(current_node.in_edges) == 0: 93 current_node.in_edges.append('data') 94 95 if node_type.lower() in MXNetEmitter.activation_map: 96 func = getattr(self, "emit_Activation") 97 line = func(current_node, MXNetEmitter.activation_map[node_type.lower()].lower()) 98 self.add_body(1, line) 99 100 elif hasattr(self, "emit_" + node_type): 101 func = getattr(self, "emit_" + node_type) 102 line = func(current_node) 103 if line != None: 104 self.add_body(1, line) 105 else: 106 print("MXNet Emitter has not supported operator [%s]." % (node_type)) 107 self.emit_UNKNOWN(current_node) 108 109 if node_type == "DataInput": 110 cur_shape = list() 111 first = True 112 for dim in current_node.IR_layer.attr["shape"].shape.dim: 113 if dim.size == -1 and first: 114 cur_shape.append(1) 115 print("Detect input layer [{}] using infer batch size, set it as default value [1]".format(current_node.name)) 116 else: 117 if dim.size == -1: 118 print("Warning: user should change input size manually") 119 cur_shape.append(dim.size) 120 first = False 121 122 cur_shape.insert(1, cur_shape.pop()) 123 shape[current_node.name] = ', '.join('%s' % i for i in cur_shape) 124 self.input_name_shape = {current_node.name: tuple(cur_shape)} 125 126 127 if self.weight_loaded: 128 fullpath = os.path.abspath(self.output_weights_file) 129 dirname = os.path.dirname(fullpath) 130 if not os.path.exists(dirname): 131 os.makedirs(dirname) 132 with open(self.output_weights_file, 'wb') as outfile: 133 np.save(outfile, self.output_weights) 134 135 comment = "\n # if a GPU is available, change mx.cpu() to mx.gpu()" 136 # We use the real_name for specifying the input layer in data_names 137 # since MXNet API wants the actual name of the layer. On the other 138 # hand, the module API wants the last symbol in the symbol chain, so 139 # for the output node we need to use the actual python variable name 140 # of the last layer (real_variable_name). 141 last_line = "{:<15} = mx.mod.Module(symbol = {}, context = mx.cpu(), data_names = ['{}'])".format( 142 "model", 143 ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.output_layers if self.IR_graph.get_node(name).type !='Pack' and self.IR_graph.get_node(name).type != 'Shape']), 144 ', '.join([self.IR_graph.get_node(name).real_name for name in self.IR_graph.input_layers if self.IR_graph.get_node(name).type != 'Const'])) 145 146 self.add_body(1, comment) 147 self.add_body(1, last_line) 148 self.add_body(1, "return model") 149 150 151 self.add_body(0, "") 152 for code in self.layers_codes.values(): 153 self.add_body(0, code) 154 155 weight_code = "" 156 if not self.weight_loaded: 157 weight_code += "# emitter does not detect any import weights, you may generate weights file manually\n" 158 159 weight_code += self.gen_weight_code(shape, phase) 160 161 main_code = "if __name__ == '__main__':\n model = RefactorModel()\n" 162 if self.weight_loaded: 163 main_code += " # remember to adjust params path\n model = deploy_weight(model, '{}')\n".format(self.output_weights_file) 164 165 if phase == 'train': 166 train_code = """def train(model): 167 import logging 168 logging.getLogger().setLevel(logging.DEBUG) 169 model.fit(train_iter, # train data 170 eval_data = val_iter, # validation data 171 optimizer = 'sgd', # Defaults to 'sgd' 172 optimizer_params = {'learning_rate':0.01}, # use fixed learning rate 173 eval_metric = 'acc', # report accuracy during training, other possible predefined metrics are: 'ce', 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy' 174 batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches 175 num_epoch = 10) # train for at most 10 dataset passes\n\n 176""" 177 code = self.body_code + weight_code + train_code + main_code 178 else: 179 test_code = """from collections import namedtuple 180Batch = namedtuple('Batch', ['data']) 181 182 183def get_image(url, show=False): 184 import cv2 185 # download and show the image 186 fname = mx.test_utils.download(url) 187 img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB) 188 if img is None: 189 return None 190 if show: 191 import matplotlib.pyplot as plt 192 plt.imshow(img) 193 plt.axis('off') 194 # convert into format (batch, RGB, width, height) 195 img = cv2.resize(img, (224, 224)) 196 img = np.swapaxes(img, 0, 2) 197 img = np.swapaxes(img, 1, 2) 198 img = img[np.newaxis, :] 199 return img 200 201 202def predict(model, labels, url): 203 # to show the image, change the argument show into True 204 img = get_image(url, show = False) 205 # compute the predict probabilities 206 model.forward(Batch([mx.nd.array(img)])) 207 prob = model.get_outputs()[0].asnumpy() 208 # print the top-5 209 prob = np.squeeze(prob) 210 a = np.argsort(prob)[::-1] 211 for i in a[0:5]: 212 print('prbability = %f, class = %s' %(prob[i], labels[i]))\n\n 213""" 214 215 main_code += """ 216 # # call function predict 217 # with open('synset.txt', 'r') as f: 218 # labels = [l.rstrip() for l in f] 219 # predict(model, labels, 'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg') 220""" 221 222 code = self.body_code + weight_code + test_code + main_code 223 224 return code 225 226 227 def gen_weight_code(self, shape, phase): 228 str = "def deploy_weight(model, weight_file):\n" 229 str += """ 230 if weight_file == None: 231 return 232 233 try: 234 weights_dict = np.load(weight_file, allow_pickle=True).item() 235 except: 236 weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item() 237 238 arg_params = dict() 239 aux_params = dict() 240 for weight_name, weight_data in weights_dict.items(): 241 weight_name = str(weight_name) 242 if "moving" in weight_name: 243 aux_params[weight_name] = mx.nd.array(weight_data) 244 else: 245 arg_params[weight_name] = mx.nd.array(weight_data) 246 247""" 248 if phase == 'train': 249 str += " model.bind(for_training = True, data_shapes = [" 250 else: 251 str += " model.bind(for_training = False, data_shapes = [" 252 first = True 253 for k, v in shape.items(): 254 if not first: 255 str += ", " 256 str += "('" + k + "', " + "(" + v + "))" 257 first = False 258 str += "])\n" 259 str += " model.set_params(arg_params = arg_params, aux_params = aux_params, allow_missing = True, allow_extra=True)\n\n return model\n\n\n" 260 return str 261 262 263 @staticmethod 264 def calculate_same_pad(data_shape, kernel, stride): 265 if (data_shape % stride == 0): 266 pad = max(kernel - stride, 0) 267 else: 268 pad = max(kernel - (data_shape % stride), 0) 269 if pad % 2 == 0: 270 return False, pad 271 else: 272 return True, pad 273 274 275 @staticmethod 276 def transfer_pad(pad_list): 277 defuse_pad = False 278 pad = list() 279 280 assert len(pad_list) % 2 == 0 281 mid = int(len(pad_list)/2) 282 pad_first = pad_list[1:mid-1] 283 pad_second = pad_list[mid+1:-1] 284 285 for i in range(0, mid-2): 286 if not pad_first[i] == pad_second[i]: 287 defuse_pad = True 288 289 if defuse_pad: 290 pad.extend([0] * 4) 291 for i in range(0, mid-2): 292 pad.extend([pad_first[i], pad_second[i]]) 293 else: 294 pad = pad_first 295 296 return defuse_pad, pad 297 298 299 @staticmethod 300 def transpose(data, dim): 301 if dim == 1: 302 data = data.transpose((2, 1, 0)) 303 elif dim == 2: 304 data = data.transpose((3, 2, 0, 1)) 305 elif dim == 3: 306 data = data.transpose((4, 3, 0, 1, 2)) 307 else: 308 raise ValueError("The weight of dim {} cannot transpose" % dim) 309 310 return data 311 312 313 def set_pad(self, IR_node, code, pad, _max_pool): 314 if _max_pool: 315 constant_value = "float('-inf')" 316 else: 317 constant_value = "0.0" 318 319 code = "{:<15} = mx.sym.pad(data = {}, mode = 'constant', pad_width={}, constant_value = {}, name = '{}')".format( 320 IR_node.variable_name + "_pad", 321 self.parent_variable_name(IR_node), 322 tuple(pad), 323 constant_value, 324 IR_node.name + "_pad") 325 326 for e in IR_node.in_edges: 327 e = e.split(':')[0] 328 if e == 'data': 329 continue 330 self.IR_layer_map[e].out_edges = [x if not self.IR_layer_map[x.split(':')[0]].name == IR_node.variable_name else IR_node.variable_name + "_pad" for x in self.IR_layer_map[e].out_edges] 331 332 return code 333 334 335 def emit_UNKNOWN(self, IR_node): 336 print(IR_node.name) 337 338 339 def emit_FullyConnected(self, IR_node): 340 if self.weight_loaded: 341 weight_dict = self.weights[IR_node.name] 342 parent = self.IR_graph.get_parent(IR_node.name, [0]) 343 while parent.type == "Flatten" or parent.type == 'Dropout': 344 parent = self.IR_graph.get_parent(parent.name, [0]) 345 dim = len(parent.layer.attr['_output_shapes'].list.shape[0].dim) 346 if dim > 2: 347 original_dims = weight_dict['weights'].shape 348 dims = [i.size for i in parent.layer.attr['_output_shapes'].list.shape[0].dim[1:]] + [-1] 349 weight_dict['weights'] = np.reshape(weight_dict['weights'], dims) 350 weight_dict['weights'] = np.transpose(weight_dict['weights'], [dim - 2] + list(range(0, dim - 2)) + [dim - 1]) 351 weight_dict['weights'] = np.reshape(weight_dict['weights'], original_dims) 352 self.output_weights[IR_node.name + "_weight"] = weight_dict['weights'].transpose((1, 0)) 353 354 num_hidden = IR_node.IR_layer.attr["units"].i 355 no_bias = not IR_node.IR_layer.attr["use_bias"].b 356 if not no_bias and self.weight_loaded: 357 self.output_weights[IR_node.name + "_bias"] = weight_dict['bias'] 358 359 code = "{:<15} = mx.sym.FullyConnected(data = {}, num_hidden = {}, no_bias = {}, name = '{}')".format( 360 IR_node.variable_name, 361 self.parent_variable_name(IR_node), 362 num_hidden, 363 no_bias, 364 IR_node.name) 365 366 return code 367 368 369 def _emit_convolution(self, IR_node, pattern): 370 if self.weight_loaded: 371 weight_dict = self.weights[IR_node.name] 372 weights = weight_dict['weights'] 373 374 dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2 375 376 kernel = list() 377 for idx in range(0, dim): 378 kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx]) 379 380 stride = list() 381 for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]: 382 stride.append(e) 383 384 dilate = list() 385 for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]: 386 dilate.append(e) 387 if dilate == []: dilate = [1, 1] 388 dilate = ', '.join('%s' % i for i in dilate) 389 390 defuse_pad = False 391 pad = list() 392 if "pads" in IR_node.IR_layer.attr: 393 output_shape = list() 394 for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim: 395 output_shape.append(e.size) 396 397 # print("Warning: MXNet Convolution Layer pad does not match IR Convolution Layer pad") 398 defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i) 399 400 num_filter = 0 401 if pattern == "Deconvolution": 402 num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2] 403 else: 404 num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1] 405 406 use_bias = IR_node.get_attr('use_bias', False) 407 if use_bias and self.weight_loaded: 408 self.output_weights[IR_node.name + "_bias"] = weight_dict['bias'] 409 410 if pattern == "DepthwiseConv": 411 num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2] 412 num_filter = num_filter * num_group 413 pattern = "Convolution" 414 if self.weight_loaded: 415 weights = np.swapaxes(weights, -1, -2) 416 417 else: 418 num_group = IR_node.get_attr('group', 1) 419 420 # layout = IR_node.IR_layer.attr["data_format"].s 421 if dim == 1: 422 layout = 'NCW' 423 elif dim == 2: 424 layout = 'NCHW' 425 elif dim == 3: 426 layout = 'NCDHW' 427 428 if self.weight_loaded: 429 # if layout not in MXNetEmitter.channels_last: 430 weights = MXNetEmitter.transpose(weights, dim) 431 self.output_weights[IR_node.name + "_weight"] = weights 432 433 code = "" 434 if not defuse_pad: 435 code += "{:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), pad={}, num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format( 436 IR_node.variable_name, 437 pattern, 438 self.parent_variable_name(IR_node), 439 tuple(kernel), 440 tuple(stride), 441 dilate, 442 tuple(pad), 443 num_filter, 444 num_group, 445 not use_bias, 446 layout, 447 IR_node.name) 448 else: 449 code += self.set_pad(IR_node, code, pad, False) 450 code += "\n {:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format( 451 IR_node.variable_name, 452 pattern, 453 IR_node.variable_name + "_pad", 454 tuple(kernel), 455 tuple(stride), 456 dilate, 457 num_filter, 458 num_group, 459 not use_bias, 460 layout, 461 IR_node.name) 462 463 return code 464 465 466 def emit_Conv(self, IR_node): 467 return self._emit_convolution(IR_node, "Convolution") 468 469 470 def emit_DepthwiseConv(self, IR_node): 471 return self._emit_convolution(IR_node, "DepthwiseConv") 472 473 474 def emit_ConvTranspose(self, IR_node): 475 return self._emit_convolution(IR_node, "Deconvolution") 476 477 478 def emit_DataInput(self, IR_node): 479 shape = list() 480 shape.extend(IR_node.IR_layer.attr["shape"].list.i) 481 482 code = "{:<15} = mx.sym.var('{}')".format(IR_node.variable_name, IR_node.name) 483 return code 484 485 486 # Add LeakyReLU Elu(slope not support) 487 def emit_Activation(self, IR_node, act_type): 488 489 act_type = act_type 490 func_name = "" 491 492 if act_type == "elu": 493 func_name = "LeakyReLU" 494 else: 495 func_name = "Activation" 496 497 code = "{:<15} = mx.sym.{}(data = {}, act_type = '{}', name = '{}')".format( 498 IR_node.variable_name, 499 func_name, 500 self.parent_variable_name(IR_node), 501 act_type, 502 IR_node.name) 503 504 return code 505 506 507 def emit_BatchNorm(self, IR_node): 508 IR_node_after = self.IR_graph.get_son(IR_node.name, [0]) 509 if IR_node_after.type == 'Scale': 510 if self.weight_loaded: 511 weight_dict = self.weights[IR_node.name] 512 weight_dict_scale = self.weights[IR_node_after.name] 513 514 # axis = IR_node.IR_layer.attr["axis"].i 515 axis = 1 516 eps = IR_node.IR_layer.attr["epsilon"].f 517 momentum = IR_node.IR_layer.attr["momentum"].f 518 519 fix_gamma = not IR_node.IR_layer.attr["scale"].b 520 521 if self.weight_loaded: 522 if not fix_gamma: 523 # self.output_weights[IR_node.name + "_gamma"] = np.multiply(weight_dict['scale'], weight_dict_scale['scale']) 524 # self.output_weights[IR_node.name + "_beta"] = np.multiply(weight_dict['bias'], weight_dict_scale['scale']) + weight_dict_scale['bias'] 525 self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale'] 526 self.output_weights[IR_node.name + "_beta"] = weight_dict['bias'] 527 528 # not supported yet 529 use_global_stats = "False" 530 if self.weight_loaded: 531 self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var'] 532 self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean'] 533 534 code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format( 535 IR_node.variable_name, 536 self.parent_variable_name(IR_node), 537 axis, 538 eps, 539 momentum, 540 fix_gamma, 541 use_global_stats, 542 IR_node.name) 543 544 return code 545 546 else: 547 if self.weight_loaded: 548 weight_dict = self.weights[IR_node.name] 549 550 # axis = IR_node.IR_layer.attr["axis"].i 551 axis = 1 552 eps = IR_node.IR_layer.attr["epsilon"].f 553 momentum = IR_node.IR_layer.attr["momentum"].f 554 555 fix_gamma = not IR_node.IR_layer.attr["scale"].b 556 557 if self.weight_loaded: 558 if not fix_gamma: 559 self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale'] 560 self.output_weights[IR_node.name + "_beta"] = weight_dict['bias'] 561 562 # not supported yet 563 use_global_stats = "False" 564 if self.weight_loaded: 565 self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var'] 566 self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean'] 567 568 code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format( 569 IR_node.variable_name, 570 self.parent_variable_name(IR_node), 571 axis, 572 eps, 573 momentum, 574 fix_gamma, 575 use_global_stats, 576 IR_node.name) 577 578 return code 579 580 def emit_Scale(self, IR_node): 581 if self.weight_loaded: 582 weight_dict = self.weights[IR_node.name] 583 584 # axis = IR_node.IR_layer.attr["axis"].i 585 axis = 1 586 eps = 0.0 587 momentum = 0.0 588 589 fix_gamma = not IR_node.IR_layer.attr["scale"].b 590 591 if self.weight_loaded: 592 if not fix_gamma: 593 self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale'] 594 self.output_weights[IR_node.name + "_beta"] = weight_dict['bias'] 595 596 # not supported yet 597 use_global_stats = "False" 598 if self.weight_loaded: 599 self.output_weights[IR_node.name + "_moving_var"] = weight_dict['scale_var'] 600 self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['scale_mean'] 601 602 code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format( 603 IR_node.variable_name, 604 self.parent_variable_name(IR_node), 605 axis, 606 eps, 607 momentum, 608 fix_gamma, 609 use_global_stats, 610 IR_node.name) 611 612 return code 613 614 615 616 def emit_Pool(self, IR_node): 617 618 global_pool = IR_node.IR_layer.attr["global_pooling"].b 619 620 kernel = list() 621 if global_pool: 622 kernel = [1] * (len(IR_node.IR_layer.attr["strides"].list.i) - 2) 623 else: 624 for e in IR_node.IR_layer.attr["kernel_shape"].list.i[1:-1]: 625 kernel.append(e) 626 627 pool_type = IR_node.get_attr('pooling_type').lower() 628 629 stride = list() 630 for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]: 631 stride.append(e) 632 633 defuse_pad = False 634 pad = list() 635 if "pads" in IR_node.IR_layer.attr: 636 output_shape = list() 637 for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim: 638 output_shape.append(e.size) 639 640 # print("Warning: MXNet Pooling Layer pad does not match IR Pooling Layer pad") 641 defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i) 642 code = "" 643 if not defuse_pad: 644 code += "{:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, pad={}, name = '{}')".format( 645 IR_node.variable_name, 646 self.parent_variable_name(IR_node), 647 global_pool, 648 tuple(kernel), 649 pool_type, 650 tuple(stride), 651 tuple(pad), 652 IR_node.name) 653 else: 654 code += self.set_pad(IR_node, code, pad, pool_type == "max") 655 code += "\n {:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, name = '{}')".format( 656 IR_node.variable_name, 657 IR_node.variable_name + "_pad", 658 global_pool, 659 tuple(kernel), 660 pool_type, 661 tuple(stride), 662 IR_node.name) 663 664 return code 665 666 667 def emit_SoftmaxOutput(self, IR_node): 668 669 code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format( 670 IR_node.variable_name, 671 self.parent_variable_name(IR_node) 672 ) 673 674 return code 675 676 677 def emit_Softmax(self, IR_node): 678 679 code = "" 680 681 if len(IR_node.out_edges) == 0: 682 code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format( 683 IR_node.variable_name, 684 self.parent_variable_name(IR_node)) 685 else: 686 axis = IR_node.IR_layer.attr["dim"].i 687 code = "{:<15} = mx.sym.softmax(data = {}, axis = {}, name = '{}')".format( 688 IR_node.variable_name, 689 self.parent_variable_name(IR_node), 690 axis, 691 IR_node.name) 692 693 return code 694 695 696 def emit_Squeeze(self, IR_node): 697 return self.emit_Flatten(IR_node) 698 699 700 # def emit_ConvTranspose(self, IR_node): 701 # if self.weight_loaded: 702 # weight_dict = self.weights[IR_node.name] 703 # weights = weight_dict['weights'] 704 705 # dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2 706 707 # kernel = list() 708 # for idx in range(0, dim): 709 # kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx]) 710 711 # stride = list() 712 # for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]: 713 # stride.append(e) 714 715 # dilate = list() 716 # for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]: 717 # dilate.append(e) 718 # dilate = ', '.join('%s' % i for i in dilate) 719 720 # defuse_pad = False 721 # pad = list() 722 # if "pads" in IR_node.IR_layer.attr: 723 # output_shape = list() 724 # for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim: 725 # output_shape.append(e.size) 726 727 # # print("Warning: MXNet Deconvolution Layer pad does not match IR Deconvolution Layer pad") 728 # defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i) 729 # pad = ', '.join('%s' % i for i in pad) 730 731 # kernel = ', '.join('%s' % i for i in kernel) 732 # stride = ', '.join('%s' % i for i in stride) 733 734 # num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2] 735 # no_bias = not IR_node.IR_layer.attr["use_bias"].b 736 # if not no_bias and self.weight_loaded: 737 # self.output_weights[IR_node.replace_scope(IR_node.name) + "_bias"] = weight_dict['bias'] 738 739 # # layout = IR_node.IR_layer.attr["data_format"].s 740 # if dim == 1: 741 # layout = 'NCW' 742 # elif dim == 2: 743 # layout = 'NCHW' 744 # elif dim == 3: 745 # layout = 'NCDHW' 746 747 # if self.weight_loaded: 748 # # if layout not in MXNetEmitter.channels_last: 749 # weights = MXNetEmitter.transpose(weights, dim) 750 # self.output_weights[IR_node.replace_scope(IR_node.name) + "_weight"] = weights 751 752 # code = "" 753 # if not defuse_pad: 754 # code = "{:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), pad = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format( 755 # IR_node.replace_scope(IR_node.name), 756 # IR_node.replace_scope(IR_node.in_edges[0]), 757 # kernel, 758 # stride, 759 # dilate, 760 # pad, 761 # num_filter, 762 # no_bias, 763 # layout, 764 # IR_node.replace_scope(IR_node.name)) 765 # else: 766 # code = self.set_pad(IR_node, code, pad) 767 # code += "\n {:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format( 768 # IR_node.replace_scope(IR_node.name), IR_node.replace_scope(IR_node.name) + "_pad", kernel, stride, dilate, num_filter, no_bias, layout, IR_node.replace_scope(IR_node.name)) 769 770 # return code 771 772 773 def emit_Embedding(self, IR_node): 774 775 input_dim = IR_node.IR_layer.attr["input_dim"].i 776 output_dim = IR_node.IR_layer.attr["output_dim"].i 777 dtype = MXNetEmitter.dtype_map.get(IR_node.layer.attr["dtype"].type, "float32") 778 779 weight_dict = self.weights[IR_node.name] 780 781 if self.weight_loaded: 782 self.output_weights[IR_node.name + "_weight"] = weight_dict['weights'] 783 784 code = "{:<15} = mx.sym.Embedding(data = {}, input_dim = {}, output_dim = {}, dtype = '{}', name = '{}')".format( 785 IR_node.variable_name, 786 self.parent_variable_name(IR_node), 787 input_dim, 788 output_dim, 789 dtype, 790 IR_node.name) 791 792 return code 793 794 795 def emit_LeakyRelu(self, IR_node): 796 alpha = IR_node.IR_layer.attr['alpha'].f 797 code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, name = '{}')".format( 798 IR_node.variable_name, 799 self.parent_variable_name(IR_node), 800 alpha, 801 IR_node.name 802 ) 803 return code 804 805 def emit_PRelu(self, IR_node): 806 slope = IR_node.get_attr('gamma') 807 code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, act_type = '{}', name = '{}')".format( 808 IR_node.variable_name, 809 self.parent_variable_name(IR_node), 810 slope, 811 'prelu', 812 IR_node.name 813 ) 814 return code 815 816 def emit_Elu(self, IR_node): 817 alpha = IR_node.IR_layer.attr['alpha'].f 818 code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, act_type = {}, name = '{}')".format( 819 IR_node.variable_name, 820 self.parent_variable_name(IR_node), 821 alpha, 822 'elu', 823 IR_node.name 824 ) 825 return code 826 827 def emit_Dropout(self, IR_node): 828 p = IR_node.IR_layer.attr["keep_prob"].f 829 mode = IR_node.IR_layer.attr["mode"].s.lower().decode() if 'mode' in IR_node.layer.attr else 'training' 830 code = "{:<15} = mx.sym.Dropout(data = {}, p = {}, mode = '{}', name = '{}')".format( 831 IR_node.variable_name, 832 self.parent_variable_name(IR_node), 833 p, 834 mode, 835 IR_node.name) 836 837 return code 838 839 840 # reverse cannot support yet 841 def emit_Reshape(self, IR_node): 842 shape = list() 843 for e in IR_node.IR_layer.attr["shape"].list.i: 844 shape.append(e) 845 shape = ', '.join('%s' % i for i in shape) 846 reverse = False 847 848 code = "{:<15} = mx.sym.reshape(data = {}, shape = ({}), reverse = {}, name = '{}')".format( 849 IR_node.variable_name, 850 self.parent_variable_name(IR_node), 851 shape, 852 reverse, 853 IR_node.name) 854 855 return code 856 857 858 def emit_Flatten(self, IR_node): 859 # code = "{:<15} = mx.sym.transpose(data = {}, axes = (0, 2, 3, 1))\n".format("trans", self.parent_variable_name(IR_node)) 860 code = "{:<15} = mx.sym.flatten(data = {}, name = '{}')".format( 861 IR_node.variable_name, 862 self.parent_variable_name(IR_node), 863 IR_node.name) 864 865 return code 866 867 868 @staticmethod 869 def _convert_axis(IR_node, axis): 870 ndim = len(IR_node.layer.attr['_output_shapes'].list.shape[0].dim) 871 if axis == 0: 872 return 0 873 elif axis == ndim - 1: 874 return 1 875 else: 876 return axis + 1 877 878 879 def emit_Concat(self, IR_node): 880 dim = MXNetEmitter._convert_axis(IR_node, IR_node.IR_layer.attr["axis"].i) 881 code = "{:<15} = mx.sym.concat({}, dim = {}, name = '{}')".format( 882 IR_node.variable_name, 883 ', '.join(self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))), 884 dim, 885 IR_node.name) 886 887 return code 888 889 890 def emit_Cast(self, IR_node): 891 dtype = IR_node.IR_layer.attr["dtype"].type 892 code = "{:<15} = mx.sym.cast(data = {}, dtype = {}, name = '{}')".format( 893 IR_node.variable_name, 894 self.parent_variable_name(IR_node), 895 dtype, 896 IR_node.name) 897 898 return code 899 900 901 def emit_Expand_dims(self, IR_node): 902 axis = IR_node.IR_layer.attr["axis"].i 903 code = "{:<15} = mx.sym.expand_dims(data = {}, axis = {}, name = '{}')".format( 904 IR_node.variable_name, 905 self.parent_variable_name(IR_node), 906 axis, 907 IR_node.name) 908 909 return code 910 911 912 def emit_Pad(self, IR_node): 913 mode = IR_node.IR_layer.attr["mode"].s.lower().decode() 914 pad_width = list() 915 pad_width.extend([0]*4) 916 padding = convert_onnx_pad_to_tf(IR_node.get_attr("pads"))[1:-1] 917 for padding_pair in padding: 918 pad_width.extend(padding_pair) 919 920 pad_width = ', '.join('%s' % i for i in pad_width) 921 922 code = "{:<15} = mx.sym.pad(data = {}, mode = '{}', pad_width = ({}), name = '{}')".format( 923 IR_node.variable_name, 924 self.parent_variable_name(IR_node), 925 mode, 926 pad_width, 927 IR_node.name) 928 929 return code 930 931 932 def emit_Add(self, IR_node): 933 code = "{:<15} = mx.sym.broadcast_add({}, {})".format( 934 IR_node.variable_name, 935 self.parent_variable_name(IR_node), 936 self.parent_variable_name(IR_node, [1])) 937 938 return code 939 940 941 def emit_Mul(self, IR_node): 942 943 code = "{:<15} = mx.sym.broadcast_mul({}, {})".format( 944 IR_node.variable_name, 945 self.parent_variable_name(IR_node), 946 self.parent_variable_name(IR_node, [1])) 947 948 return code 949 950 951 def emit_ReduceMean(self, IR_node): 952 axes = IR_node.layer.attr['axes'].list.i[:] 953 axes = ','.join('%s' % MXNetEmitter.transpose_map[i] for i in axes) 954 955 code = "{:<15} = mx.sym.mean(data = {}, axis = ({}), keepdims = {})".format( 956 IR_node.variable_name, 957 self.parent_variable_name(IR_node), 958 axes, 959 IR_node.layer.attr['keepdims'].b) 960 961 return code 962 963 964 def emit_LRN(self, IR_node): 965 output_name = IR_node.variable_name 966 input_name = self.parent_variable_name(IR_node) 967 IR_name = IR_node.name 968 alpha = IR_node.get_attr('alpha') 969 beta = IR_node.get_attr('beta') 970 bias = IR_node.get_attr('bias') 971 size = IR_node.get_attr('size') 972 973 974 code = "{:<15} = mx.sym.LRN(data = {}, alpha = {}, beta = {}, knorm = {}, nsize = {}, name = '{}')".format( 975 output_name, 976 input_name, 977 alpha, 978 beta, 979 bias, 980 size, 981 IR_name) 982 983 return code 984 985 def emit_Constant(self, IR_node): 986 # save the constant into weight dict 987 if IR_node.get_attr('value'): 988 value = IR_node.get_attr('value') 989 else: 990 value = self.weights[IR_node.name]['value'] 991 992 if not isinstance(value, list): 993 self.output_weights[IR_node.name + '_weight'] = [value] # mxnet's bug, it does not surpport scalar weight. 994 code = "{:<15} = mx.sym.var(name = '{}', shape=(1,))".format(IR_node.variable_name, IR_node.name+'_weight') 995 else: 996 shape = np.array(value).shape 997 self.output_weights[IR_node.name + '_weight'] = value 998 999 code = "{:<15} = mx.sym.var(name = '{}', shape={})".format(IR_node.variable_name, IR_node.name+'_weight', shape) 1000 1001 return code 1002 1003 def emit_Sub(self, IR_node): 1004 code = "{:<15} = mx.sym.broadcast_sub({}, {})".format( 1005 IR_node.variable_name, 1006 self.parent_variable_name(IR_node), 1007 self.parent_variable_name(IR_node, [1])) 1008 1009 return code 1010 1011 1012 def emit_Relu6(self, IR_node): 1013 codes = list() 1014 codes.append(self.emit_Activation(IR_node, 'relu')) 1015 old_name = IR_node.variable_name 1016 IR_node.real_name = IR_node.real_name + "_clip" 1017 codes.append("{:<15} = mx.sym.clip({}, a_min=0, a_max=6, name='{}')".format( 1018 IR_node.real_variable_name, 1019 old_name, 1020 IR_node.real_name)) 1021 1022 return codes 1023 1024 1025 def emit_Slice(self, IR_node): 1026 1027 starts = IR_node.get_attr('starts') 1028 starts = [starts[0], starts[-1]] + starts[1:-1] 1029 ends = IR_node.get_attr('ends') 1030 ends = [ends[0], ends[-1]] + ends[1:-1] 1031 ends = [i if i else None for i in ends] 1032 strides = IR_node.get_attr('strides') 1033 if strides: 1034 strides = [strides[0], strides[-1]] + strides[1:-1] 1035 1036 code = "{:<15} = mx.sym.slice({}, begin={}, end={}, step={}, name='{}')".format( 1037 IR_node.real_variable_name, 1038 self.parent_variable_name(IR_node), 1039 starts, 1040 ends, 1041 strides, 1042 IR_node.name 1043 ) 1044 return code 1045 1046 def emit_Const(self, IR_node): 1047 pass 1048 1049 def emit_Shape(self, IR_node): 1050 code = "{:<15} = mx.sym.var(init = mx.init.Constant({}.infer_shape({}={})[1][0]), name='{}')".format( 1051 IR_node.real_variable_name, 1052 self.parent_variable_name(IR_node), 1053 list(self.input_name_shape.keys())[0], 1054 list(self.input_name_shape.values())[0], 1055 IR_node.name 1056 ) 1057 return code 1058 1059 def emit_Pack(self, IR_node): 1060 pass 1061 1062 def emit_Unsqueeze(self, IR_node): 1063 axis = IR_node.get_attr('axes')[0] 1064 code = "{:<15} = mx.sym.expand_dims(data = {}, axis = {}, name = '{}')".format( 1065 IR_node.variable_name, 1066 self.parent_variable_name(IR_node), 1067 axis, 1068 IR_node.name) 1069 1070 return code 1071 1072 def emit_Unstack(self, IR_node): 1073 squeeze_axis = axis = IR_node.get_attr('axis') 1074 num = IR_node.get_attr('num') 1075 if num is None: 1076 args_str = "" 1077 for input_name in self.IR_graph.input_layers: 1078 if self.IR_graph.get_node(input_name).type!='Const': 1079 args_str += '{}={}, '.format(self.IR_graph.get_node(input_name).real_variable_name, self.data_input_shape[input_name]) 1080 1081 args_str = args_str[:-2] 1082 num_outputs = "{}.infer_shape({})[1][0][{}]".format( 1083 IR_node.variable_name, 1084 args_str, 1085 axis 1086 ) 1087 else: 1088 num_outputs = num 1089 1090 code = "{:<15} = mx.sym.split({}, num_outputs={}, axis={}, squeeze_axis={})".format( 1091 IR_node.variable_name, 1092 self.parent_variable_name(IR_node), 1093 num_outputs, 1094 axis, 1095 squeeze_axis 1096 ) 1097 return code 1098 1099 def emit_Fill(self, IR_node): 1100 value = IR_node.get_attr('value') 1101 code = "{:<15} = mx.sym.full({}, {})".format( 1102 IR_node.variable_name, 1103 self.parent_variable_name(IR_node), 1104 value 1105 ) 1106 return code 1107 1108 def emit_Split(self, IR_node): 1109 axis = IR_node.get_attr('axis') 1110 num_outputs = IR_node.get_attr('split') 1111 1112 if isinstance(num_outputs, list): 1113 raise NotImplementedError() 1114 code = "{:<15} = mx.sym.split({}, num_outputs={}, axis={})".format( 1115 IR_node.variable_name, 1116 self.parent_variable_name(IR_node), 1117 num_outputs, 1118 axis) 1119 1120 return code 1121 1122 1123 def emit_Sigmoid(self, IR_node): 1124 code = "{:<15} = mx.sym.sigmoid(data={}, name='{}')".format( 1125 IR_node.variable_name, 1126 self.parent_variable_name(IR_node), 1127 IR_node.name 1128 ) 1129 return code 1130 1131 1132 def emit_Tanh(self, IR_node): 1133 code = "{:<15} = mx.sym.tanh(data={}, name='{}')".format( 1134 IR_node.variable_name, 1135 self.parent_variable_name(IR_node), 1136 IR_node.name 1137 ) 1138 return code 1139 1140 1141 def emit_Maxmum(self, IR_node): 1142 code = "{:<15} = mx.sym.maxmum({}, {}, name='{}')".format( 1143 IR_node.variable_name, 1144 self.parent_variable_name(IR_node), 1145 self.parent_variable_name(IR_node, [1]), 1146 IR_node.name 1147 ) 1148 return code 1149 1150 1151 def emit_Minimum(self, IR_node): 1152 code = "{:<15} = mx.sym.minimum({}, {}, name='{}')".format( 1153 IR_node.variable_name, 1154 self.parent_variable_name(IR_node), 1155 self.parent_variable_name(IR_node, [1]), 1156 IR_node.name 1157 ) 1158 return code 1159 1160 1161 def emit_Scope(self, IR_node): 1162 import re 1163 pattern = IR_node.pattern 1164 1165 if pattern not in self.naive_scope_pattern and re.sub(r'(_\d+)*$', '', IR_node.pattern) not in self.naive_scope_pattern: 1166 origi_pattern = re.sub(r'(_\d+)*$', '', IR_node.pattern) 1167 func = getattr(self, "_emit_" + origi_pattern) 1168 code = func(IR_node) 1169 else: 1170 code = "{:<15} = __{}({})".format( 1171 IR_node.real_variable_name, 1172 IR_node.pattern, 1173 ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges)) 1174 self._gen_scope_code(IR_node) 1175 return code 1176 1177 1178 def _gen_scope_code(self, scope_node): 1179 1180 def _get_weight_related_op_name(node): 1181 weight_related_ops = ['Constant', 'Conv', 'FullyConnected', 'BatchNorm'] 1182 op_type = node.type 1183 if op_type in weight_related_ops: 1184 return op_type, node.name 1185 1186 def _scope_func(params, code, return_var): 1187 code = """ 1188 def __call__(self, {}): 1189{} 1190 return {} 1191 """.format(params, code, ', '.join(return_var)) 1192 return code 1193 1194 class_inits = dict() 1195 1196 body_code = str() 1197 for node_name in scope_node.topology_list: 1198 node = self.IR_graph.get_node(node_name) 1199 node_type = node.type 1200 1201 if hasattr(self, "emit_" + node_type): 1202 func = getattr(self, "emit_" + node_type) 1203 line = func(node) 1204 if line != None: 1205 body_code += " " + line + '\n' 1206 inits = _get_weight_related_op_name(node) 1207 if inits: 1208 if class_inits.get(inits[0], None): 1209 class_inits[inits[0]].append(inits[1]) 1210 else: 1211 class_inits[inits[0]] = list([inits[1]]) 1212 else: 1213 print("MXNetEmitter has not supported operator [%s]." % (node_type)) 1214 self.emit_UNKNOWN(node) 1215 1216 # param_code does not need parameter slice. 1217 param_code = ', '.join('%s' %self.IR_graph.get_node(s).real_variable_name for s in scope_node.in_edges) 1218 function_code = _scope_func(param_code, body_code, scope_node.return_variables) 1219 1220 return class_inits, function_code 1221 1222 1223 def _emit_gru_cell(self, IR_node): 1224 if not self.layers_codes.get(IR_node.pattern, None): 1225 class_inits, func_code = self._gen_scope_code(IR_node) 1226 variables, variable_codes, init_code, func_code = self.process_inits_func_code(class_inits, func_code) 1227 1228 states = [self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges] 1229 states.pop(0) 1230 states_code = ', '.join(states) 1231 1232 class_code =''' 1233class _{}(mx.rnn.BaseRNNCell): 1234 def __init__(self, {}): 1235 1236{} 1237 1238{} 1239 1240 '''.format(IR_node.pattern, 1241 ', '.join(variables), 1242 init_code, 1243 func_code) 1244 self.layers_codes[IR_node.pattern] = class_code 1245 1246 if not hasattr(self, 'pattern_variables'): 1247 self.pattern_variables = {IR_node.pattern: variables} 1248 else: 1249 self.pattern_variables[IR_node.pattern] = variables 1250 1251 code = variable_codes 1252 code.append("{:<15} = _{}({})({})".format( 1253 IR_node.real_variable_name, 1254 IR_node.pattern, 1255 ', '.join(variables), 1256 ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges))) 1257 else: 1258 code = "{:<15} = _{}({})({})".format( 1259 IR_node.real_variable_name, 1260 IR_node.pattern, 1261 ', '.join(self.pattern_variables[IR_node.pattern]), 1262 ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges)) 1263 1264 return code 1265 1266 1267 def _emit_h_zero(self, IR_node): 1268 code = "{:<15} = mx.sym.full((1, {}), {})".format( 1269 IR_node.variable_name, 1270 IR_node.get_attr('fill_size'), 1271 IR_node.get_attr('fill_value') 1272 ) 1273 return code 1274 1275 1276 def _emit_lstm_cell(self, IR_node): 1277 1278 if not self.layers_codes.get(IR_node.pattern, None): 1279 class_inits, func_code = self._gen_scope_code(IR_node) 1280 variables, variable_codes, init_code, func_code = self.process_inits_func_code(class_inits, func_code) 1281 1282 states = [self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges] 1283 states.pop(0) 1284 states_code = ', '.join(states) 1285 1286 class_code =''' 1287class _{}(mx.rnn.BaseRNNCell): 1288 def __init__(self, {}): 1289 1290{} 1291 1292{} 1293 1294 '''.format(IR_node.pattern, 1295 ', '.join(variables), 1296 init_code, 1297 func_code) 1298 self.layers_codes[IR_node.pattern] = class_code 1299 1300 if not hasattr(self, 'pattern_variables'): 1301 self.pattern_variables = {IR_node.pattern: variables} 1302 else: 1303 self.pattern_variables[IR_node.pattern] = variables 1304 1305 code = variable_codes 1306 code.append("{:<15} = _{}({})({})".format( 1307 IR_node.real_variable_name, 1308 IR_node.pattern, 1309 ', '.join(variables), 1310 ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges))) 1311 else: 1312 code = "{:<15} = _{}({})({})".format( 1313 IR_node.real_variable_name, 1314 IR_node.pattern, 1315 ', '.join(self.pattern_variables[IR_node.pattern]), 1316 ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges)) 1317 1318 return code 1319 1320 1321 def process_inits_func_code(self, class_inits, func_code): 1322 init_code = str() 1323 variables = list() 1324 variable_codes = list() 1325 for k, v in class_inits.items(): 1326 if k == 'FullyConnected': 1327 for i, name in enumerate(class_inits[k]): 1328 variable_name = self.IR_graph.get_node(name).variable_name 1329 variables.append("W_" + variable_name) 1330 variable_codes.append("W_{:<15} = mx.sym.var(name='{}_weight')".format(variable_name, name)) 1331 init_code += " self.W_{} = W_{}\n".format(variable_name, variable_name) 1332 1333 if self.weight_loaded and self.weights[name].get('bias', None).any() != None: 1334 variable_codes.append("B_{:<15} = mx.sym.var(name='{}_bias')".format(variable_name, name)) 1335 variables.append("B_" + variable_name) 1336 init_code += " self.B_{} = B_{}\n".format(variable_name, variable_name) 1337 func_code = func_code.replace("name = '{}'".format(name), "name = '{}', weight = self.W_{}, bias = self.B_{}".format(name, variable_name, variable_name)) 1338 else: 1339 func_code = func_code.replace("name = '{}'".format(name), "name = '{}', weight = self.W_{}".format(name, variable_name)) 1340 elif k == 'Constant': 1341 for name in class_inits[k]: 1342 variable_name = self.IR_graph.get_node(name.replace('_weight', '')).variable_name 1343 variables.append(variable_name) 1344 constant_line = self.emit_Constant(self.IR_graph.get_node(name.replace('_weight', ''))) 1345 variable_codes.append("{:<15} = {}".format(variable_name, '='.join(constant_line.split('=')[1:]))) 1346 init_code += " self.{} = {}\n".format(variable_name, variable_name) 1347 func_code = func_code.replace(constant_line, constant_line.split('=')[0] + ' = self.'+constant_line.split('=')[0]) 1348 else: 1349 raise NotImplementedError 1350 1351 return variables, variable_codes, init_code, func_code 1352 1353