1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6import os
7
8from mmdnn.conversion.common.IR.IR_graph import IRGraph, IRGraphNode
9import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
10from mmdnn.conversion.common.IR.graph_pb2 import NodeDef, GraphDef, DataType
11from mmdnn.conversion.common.DataStructure.emitter import Emitter
12from mmdnn.conversion.common.utils import *
13from mmdnn.conversion.rewriter.folder import Folder
14
15
16class TensorflowEmitter(Emitter):
17
18    dtype_map = {
19        graph_pb2.DT_FLOAT16 : "tf.float16",
20        graph_pb2.DT_FLOAT32 : "tf.float32",
21        graph_pb2.DT_FLOAT64 : "tf.float64",
22        graph_pb2.DT_INT16 : "tf.int16",
23        graph_pb2.DT_INT32 : "tf.int32",
24        graph_pb2.DT_INT64 : "tf.int64",
25        graph_pb2.DT_UINT8 : "tf.uint8",
26        graph_pb2.DT_UINT16 : "tf.uint16"
27    }
28
29
30    @property
31    def header_code(self):
32        return """import tensorflow as tf
33
34_weights_dict = dict()
35
36is_train = {}
37
38def load_weights(weight_file):
39    import numpy as np
40
41    if weight_file == None:
42        return
43
44    try:
45        weights_dict = np.load(weight_file, allow_pickle=True).item()
46    except:
47        weights_dict = np.load(weight_file, allow_pickle=True, encoding='bytes').item()
48
49    return weights_dict
50
51
52def KitModel(weight_file = None):
53    global _weights_dict
54    _weights_dict = load_weights(weight_file)
55""".format(self.trainable)
56
57
58    def __init__(self, model):
59        super(TensorflowEmitter, self).__init__()
60
61        from six import string_types as _string_types
62        if isinstance(model, _string_types):
63            network_path = model
64        else:
65            network_path = model[0]
66            self._load_weights(model[1])
67
68        self.IR_graph = IRGraph(network_path)
69        super(TensorflowEmitter, self)._build()
70
71        folder = Folder(self.IR_graph, self.weights_dict)
72        folder.fold()
73
74    def gen_code(self, phase):
75        self.trainable = (phase == 'train')
76        self.add_body(0, self.header_code)
77
78        for layer in self.IR_graph.topological_sort:
79            current_node = self.IR_graph.get_node(layer)
80            node_type = current_node.type
81
82            if hasattr(self, "emit_" + node_type):
83                func = getattr(self, "emit_" + node_type)
84                line = func(current_node)
85                if line != None:
86                    self.add_body(1, line)
87            else:
88                print("TensorflowEmitter has not supported operator [%s]." % (node_type))
89                self.emit_UNKNOWN(current_node)
90
91
92        self.add_body(1, "return {}, {}".format(
93            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.input_layers if self.IR_graph.get_node(name).type != 'Const' and not self.IR_graph.get_node(name).get_attr('feed_weights')]),
94            ', '.join([self.IR_graph.get_node(name).real_variable_name for name in self.IR_graph.output_layers if self.IR_graph.get_node(name).type != 'Pack' and  self.IR_graph.get_node(name).type !='Shape'])))
95
96
97
98        self.add_body(0, "")
99        for i in self.used_layers:
100            func = getattr(self, "_layer_" + i)
101            func()
102
103        self.add_body(0, "")
104        for code in self.layers_codes.values():
105            self.add_body(0, code)
106
107        return self.body_code
108
109
110    def parent_variable_name(self, IR_node, path=[0]):
111        if not IR_node.in_edges and IR_node.name in self.weights_dict.keys():
112            return "tf.constant(_weights_dict['{}']['weights'], name='{}')".format(
113                IR_node.name,
114                IR_node.name)
115        return super(TensorflowEmitter, self).parent_variable_name(IR_node, path)
116
117
118    @staticmethod
119    def _shapeToStr(shapes):
120        ret = [dim.size if dim.size != -1 else 'None' for dim in shapes.dim]
121        return ', '.join('%s' % i for i in ret)
122
123
124    def emit_Conv(self, IR_node):
125        self.used_layers.add(IR_node.type)
126        strides_str = ', '.join('%s' % i for i in IR_node.get_attr('strides')[1:-1])
127        input_node, padding = self._defuse_padding(IR_node)
128        data_format = IR_node.get_attr('data_format')
129        code = "{:<15} = convolution({}, group={}, strides=[{}], padding='{}', name='{}')".format(
130            IR_node.variable_name,
131            input_node,
132            IR_node.get_attr('group', 1),
133            strides_str,
134            padding,
135            IR_node.name)
136        return code
137
138    def _defuse_padding(self, IR_node, extra_str=""):
139        auto_pad = IR_node.get_attr('auto_pad')
140        if auto_pad:
141            input_node = self.parent_variable_name(IR_node)
142            if auto_pad == 'VALID':
143                padding = 'VALID'
144            elif auto_pad.startswith("SAME"):
145                padding = 'SAME'
146            else:
147                raise ValueError("Unknown padding type [{}].".format(auto_pad))
148
149            return input_node, padding
150
151        else:
152            padding = IR_node.get_attr("pads")
153            padding = convert_onnx_pad_to_tf(padding)
154            if not is_valid_padding(padding):
155                input_node = IR_node.variable_name + '_pad'
156                self.add_body(1, "{:<15} = tf.pad({}, paddings = {}{})".format(
157                    input_node,
158                    self.parent_variable_name(IR_node),
159                    padding,
160                    extra_str
161                    ))
162            else:
163                input_node = self.parent_variable_name(IR_node)
164
165            return input_node, 'VALID'
166
167
168    def emit_Constant(self, IR_node):
169        if 'dtype' in IR_node.layer.attr:
170            dtype_str = "{}".format(self.dtype_map[IR_node.layer.attr['dtype'].type])
171        else:
172            dtype_str = "tf.float32"
173        code = "{:<15} = tf.constant({}, dtype={}, name='{}')".format(
174            IR_node.variable_name,
175            "_weights_dict['{}']['value']".format(IR_node.name) if IR_node.get_attr('value')== None else IR_node.get_attr('value'),
176            dtype_str,
177            IR_node.name)
178
179        return code
180
181
182    def emit_Pool(self, IR_node):
183        pooling_type = IR_node.get_attr('pooling_type')
184        if pooling_type == 'MAX':
185            op = 'max_pool'
186            padding_const = ", constant_values=float('-Inf')"
187        elif pooling_type == 'AVG':
188            op = 'avg_pool'
189            padding_const = ""
190        else:
191            raise ValueError("unknown pooling type [{}].".format(pooling_type))
192
193        arrlen = len(IR_node.get_attr('strides'))
194        dim_str = '3d' if arrlen == 5 else ""
195
196        if IR_node.layer.attr['global_pooling'].b:
197            code = "{:<15} = tf.nn.{}{}({}, [1] + {}.get_shape().as_list()[1:-1] + [1], strides = [1] * {}, padding = 'VALID', name = '{}')".format(
198                IR_node.variable_name,
199                op,
200                dim_str,
201                self.parent_variable_name(IR_node),
202                self.parent_variable_name(IR_node),
203                arrlen,
204                IR_node.name)
205        else:
206            dim = len(IR_node.get_attr("strides")) - 2
207            dilations = IR_node.get_attr('dilations')
208            if dilations:
209                for e in IR_node.get_attr('dilations'):
210                    assert e == 1
211
212            pool_size = IR_node.get_attr('kernel_shape')[1:-1]
213            strides = IR_node.get_attr('strides')[1:-1]
214            padding = IR_node.get_attr('pads')[1:dim]
215
216            if pooling_type == "AVG" and pool_size.count(pool_size[0]) == len(pool_size) and strides[0] == 1 and strides.count(strides[0]) == len(strides) and padding.count(padding[0]) == len(padding) and pool_size[0] == padding[0]*2 + 1:
217                kernel_shape_str = ', '.join('%s' % i for i in IR_node.get_attr('kernel_shape'))
218                strides_str = ', '.join('%s' % i for i in IR_node.get_attr('strides'))
219
220                code = "{:<15} = tf.nn.{}{}({}, [{}], [{}], padding='{}', name='{}')".format(
221                    IR_node.variable_name,
222                    op,
223                    dim_str,
224                    self.parent_variable_name(IR_node),
225                    kernel_shape_str,
226                    strides_str,
227                    'SAME',
228                    IR_node.name)
229            else:
230                kernel_shape_str = ', '.join('%s' % i for i in IR_node.get_attr('kernel_shape'))
231                strides_str = ', '.join('%s' % i for i in IR_node.get_attr('strides'))
232                input_node, padding = self._defuse_padding(IR_node, padding_const)
233                code = "{:<15} = tf.nn.{}{}({}, [{}], [{}], padding='{}', name='{}')".format(
234                    IR_node.variable_name,
235                    op,
236                    dim_str,
237                    input_node,
238                    kernel_shape_str,
239                    strides_str,
240                    padding,
241                    IR_node.name)
242
243        return code
244
245    def emit_UNKNOWN(self, IR_node):
246        print(IR_node.name)
247
248    def emit_Add(self, IR_node):
249        code = "{:<15} = {}".format(
250            IR_node.variable_name,
251            ' + '.join('%s' % self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))))
252
253        return code
254
255    def emit_DataInput(self, IR_node):
256        assert not IR_node.in_edges
257        shape_str = self._shapeToStr(IR_node.layer.attr["shape"].shape)
258
259        if 'dtype' in IR_node.layer.attr:
260            dtype_str = "{}, ".format(self.dtype_map[IR_node.layer.attr['dtype'].type])
261        else:
262            dtype_str = "tf.float32,"
263
264        code = "{:<15} = tf.placeholder({} shape = ({}), name = '{}')".format(
265            IR_node.variable_name, dtype_str, shape_str, IR_node.name
266        )
267        return code
268
269    def emit_Dropout(self, IR_node):
270        parent = self.IR_graph.get_parent(IR_node.name, [0])
271        if self.trainable:
272            self.add_body(1, "{:<15} = Dropout(name = '{}', dropout_rate = {})({})".format(
273                IR_node.variable_name,
274                IR_node.name,
275                1 - IR_node.IR_layer.attr["keep_prob"].f,
276                parent.real_variable_name))
277        else:
278            IR_node.real_name = parent.real_name
279
280
281    def emit_FullyConnected(self, IR_node):
282        if IR_node.name in self.weights_dict and 'weights' in self.weights_dict[IR_node.name]:
283            kernel_str = "kernel_initializer = tf.constant_initializer(_weights_dict['{}']['weights']), ".format(IR_node.name)
284        else: kernel_str = ""
285
286        if IR_node.name in self.weights_dict and 'bias' in self.weights_dict[IR_node.name]:
287            bias_str = "bias_initializer = tf.constant_initializer(_weights_dict['{}']['bias']), ".format(IR_node.name)
288        else: bias_str = ""
289
290        # check whether flatten operator should be added
291        parent = self.IR_graph.get_parent(IR_node.name, [0])
292        parent_shape = shape_to_list(parent.get_attr('_output_shapes')[0])
293        if len(parent_shape) > 2:
294            # flatten is needed
295            self.add_body(1, "{:<15} = tf.contrib.layers.flatten({})".format(
296                IR_node.variable_name + '_flatten',
297                self.parent_variable_name(IR_node)))
298
299            code = "{:<15} = tf.layers.dense({}, {}, {}{}use_bias = {})".format(
300                IR_node.variable_name,
301                IR_node.variable_name + '_flatten',
302                IR_node.layer.attr['units'].i,
303                kernel_str,
304                bias_str,
305                IR_node.layer.attr['use_bias'].b)
306            return code
307
308        else:
309            code = "{:<15} = tf.layers.dense({}, {}, {}{}use_bias = {})".format(
310                IR_node.variable_name,
311                self.parent_variable_name(IR_node),
312                IR_node.layer.attr['units'].i,
313                kernel_str,
314                bias_str,
315                IR_node.layer.attr['use_bias'].b)
316            return code
317
318
319    def emit_UpSampling2D(self, IR_node):
320        scales = IR_node.get_attr('scales')
321        scales = tuple(scales)
322
323        code = "{:<15} = tf.keras.layers.UpSampling2D(size={})({})".format(
324            IR_node.variable_name,
325            scales,
326            self.parent_variable_name(IR_node))
327        return code
328
329    def emit_Flatten(self, IR_node):
330        #self._emit_unary_operation(IR_node, "contrib.layers.flatten")
331        code = "{:<15} = tf.contrib.layers.flatten({})".format(
332            IR_node.variable_name,
333            self.parent_variable_name(IR_node))
334        return code
335
336
337    def emit_Mul(self, IR_node):
338
339        code = "{:<15} = {}".format(
340            IR_node.variable_name,
341            ' * '.join('%s' % self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))))
342        return code
343
344
345    def emit_Const(self, IR_node):
346        if 'dtype' in IR_node.layer.attr:
347            dtype_str = "dtype={}".format(self.dtype_map[IR_node.layer.attr['dtype'].type])
348            if 'int' in dtype_str:
349                code = "{:<15} = tf.constant({}, {}, shape=(1,))".format(
350                    IR_node.variable_name,
351                    IR_node.layer.attr['value'].i,
352                    dtype_str)
353            else:
354                code = "{:<15} = tf.constant({}, {}, shape=(1,))".format(
355                    IR_node.variable_name,
356                    IR_node.layer.attr['value'].f,
357                    dtype_str)
358        else:
359            dtype_str = "dtype=tf.float32"
360            code ="{:<15} = tf.constant({}, {}, shape=(1,))".format(
361                IR_node.variable_name,
362                IR_node.layer.attr['value'].f,
363                dtype_str)
364
365        return code
366
367    def emit_Transpose(self, IR_node):
368        code ="{:<15} = tf.transpose(a = {}, perm = {})".format(
369            IR_node.variable_name,
370            self.parent_variable_name(IR_node, [0]),
371            self.parent_variable_name(IR_node, [1]))
372
373        return code
374
375    def emit_Gather(self, IR_node):
376        variable_str = "tf.convert_to_tensor(_weights_dict['{}']['weights'])".format(IR_node.name)
377
378        code = "{:<15} = tf.gather(params = {}, indices = {}, axis = {})".format(
379            IR_node.variable_name,
380            variable_str,
381            self.parent_variable_name(IR_node),
382            IR_node.get_attr('axis')
383            )
384
385        return code
386
387    def emit_Unstack(self, IR_node):
388        code = "{:<15} = tf.unstack(value={}, num={}, axis={})".format(
389            IR_node.variable_name,
390            self.parent_variable_name(IR_node),
391            IR_node.get_attr('num'),
392            IR_node.get_attr('axis')
393        )
394        return code
395
396    def emit_Reshape(self, IR_node):
397        code = "{:<15} = tf.reshape({}, [{}], '{}')".format(
398            IR_node.variable_name,
399            self.parent_variable_name(IR_node),
400            ', '.join('%s' % i for i in IR_node.get_attr('shape')),
401            IR_node.name)
402
403        return code
404
405
406    def emit_Sub(self, IR_node):
407        code = "{:<15} = {}".format(
408            IR_node.variable_name,
409            ' - '.join('%s' % self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))))
410
411        return code
412
413    def emit_Div(self, IR_node):
414        code = "{:<15} = tf.div({}, {}, name='{}')".format(
415            IR_node.variable_name,
416            self.parent_variable_name(IR_node),
417            self.parent_variable_name(IR_node, [1]),
418            IR_node.name
419        )
420        return code
421
422    def _emit_unary_operation(self, IR_node, op_name):
423        code = "{:<15} = tf.{}({}, name = '{}')".format(
424            IR_node.variable_name,
425            op_name,
426            self.parent_variable_name(IR_node),
427            IR_node.name)
428        return code
429
430    def emit_Tanh(self, IR_node):
431        code = self._emit_unary_operation(IR_node, 'tanh')
432        return code
433
434    def emit_Elu(self, IR_node):
435        return self._emit_unary_operation(IR_node, 'nn.elu')
436
437
438    def emit_Relu(self, IR_node):
439        return self._emit_unary_operation(IR_node, 'nn.relu')
440
441
442    def emit_Relu6(self, IR_node):
443        return self._emit_unary_operation(IR_node, 'nn.relu6')
444
445
446    def emit_CRelu(self, IR_node):
447        return self._emit_unary_operation(IR_node, 'nn.crelu')
448
449
450    def emit_PRelu(self, IR_node):
451        self.used_layers.add(IR_node.type)
452        code = "{:<15} = prelu({}, name='{}')".format(
453            IR_node.variable_name,
454            self.parent_variable_name(IR_node),
455            IR_node.name)
456        return code
457
458    def emit_LeakyRelu(self, IR_node):
459        self.add_body(1, "{:<15} = tf.nn.leaky_relu({}, alpha={}, name='{}')".format(
460            IR_node.variable_name,
461            self.parent_variable_name(IR_node),
462            IR_node.get_attr('alpha'),
463            IR_node.name
464        ))
465
466
467    def emit_Softmax(self, IR_node):
468        return self._emit_unary_operation(IR_node, 'nn.softmax')
469
470
471    def emit_Sigmoid(self, IR_node):
472        code = self._emit_unary_operation(IR_node, 'sigmoid')
473        return code
474
475    def emit_Embedding(self, IR_node):
476        variable_str = "tf.convert_to_tensor(_weights_dict['{}']['weights'])".format(IR_node.name)
477        code = "{:<15} = tf.nn.embedding_lookup(params = {}, ids = {})".format(
478            IR_node.variable_name,
479            variable_str,
480            self.parent_variable_name(IR_node))
481        return code
482
483    def emit_LSTM(self, IR_node):
484        return self.emit_RNNs(IR_node, "LSTM")
485
486
487    def emit_GRU(self, IR_node):
488        return self.emit_RNNs(IR_node, "GRU")
489
490
491    def emit_Concat(self, IR_node):
492
493        code = "{:<15} = tf.concat([{}], {}, name = '{}')".format(
494            IR_node.variable_name,
495            ', '.join(self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))),
496            IR_node.layer.attr['axis'].i,
497            IR_node.name)
498
499        return code
500
501    def emit_BatchNorm(self, IR_node):
502        self.used_layers.add(IR_node.type)
503        code = "{:<15} = batch_normalization({}, variance_epsilon={}, name='{}')".format(
504            IR_node.variable_name,
505            self.parent_variable_name(IR_node),
506            IR_node.get_attr('epsilon'),
507            IR_node.name)
508        return code
509
510    def emit_Scale(self, IR_node):
511        self.used_layers.add(IR_node.type)
512        code = "{:<15} = scale({}, name='{}')".format(
513            IR_node.variable_name,
514            self.parent_variable_name(IR_node),
515            IR_node.name)
516        return code
517
518    def emit_Pad(self, IR_node):
519        padding = IR_node.get_attr('pads')
520        padding = convert_onnx_pad_to_tf(padding)
521
522        mode = IR_node.get_attr('mode', 'constant')
523        mode = mode.lower()
524        if mode == 'constant' or mode == 'reflect':
525            mode = mode.upper()
526        elif mode == 'edge':
527            mode = 'SYMMETRIC'
528        else:
529            raise NotImplementedError("Not support padding mode {}.".format(mode))
530        code = "{:<15} = tf.pad({}, {}, '{}', name='{}')".format(
531            IR_node.variable_name,
532            self.parent_variable_name(IR_node),
533            padding,
534            mode,
535            IR_node.variable_name)
536        return code
537
538    def emit_Squeeze(self, IR_node):
539        code = "{:<15} = tf.squeeze({}, [{}], name = '{}')".format(
540            IR_node.variable_name,
541            self.parent_variable_name(IR_node),
542            ', '.join('%s' % axis for axis in IR_node.layer.attr['axes'].list.i),
543            IR_node.name)
544        return code
545
546
547    def emit_ReduceMean(self, IR_node):
548        code = "{:<15} = tf.reduce_mean({}, [{}], {}, name = '{}')".format(
549            IR_node.variable_name,
550            self.parent_variable_name(IR_node),
551            ','.join('%s' % i for i in IR_node.get_attr('axes')),
552            IR_node.get_attr('keepdims'),
553            IR_node.name)
554        return code
555
556    def emit_LRN(self, IR_node):
557        input_name = IR_node.variable_name
558        output_name = self.parent_variable_name(IR_node)
559        IR_name = IR_node.name
560        size = IR_node.get_attr('size')
561        depth_radius = int(IR_node.get_attr('size') / 2)
562        bias = IR_node.get_attr('bias', 1)
563        alpha = IR_node.get_attr('alpha') / size
564        beta = IR_node.get_attr('beta')
565
566        code = "{:<15} = tf.nn.lrn({}, depth_radius={}, bias={}, alpha={}, beta={}, name='{}')".format(
567            input_name,
568            output_name,
569            depth_radius,
570            bias,
571            alpha,
572            beta,
573            IR_name)
574        return code
575
576    def emit_SeparableConv(self, IR_node):
577        self.used_layers.add(IR_node.type)
578        strides_str = ', '.join('%s' % i for i in IR_node.get_attr('strides'))
579        input_node, padding = self._defuse_padding(IR_node)
580        code = "{:<15} = separable_convolution({}, strides = [{}], padding = '{}', name = '{}')".format(
581            IR_node.variable_name,
582            input_node,
583            strides_str,
584            padding,
585            IR_node.name)
586        return code
587
588
589    def emit_DepthwiseConv(self, IR_node):
590        self.used_layers.add(IR_node.type)
591        strides_str = ', '.join('%s' % i for i in IR_node.layer.attr['strides'].list.i)
592        input_node, padding = self._defuse_padding(IR_node)
593        code = "{:<15} = depthwise_convolution({}, strides = [{}], padding = '{}', name = '{}')".format(
594            IR_node.variable_name,
595            input_node,
596            strides_str,
597            padding,
598            IR_node.name)
599        return code
600
601    def emit_Crop(self, IR_node):
602        border = IR_node.get_attr('border')
603        assert len(border) == 4
604
605        output_shape = IR_node.get_attr('_output_shapes')[0]
606        output_shape = shape_to_list(output_shape)
607
608        code = "{:<15} = tf.image.crop_to_bounding_box({}, offset_height={}, offset_width={}, target_height={}, target_width={})".format(
609            IR_node.variable_name,
610            self.parent_variable_name(IR_node),
611            border[0],
612            border[1],
613            output_shape[1],
614            output_shape[2])
615
616        return code
617
618    def emit_ConvTranspose(self, IR_node):
619        self.used_layers.add(IR_node.type)
620        output_shape = [1] + shape_to_list(IR_node.get_attr('_output_shapes')[0])[1:]
621        input_node, padding = self._defuse_padding(IR_node)
622        code = "{:<15} = convolution_transpose({}, output_shape={}, strides={}, padding='{}', name='{}')".format(
623            IR_node.variable_name,
624            input_node,
625            output_shape,
626            IR_node.get_attr('strides'),
627            padding,
628            IR_node.name)
629        return code
630
631    def emit_Slice(self, IR_node):
632        extra_str = ""
633        if IR_node.get_attr('begin_mask'):
634            extra_str += ", begin_mask={}".format(IR_node.get_attr('begin_mask'))
635        if IR_node.get_attr('end_mask') != None:
636            extra_str += ", end_mask={}".format(IR_node.get_attr('end_mask'))
637        if IR_node.get_attr('shrink_axis_mask') != None:
638            extra_str += ", shrink_axis_mask={}".format(IR_node.get_attr('shrink_axis_mask'))
639        if IR_node.get_attr('new_axis_mask')!= None:
640            extra_str += ", new_axis_mask={}".format(IR_node.get_attr('new_axis_mask'))
641
642        if IR_node.get_attr('starts') != None:
643            starts = IR_node.get_attr('starts')
644        else:
645            starts = self.parent_variable_name(IR_node, [1])
646
647        if IR_node.get_attr('ends') != None:
648            ends = IR_node.get_attr('ends')
649        else:
650            ends = self.parent_variable_name(IR_node, [2])
651
652        if IR_node.get_attr('strides') != None:
653            strides = IR_node.get_attr('strides')
654        else:
655            strides = self.parent_variable_name(IR_node, [3])
656
657        code = "{:<15} = tf.strided_slice({}, {}, {}, {} {}, name='{}')".format(
658            IR_node.variable_name,
659            self.parent_variable_name(IR_node),
660            starts,
661            ends,
662            strides,
663            extra_str,
664            IR_node.name)
665
666        return code
667
668
669    def emit_Shape(self, IR_node):
670        code = "{:<15} = tf.shape({}, name='{}')".format(
671            IR_node.variable_name,
672            self.parent_variable_name(IR_node),
673            IR_node.name)
674        return code
675
676    def emit_Pack(self, IR_node):
677        code = "{:<15} = tf.stack({}, axis={}, name='{}')".format(
678            IR_node.variable_name,
679            '[' +  ','.join('%s' % self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))) + ']',
680            IR_node.get_attr('axis'),
681            IR_node.name)
682        return code
683
684    def emit_Split(self, IR_node):
685        code = "{:<15} = tf.split({}, {}, {}, name='{}')".format(
686            IR_node.variable_name,
687            self.parent_variable_name(IR_node),
688            IR_node.get_attr('split'),
689            IR_node.get_attr('axis'),
690            IR_node.name)
691        return code
692
693    def emit_Unsqueeze(self, IR_node):
694        code = "{:<15} = tf.expand_dims({}, axis={}, name='{}')".format(
695            IR_node.variable_name,
696            self.parent_variable_name(IR_node),
697            IR_node.get_attr('axes')[0],
698            IR_node.name)
699        return code
700
701    def emit_Fill(self, IR_node):
702        code = "{:<15} = tf.fill({}, {}, name='{}')".format(
703            IR_node.variable_name,
704            self.parent_variable_name(IR_node),
705            IR_node.get_attr('value'),
706            IR_node.name)
707        return code
708
709    def emit_Maximum(self, IR_node):
710        code = "{:<15} = tf.maximum({}, {}, name='{}')".format(
711            IR_node.variable_name,
712            self.parent_variable_name(IR_node),
713            self.parent_variable_name(IR_node, [1]),
714            IR_node.name
715        )
716        return code
717
718    def emit_Minimum(self, IR_node):
719        code = "{:<15} = tf.minimum({}, {}, name='{}')".format(
720            IR_node.variable_name,
721            self.parent_variable_name(IR_node),
722            self.parent_variable_name(IR_node, [1]),
723            IR_node.name
724        )
725        return code
726
727    def emit_Scope(self, IR_node):
728        input_vars = [self.parent_variable_name(IR_node, [idx]) for idx in range(len(IR_node.in_edges))]
729        input_vars.append('_weights_dict')
730        code = "{:<15} = _{}({})".format(
731            IR_node.real_variable_name,
732            IR_node.pattern,
733            ', '.join(input_vars))
734        self._gen_scope_code(IR_node)
735        return code
736
737
738    def _gen_scope_code(self, scope_node):
739
740        def _scope_func(scope_name, params, code, return_var):
741            code = """
742def _{}({}):
743{}
744    return {}
745    """.format(scope_name, params, code, ', '.join(return_var))
746            return code
747
748        if not self.layers_codes.get(scope_node.pattern, None):
749            body_code = str()
750            for node_name in scope_node.topology_list:
751                node = self.IR_graph.get_node(node_name)
752                node_type = node.type
753
754                if hasattr(self, "emit_" + node_type):
755                    func = getattr(self, "emit_" + node_type)
756                    line = func(node)
757                    if line != None:
758                        body_code += "    " + line + '\n'
759                else:
760                    print("TensorflowEmitter has not supported operator [%s]." % (node_type))
761                    self.emit_UNKNOWN(node)
762
763            # param_code does not need parameter slice.
764            input_params = scope_node.input_params
765            input_params.append("_weights_dict")
766            param_code = ', '.join(input_params)
767            function_code = _scope_func(scope_node.pattern, param_code, body_code, scope_node.return_variables)
768
769            self.layers_codes[scope_node.pattern] = function_code
770
771
772
773    def _layer_Conv(self):
774        self.add_body(0, """
775def convolution(input, name, group, **kwargs):
776    w = tf.Variable(_weights_dict[name]['weights'], trainable=is_train, name=name + "_weight")
777    if group == 1:
778        layer = tf.nn.convolution(input, w, name=name, **kwargs)
779    else:
780        weight_groups = tf.split(w, num_or_size_splits=group, axis=-1)
781        xs = tf.split(input, num_or_size_splits=group, axis=-1)
782        convolved = [tf.nn.convolution(x, weight, name=name, **kwargs) for
783                    (x, weight) in zip(xs, weight_groups)]
784        layer = tf.concat(convolved, axis=-1)
785
786    if 'bias' in _weights_dict[name]:
787        b = tf.Variable(_weights_dict[name]['bias'], trainable=is_train, name=name + "_bias")
788        layer = layer + b
789    return layer""")
790
791
792    def _layer_PRelu(self):
793        self.add_body(0, """
794def prelu(input, name):
795    gamma = tf.Variable(_weights_dict[name]['gamma'], name=name + "_gamma", trainable=is_train)
796    return tf.maximum(0.0, input) + gamma * tf.minimum(0.0, input)
797    """)
798
799
800    def _layer_BatchNorm(self):
801        self.add_body(0, """
802def batch_normalization(input, name, **kwargs):
803    mean = tf.Variable(_weights_dict[name]['mean'], name = name + "_mean", trainable = is_train)
804    variance = tf.Variable(_weights_dict[name]['var'], name = name + "_var", trainable = is_train)
805    offset = tf.Variable(_weights_dict[name]['bias'], name = name + "_bias", trainable = is_train) if 'bias' in _weights_dict[name] else None
806    scale = tf.Variable(_weights_dict[name]['scale'], name = name + "_scale", trainable = is_train) if 'scale' in _weights_dict[name] else None
807    return tf.nn.batch_normalization(input, mean, variance, offset, scale, name = name, **kwargs)
808""")
809
810
811    def _layer_Scale(self):
812        self.add_body(0, """
813def scale(input, name, **kwargs):
814    mean = tf.Variable(_weights_dict[name]['scale_mean'], name = name + "_mean", trainable = is_train)
815    variance = tf.Variable(_weights_dict[name]['scale_var'], name = name + "_var", trainable = is_train)
816    offset = tf.Variable(_weights_dict[name]['bias'], name = name + "_bias", trainable = is_train) if 'bias' in _weights_dict[name] else None
817    scale = tf.Variable(_weights_dict[name]['scale'], name = name + "_scale", trainable = is_train) if 'scale' in _weights_dict[name] else None
818    return tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon = 0, name = name)
819""")
820
821
822    def _layer_SeparableConv(self):
823        self.add_body(0, """
824def separable_convolution(input, name, **kwargs):
825    depthwise = tf.Variable(_weights_dict[name]['depthwise_filter'], trainable = is_train, name = name + "_df")
826    pointwise = tf.Variable(_weights_dict[name]['pointwise_filter'], trainable = is_train, name = name + "_pf")
827    layer = tf.nn.separable_conv2d(input, depthwise, pointwise, **kwargs)
828    if 'bias' in _weights_dict[name]:
829        b = tf.Variable(_weights_dict[name]['bias'], trainable = is_train, name = name + "_bias")
830        layer = layer + b
831    return layer""")
832
833
834    def _layer_DepthwiseConv(self):
835        self.add_body(0, """
836def depthwise_convolution(input, name, **kwargs):
837    depthwise = tf.Variable(_weights_dict[name]['weights'], trainable = is_train, name = name + "_df")
838    layer = tf.nn.depthwise_conv2d(input, depthwise, **kwargs)
839    if 'bias' in _weights_dict[name]:
840        b = tf.Variable(_weights_dict[name]['bias'], trainable = is_train, name = name + "_bias")
841        layer = layer + b
842    return layer""")
843
844
845    def _layer_ConvTranspose(self):
846        self.add_body(0, """
847def convolution_transpose(input, name, **kwargs):
848    w = tf.Variable(_weights_dict[name]['weights'], trainable=is_train, name=name + "_weight")
849    dim = _weights_dict[name]['weights'].ndim - 2
850    if dim == 2:
851        layer = tf.nn.conv2d_transpose(input, w, **kwargs)
852    elif dim == 3:
853        layer = tf.nn.conv3d_transpose(input, w, **kwargs)
854    else:
855        raise ValueError("Error dim number {} in ConvTranspose".format(dim))
856
857    if 'bias' in _weights_dict[name]:
858        b = tf.Variable(_weights_dict[name]['bias'], trainable=is_train, name=name + "_bias")
859        layer = layer + b
860    return layer""")
861