1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17# pylint: disable=invalid-name, import-self
18"""Keras frontend."""
19from __future__ import absolute_import as _abs
20import sys
21import numpy as np
22import tvm
23from .. import analysis
24from .. import expr as _expr
25from .. import module as _module
26from .. import op as _op
27from ... import nd as _nd
28from .common import ExprTable, new_var
29
30__all__ = ['from_keras']
31
32
33def _check_data_format(keras_layer):
34    if hasattr(keras_layer, ('data_format')):
35        if keras_layer.data_format != 'channels_last':
36            raise ValueError("Keras frontend currently supports data_format = channels_last only.")
37
38
39def _get_pad_pair(input1d, kernel1d, stride1d):
40    out1d = (input1d + stride1d - 1) // stride1d
41    pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0)
42    pad_before = pad // 2
43    pad_after = pad - pad_before
44    return [pad_before, pad_after]
45
46
47def _get_elu(inexpr, alpha):
48    """A helper method for elu."""
49    return _op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') - \
50        _op.exp(inexpr)) + _op.nn.relu(inexpr)
51
52
53def _as_list(arr):
54    """Force being a list, ignore if already is."""
55    if isinstance(arr, list):
56        return arr
57    return [arr]
58
59
60def _convert_recurrent_activation(inexpr, keras_layer):
61    act_type = keras_layer.recurrent_activation.__name__
62    return _convert_activation(inexpr, act_type, None)
63
64
65def _convert_activation(inexpr, keras_layer, _):
66    if isinstance(keras_layer, str):
67        act_type = keras_layer
68    else:
69        if sys.version_info.major < 3:
70            act_type = keras_layer.activation.func_name
71        else:
72            act_type = keras_layer.activation.__name__
73    if act_type == 'linear':
74        if isinstance(keras_layer, str):
75            return inexpr
76        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
77        beta = keras_layer.beta if hasattr(keras_layer, 'beta') else 0.
78        alpha = _expr.const(alpha, dtype='float32')
79        beta = _expr.const(beta, dtype='float32')
80        return _op.add(_op.multiply(inexpr, alpha), beta)
81    if act_type == 'softmax':
82        return _op.nn.softmax(inexpr, axis=1)
83    if act_type == 'sigmoid':
84        return _op.sigmoid(inexpr)
85    if act_type == 'tanh':
86        return _op.tanh(inexpr)
87    if act_type == 'relu':
88        return _op.nn.relu(inexpr)
89    if act_type == 'softplus':
90        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
91    if act_type == 'elu':
92        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
93        alpha = _expr.const(alpha, dtype='float32')
94        return _get_elu(inexpr, alpha)
95    if act_type == 'selu':
96        # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
97        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
98            else 1.6732632423543772848170429916717
99        gamma = keras_layer.gamma if hasattr(keras_layer, 'gamma') \
100            else 1.0507009873554804934193349852946
101        alpha = _expr.const(alpha, dtype='float32')
102        gamma = _expr.const(gamma, dtype='float32')
103        return gamma * _get_elu(inexpr, alpha)
104    if act_type == 'relu6':
105        return _op.clip(inexpr, a_min=0., a_max=6.)
106    if act_type == 'softsign':
107        return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
108    if act_type == 'hard_sigmoid':
109        x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
110        return _op.clip(x, a_min=0., a_max=1.)
111
112    raise tvm.error.OpNotImplemented(
113        'Operator {} is not supported in frontend Keras.'.format(act_type))
114
115
116def _convert_advanced_activation(inexpr, keras_layer, etab):
117    act_type = type(keras_layer).__name__
118
119    if act_type == 'Softmax':
120        axis = keras_layer.axis
121        dims = len(keras_layer.input_shape)
122        if isinstance(axis, list):
123            raise tvm.error.OpAttributeUnImplemented(
124                'Softmax with axes {} is not supported.'.format(axis))
125        if axis == -1:
126            axis = 1
127        else:
128            axis = axis + 1 if axis < dims - 1 else 1
129        return _op.nn.softmax(inexpr, axis=axis)
130    if act_type == 'ReLU':
131        threshold = _expr.const(keras_layer.threshold, dtype='float32')
132        if keras_layer.max_value and float(keras_layer.threshold) == 0:
133            # f(x) = max_value, for x >= max_value
134            # f(x) = x,         for threshold <= x < max_value
135            return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
136        elif keras_layer.max_value and _op.greater(threshold, inexpr).astype('float32'):
137            # f(x) = negative_slope * (inexpr - threshold)
138            negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
139            return _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
140        return _op.nn.relu(inexpr)
141    if act_type == 'LeakyReLU':
142        return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
143    if act_type == 'ELU':
144        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
145        alpha = _expr.const(alpha, dtype='float32')
146        return _get_elu(inexpr, alpha)
147    if act_type == 'PReLU':
148        assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
149        _check_data_format(keras_layer)
150        size = len(keras_layer.alpha.shape)
151        alpha = etab.new_const(keras_layer.get_weights()[0] \
152                               .transpose(np.roll(range(size), 1)))
153        return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
154    if act_type == 'ThresholdedReLU':
155        theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
156        return _op.multiply(inexpr, _op.greater(inexpr, \
157            _expr.const(theta, dtype='float32')).astype('float32'))
158
159    raise tvm.error.OpNotImplemented(
160        'Operator {} is not supported in frontend Keras.'.format(act_type))
161
162
163def _convert_merge(inexpr, keras_layer, _):
164    merge_type = type(keras_layer).__name__
165    ret = inexpr[0]
166    if merge_type == 'Dot':
167        axes = keras_layer.axes
168        if isinstance(keras_layer.axes, int):
169            axes = [keras_layer.axes, keras_layer.axes]
170        if isinstance(axes, list):
171            if len(axes) != 2:
172                raise tvm.error.OpAttributeUnImplemented(
173                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
174            for i, axis in enumerate(axes):
175                if axis not in [1, 2]:
176                    raise tvm.error.OpAttributeUnImplemented(
177                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
178                if axes[i] == 2:
179                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
180        else:
181            raise tvm.error.OpAttributeUnImplemented(
182                'Dot with axes {} is not supported.'.format(keras_layer.axes))
183        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
184        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
185    elif merge_type == 'Subtract':
186        assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
187        ret = _op.subtract(ret, inexpr[1])
188    elif merge_type in ['Add', 'Multiply', 'Maximum']:
189        op_map = {'Add':_op.add, 'Multiply':_op.multiply, 'Maximum':_op.maximum}
190        for i in range(1, len(inexpr)):
191            ret = op_map[merge_type](ret, inexpr[i])
192    elif merge_type == 'Average':
193        for i in range(1, len(inexpr)):
194            ret = _op.add(ret, inexpr[i])
195        ret = ret / _expr.const(len(inexpr), dtype='float32')
196    else:
197        raise tvm.error.OpNotImplemented(
198            'Operator {} is not supported in frontend Keras.'.format(merge_type))
199    return ret
200
201
202def _convert_permute(inexpr, keras_layer, _):
203    return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)
204
205
206def _convert_dense(inexpr, keras_layer, etab):
207    weightList = keras_layer.get_weights()
208    weight = etab.new_const(weightList[0].transpose([1, 0]))
209    params = {'weight':weight, 'units':weightList[0].shape[1]}
210    input_shape = keras_layer.input_shape
211    input_dim = len(input_shape)
212    # In case of RNN dense, input shape will be (1, 1, n)
213    if input_dim > 2:
214        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
215        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
216            raise tvm.error.OpAttributeInvalid(
217                'Input shape {} is not valid for operator Dense.'.format(input_shape))
218        inexpr = _op.squeeze(inexpr, axis=0)
219    out = _op.nn.dense(data=inexpr, **params)
220    if keras_layer.use_bias:
221        bias = etab.new_const(weightList[1])
222        out = _op.nn.bias_add(out, bias)
223    # defuse activation
224    if sys.version_info.major < 3:
225        act_type = keras_layer.activation.func_name
226    else:
227        act_type = keras_layer.activation.__name__
228    if act_type != 'linear':
229        out = _convert_activation(out, act_type, etab)
230    if input_dim > 2:
231        out = _op.expand_dims(out, axis=0)
232    return out
233
234
235def _convert_convolution(inexpr, keras_layer, etab):
236    _check_data_format(keras_layer)
237    is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
238    is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
239    weightList = keras_layer.get_weights()
240    if is_deconv:
241        kernel_h, kernel_w, n_filters, in_channels = weightList[0].shape
242        weight = weightList[0].transpose([3, 2, 0, 1])
243    elif is_depthconv:
244        kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
245        weight = weightList[0].transpose([2, 3, 0, 1])
246    else:
247        kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
248        weight = weightList[0].transpose([3, 2, 0, 1])
249    if isinstance(keras_layer.dilation_rate, (list, tuple)):
250        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
251    else:
252        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
253    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
254    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
255    stride_h, stride_w = keras_layer.strides
256    params = {'weight': etab.new_const(weight),
257              'kernel_size': [kernel_h, kernel_w],
258              'strides': [stride_h, stride_w],
259              'dilation': dilation,
260              'padding': [0, 0]}
261    if is_depthconv:
262        params['channels'] = in_channels * depth_mult
263        params['groups'] = in_channels
264    else:
265        params['channels'] = n_filters
266    if keras_layer.padding == 'valid':
267        pass
268    # we insert a separate pad operator
269    elif keras_layer.padding == 'same':
270        in_h = keras_layer.input_shape[1]
271        in_w = keras_layer.input_shape[2]
272        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
273        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
274        if pad_t == pad_b and pad_l == pad_r:
275            params['padding'] = (pad_t, pad_l)
276        else:
277            inexpr = _op.nn.pad(data=inexpr, pad_width=(
278                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
279    else:
280        msg = 'Padding with {} is not supported for operator Convolution ' \
281              'in frontend Keras.'
282        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
283    if is_deconv:
284        out = _op.nn.conv2d_transpose(data=inexpr, **params)
285    else:
286        out = _op.nn.conv2d(data=inexpr, **params)
287    if keras_layer.use_bias:
288        bias = etab.new_const(weightList[1])
289        out = _op.nn.bias_add(out, bias)
290    # defuse activation
291    if sys.version_info.major < 3:
292        act_type = keras_layer.activation.func_name
293    else:
294        act_type = keras_layer.activation.__name__
295    if act_type != 'linear':
296        out = _convert_activation(out, act_type, etab)
297    return out
298
299
300def _convert_separable_convolution(inexpr, keras_layer, etab):
301    _check_data_format(keras_layer)
302    weightList = keras_layer.get_weights()
303    # depthwise conv
304    kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
305    stride_h, stride_w = keras_layer.strides
306    weight0 = weightList[0].transpose([2, 3, 0, 1])
307    params0 = {'weight': etab.new_const(weight0),
308               'channels': in_channels * depth_mult,
309               'groups': in_channels,
310               'kernel_size': [kernel_h, kernel_w],
311               'strides': [stride_h, stride_w],
312               'dilation': [1, 1],
313               'padding': [0, 0]}
314    if keras_layer.padding == 'valid':
315        pass
316    # we insert a separate pad operator
317    elif keras_layer.padding == 'same':
318        in_h = keras_layer.input_shape[1]
319        in_w = keras_layer.input_shape[2]
320        pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h)
321        pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w)
322        if pad_t == pad_b and pad_l == pad_r:
323            params0['padding'] = (pad_t, pad_l)
324        else:
325            inexpr = _op.nn.pad(data=inexpr, pad_width=(
326                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
327    else:
328        msg = 'Padding with {} is not supported for operator Separable ' \
329              'Convolution in frontend Keras.'
330        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
331
332    depthconv = _op.nn.conv2d(data=inexpr, **params0)
333    # pointwise conv
334    weight1 = weightList[1].transpose([3, 2, 0, 1])
335    params1 = {'weight': etab.new_const(weight1),
336               'channels': weight1.shape[0],
337               'groups': 1,
338               'kernel_size': [1, 1],
339               'strides': [1, 1],
340               'dilation': [1, 1]}
341    out = _op.nn.conv2d(data=depthconv, **params1)
342    if keras_layer.use_bias:
343        bias = etab.new_const(weightList[2])
344        out = _op.nn.bias_add(out, bias)
345    # defuse activation
346    if sys.version_info.major < 3:
347        act_type = keras_layer.activation.func_name
348    else:
349        act_type = keras_layer.activation.__name__
350    if act_type != 'linear':
351        out = _convert_activation(out, act_type, etab)
352    return out
353
354
355def _convert_flatten(inexpr, keras_layer, _):
356    _check_data_format(keras_layer)
357    # NCHW -> NHWC so that dense can be correctly converted
358    inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
359    return _op.nn.batch_flatten(inexpr)
360
361
362def _convert_pooling(inexpr, keras_layer, etab):
363    _check_data_format(keras_layer)
364    pool_type = type(keras_layer).__name__
365    # global pool in keras = global pool + flatten in nnvm/relay
366    if pool_type == 'GlobalMaxPooling2D':
367        return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
368    if pool_type == 'GlobalAveragePooling2D':
369        return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
370    pool_h, pool_w = keras_layer.pool_size
371    stride_h, stride_w = keras_layer.strides
372    params = {'pool_size': [pool_h, pool_w],
373              'strides': [stride_h, stride_w],
374              'padding': [0, 0]}
375    if keras_layer.padding == 'valid':
376        pass
377    elif keras_layer.padding == 'same':
378        in_h = keras_layer.input_shape[1]
379        in_w = keras_layer.input_shape[2]
380        pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
381        pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
382        params['padding'] = [pad_t, pad_l, pad_b, pad_r]
383    else:
384        raise tvm.error.OpAttributeUnImplemented(
385            'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
386    if pool_type == 'MaxPooling2D':
387        return _op.nn.max_pool2d(inexpr, **params)
388    if pool_type == 'AveragePooling2D':
389        params['count_include_pad'] = False
390        return _op.nn.avg_pool2d(inexpr, **params)
391    raise tvm.error.OpNotImplemented(
392        'Operator {} is not supported for frontend Keras.'.format(keras_layer))
393
394
395def _convert_upsample(inexpr, keras_layer, _):
396    _check_data_format(keras_layer)
397    upsample_type = type(keras_layer).__name__
398    params = {}
399    if upsample_type == 'UpSampling1D':
400        h = keras_layer.size
401        params['scale_h'] = h
402    elif upsample_type == 'UpSampling2D':
403        h, w = keras_layer.size
404        if h != w:
405            raise tvm.error.OpAttributeInvalid(
406                'Height must equal width for operator Upsample.')
407        params['scale_h'] = h
408        params['scale_w'] = h
409
410        if hasattr(keras_layer, 'interpolation'):
411            interpolation = keras_layer.interpolation
412            if interpolation == 'nearest':
413                params['method'] = 'nearest_neighbor'
414            else:
415                params['method'] = 'bilinear'
416
417    elif upsample_type == 'UpSampling3D':
418        h, w, d = keras_layer.size
419        if h != w or w != d:
420            raise tvm.error.OpAttributeInvalid(
421                'Height, width, and depth must all be equal for operator Upsample.')
422        params['scale_h'] = h
423        params['scale_w'] = h
424    else:
425        raise tvm.error.OpNotImplemented(
426            'Operator {} is not supported for frontend Keras.'.format(upsample_type))
427    return _op.nn.upsampling(inexpr, **params)
428
429
430def _convert_cropping(inexpr, keras_layer, _):
431    _check_data_format(keras_layer)
432    crop_type = type(keras_layer).__name__
433    if crop_type == 'Cropping2D':
434        (_, in_h, in_w, _) = keras_layer.input_shape
435        ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
436    else:
437        raise tvm.error.OpNotImplemented(
438            'Operator {} is not supported for frontend Keras.'.format(crop_type))
439    int32_max = np.iinfo(np.int32).max
440    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l], \
441        end=[int32_max, int32_max, in_h-crop_b, in_w-crop_r])
442
443
444def _convert_batchnorm(inexpr, keras_layer, etab):
445    params = {'scale': False,
446              'center': False,
447              'epsilon': keras_layer.epsilon}
448    idx = 0
449    if keras_layer.scale:
450        params['scale'] = True
451        gamma = keras_layer.get_weights()[idx]
452        params['gamma'] = etab.new_const(gamma)
453        idx += 1
454    if keras_layer.center:
455        params['center'] = True
456        beta = keras_layer.get_weights()[idx]
457        params['beta'] = etab.new_const(beta)
458        idx += 1
459    moving_mean = keras_layer.get_weights()[idx]
460    moving_var = keras_layer.get_weights()[idx + 1]
461    params['moving_mean'] = etab.new_const(moving_mean)
462    params['moving_var'] = etab.new_const(moving_var)
463    # in case beta or gamma is not defined
464    params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
465                     'beta' not in params else params['beta']
466    params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
467                      'gamma' not in params else params['gamma']
468    result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
469    return result
470
471
472def _convert_padding(inexpr, keras_layer, _):
473    _check_data_format(keras_layer)
474    padding_type = type(keras_layer).__name__
475    padding = keras_layer.padding
476    top = left = bottom = right = 0
477    if padding_type == 'ZeroPadding2D':
478        if isinstance(padding, int):
479            top = left = bottom = right = padding
480        elif isinstance(padding, tuple):
481            if isinstance(padding[0], int):
482                top, left = padding
483                bottom, right = padding
484            elif isinstance(padding[0], tuple):
485                top, bottom = padding[0]
486                left, right = padding[1]
487            else:
488                msg = 'Value {} in attribute "padding" of operator Padding ' \
489                      'is not valid.'
490                raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
491        else:
492            msg = 'Value {} in attribute "padding" of operator Padding is ' \
493                  'not valid.'
494            raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
495    else:
496        msg = 'Operator {} is not supported in frontend Keras.'
497        raise tvm.error.OpNotImplemented(msg.format(padding_type))
498    return _op.nn.pad(data=inexpr,
499                      pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))
500
501
502def _convert_concat(inexpr, keras_layer, _):
503    _check_data_format(keras_layer)
504    return _op.concatenate(_as_list(inexpr), axis=1)
505
506
507def _convert_reshape(inexpr, keras_layer, _):
508    _check_data_format(keras_layer)
509    inshape = keras_layer.input_shape # includes batch
510    tshape = keras_layer.target_shape # no batch
511    if len(inshape) == 3 and len(tshape) == 1:
512        # (?, a, b) -> (-1, ab)
513        shape = (-1, tshape[0])
514    elif len(inshape) in [2, 3] and len(tshape) == 2:
515        # (?, cc) -> (-1, c, c)
516        # (?, a, b) -> (-1, c, c)
517        assert tshape[0] == tshape[1], \
518            "Only supports square target shapes, but got {}".format(tshape)
519        shape = (-1, ) + tshape
520    else:
521        # (?, h, w, c) -> (-1, c, H, W)
522        # (?, h, w, c) -> (-1, c, hw)
523        # (?, hw, c) -> (-1, c, h, w)
524        ch = inshape[-1]
525        assert ch == tshape[-1], \
526            "Only supports last dimension in target shape being equal to " \
527            "the channel number of input tensor."
528        shape = (-1, ch) + tshape[:-1]
529    return _op.reshape(inexpr, newshape=shape)
530
531
532def _convert_lstm(inexpr, keras_layer, etab):
533    _check_data_format(keras_layer)
534    if not isinstance(inexpr, list):
535        buf = np.zeros((1, keras_layer.units), 'float32')
536        c_op = etab.new_const(buf)
537        h_op = etab.new_const(buf)
538        inexpr = [inexpr, h_op, c_op]
539    in_data = inexpr[0]
540    next_h = inexpr[1]
541    next_c = inexpr[2]
542    weightList = keras_layer.get_weights()
543    in_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0])
544    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
545    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
546    in_bias = etab.new_const(weightList[2])
547    units = list(weightList[0].shape)[1]
548    time_steps = in_shape[1]
549    in_data = _op.squeeze(in_data, axis=[0])
550    in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0)
551    # loop for the number of time_steps
552    for data in in_data:
553        ixh1 = _op.nn.dense(data, kernel_weight, units=units)
554        ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias)
555        gate = ixh1 + ixh2
556        gates = _op.split(gate, indices_or_sections=4, axis=1)
557        in_gate = _convert_recurrent_activation(gates[0], keras_layer)
558        in_transform = _convert_recurrent_activation(gates[1], keras_layer)
559        next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None)
560        out_gate = _convert_recurrent_activation(gates[3], keras_layer)
561        next_h = out_gate * _convert_activation(next_c, keras_layer, None)
562    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
563    out = _op.reshape(next_h, newshape=out_shape)
564    return [out, next_h, next_c]
565
566
567def _convert_simple_rnn(inexpr, keras_layer, etab):
568    _check_data_format(keras_layer)
569    if not isinstance(inexpr, list):
570        buf = np.zeros((1, keras_layer.units), 'float32')
571        prev_op = etab.new_const(buf)
572        inexpr = [inexpr, prev_op]
573    in_data = inexpr[0]
574    prev_op = inexpr[1]
575    weightList = keras_layer.get_weights()
576    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
577    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
578    in_bias = etab.new_const(weightList[2])
579    units = list(weightList[0].shape)[1]
580    in_data = _op.nn.batch_flatten(in_data)
581    ixh = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), bias=in_bias)
582    prev_op = _op.nn.batch_flatten(prev_op)
583    ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
584    output = ixh + ixh2
585    output = _convert_activation(output, keras_layer, None)
586    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
587    output = _op.reshape(output, newshape=out_shape)
588    return [output, output]
589
590
591def _convert_gru(inexpr, keras_layer, etab):
592    _check_data_format(keras_layer)
593    if not isinstance(inexpr, list):
594        buf = np.zeros((1, keras_layer.units), 'float32')
595        h_tm1 = etab.new_const(buf)
596        inexpr = [inexpr, h_tm1]
597    in_data = inexpr[0]
598    h_tm1_op = inexpr[1]
599    weightList = keras_layer.get_weights()
600    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
601    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
602    in_bias = etab.new_const(weightList[2])
603    units = list(weightList[0].shape)[1]
604    in_data = _op.nn.batch_flatten(in_data)
605    matrix_x = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), in_bias)
606    # inputs projected by all gate matrices at once
607    split_indices = [keras_layer.units, 2 * keras_layer.units]
608    gates = _op.split(matrix_x, indices_or_sections=split_indices, axis=1)
609    x_z = gates[0]
610    x_r = gates[1]
611    x_h = gates[2]
612    # hidden state projected separately for update/reset and new
613    units = 2 * keras_layer.units
614    split_indices = [units]
615    rec_weights = _op.split(recurrent_weight, indices_or_sections=split_indices, axis=0)
616    h_tm1_op = _op.nn.batch_flatten(h_tm1_op)
617    matrix_inner = _op.nn.dense(h_tm1_op, rec_weights[0], units=units)
618    split_indices = [keras_layer.units]
619    recurrent = _op.split(matrix_inner, indices_or_sections=split_indices, axis=1)
620    recurrent_z = recurrent[0]
621    recurrent_r = recurrent[1]
622    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
623    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
624    units = keras_layer.units
625    recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
626    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
627    # previous and candidate state mixed by update gate
628    output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
629    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
630    output = _op.reshape(output, newshape=out_shape)
631    return [output, output]
632
633
634def _default_skip(inexpr, keras_layer, _): # pylint: disable=unused-argument
635    """Layers that can be skipped because they are train time only."""
636    return inexpr
637
638
639_convert_map = {
640    'Dense'                    : _convert_dense,
641    'Activation'               : _convert_activation,
642    'Softmax'                  : _convert_advanced_activation,
643    'ReLU'                     : _convert_advanced_activation,
644    'LeakyReLU'                : _convert_advanced_activation,
645    'PReLU'                    : _convert_advanced_activation,
646    'ELU'                      : _convert_advanced_activation,
647    'ThresholdedReLU'          : _convert_advanced_activation,
648
649    'AveragePooling2D'         : _convert_pooling,
650    'MaxPooling2D'             : _convert_pooling,
651    'GlobalAveragePooling2D'   : _convert_pooling,
652    'GlobalMaxPooling2D'       : _convert_pooling,
653    'Conv2D'                   : _convert_convolution,
654    'Conv2DTranspose'          : _convert_convolution,
655    'DepthwiseConv2D'          : _convert_convolution,
656    'SeparableConv2D'          : _convert_separable_convolution,
657
658    'Flatten'                  : _convert_flatten,
659    'Reshape'                  : _convert_reshape,
660    'Concatenate'              : _convert_concat,
661    'BatchNormalization'       : _convert_batchnorm,
662
663    'Add'                      : _convert_merge,
664    'Subtract'                 : _convert_merge,
665    'Multiply'                 : _convert_merge,
666    'ZeroPadding2D'            : _convert_padding,
667    'UpSampling2D'             : _convert_upsample,
668    'Cropping2D'               : _convert_cropping,
669
670    # 'ZeroPadding1D'          : _convert_padding,
671    # 'AveragePooling1D'       : _convert_pooling,
672    # 'MaxPooling1D'           : _convert_pooling,
673    # 'GlobalAveragePooling1D' : _convert_pooling,
674    # 'GlobalMaxPooling1D'     : _convert_pooling,
675    # 'Cropping1D'             : _convert_cropping,
676    # 'UpSampling1D'           : _convert_upsample,
677    # 'UpSampling3D'           : _convert_upsample,
678    # 'Conv1D'                 : _convert_convolution1d,
679
680    'SimpleRNN'                : _convert_simple_rnn,
681    'LSTM'                     : _convert_lstm,
682    'GRU'                      : _convert_gru,
683    # 'Bidirectional'          : _convert_bidirectional,
684    # 'TimeDistributed'        : _default_skip,
685
686    'Average'                : _convert_merge,
687    'Maximum'                : _convert_merge,
688    'Dot'                    : _convert_merge,
689    'Permute'                : _convert_permute,
690    # 'Embedding'              : _convert_embedding,
691    # 'RepeatVector'           : _convert_repeat_vector,
692
693    'InputLayer'               : _default_skip,
694    'Dropout'                  : _default_skip,
695    'SpatialDropout2D'         : _default_skip,
696    'SpatialDropout1D'         : _default_skip,
697}
698
699
700def _check_unsupported_layers(model):
701    missing_ops = set()
702    for layer in model.layers:
703        op_name = type(layer).__name__
704        if op_name not in _convert_map:
705            missing_ops.add(op_name)
706
707    if missing_ops:
708        raise NotImplementedError( \
709            "The following operators are not implemented: {}".format(missing_ops))
710
711
712def keras_op_to_relay(inexpr, keras_layer, outname, etab):
713    """Convert a Keras layer to a Relay expression and update the expression table.
714
715    Parameters
716    ----------
717    inexpr : relay.expr.Expr or a list of it
718        The input Relay expression(s).
719
720    keras_layer : keras.layers
721        The Keras layer to be converted.
722
723    outname : str
724        Name of the output Relay expression.
725
726    etab : relay.frontend.common.ExprTable
727        The global expression table to be updated.
728    """
729    op_name = type(keras_layer).__name__
730    if op_name not in _convert_map:
731        raise tvm.error.OpNotImplemented(
732            'Operator {} is not supported for frontend Keras.'.format(op_name))
733    outs = _convert_map[op_name](inexpr, keras_layer, etab)
734    outs = _as_list(outs)
735    for t_idx, out in enumerate(outs):
736        name = outname + ":" + str(t_idx)
737        etab.set_expr(name, out)
738
739
740def from_keras(model, shape=None):
741    """Convert keras model to relay Function.
742
743    Parameters
744    ----------
745    model : keras.engine.training.Model
746        The keras model to be converted.
747
748    shape: dict of str to int list/tuple
749        Input shapes of the model, optional
750
751    Returns
752    -------
753    mod : tvm.relay.Module
754        The relay module for compilation.
755
756    params : dict of str to tvm.NDArray
757        The parameter dict to be used by Relay.
758    """
759    try:
760        import keras
761    except ImportError:
762        raise ImportError('Keras must be installed')
763    assert isinstance(model, keras.engine.training.Model)
764    if keras.backend.backend() != 'tensorflow':
765        raise ValueError("Keras frontend currently supports tensorflow backend only.")
766    if keras.backend.image_data_format() != 'channels_last':
767        raise ValueError("Keras frontend currently supports data_format = channels_last only.")
768    _check_unsupported_layers(model)
769
770    def _convert_input_layer(keras_layer):
771        input_name = keras_layer.name
772        input_shape = shape[input_name] if shape is not None and input_name in shape else None
773        etab.set_expr(input_name, new_var(input_name, shape=input_shape))
774
775    etab = ExprTable()
776    for keras_layer in model.layers:
777        if isinstance(keras_layer, keras.engine.InputLayer):
778            _convert_input_layer(keras_layer)
779        else:
780            inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \
781                       else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \
782                       else None
783            if inbound_nodes is None:
784                raise TypeError("Unknown layer type or unsupported Keras version : {}"
785                                .format(keras_layer))
786            for node_idx, node in enumerate(inbound_nodes):
787                # If some nodes in imported model is not relevant to the current model,
788                # skip such layers. model._network_nodes contains keys of all nodes relevant
789                # to the current model.
790                if not model._node_key(keras_layer, node_idx) in model._network_nodes:
791                    continue
792                inexpr = []
793                # Since Keras allows creating multiple layers from the same name instance,
794                # we append node index to the expr name to make it unique.
795                # The one exception is InputLayer. Changing input variable names after conversion
796                # would confuse users, so we should keep them as far as possible. Fortunately,
797                # they are named uniquely to input_1, input_2, input_3... by default.
798                zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers)
799                for n_idx, t_idx, inbound_layer in zip_node:
800                    if isinstance(inbound_layer, keras.engine.InputLayer):
801                        expr_name = inbound_layer.name
802                        _convert_input_layer(inbound_layer)
803                    else:
804                        expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx)
805                    expr = etab.get_expr(expr_name)
806                    inexpr.append(expr)
807                if len(inexpr) == 1:
808                    inexpr = inexpr[0]
809                keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ':' + str(node_idx), etab)
810    # model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor_index(oc[2])
811    # Get all output nodes in etab using the name made from above values.
812    # The out exprs were added to etab in keras_op_to_relay using this name.
813    outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \
814               for oc in model._output_coordinates]
815    outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr)
816    func = _expr.Function(analysis.free_vars(outexpr), outexpr)
817    params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()}
818    return _module.Module.from_expr(func), params
819