1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6import os
7
8import math
9import mxnet as mx
10import numpy as np
11from mmdnn.conversion.common.IR.IR_graph import IRGraph, IRGraphNode
12import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
13from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
14from mmdnn.conversion.common.DataStructure.emitter import Emitter
15from mmdnn.conversion.common.utils import *
16from mmdnn.conversion.rewriter.folder import Folder
17
18class MXNetEmitter(Emitter):
19
    """Emit MXNet symbol-API Python code (and converted weights) from an MMdnn IR graph.

    gen_code() produces a script defining RefactorModel() and deploy_weight();
    each IR node is dispatched to the matching ``emit_<Type>`` method
    (activation types go through emit_Activation, unknown types through
    emit_UNKNOWN).
    """

    # IR tensor dtype enum -> MXNet dtype string.
    # emit_Embedding falls back to "float32" for dtypes missing here.
    dtype_map = {
        graph_pb2.DT_FLOAT16    : "float16",
        graph_pb2.DT_FLOAT32    : "float32",
        graph_pb2.DT_FLOAT64    : "float64",
        graph_pb2.DT_INT32      : "int32",
        graph_pb2.DT_UINT8      : "uint8"
    }

    # Lowercased IR op type -> activation name. gen_code lowercases the value
    # again and passes it to emit_Activation as the MXNet act_type.
    activation_map = {
        "relu"      : "Relu",
        "sigmoid"   : "Sigmoid",
        "tanh"      : "Tanh",
        "elu"       : "Elu"
    }

    # NOTE(review): not referenced in the code visible here. It appears to map a
    # channel-last axis index to its channel-first position (1 -> 2, 2 -> 3,
    # last -> 1); confirm before relying on it.
    transpose_map = {
        1 : 2,
        2 : 3,
       -1 : 1
    }

    # NOTE(review): not referenced in the code visible here.
    naive_scope_pattern = []

    # Channel-last layout names; only referenced by a commented-out check in
    # _emit_convolution.
    channels_last = ['NDHWC', 'NHWC']
44
45    def __init__(self, model):
46        super(MXNetEmitter, self).__init__()
47        from six import string_types as _string_types
48
49        if isinstance(model, _string_types):
50            network_path = model
51            self.weight_loaded = False
52        elif len(model) == 3:
53            network_path = model[0]
54            weight_path = model[1]
55            self.output_weights_file = model[2]
56            self.output_weights = dict()
57            self._load_weights(weight_path)
58            self.weights = self.weights_dict
59        else:
60            raise ValueError("the # of input arguments [{}] is not supported" % len(model))
61
62        self.IR_graph = IRGraph(network_path)
63        self.IR_graph.build()
64
65        folder = Folder(self.IR_graph, self.weights)
66        folder.fold()
67
68    @property
69    def header_code(self):
70        return """import mxnet as mx
71import numpy as np
72import math
73
74# mxnet-cpu only support channel first, default convert the model and weight as channel first
75
76def RefactorModel():
77"""
78
79
80    def gen_code(self, phase):
81        self.IR_layer_map = dict()
82        self.add_body(0, self.header_code)
83        for layer in self.IR_graph.topological_sort:
84            self.IR_layer_map[layer] = self.IR_graph.get_node(layer)
85
86        shape = dict()
87        for layer in self.IR_graph.topological_sort:
88            current_node = self.IR_graph.get_node(layer)
89            node_type = current_node.type
90
91
92            if len(current_node.in_edges) == 0:
93                current_node.in_edges.append('data')
94
95            if node_type.lower() in MXNetEmitter.activation_map:
96                func = getattr(self, "emit_Activation")
97                line = func(current_node, MXNetEmitter.activation_map[node_type.lower()].lower())
98                self.add_body(1, line)
99
100            elif hasattr(self, "emit_" + node_type):
101                func = getattr(self, "emit_" + node_type)
102                line = func(current_node)
103                if line != None:
104                    self.add_body(1, line)
105            else:
106                print("MXNet Emitter has not supported operator [%s]." % (node_type))
107                self.emit_UNKNOWN(current_node)
108
109            if node_type == "DataInput":
110                cur_shape = list()
111                first = True
112                for dim in current_node.IR_layer.attr["shape"].shape.dim:
113                    if dim.size == -1 and first:
114                        cur_shape.append(1)
115                        print("Detect input layer [{}] using infer batch size, set it as default value [1]".format(current_node.name))
116                    else:
117                        if dim.size == -1:
118                            print("Warning: user should change input size manually")
119                        cur_shape.append(dim.size)
120                    first = False
121
122                cur_shape.insert(1, cur_shape.pop())
123                shape[current_node.name] = ', '.join('%s' % i for i in cur_shape)
124                self.input_name_shape = {current_node.name: tuple(cur_shape)}
125
126
127        if self.weight_loaded:
128            fullpath = os.path.abspath(self.output_weights_file)
129            dirname = os.path.dirname(fullpath)
130            if not os.path.exists(dirname):
131                os.makedirs(dirname)
132            with open(self.output_weights_file, 'wb') as outfile:
133                np.save(outfile, self.output_weights)
134
135        comment = "\n    # if a GPU is available, change mx.cpu() to mx.gpu()"
136        # We use the real_name for specifying the input layer in data_names
137        # since MXNet API wants the actual name of the layer. On the other
138        # hand, the module API wants the last symbol in the symbol chain, so
139        # for the output node we need to use the actual python variable name
140        # of the last layer (real_variable_name).
141        last_line = "{:<15} = mx.mod.Module(symbol = {}, context = mx.cpu(), data_names = ['{}'])".format(
142            "model",
143            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.output_layers if self.IR_graph.get_node(name).type !='Pack' and self.IR_graph.get_node(name).type != 'Shape']),
144            ', '.join([self.IR_graph.get_node(name).real_name for name in self.IR_graph.input_layers if self.IR_graph.get_node(name).type != 'Const']))
145
146        self.add_body(1, comment)
147        self.add_body(1, last_line)
148        self.add_body(1, "return model")
149
150
151        self.add_body(0, "")
152        for code in self.layers_codes.values():
153            self.add_body(0, code)
154
155        weight_code = ""
156        if not self.weight_loaded:
157            weight_code += "# emitter does not detect any import weights, you may generate weights file manually\n"
158
159        weight_code += self.gen_weight_code(shape, phase)
160
161        main_code = "if __name__ == '__main__':\n    model = RefactorModel()\n"
162        if self.weight_loaded:
163            main_code += "    # remember to adjust params path\n    model = deploy_weight(model, '{}')\n".format(self.output_weights_file)
164
165        if phase == 'train':
166            train_code = """def train(model):
167    import logging
168    logging.getLogger().setLevel(logging.DEBUG)
169    model.fit(train_iter, # train data
170            eval_data = val_iter, # validation data
171            optimizer = 'sgd', # Defaults to 'sgd'
172            optimizer_params = {'learning_rate':0.01}, # use fixed learning rate
173            eval_metric = 'acc', # report accuracy during training, other possible predefined metrics are: 'ce', 'f1', 'mae', 'mse', 'rmse', 'top_k_accuracy'
174            batch_end_callback = mx.callback.Speedometer(batch_size, 100), # output progress for each 100 data batches
175            num_epoch = 10) # train for at most 10 dataset passes\n\n
176"""
177            code = self.body_code + weight_code + train_code + main_code
178        else:
179            test_code = """from collections import namedtuple
180Batch = namedtuple('Batch', ['data'])
181
182
183def get_image(url, show=False):
184    import cv2
185    # download and show the image
186    fname = mx.test_utils.download(url)
187    img = cv2.cvtColor(cv2.imread(fname), cv2.COLOR_BGR2RGB)
188    if img is None:
189        return None
190    if show:
191        import matplotlib.pyplot as plt
192        plt.imshow(img)
193        plt.axis('off')
194    # convert into format (batch, RGB, width, height)
195    img = cv2.resize(img, (224, 224))
196    img = np.swapaxes(img, 0, 2)
197    img = np.swapaxes(img, 1, 2)
198    img = img[np.newaxis, :]
199    return img
200
201
202def predict(model, labels, url):
203    # to show the image, change the argument show into True
204    img = get_image(url, show = False)
205    # compute the predict probabilities
206    model.forward(Batch([mx.nd.array(img)]))
207    prob = model.get_outputs()[0].asnumpy()
208    # print the top-5
209    prob = np.squeeze(prob)
210    a = np.argsort(prob)[::-1]
211    for i in a[0:5]:
212        print('prbability = %f, class = %s' %(prob[i], labels[i]))\n\n
213"""
214
215            main_code += """
216    # # call function predict
217    # with open('synset.txt', 'r') as f:
218    #     labels = [l.rstrip() for l in f]
219    # predict(model, labels, 'http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg')
220"""
221
222            code = self.body_code + weight_code + test_code + main_code
223
224        return code
225
226
227    def gen_weight_code(self, shape, phase):
228        str = "def deploy_weight(model, weight_file):\n"
229        str += """
230    if weight_file == None:
231        return
232
233    try:
234        weights_dict = np.load(weight_file, allow_pickle=True).item()
235    except:
236        weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
237
238    arg_params = dict()
239    aux_params = dict()
240    for weight_name, weight_data in weights_dict.items():
241        weight_name = str(weight_name)
242        if "moving" in weight_name:
243            aux_params[weight_name] = mx.nd.array(weight_data)
244        else:
245            arg_params[weight_name] = mx.nd.array(weight_data)
246
247"""
248        if phase == 'train':
249            str += "    model.bind(for_training = True, data_shapes = ["
250        else:
251            str += "    model.bind(for_training = False, data_shapes = ["
252        first = True
253        for k, v in shape.items():
254            if not first:
255                str += ", "
256            str += "('" + k + "', " + "(" + v + "))"
257            first = False
258        str += "])\n"
259        str += "    model.set_params(arg_params = arg_params, aux_params = aux_params, allow_missing = True, allow_extra=True)\n\n    return model\n\n\n"
260        return str
261
262
263    @staticmethod
264    def calculate_same_pad(data_shape, kernel, stride):
265        if (data_shape % stride == 0):
266            pad = max(kernel - stride, 0)
267        else:
268            pad = max(kernel - (data_shape % stride), 0)
269        if pad % 2 == 0:
270            return False, pad
271        else:
272            return True, pad
273
274
275    @staticmethod
276    def transfer_pad(pad_list):
277        defuse_pad = False
278        pad = list()
279
280        assert len(pad_list) % 2 == 0
281        mid = int(len(pad_list)/2)
282        pad_first = pad_list[1:mid-1]
283        pad_second = pad_list[mid+1:-1]
284
285        for i in range(0, mid-2):
286            if not pad_first[i] == pad_second[i]:
287                defuse_pad = True
288
289        if defuse_pad:
290            pad.extend([0] * 4)
291            for i in range(0, mid-2):
292                pad.extend([pad_first[i], pad_second[i]])
293        else:
294            pad = pad_first
295
296        return defuse_pad, pad
297
298
299    @staticmethod
300    def transpose(data, dim):
301        if dim == 1:
302            data = data.transpose((2, 1, 0))
303        elif dim == 2:
304            data = data.transpose((3, 2, 0, 1))
305        elif dim == 3:
306            data = data.transpose((4, 3, 0, 1, 2))
307        else:
308            raise ValueError("The weight of dim {} cannot transpose" % dim)
309
310        return data
311
312
313    def set_pad(self, IR_node, code, pad, _max_pool):
314        if _max_pool:
315            constant_value = "float('-inf')"
316        else:
317            constant_value = "0.0"
318
319        code = "{:<15} = mx.sym.pad(data = {}, mode = 'constant', pad_width={}, constant_value = {}, name = '{}')".format(
320                IR_node.variable_name + "_pad",
321                self.parent_variable_name(IR_node),
322                tuple(pad),
323                constant_value,
324                IR_node.name + "_pad")
325
326        for e in IR_node.in_edges:
327            e = e.split(':')[0]
328            if e == 'data':
329                continue
330            self.IR_layer_map[e].out_edges = [x if not self.IR_layer_map[x.split(':')[0]].name == IR_node.variable_name else IR_node.variable_name + "_pad" for x in self.IR_layer_map[e].out_edges]
331
332        return code
333
334
335    def emit_UNKNOWN(self, IR_node):
336        print(IR_node.name)
337
338
339    def emit_FullyConnected(self, IR_node):
340        if self.weight_loaded:
341            weight_dict = self.weights[IR_node.name]
342            parent = self.IR_graph.get_parent(IR_node.name, [0])
343            while parent.type == "Flatten" or parent.type == 'Dropout':
344                parent = self.IR_graph.get_parent(parent.name, [0])
345            dim = len(parent.layer.attr['_output_shapes'].list.shape[0].dim)
346            if dim > 2:
347                original_dims = weight_dict['weights'].shape
348                dims = [i.size for i in parent.layer.attr['_output_shapes'].list.shape[0].dim[1:]] + [-1]
349                weight_dict['weights'] = np.reshape(weight_dict['weights'], dims)
350                weight_dict['weights'] = np.transpose(weight_dict['weights'], [dim - 2] + list(range(0, dim - 2)) + [dim - 1])
351                weight_dict['weights'] = np.reshape(weight_dict['weights'], original_dims)
352            self.output_weights[IR_node.name + "_weight"] = weight_dict['weights'].transpose((1, 0))
353
354        num_hidden = IR_node.IR_layer.attr["units"].i
355        no_bias = not IR_node.IR_layer.attr["use_bias"].b
356        if not no_bias and self.weight_loaded:
357            self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']
358
359        code = "{:<15} = mx.sym.FullyConnected(data = {}, num_hidden = {}, no_bias = {}, name = '{}')".format(
360                IR_node.variable_name,
361                self.parent_variable_name(IR_node),
362                num_hidden,
363                no_bias,
364                IR_node.name)
365
366        return code
367
368
369    def _emit_convolution(self, IR_node, pattern):
370        if self.weight_loaded:
371            weight_dict = self.weights[IR_node.name]
372            weights = weight_dict['weights']
373
374        dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2
375
376        kernel = list()
377        for idx in range(0, dim):
378            kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx])
379
380        stride = list()
381        for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
382            stride.append(e)
383
384        dilate = list()
385        for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]:
386            dilate.append(e)
387        if dilate == []: dilate = [1, 1]
388        dilate = ', '.join('%s' % i for i in dilate)
389
390        defuse_pad = False
391        pad = list()
392        if "pads" in IR_node.IR_layer.attr:
393            output_shape = list()
394            for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
395                output_shape.append(e.size)
396
397            # print("Warning: MXNet Convolution Layer pad does not match IR Convolution Layer pad")
398            defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)
399
400        num_filter = 0
401        if pattern == "Deconvolution":
402            num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
403        else:
404            num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-1]
405
406        use_bias = IR_node.get_attr('use_bias', False)
407        if use_bias and self.weight_loaded:
408            self.output_weights[IR_node.name + "_bias"] = weight_dict['bias']
409
410        if pattern == "DepthwiseConv":
411            num_group = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
412            num_filter = num_filter * num_group
413            pattern = "Convolution"
414            if self.weight_loaded:
415                weights = np.swapaxes(weights, -1, -2)
416
417        else:
418            num_group = IR_node.get_attr('group', 1)
419
420        # layout = IR_node.IR_layer.attr["data_format"].s
421        if dim == 1:
422            layout = 'NCW'
423        elif dim == 2:
424            layout = 'NCHW'
425        elif dim == 3:
426            layout = 'NCDHW'
427
428        if self.weight_loaded:
429            # if layout not in MXNetEmitter.channels_last:
430            weights = MXNetEmitter.transpose(weights, dim)
431            self.output_weights[IR_node.name + "_weight"] = weights
432
433        code = ""
434        if not defuse_pad:
435            code += "{:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), pad={}, num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format(
436                IR_node.variable_name,
437                pattern,
438                self.parent_variable_name(IR_node),
439                tuple(kernel),
440                tuple(stride),
441                dilate,
442                tuple(pad),
443                num_filter,
444                num_group,
445                not use_bias,
446                layout,
447                IR_node.name)
448        else:
449            code += self.set_pad(IR_node, code, pad, False)
450            code += "\n    {:<15} = mx.sym.{}(data={}, kernel={}, stride={}, dilate = ({}), num_filter = {}, num_group = {}, no_bias = {}, layout = '{}', name = '{}')".format(
451                IR_node.variable_name,
452                pattern,
453                IR_node.variable_name + "_pad",
454                tuple(kernel),
455                tuple(stride),
456                dilate,
457                num_filter,
458                num_group,
459                not use_bias,
460                layout,
461                IR_node.name)
462
463        return code
464
465
466    def emit_Conv(self, IR_node):
467        return self._emit_convolution(IR_node, "Convolution")
468
469
470    def emit_DepthwiseConv(self, IR_node):
471        return self._emit_convolution(IR_node, "DepthwiseConv")
472
473
474    def emit_ConvTranspose(self, IR_node):
475        return self._emit_convolution(IR_node, "Deconvolution")
476
477
478    def emit_DataInput(self, IR_node):
479        shape = list()
480        shape.extend(IR_node.IR_layer.attr["shape"].list.i)
481
482        code = "{:<15} = mx.sym.var('{}')".format(IR_node.variable_name, IR_node.name)
483        return code
484
485
486    # Add LeakyReLU Elu(slope not support)
487    def emit_Activation(self, IR_node, act_type):
488
489        act_type = act_type
490        func_name = ""
491
492        if act_type == "elu":
493            func_name = "LeakyReLU"
494        else:
495            func_name = "Activation"
496
497        code = "{:<15} = mx.sym.{}(data = {}, act_type = '{}', name = '{}')".format(
498                IR_node.variable_name,
499                func_name,
500                self.parent_variable_name(IR_node),
501                act_type,
502                IR_node.name)
503
504        return code
505
506
507    def emit_BatchNorm(self, IR_node):
508        IR_node_after = self.IR_graph.get_son(IR_node.name, [0])
509        if IR_node_after.type == 'Scale':
510            if self.weight_loaded:
511                weight_dict = self.weights[IR_node.name]
512                weight_dict_scale = self.weights[IR_node_after.name]
513
514            # axis = IR_node.IR_layer.attr["axis"].i
515            axis = 1
516            eps = IR_node.IR_layer.attr["epsilon"].f
517            momentum = IR_node.IR_layer.attr["momentum"].f
518
519            fix_gamma = not IR_node.IR_layer.attr["scale"].b
520
521            if self.weight_loaded:
522                if not fix_gamma:
523                #     self.output_weights[IR_node.name + "_gamma"] = np.multiply(weight_dict['scale'], weight_dict_scale['scale'])
524                # self.output_weights[IR_node.name + "_beta"] = np.multiply(weight_dict['bias'], weight_dict_scale['scale']) + weight_dict_scale['bias']
525                    self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale']
526                self.output_weights[IR_node.name + "_beta"] = weight_dict['bias']
527
528            # not supported yet
529            use_global_stats = "False"
530            if self.weight_loaded:
531                self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var']
532                self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean']
533
534            code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
535                    IR_node.variable_name,
536                    self.parent_variable_name(IR_node),
537                    axis,
538                    eps,
539                    momentum,
540                    fix_gamma,
541                    use_global_stats,
542                    IR_node.name)
543
544            return code
545
546        else:
547            if self.weight_loaded:
548                weight_dict = self.weights[IR_node.name]
549
550            # axis = IR_node.IR_layer.attr["axis"].i
551            axis = 1
552            eps = IR_node.IR_layer.attr["epsilon"].f
553            momentum = IR_node.IR_layer.attr["momentum"].f
554
555            fix_gamma = not IR_node.IR_layer.attr["scale"].b
556
557            if self.weight_loaded:
558                if not fix_gamma:
559                    self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale']
560                self.output_weights[IR_node.name + "_beta"] = weight_dict['bias']
561
562            # not supported yet
563            use_global_stats = "False"
564            if self.weight_loaded:
565                self.output_weights[IR_node.name + "_moving_var"] = weight_dict['var']
566                self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['mean']
567
568            code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
569                    IR_node.variable_name,
570                    self.parent_variable_name(IR_node),
571                    axis,
572                    eps,
573                    momentum,
574                    fix_gamma,
575                    use_global_stats,
576                    IR_node.name)
577
578            return code
579
580    def emit_Scale(self, IR_node):
581        if self.weight_loaded:
582            weight_dict = self.weights[IR_node.name]
583
584        # axis = IR_node.IR_layer.attr["axis"].i
585        axis = 1
586        eps = 0.0
587        momentum = 0.0
588
589        fix_gamma = not IR_node.IR_layer.attr["scale"].b
590
591        if self.weight_loaded:
592            if not fix_gamma:
593                self.output_weights[IR_node.name + "_gamma"] = weight_dict['scale']
594            self.output_weights[IR_node.name + "_beta"] = weight_dict['bias']
595
596        # not supported yet
597        use_global_stats = "False"
598        if self.weight_loaded:
599            self.output_weights[IR_node.name + "_moving_var"] = weight_dict['scale_var']
600            self.output_weights[IR_node.name + "_moving_mean"] = weight_dict['scale_mean']
601
602        code = "{:<15} = mx.sym.BatchNorm(data = {}, axis = {}, eps = {}, momentum = {}, fix_gamma = {}, use_global_stats = {}, name = '{}')".format(
603                IR_node.variable_name,
604                self.parent_variable_name(IR_node),
605                axis,
606                eps,
607                momentum,
608                fix_gamma,
609                use_global_stats,
610                IR_node.name)
611
612        return code
613
614
615
616    def emit_Pool(self, IR_node):
617
618        global_pool = IR_node.IR_layer.attr["global_pooling"].b
619
620        kernel = list()
621        if global_pool:
622            kernel = [1] * (len(IR_node.IR_layer.attr["strides"].list.i) - 2)
623        else:
624            for e in IR_node.IR_layer.attr["kernel_shape"].list.i[1:-1]:
625                kernel.append(e)
626
627        pool_type = IR_node.get_attr('pooling_type').lower()
628
629        stride = list()
630        for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
631            stride.append(e)
632
633        defuse_pad = False
634        pad = list()
635        if "pads" in IR_node.IR_layer.attr:
636            output_shape = list()
637            for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
638                output_shape.append(e.size)
639
640            # print("Warning: MXNet Pooling Layer pad does not match IR Pooling Layer pad")
641            defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)
642        code = ""
643        if not defuse_pad:
644            code += "{:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, pad={}, name = '{}')".format(
645                    IR_node.variable_name,
646                    self.parent_variable_name(IR_node),
647                    global_pool,
648                    tuple(kernel),
649                    pool_type,
650                    tuple(stride),
651                    tuple(pad),
652                    IR_node.name)
653        else:
654            code += self.set_pad(IR_node, code, pad, pool_type == "max")
655            code += "\n    {:<15} = mx.sym.Pooling(data = {}, global_pool = {}, kernel={}, pool_type = '{}', stride={}, name = '{}')".format(
656                    IR_node.variable_name,
657                    IR_node.variable_name + "_pad",
658                    global_pool,
659                    tuple(kernel),
660                    pool_type,
661                    tuple(stride),
662                    IR_node.name)
663
664        return code
665
666
667    def emit_SoftmaxOutput(self, IR_node):
668
669        code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format(
670            IR_node.variable_name,
671            self.parent_variable_name(IR_node)
672        )
673
674        return code
675
676
677    def emit_Softmax(self, IR_node):
678
679        code = ""
680
681        if len(IR_node.out_edges) == 0:
682            code = "{:<15} = mx.sym.SoftmaxOutput(data = {}, name = 'softmax')".format(
683                    IR_node.variable_name,
684                    self.parent_variable_name(IR_node))
685        else:
686            axis = IR_node.IR_layer.attr["dim"].i
687            code = "{:<15} = mx.sym.softmax(data = {}, axis = {}, name = '{}')".format(
688                    IR_node.variable_name,
689                    self.parent_variable_name(IR_node),
690                    axis,
691                    IR_node.name)
692
693        return code
694
695
696    def emit_Squeeze(self, IR_node):
697        return self.emit_Flatten(IR_node)
698
699
700    # def emit_ConvTranspose(self, IR_node):
701    #     if self.weight_loaded:
702    #         weight_dict = self.weights[IR_node.name]
703    #         weights = weight_dict['weights']
704
705    #     dim = len(IR_node.IR_layer.attr["kernel_shape"].list.i) - 2
706
707    #     kernel = list()
708    #     for idx in range(0, dim):
709    #         kernel.append(IR_node.IR_layer.attr["kernel_shape"].list.i[idx])
710
711    #     stride = list()
712    #     for e in IR_node.IR_layer.attr["strides"].list.i[1:-1]:
713    #         stride.append(e)
714
715    #     dilate = list()
716    #     for e in IR_node.IR_layer.attr["dilations"].list.i[1:-1]:
717    #         dilate.append(e)
718    #     dilate = ', '.join('%s' % i for i in dilate)
719
720    #     defuse_pad = False
721    #     pad = list()
722    #     if "pads" in IR_node.IR_layer.attr:
723    #         output_shape = list()
724    #         for e in IR_node.IR_layer.attr["_output_shapes"].list.shape[0].dim:
725    #             output_shape.append(e.size)
726
727    #         # print("Warning: MXNet Deconvolution Layer pad does not match IR Deconvolution Layer pad")
728    #         defuse_pad, pad = MXNetEmitter.transfer_pad(IR_node.IR_layer.attr["pads"].list.i)
729    #     pad = ', '.join('%s' % i for i in pad)
730
731    #     kernel = ', '.join('%s' % i for i in kernel)
732    #     stride = ', '.join('%s' % i for i in stride)
733
734    #     num_filter = IR_node.IR_layer.attr["kernel_shape"].list.i[-2]
735    #     no_bias = not IR_node.IR_layer.attr["use_bias"].b
736    #     if not no_bias and self.weight_loaded:
737    #         self.output_weights[IR_node.replace_scope(IR_node.name) + "_bias"] = weight_dict['bias']
738
739    #     # layout = IR_node.IR_layer.attr["data_format"].s
740    #     if dim == 1:
741    #         layout = 'NCW'
742    #     elif dim == 2:
743    #         layout = 'NCHW'
744    #     elif dim == 3:
745    #         layout = 'NCDHW'
746
747    #     if self.weight_loaded:
748    #         # if layout not in MXNetEmitter.channels_last:
749    #         weights = MXNetEmitter.transpose(weights, dim)
750    #         self.output_weights[IR_node.replace_scope(IR_node.name) + "_weight"] = weights
751
752    #     code = ""
753    #     if not defuse_pad:
754    #         code = "{:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), pad = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format(
755    #                 IR_node.replace_scope(IR_node.name),
756    #                 IR_node.replace_scope(IR_node.in_edges[0]),
757    #                 kernel,
758    #                 stride,
759    #                 dilate,
760    #                 pad,
761    #                 num_filter,
762    #                 no_bias,
763    #                 layout,
764    #                 IR_node.replace_scope(IR_node.name))
765    #     else:
766    #         code = self.set_pad(IR_node, code, pad)
767    #         code += "\n    {:<15} = mx.sym.Deconvolution(data = {}, kernel = ({}), stride = ({}), dilate = ({}), num_filter = {}, no_bias = {}, layout = '{}', name = '{}')".format(
768    #                 IR_node.replace_scope(IR_node.name), IR_node.replace_scope(IR_node.name) + "_pad", kernel, stride, dilate, num_filter, no_bias, layout, IR_node.replace_scope(IR_node.name))
769
770    #     return code
771
772
773    def emit_Embedding(self, IR_node):
774
775        input_dim = IR_node.IR_layer.attr["input_dim"].i
776        output_dim = IR_node.IR_layer.attr["output_dim"].i
777        dtype = MXNetEmitter.dtype_map.get(IR_node.layer.attr["dtype"].type, "float32")
778
779        weight_dict = self.weights[IR_node.name]
780
781        if self.weight_loaded:
782            self.output_weights[IR_node.name + "_weight"] = weight_dict['weights']
783
784        code = "{:<15} = mx.sym.Embedding(data = {}, input_dim = {}, output_dim = {}, dtype = '{}', name = '{}')".format(
785                IR_node.variable_name,
786                self.parent_variable_name(IR_node),
787                input_dim,
788                output_dim,
789                dtype,
790                IR_node.name)
791
792        return code
793
794
795    def emit_LeakyRelu(self, IR_node):
796        alpha = IR_node.IR_layer.attr['alpha'].f
797        code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, name = '{}')".format(
798                IR_node.variable_name,
799                self.parent_variable_name(IR_node),
800                alpha,
801                IR_node.name
802        )
803        return code
804
805    def emit_PRelu(self, IR_node):
806        slope = IR_node.get_attr('gamma')
807        code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, act_type = '{}', name = '{}')".format(
808                IR_node.variable_name,
809                self.parent_variable_name(IR_node),
810                slope,
811                'prelu',
812                IR_node.name
813        )
814        return code
815
816    def emit_Elu(self, IR_node):
817        alpha = IR_node.IR_layer.attr['alpha'].f
818        code = "{:<15} = mx.sym.LeakyReLU(data = {}, slope = {}, act_type = {}, name = '{}')".format(
819                IR_node.variable_name,
820                self.parent_variable_name(IR_node),
821                alpha,
822                'elu',
823                IR_node.name
824        )
825        return code
826
827    def emit_Dropout(self, IR_node):
828        p = IR_node.IR_layer.attr["keep_prob"].f
829        mode = IR_node.IR_layer.attr["mode"].s.lower().decode() if 'mode' in IR_node.layer.attr else 'training'
830        code = "{:<15} = mx.sym.Dropout(data = {}, p = {}, mode = '{}', name = '{}')".format(
831                IR_node.variable_name,
832                self.parent_variable_name(IR_node),
833                p,
834                mode,
835                IR_node.name)
836
837        return code
838
839
    # `reverse` is not supported yet; it is always emitted as False.
841    def emit_Reshape(self, IR_node):
842        shape = list()
843        for e in IR_node.IR_layer.attr["shape"].list.i:
844            shape.append(e)
845        shape = ', '.join('%s' % i for i in shape)
846        reverse = False
847
848        code = "{:<15} = mx.sym.reshape(data = {}, shape = ({}), reverse = {}, name = '{}')".format(
849                IR_node.variable_name,
850                self.parent_variable_name(IR_node),
851                shape,
852                reverse,
853                IR_node.name)
854
855        return code
856
857
858    def emit_Flatten(self, IR_node):
859        # code = "{:<15} = mx.sym.transpose(data = {}, axes = (0, 2, 3, 1))\n".format("trans", self.parent_variable_name(IR_node))
860        code = "{:<15} = mx.sym.flatten(data = {}, name = '{}')".format(
861                IR_node.variable_name,
862                self.parent_variable_name(IR_node),
863                IR_node.name)
864
865        return code
866
867
868    @staticmethod
869    def _convert_axis(IR_node, axis):
870        ndim = len(IR_node.layer.attr['_output_shapes'].list.shape[0].dim)
871        if axis == 0:
872            return 0
873        elif axis == ndim - 1:
874            return 1
875        else:
876            return axis + 1
877
878
879    def emit_Concat(self, IR_node):
880        dim = MXNetEmitter._convert_axis(IR_node, IR_node.IR_layer.attr["axis"].i)
881        code = "{:<15} = mx.sym.concat({}, dim = {}, name = '{}')".format(
882                IR_node.variable_name,
883                ', '.join(self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))),
884                dim,
885                IR_node.name)
886
887        return code
888
889
890    def emit_Cast(self, IR_node):
891        dtype = IR_node.IR_layer.attr["dtype"].type
892        code = "{:<15} = mx.sym.cast(data = {}, dtype = {}, name = '{}')".format(
893                IR_node.variable_name,
894                self.parent_variable_name(IR_node),
895                dtype,
896                IR_node.name)
897
898        return code
899
900
901    def emit_Expand_dims(self, IR_node):
902        axis = IR_node.IR_layer.attr["axis"].i
903        code = "{:<15} = mx.sym.expand_dims(data = {}, axis = {}, name = '{}')".format(
904                IR_node.variable_name,
905                self.parent_variable_name(IR_node),
906                axis,
907                IR_node.name)
908
909        return code
910
911
912    def emit_Pad(self, IR_node):
913        mode = IR_node.IR_layer.attr["mode"].s.lower().decode()
914        pad_width = list()
915        pad_width.extend([0]*4)
916        padding = convert_onnx_pad_to_tf(IR_node.get_attr("pads"))[1:-1]
917        for padding_pair in padding:
918            pad_width.extend(padding_pair)
919
920        pad_width = ', '.join('%s' % i for i in pad_width)
921
922        code = "{:<15} = mx.sym.pad(data = {}, mode = '{}', pad_width = ({}), name = '{}')".format(
923                IR_node.variable_name,
924                self.parent_variable_name(IR_node),
925                mode,
926                pad_width,
927                IR_node.name)
928
929        return code
930
931
932    def emit_Add(self, IR_node):
933        code = "{:<15} = mx.sym.broadcast_add({}, {})".format(
934                IR_node.variable_name,
935                self.parent_variable_name(IR_node),
936                self.parent_variable_name(IR_node, [1]))
937
938        return code
939
940
941    def emit_Mul(self, IR_node):
942
943        code = "{:<15} = mx.sym.broadcast_mul({}, {})".format(
944                IR_node.variable_name,
945                self.parent_variable_name(IR_node),
946                self.parent_variable_name(IR_node, [1]))
947
948        return code
949
950
951    def emit_ReduceMean(self, IR_node):
952        axes = IR_node.layer.attr['axes'].list.i[:]
953        axes = ','.join('%s' % MXNetEmitter.transpose_map[i] for i in axes)
954
955        code = "{:<15} = mx.sym.mean(data = {}, axis = ({}), keepdims = {})".format(
956                IR_node.variable_name,
957                self.parent_variable_name(IR_node),
958                axes,
959                IR_node.layer.attr['keepdims'].b)
960
961        return code
962
963
964    def emit_LRN(self, IR_node):
965        output_name = IR_node.variable_name
966        input_name = self.parent_variable_name(IR_node)
967        IR_name = IR_node.name
968        alpha = IR_node.get_attr('alpha')
969        beta = IR_node.get_attr('beta')
970        bias = IR_node.get_attr('bias')
971        size = IR_node.get_attr('size')
972
973
974        code = "{:<15} = mx.sym.LRN(data = {}, alpha = {}, beta = {}, knorm = {}, nsize = {}, name = '{}')".format(
975                output_name,
976                input_name,
977                alpha,
978                beta,
979                bias,
980                size,
981                IR_name)
982
983        return code
984
985    def emit_Constant(self, IR_node):
986        # save the constant into weight dict
987        if IR_node.get_attr('value'):
988            value = IR_node.get_attr('value')
989        else:
990            value = self.weights[IR_node.name]['value']
991
992        if not isinstance(value, list):
993            self.output_weights[IR_node.name + '_weight'] = [value] # mxnet's bug, it does not surpport scalar weight.
994            code = "{:<15} = mx.sym.var(name = '{}', shape=(1,))".format(IR_node.variable_name, IR_node.name+'_weight')
995        else:
996            shape = np.array(value).shape
997            self.output_weights[IR_node.name + '_weight'] = value
998
999            code = "{:<15} = mx.sym.var(name = '{}', shape={})".format(IR_node.variable_name, IR_node.name+'_weight', shape)
1000
1001        return code
1002
1003    def emit_Sub(self, IR_node):
1004        code = "{:<15} = mx.sym.broadcast_sub({}, {})".format(
1005                IR_node.variable_name,
1006                self.parent_variable_name(IR_node),
1007                self.parent_variable_name(IR_node, [1]))
1008
1009        return code
1010
1011
1012    def emit_Relu6(self, IR_node):
1013        codes = list()
1014        codes.append(self.emit_Activation(IR_node, 'relu'))
1015        old_name = IR_node.variable_name
1016        IR_node.real_name = IR_node.real_name + "_clip"
1017        codes.append("{:<15} = mx.sym.clip({}, a_min=0, a_max=6, name='{}')".format(
1018            IR_node.real_variable_name,
1019            old_name,
1020            IR_node.real_name))
1021
1022        return codes
1023
1024
1025    def emit_Slice(self, IR_node):
1026
1027        starts = IR_node.get_attr('starts')
1028        starts = [starts[0], starts[-1]] + starts[1:-1]
1029        ends = IR_node.get_attr('ends')
1030        ends = [ends[0], ends[-1]] + ends[1:-1]
1031        ends = [i if i else None for i in ends]
1032        strides = IR_node.get_attr('strides')
1033        if strides:
1034            strides = [strides[0], strides[-1]] + strides[1:-1]
1035
1036        code =  "{:<15} = mx.sym.slice({}, begin={}, end={}, step={}, name='{}')".format(
1037            IR_node.real_variable_name,
1038            self.parent_variable_name(IR_node),
1039            starts,
1040            ends,
1041            strides,
1042            IR_node.name
1043        )
1044        return code
1045
1046    def emit_Const(self, IR_node):
1047        pass
1048
1049    def emit_Shape(self, IR_node):
1050        code = "{:<15} = mx.sym.var(init = mx.init.Constant({}.infer_shape({}={})[1][0]), name='{}')".format(
1051            IR_node.real_variable_name,
1052            self.parent_variable_name(IR_node),
1053            list(self.input_name_shape.keys())[0],
1054            list(self.input_name_shape.values())[0],
1055            IR_node.name
1056        )
1057        return code
1058
1059    def emit_Pack(self, IR_node):
1060        pass
1061
1062    def emit_Unsqueeze(self, IR_node):
1063        axis = IR_node.get_attr('axes')[0]
1064        code = "{:<15} = mx.sym.expand_dims(data = {}, axis = {}, name = '{}')".format(
1065                IR_node.variable_name,
1066                self.parent_variable_name(IR_node),
1067                axis,
1068                IR_node.name)
1069
1070        return code
1071
1072    def emit_Unstack(self, IR_node):
1073        squeeze_axis = axis = IR_node.get_attr('axis')
1074        num = IR_node.get_attr('num')
1075        if num is None:
1076            args_str = ""
1077            for input_name in self.IR_graph.input_layers:
1078                if self.IR_graph.get_node(input_name).type!='Const':
1079                    args_str += '{}={}, '.format(self.IR_graph.get_node(input_name).real_variable_name, self.data_input_shape[input_name])
1080
1081            args_str = args_str[:-2]
1082            num_outputs = "{}.infer_shape({})[1][0][{}]".format(
1083                IR_node.variable_name,
1084                args_str,
1085                axis
1086            )
1087        else:
1088            num_outputs = num
1089
1090        code = "{:<15} = mx.sym.split({}, num_outputs={}, axis={}, squeeze_axis={})".format(
1091            IR_node.variable_name,
1092            self.parent_variable_name(IR_node),
1093            num_outputs,
1094            axis,
1095            squeeze_axis
1096        )
1097        return code
1098
1099    def emit_Fill(self, IR_node):
1100        value = IR_node.get_attr('value')
1101        code = "{:<15} = mx.sym.full({}, {})".format(
1102            IR_node.variable_name,
1103            self.parent_variable_name(IR_node),
1104            value
1105        )
1106        return code
1107
1108    def emit_Split(self, IR_node):
1109        axis = IR_node.get_attr('axis')
1110        num_outputs = IR_node.get_attr('split')
1111
1112        if isinstance(num_outputs, list):
1113            raise NotImplementedError()
1114        code = "{:<15} = mx.sym.split({}, num_outputs={}, axis={})".format(
1115            IR_node.variable_name,
1116            self.parent_variable_name(IR_node),
1117            num_outputs,
1118            axis)
1119
1120        return code
1121
1122
1123    def emit_Sigmoid(self, IR_node):
1124        code = "{:<15} = mx.sym.sigmoid(data={}, name='{}')".format(
1125            IR_node.variable_name,
1126            self.parent_variable_name(IR_node),
1127            IR_node.name
1128        )
1129        return code
1130
1131
1132    def emit_Tanh(self, IR_node):
1133        code = "{:<15} = mx.sym.tanh(data={}, name='{}')".format(
1134            IR_node.variable_name,
1135            self.parent_variable_name(IR_node),
1136            IR_node.name
1137        )
1138        return code
1139
1140
1141    def emit_Maxmum(self, IR_node):
1142        code = "{:<15} = mx.sym.maxmum({}, {}, name='{}')".format(
1143            IR_node.variable_name,
1144            self.parent_variable_name(IR_node),
1145            self.parent_variable_name(IR_node, [1]),
1146            IR_node.name
1147        )
1148        return code
1149
1150
1151    def emit_Minimum(self, IR_node):
1152        code = "{:<15} = mx.sym.minimum({}, {}, name='{}')".format(
1153            IR_node.variable_name,
1154            self.parent_variable_name(IR_node),
1155            self.parent_variable_name(IR_node, [1]),
1156            IR_node.name
1157        )
1158        return code
1159
1160
1161    def emit_Scope(self, IR_node):
1162        import re
1163        pattern = IR_node.pattern
1164
1165        if pattern not in self.naive_scope_pattern and re.sub(r'(_\d+)*$', '', IR_node.pattern) not in self.naive_scope_pattern:
1166            origi_pattern = re.sub(r'(_\d+)*$', '', IR_node.pattern)
1167            func = getattr(self, "_emit_" + origi_pattern)
1168            code = func(IR_node)
1169        else:
1170            code = "{:<15} = __{}({})".format(
1171                IR_node.real_variable_name,
1172                IR_node.pattern,
1173                ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges))
1174            self._gen_scope_code(IR_node)
1175        return code
1176
1177
1178    def _gen_scope_code(self, scope_node):
1179
1180        def _get_weight_related_op_name(node):
1181            weight_related_ops = ['Constant', 'Conv', 'FullyConnected', 'BatchNorm']
1182            op_type = node.type
1183            if op_type in weight_related_ops:
1184                return op_type, node.name
1185
1186        def _scope_func(params, code, return_var):
1187            code = """
1188    def __call__(self, {}):
1189{}
1190        return {}
1191    """.format(params, code, ', '.join(return_var))
1192            return code
1193
1194        class_inits = dict()
1195
1196        body_code = str()
1197        for node_name in scope_node.topology_list:
1198            node = self.IR_graph.get_node(node_name)
1199            node_type = node.type
1200
1201            if hasattr(self, "emit_" + node_type):
1202                func = getattr(self, "emit_" + node_type)
1203                line = func(node)
1204                if line != None:
1205                    body_code += "        " + line + '\n'
1206                    inits = _get_weight_related_op_name(node)
1207                    if inits:
1208                        if class_inits.get(inits[0], None):
1209                            class_inits[inits[0]].append(inits[1])
1210                        else:
1211                            class_inits[inits[0]] = list([inits[1]])
1212            else:
1213                print("MXNetEmitter has not supported operator [%s]." % (node_type))
1214                self.emit_UNKNOWN(node)
1215
1216        # param_code does not need parameter slice.
1217        param_code = ', '.join('%s'  %self.IR_graph.get_node(s).real_variable_name for s in scope_node.in_edges)
1218        function_code = _scope_func(param_code, body_code, scope_node.return_variables)
1219
1220        return class_inits, function_code
1221
1222
1223    def _emit_gru_cell(self, IR_node):
1224        if not self.layers_codes.get(IR_node.pattern, None):
1225            class_inits, func_code = self._gen_scope_code(IR_node)
1226            variables, variable_codes, init_code, func_code = self.process_inits_func_code(class_inits, func_code)
1227
1228            states = [self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges]
1229            states.pop(0)
1230            states_code = ', '.join(states)
1231
1232            class_code ='''
1233class _{}(mx.rnn.BaseRNNCell):
1234    def __init__(self, {}):
1235
1236{}
1237
1238{}
1239
1240            '''.format(IR_node.pattern,
1241            ', '.join(variables),
1242            init_code,
1243            func_code)
1244            self.layers_codes[IR_node.pattern] = class_code
1245
1246            if not hasattr(self, 'pattern_variables'):
1247                self.pattern_variables = {IR_node.pattern: variables}
1248            else:
1249                self.pattern_variables[IR_node.pattern] = variables
1250
1251            code = variable_codes
1252            code.append("{:<15} = _{}({})({})".format(
1253                IR_node.real_variable_name,
1254                IR_node.pattern,
1255                ', '.join(variables),
1256                ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges)))
1257        else:
1258            code = "{:<15} = _{}({})({})".format(
1259                IR_node.real_variable_name,
1260                IR_node.pattern,
1261                ', '.join(self.pattern_variables[IR_node.pattern]),
1262                ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges))
1263
1264        return code
1265
1266
1267    def _emit_h_zero(self, IR_node):
1268        code = "{:<15} = mx.sym.full((1, {}), {})".format(
1269            IR_node.variable_name,
1270            IR_node.get_attr('fill_size'),
1271            IR_node.get_attr('fill_value')
1272        )
1273        return code
1274
1275
1276    def _emit_lstm_cell(self, IR_node):
1277
1278        if not self.layers_codes.get(IR_node.pattern, None):
1279            class_inits, func_code = self._gen_scope_code(IR_node)
1280            variables, variable_codes, init_code, func_code = self.process_inits_func_code(class_inits, func_code)
1281
1282            states = [self.IR_graph.get_node(s).real_variable_name for s in IR_node.in_edges]
1283            states.pop(0)
1284            states_code = ', '.join(states)
1285
1286            class_code ='''
1287class _{}(mx.rnn.BaseRNNCell):
1288    def __init__(self, {}):
1289
1290{}
1291
1292{}
1293
1294            '''.format(IR_node.pattern,
1295            ', '.join(variables),
1296            init_code,
1297            func_code)
1298            self.layers_codes[IR_node.pattern] = class_code
1299
1300            if not hasattr(self, 'pattern_variables'):
1301                self.pattern_variables = {IR_node.pattern: variables}
1302            else:
1303                self.pattern_variables[IR_node.pattern] = variables
1304
1305            code = variable_codes
1306            code.append("{:<15} = _{}({})({})".format(
1307                IR_node.real_variable_name,
1308                IR_node.pattern,
1309                ', '.join(variables),
1310                ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges)))
1311        else:
1312            code = "{:<15} = _{}({})({})".format(
1313                IR_node.real_variable_name,
1314                IR_node.pattern,
1315                ', '.join(self.pattern_variables[IR_node.pattern]),
1316                ', '.join(self.parent_variable_name(IR_node, s) for s in IR_node.in_edges))
1317
1318        return code
1319
1320
1321    def process_inits_func_code(self, class_inits, func_code):
1322        init_code = str()
1323        variables = list()
1324        variable_codes = list()
1325        for k, v in class_inits.items():
1326            if k == 'FullyConnected':
1327                for i, name in enumerate(class_inits[k]):
1328                    variable_name = self.IR_graph.get_node(name).variable_name
1329                    variables.append("W_" + variable_name)
1330                    variable_codes.append("W_{:<15} = mx.sym.var(name='{}_weight')".format(variable_name, name))
1331                    init_code += "        self.W_{} = W_{}\n".format(variable_name, variable_name)
1332
1333                    if self.weight_loaded and self.weights[name].get('bias', None).any() != None:
1334                        variable_codes.append("B_{:<15} = mx.sym.var(name='{}_bias')".format(variable_name, name))
1335                        variables.append("B_" + variable_name)
1336                        init_code += "        self.B_{} = B_{}\n".format(variable_name, variable_name)
1337                        func_code = func_code.replace("name = '{}'".format(name), "name = '{}', weight = self.W_{}, bias = self.B_{}".format(name, variable_name, variable_name))
1338                    else:
1339                        func_code = func_code.replace("name = '{}'".format(name), "name = '{}', weight = self.W_{}".format(name, variable_name))
1340            elif k == 'Constant':
1341                for name in class_inits[k]:
1342                    variable_name = self.IR_graph.get_node(name.replace('_weight', '')).variable_name
1343                    variables.append(variable_name)
1344                    constant_line = self.emit_Constant(self.IR_graph.get_node(name.replace('_weight', '')))
1345                    variable_codes.append("{:<15} = {}".format(variable_name, '='.join(constant_line.split('=')[1:])))
1346                    init_code += "        self.{} = {}\n".format(variable_name, variable_name)
1347                    func_code = func_code.replace(constant_line, constant_line.split('=')[0] + ' = self.'+constant_line.split('=')[0])
1348            else:
1349                raise NotImplementedError
1350
1351        return variables, variable_codes, init_code, func_code
1352
1353