# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, import-self
"""Keras frontend."""
from __future__ import absolute_import as _abs
import sys
import numpy as np
import tvm
from .. import analysis
from .. import expr as _expr
from .. import module as _module
from .. import op as _op
from ... import nd as _nd
from .common import ExprTable, new_var

__all__ = ['from_keras']


def _check_data_format(keras_layer):
    if hasattr(keras_layer, 'data_format'):
        if keras_layer.data_format != 'channels_last':
            raise ValueError("Keras frontend currently supports data_format = channels_last only.")


def _get_pad_pair(input1d, kernel1d, stride1d):
    """Compute TensorFlow-style 'same' padding along one dimension.

    For example, input1d=7, kernel1d=3, stride1d=2 gives out1d=4 and a
    total pad of 2, split as [1, 1].
    """
    out1d = (input1d + stride1d - 1) // stride1d
    pad = np.maximum((out1d - 1) * stride1d + kernel1d - input1d, 0)
    pad_before = pad // 2
    pad_after = pad - pad_before
    return [pad_before, pad_after]
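

# The helper below composes elu from relu/exp primitives:
# -alpha * relu(1 - exp(x)) + relu(x) equals alpha * (exp(x) - 1) for x < 0
# and x for x >= 0, which matches the standard ELU definition.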
def _get_elu(inexpr, alpha):
    """A helper method for elu."""
    return _op.negative(alpha) * _op.nn.relu(_expr.const(1., dtype='float32') - \
        _op.exp(inexpr)) + _op.nn.relu(inexpr)


def _as_list(arr):
    """Force being a list, ignore if already is."""
    if isinstance(arr, list):
        return arr
    return [arr]


def _convert_recurrent_activation(inexpr, keras_layer):
    act_type = keras_layer.recurrent_activation.__name__
    return _convert_activation(inexpr, act_type, None)


def _convert_activation(inexpr, keras_layer, _):
    if isinstance(keras_layer, str):
        act_type = keras_layer
    else:
        if sys.version_info.major < 3:
            act_type = keras_layer.activation.func_name
        else:
            act_type = keras_layer.activation.__name__
    if act_type == 'linear':
        if isinstance(keras_layer, str):
            return inexpr
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        beta = keras_layer.beta if hasattr(keras_layer, 'beta') else 0.
        alpha = _expr.const(alpha, dtype='float32')
        beta = _expr.const(beta, dtype='float32')
        return _op.add(_op.multiply(inexpr, alpha), beta)
    if act_type == 'softmax':
        return _op.nn.softmax(inexpr, axis=1)
    if act_type == 'sigmoid':
        return _op.sigmoid(inexpr)
    if act_type == 'tanh':
        return _op.tanh(inexpr)
    if act_type == 'relu':
        return _op.nn.relu(inexpr)
    if act_type == 'softplus':
        return _op.log(_op.add(_op.exp(inexpr), _expr.const(1., dtype='float32')))
    if act_type == 'elu':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'selu':
        # Alpha, Gamma values obtained from https://arxiv.org/abs/1706.02515
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') \
            else 1.6732632423543772848170429916717
        gamma = keras_layer.gamma if hasattr(keras_layer, 'gamma') \
            else 1.0507009873554804934193349852946
        alpha = _expr.const(alpha, dtype='float32')
        gamma = _expr.const(gamma, dtype='float32')
        return gamma * _get_elu(inexpr, alpha)
    if act_type == 'relu6':
        return _op.clip(inexpr, a_min=0., a_max=6.)
    if act_type == 'softsign':
        return inexpr / (_expr.const(1., dtype='float32') + _op.abs(inexpr))
    if act_type == 'hard_sigmoid':
        x = (_expr.const(0.2, dtype='float32') * inexpr) + _expr.const(0.5, dtype='float32')
        return _op.clip(x, a_min=0., a_max=1.)

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))
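

# Advanced activations are standalone Keras layers (keras.layers.ReLU,
# LeakyReLU, PReLU, ...), so they are dispatched by class name rather than
# by activation function name as above.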
def _convert_advanced_activation(inexpr, keras_layer, etab):
    act_type = type(keras_layer).__name__

    if act_type == 'Softmax':
        axis = keras_layer.axis
        dims = len(keras_layer.input_shape)
        if isinstance(axis, list):
            raise tvm.error.OpAttributeUnImplemented(
                'Softmax with axes {} is not supported.'.format(axis))
        if axis == -1:
            axis = 1
        else:
            axis = axis + 1 if axis < dims - 1 else 1
        return _op.nn.softmax(inexpr, axis=axis)
    if act_type == 'ReLU':
        threshold = _expr.const(keras_layer.threshold, dtype='float32')
        if keras_layer.max_value and float(keras_layer.threshold) == 0:
            # f(x) = max_value, for x >= max_value
            # f(x) = x,         for 0 <= x < max_value
            return _op.clip(inexpr, a_min=0., a_max=float(keras_layer.max_value))
        if float(keras_layer.threshold) != 0 or keras_layer.negative_slope != 0:
            # f(x) = x,                                for x >= threshold
            # f(x) = negative_slope * (x - threshold), for x < threshold
            # Select elementwise with a mask; a Relay expression cannot be
            # used in a Python boolean condition.
            negative_slope = _expr.const(keras_layer.negative_slope, dtype='float32')
            mask = _op.greater_equal(inexpr, threshold).astype('float32')
            return mask * inexpr + (_expr.const(1., dtype='float32') - mask) * \
                _op.multiply(negative_slope, _op.subtract(inexpr, threshold))
        return _op.nn.relu(inexpr)
    if act_type == 'LeakyReLU':
        return _op.nn.leaky_relu(inexpr, alpha=float(keras_layer.alpha))
    if act_type == 'ELU':
        alpha = keras_layer.alpha if hasattr(keras_layer, 'alpha') else 1.
        alpha = _expr.const(alpha, dtype='float32')
        return _get_elu(inexpr, alpha)
    if act_type == 'PReLU':
        assert hasattr(keras_layer, 'alpha'), "alpha required for PReLU."
        _check_data_format(keras_layer)
        size = len(keras_layer.alpha.shape)
        alpha = etab.new_const(keras_layer.get_weights()[0] \
                               .transpose(np.roll(range(size), 1)))
        return _op.negative(alpha) * _op.nn.relu(_op.negative(inexpr)) + _op.nn.relu(inexpr)
    if act_type == 'ThresholdedReLU':
        theta = keras_layer.theta if hasattr(keras_layer, 'theta') else 1.
        # x * (x > theta): zero out values at or below the threshold
        return _op.multiply(inexpr, _op.greater(inexpr, \
            _expr.const(theta, dtype='float32')).astype('float32'))

    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported in frontend Keras.'.format(act_type))


def _convert_merge(inexpr, keras_layer, _):
    merge_type = type(keras_layer).__name__
    ret = inexpr[0]
    if merge_type == 'Dot':
        axes = keras_layer.axes
        if isinstance(keras_layer.axes, int):
            axes = [keras_layer.axes, keras_layer.axes]
        if isinstance(axes, list):
            if len(axes) != 2:
                raise tvm.error.OpAttributeUnImplemented(
                    'Dot with axes {} is not supported.'.format(keras_layer.axes))
            for i, axis in enumerate(axes):
                if axis not in [1, 2]:
                    raise tvm.error.OpAttributeUnImplemented(
                        'Dot with axes {} is not supported.'.format(keras_layer.axes))
                if axes[i] == 2:
                    inexpr[i] = _op.transpose(inexpr[i], axes=[0, 2, 1])
        else:
            raise tvm.error.OpAttributeUnImplemented(
                'Dot with axes {} is not supported.'.format(keras_layer.axes))
        ret_dot = _op.nn.batch_matmul(inexpr[0], inexpr[1])
        ret = _op.transpose(ret_dot, axes=[0, 2, 1])
    elif merge_type == 'Subtract':
        assert len(inexpr) == 2, "Subtract merge takes 2 inputs."
        ret = _op.subtract(ret, inexpr[1])
    elif merge_type in ['Add', 'Multiply', 'Maximum']:
        op_map = {'Add': _op.add, 'Multiply': _op.multiply, 'Maximum': _op.maximum}
        for i in range(1, len(inexpr)):
            ret = op_map[merge_type](ret, inexpr[i])
    elif merge_type == 'Average':
        for i in range(1, len(inexpr)):
            ret = _op.add(ret, inexpr[i])
        ret = ret / _expr.const(len(inexpr), dtype='float32')
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported in frontend Keras.'.format(merge_type))
    return ret


def _convert_permute(inexpr, keras_layer, _):
    return _op.transpose(inexpr, axes=(0,) + keras_layer.dims)


def _convert_dense(inexpr, keras_layer, etab):
    weightList = keras_layer.get_weights()
    # Keras Dense kernels are (in_units, out_units); relay dense expects
    # (out_units, in_units).
    weight = etab.new_const(weightList[0].transpose([1, 0]))
    params = {'weight': weight, 'units': weightList[0].shape[1]}
    input_shape = keras_layer.input_shape
    input_dim = len(input_shape)
    # In case of RNN dense, input shape will be (1, 1, n)
    if input_dim > 2:
        input_shape = tuple(dim if dim else 1 for dim in _as_list(input_shape)[0])
        if input_dim != 3 or input_shape[0] != 1 or input_shape[1] != 1:
            raise tvm.error.OpAttributeInvalid(
                'Input shape {} is not valid for operator Dense.'.format(input_shape))
        inexpr = _op.squeeze(inexpr, axis=0)
    out = _op.nn.dense(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    if input_dim > 2:
        out = _op.expand_dims(out, axis=0)
    return out
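

# Keras stores Conv2D kernels as HWIO (kernel_h, kernel_w, in_channels,
# out_channels) and Conv2DTranspose kernels as HWOI; relay's NCHW conv2d
# expects OIHW and conv2d_transpose expects IOHW, hence the transposes below.
# Depthwise kernels (kernel_h, kernel_w, in_channels, depth_multiplier) are
# converted to a grouped conv2d with groups == in_channels.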
def _convert_convolution(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    is_deconv = type(keras_layer).__name__ == 'Conv2DTranspose'
    is_depthconv = type(keras_layer).__name__ == 'DepthwiseConv2D'
    weightList = keras_layer.get_weights()
    if is_deconv:
        kernel_h, kernel_w, n_filters, in_channels = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    elif is_depthconv:
        kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
        weight = weightList[0].transpose([2, 3, 0, 1])
    else:
        kernel_h, kernel_w, in_channels, n_filters = weightList[0].shape
        weight = weightList[0].transpose([3, 2, 0, 1])
    if isinstance(keras_layer.dilation_rate, (list, tuple)):
        dilation = [keras_layer.dilation_rate[0], keras_layer.dilation_rate[1]]
    else:
        dilation = [keras_layer.dilation_rate, keras_layer.dilation_rate]
    dilated_kernel_h = (kernel_h - 1) * dilation[0] + 1
    dilated_kernel_w = (kernel_w - 1) * dilation[1] + 1
    stride_h, stride_w = keras_layer.strides
    params = {'weight': etab.new_const(weight),
              'kernel_size': [kernel_h, kernel_w],
              'strides': [stride_h, stride_w],
              'dilation': dilation,
              'padding': [0, 0]}
    if is_depthconv:
        params['channels'] = in_channels * depth_mult
        params['groups'] = in_channels
    else:
        params['channels'] = n_filters
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, dilated_kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, dilated_kernel_w, stride_w)
        if pad_t == pad_b and pad_l == pad_r:
            params['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Convolution ' \
              'in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
    if is_deconv:
        out = _op.nn.conv2d_transpose(data=inexpr, **params)
    else:
        out = _op.nn.conv2d(data=inexpr, **params)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[1])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out
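

# SeparableConv2D factors into a depthwise convolution (groups == in_channels)
# followed by a pointwise 1x1 convolution, so the two Keras kernels map to two
# relay conv2d ops.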
def _convert_separable_convolution(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    weightList = keras_layer.get_weights()
    # depthwise conv
    kernel_h, kernel_w, in_channels, depth_mult = weightList[0].shape
    stride_h, stride_w = keras_layer.strides
    weight0 = weightList[0].transpose([2, 3, 0, 1])
    params0 = {'weight': etab.new_const(weight0),
               'channels': in_channels * depth_mult,
               'groups': in_channels,
               'kernel_size': [kernel_h, kernel_w],
               'strides': [stride_h, stride_w],
               'dilation': [1, 1],
               'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    # we insert a separate pad operator
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, kernel_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, kernel_w, stride_w)
        if pad_t == pad_b and pad_l == pad_r:
            params0['padding'] = (pad_t, pad_l)
        else:
            inexpr = _op.nn.pad(data=inexpr, pad_width=(
                (0, 0), (0, 0), (pad_t, pad_b), (pad_l, pad_r)))
    else:
        msg = 'Padding with {} is not supported for operator Separable ' \
              'Convolution in frontend Keras.'
        raise tvm.error.OpAttributeUnImplemented(msg.format(keras_layer.padding))
    depthconv = _op.nn.conv2d(data=inexpr, **params0)
    # pointwise conv
    weight1 = weightList[1].transpose([3, 2, 0, 1])
    params1 = {'weight': etab.new_const(weight1),
               'channels': weight1.shape[0],
               'groups': 1,
               'kernel_size': [1, 1],
               'strides': [1, 1],
               'dilation': [1, 1]}
    out = _op.nn.conv2d(data=depthconv, **params1)
    if keras_layer.use_bias:
        bias = etab.new_const(weightList[2])
        out = _op.nn.bias_add(out, bias)
    # defuse activation
    if sys.version_info.major < 3:
        act_type = keras_layer.activation.func_name
    else:
        act_type = keras_layer.activation.__name__
    if act_type != 'linear':
        out = _convert_activation(out, act_type, etab)
    return out
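

# Keras flattens NHWC tensors in row-major order, but the converted graph is
# NCHW; transposing back to NHWC before batch_flatten keeps the flattened
# element order consistent with the weights of any following Dense layer.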
def _convert_flatten(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    # NCHW -> NHWC so that dense can be correctly converted
    inexpr = _op.transpose(inexpr, axes=[0, 2, 3, 1])
    return _op.nn.batch_flatten(inexpr)


def _convert_pooling(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    pool_type = type(keras_layer).__name__
    # global pool in keras = global pool + flatten in nnvm/relay
    if pool_type == 'GlobalMaxPooling2D':
        return _convert_flatten(_op.nn.global_max_pool2d(inexpr), keras_layer, etab)
    if pool_type == 'GlobalAveragePooling2D':
        return _convert_flatten(_op.nn.global_avg_pool2d(inexpr), keras_layer, etab)
    pool_h, pool_w = keras_layer.pool_size
    stride_h, stride_w = keras_layer.strides
    params = {'pool_size': [pool_h, pool_w],
              'strides': [stride_h, stride_w],
              'padding': [0, 0]}
    if keras_layer.padding == 'valid':
        pass
    elif keras_layer.padding == 'same':
        in_h = keras_layer.input_shape[1]
        in_w = keras_layer.input_shape[2]
        pad_t, pad_b = _get_pad_pair(in_h, pool_h, stride_h)
        pad_l, pad_r = _get_pad_pair(in_w, pool_w, stride_w)
        params['padding'] = [pad_t, pad_l, pad_b, pad_r]
    else:
        raise tvm.error.OpAttributeUnImplemented(
            'Padding with {} is not supported in operator Pooling.'.format(keras_layer.padding))
    if pool_type == 'MaxPooling2D':
        return _op.nn.max_pool2d(inexpr, **params)
    if pool_type == 'AveragePooling2D':
        params['count_include_pad'] = False
        return _op.nn.avg_pool2d(inexpr, **params)
    raise tvm.error.OpNotImplemented(
        'Operator {} is not supported for frontend Keras.'.format(pool_type))


def _convert_upsample(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    upsample_type = type(keras_layer).__name__
    params = {}
    if upsample_type == 'UpSampling1D':
        h = keras_layer.size
        params['scale_h'] = h
    elif upsample_type == 'UpSampling2D':
        h, w = keras_layer.size
        if h != w:
            raise tvm.error.OpAttributeInvalid(
                'Height must equal width for operator Upsample.')
        params['scale_h'] = h
        params['scale_w'] = h

        if hasattr(keras_layer, 'interpolation'):
            interpolation = keras_layer.interpolation
            if interpolation == 'nearest':
                params['method'] = 'nearest_neighbor'
            else:
                params['method'] = 'bilinear'

    elif upsample_type == 'UpSampling3D':
        h, w, d = keras_layer.size
        if h != w or w != d:
            raise tvm.error.OpAttributeInvalid(
                'Height, width, and depth must all be equal for operator Upsample.')
        params['scale_h'] = h
        params['scale_w'] = h
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(upsample_type))
    return _op.nn.upsampling(inexpr, **params)


def _convert_cropping(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    crop_type = type(keras_layer).__name__
    if crop_type == 'Cropping2D':
        (_, in_h, in_w, _) = keras_layer.input_shape
        ((crop_t, crop_b), (crop_l, crop_r)) = keras_layer.cropping
    else:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(crop_type))
    int32_max = np.iinfo(np.int32).max
    return _op.strided_slice(inexpr, begin=[0, 0, crop_t, crop_l],
                             end=[int32_max, int32_max, in_h - crop_b, in_w - crop_r])


def _convert_batchnorm(inexpr, keras_layer, etab):
    params = {'scale': False,
              'center': False,
              'epsilon': keras_layer.epsilon}
    idx = 0
    if keras_layer.scale:
        params['scale'] = True
        gamma = keras_layer.get_weights()[idx]
        params['gamma'] = etab.new_const(gamma)
        idx += 1
    if keras_layer.center:
        params['center'] = True
        beta = keras_layer.get_weights()[idx]
        params['beta'] = etab.new_const(beta)
        idx += 1
    moving_mean = keras_layer.get_weights()[idx]
    moving_var = keras_layer.get_weights()[idx + 1]
    params['moving_mean'] = etab.new_const(moving_mean)
    params['moving_var'] = etab.new_const(moving_var)
    # in case beta or gamma is not defined
    params['beta'] = etab.new_const(np.zeros(moving_mean.shape)) if \
        'beta' not in params else params['beta']
    params['gamma'] = etab.new_const(np.ones(moving_mean.shape)) if \
        'gamma' not in params else params['gamma']
    result, moving_mean, moving_var = _op.nn.batch_norm(inexpr, **params)
    return result


def _convert_padding(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    padding_type = type(keras_layer).__name__
    padding = keras_layer.padding
    top = left = bottom = right = 0
    if padding_type == 'ZeroPadding2D':
        if isinstance(padding, int):
            top = left = bottom = right = padding
        elif isinstance(padding, tuple):
            if isinstance(padding[0], int):
                top, left = padding
                bottom, right = padding
            elif isinstance(padding[0], tuple):
                top, bottom = padding[0]
                left, right = padding[1]
            else:
                msg = 'Value {} in attribute "padding" of operator Padding ' \
                      'is not valid.'
                raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
        else:
            msg = 'Value {} in attribute "padding" of operator Padding is ' \
                  'not valid.'
            raise tvm.error.OpAttributeInvalid(msg.format(str(padding)))
    else:
        msg = 'Operator {} is not supported in frontend Keras.'
        raise tvm.error.OpNotImplemented(msg.format(padding_type))
    return _op.nn.pad(data=inexpr,
                      pad_width=((0, 0), (0, 0), (top, bottom), (left, right)))


def _convert_concat(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    return _op.concatenate(_as_list(inexpr), axis=1)


def _convert_reshape(inexpr, keras_layer, _):
    _check_data_format(keras_layer)
    inshape = keras_layer.input_shape  # includes batch
    tshape = keras_layer.target_shape  # no batch
    if len(inshape) == 3 and len(tshape) == 1:
        # (?, a, b) -> (-1, ab)
        shape = (-1, tshape[0])
    elif len(inshape) in [2, 3] and len(tshape) == 2:
        # (?, cc) -> (-1, c, c)
        # (?, a, b) -> (-1, c, c)
        assert tshape[0] == tshape[1], \
            "Only supports square target shapes, but got {}".format(tshape)
        shape = (-1, ) + tshape
    else:
        # (?, h, w, c) -> (-1, c, h, w)
        # (?, h, w, c) -> (-1, c, hw)
        # (?, hw, c) -> (-1, c, h, w)
        ch = inshape[-1]
        assert ch == tshape[-1], \
            "Only supports last dimension in target shape being equal to " \
            "the channel number of input tensor."
        shape = (-1, ch) + tshape[:-1]
    return _op.reshape(inexpr, newshape=shape)
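

# The fused Keras LSTM kernel packs its gates in [i, f, c, o] order. Each time
# step computes gates = W x_t + U h_{t-1} + b, splits them in four, and updates
# the state as
#   c_t = f * c_{t-1} + i * activation(c_gate)
#   h_t = o * activation(c_t)
# with i, f, o passed through recurrent_activation first.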
def _convert_lstm(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        c_op = etab.new_const(buf)
        h_op = etab.new_const(buf)
        inexpr = [inexpr, h_op, c_op]
    in_data = inexpr[0]
    next_h = inexpr[1]
    next_c = inexpr[2]
    weightList = keras_layer.get_weights()
    in_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.input_shape)[0])
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    time_steps = in_shape[1]
    in_data = _op.squeeze(in_data, axis=[0])
    in_data = _op.split(in_data, indices_or_sections=time_steps, axis=0)
    # loop for the number of time_steps
    for data in in_data:
        ixh1 = _op.nn.dense(data, kernel_weight, units=units)
        ixh2 = _op.nn.bias_add(_op.nn.dense(next_h, recurrent_weight, units=units), bias=in_bias)
        gate = ixh1 + ixh2
        gates = _op.split(gate, indices_or_sections=4, axis=1)
        in_gate = _convert_recurrent_activation(gates[0], keras_layer)
        in_transform = _convert_recurrent_activation(gates[1], keras_layer)
        next_c = in_transform * next_c + in_gate * _convert_activation(gates[2], keras_layer, None)
        out_gate = _convert_recurrent_activation(gates[3], keras_layer)
        next_h = out_gate * _convert_activation(next_c, keras_layer, None)
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    out = _op.reshape(next_h, newshape=out_shape)
    return [out, next_h, next_c]


def _convert_simple_rnn(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        prev_op = etab.new_const(buf)
        inexpr = [inexpr, prev_op]
    in_data = inexpr[0]
    prev_op = inexpr[1]
    weightList = keras_layer.get_weights()
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    in_data = _op.nn.batch_flatten(in_data)
    ixh = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), bias=in_bias)
    prev_op = _op.nn.batch_flatten(prev_op)
    ixh2 = _op.nn.dense(prev_op, recurrent_weight, units=units)
    output = ixh + ixh2
    output = _convert_activation(output, keras_layer, None)
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    output = _op.reshape(output, newshape=out_shape)
    return [output, output]
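

# The fused Keras GRU kernel packs its gates in [z, r, h] order; a step
# computes
#   z = recurrent_activation(x_z + U_z h_{t-1})    update gate
#   r = recurrent_activation(x_r + U_r h_{t-1})    reset gate
#   hh = activation(x_h + U_h (r * h_{t-1}))       candidate state
#   h_t = z * h_{t-1} + (1 - z) * hh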
def _convert_gru(inexpr, keras_layer, etab):
    _check_data_format(keras_layer)
    if not isinstance(inexpr, list):
        buf = np.zeros((1, keras_layer.units), 'float32')
        h_tm1 = etab.new_const(buf)
        inexpr = [inexpr, h_tm1]
    in_data = inexpr[0]
    h_tm1_op = inexpr[1]
    weightList = keras_layer.get_weights()
    kernel_weight = etab.new_const(weightList[0].transpose([1, 0]))
    recurrent_weight = etab.new_const(weightList[1].transpose([1, 0]))
    in_bias = etab.new_const(weightList[2])
    units = list(weightList[0].shape)[1]
    in_data = _op.nn.batch_flatten(in_data)
    matrix_x = _op.nn.bias_add(_op.nn.dense(in_data, kernel_weight, units=units), in_bias)
    # inputs projected by all gate matrices at once
    split_indices = [keras_layer.units, 2 * keras_layer.units]
    gates = _op.split(matrix_x, indices_or_sections=split_indices, axis=1)
    x_z = gates[0]
    x_r = gates[1]
    x_h = gates[2]
    # hidden state projected separately for update/reset and new
    units = 2 * keras_layer.units
    split_indices = [units]
    rec_weights = _op.split(recurrent_weight, indices_or_sections=split_indices, axis=0)
    h_tm1_op = _op.nn.batch_flatten(h_tm1_op)
    matrix_inner = _op.nn.dense(h_tm1_op, rec_weights[0], units=units)
    split_indices = [keras_layer.units]
    recurrent = _op.split(matrix_inner, indices_or_sections=split_indices, axis=1)
    recurrent_z = recurrent[0]
    recurrent_r = recurrent[1]
    rec_act_z = _convert_recurrent_activation(x_z + recurrent_z, keras_layer)
    rec_act_r = _convert_recurrent_activation(x_r + recurrent_r, keras_layer)
    units = keras_layer.units
    recurrent_h = _op.nn.dense(rec_act_r * h_tm1_op, rec_weights[1], units=units)
    act_hh = _convert_activation(x_h + recurrent_h, keras_layer, None)
    # previous and candidate state mixed by update gate
    output = rec_act_z * h_tm1_op + (_expr.const(1., dtype='float32') - rec_act_z) * act_hh
    out_shape = tuple(dim if dim else 1 for dim in _as_list(keras_layer.output_shape)[0])
    output = _op.reshape(output, newshape=out_shape)
    return [output, output]


def _default_skip(inexpr, keras_layer, _):  # pylint: disable=unused-argument
    """Layers that can be skipped because they are train time only."""
    return inexpr
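

# Dispatch table from Keras layer class names to converter functions.
# Commented-out entries mark layers that are recognized but not implemented.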
_convert_map = {
    'Dense'                    : _convert_dense,
    'Activation'               : _convert_activation,
    'Softmax'                  : _convert_advanced_activation,
    'ReLU'                     : _convert_advanced_activation,
    'LeakyReLU'                : _convert_advanced_activation,
    'PReLU'                    : _convert_advanced_activation,
    'ELU'                      : _convert_advanced_activation,
    'ThresholdedReLU'          : _convert_advanced_activation,

    'AveragePooling2D'         : _convert_pooling,
    'MaxPooling2D'             : _convert_pooling,
    'GlobalAveragePooling2D'   : _convert_pooling,
    'GlobalMaxPooling2D'       : _convert_pooling,
    'Conv2D'                   : _convert_convolution,
    'Conv2DTranspose'          : _convert_convolution,
    'DepthwiseConv2D'          : _convert_convolution,
    'SeparableConv2D'          : _convert_separable_convolution,

    'Flatten'                  : _convert_flatten,
    'Reshape'                  : _convert_reshape,
    'Concatenate'              : _convert_concat,
    'BatchNormalization'       : _convert_batchnorm,

    'Add'                      : _convert_merge,
    'Subtract'                 : _convert_merge,
    'Multiply'                 : _convert_merge,
    'ZeroPadding2D'            : _convert_padding,
    'UpSampling2D'             : _convert_upsample,
    'Cropping2D'               : _convert_cropping,

    # 'ZeroPadding1D'          : _convert_padding,
    # 'AveragePooling1D'       : _convert_pooling,
    # 'MaxPooling1D'           : _convert_pooling,
    # 'GlobalAveragePooling1D' : _convert_pooling,
    # 'GlobalMaxPooling1D'     : _convert_pooling,
    # 'Cropping1D'             : _convert_cropping,
    # 'UpSampling1D'           : _convert_upsample,
    # 'UpSampling3D'           : _convert_upsample,
    # 'Conv1D'                 : _convert_convolution1d,

    'SimpleRNN'                : _convert_simple_rnn,
    'LSTM'                     : _convert_lstm,
    'GRU'                      : _convert_gru,
    # 'Bidirectional'          : _convert_bidirectional,
    # 'TimeDistributed'        : _default_skip,

    'Average'                  : _convert_merge,
    'Maximum'                  : _convert_merge,
    'Dot'                      : _convert_merge,
    'Permute'                  : _convert_permute,
    # 'Embedding'              : _convert_embedding,
    # 'RepeatVector'           : _convert_repeat_vector,

    'InputLayer'               : _default_skip,
    'Dropout'                  : _default_skip,
    'SpatialDropout2D'         : _default_skip,
    'SpatialDropout1D'         : _default_skip,
}


def _check_unsupported_layers(model):
    missing_ops = set()
    for layer in model.layers:
        op_name = type(layer).__name__
        if op_name not in _convert_map:
            missing_ops.add(op_name)

    if missing_ops:
        raise NotImplementedError(
            "The following operators are not implemented: {}".format(missing_ops))


def keras_op_to_relay(inexpr, keras_layer, outname, etab):
    """Convert a Keras layer to a Relay expression and update the expression table.

    Parameters
    ----------
    inexpr : relay.expr.Expr or a list of it
        The input Relay expression(s).

    keras_layer : keras.layers
        The Keras layer to be converted.

    outname : str
        Name of the output Relay expression.

    etab : relay.frontend.common.ExprTable
        The global expression table to be updated.
    """
    op_name = type(keras_layer).__name__
    if op_name not in _convert_map:
        raise tvm.error.OpNotImplemented(
            'Operator {} is not supported for frontend Keras.'.format(op_name))
    outs = _convert_map[op_name](inexpr, keras_layer, etab)
    outs = _as_list(outs)
    for t_idx, out in enumerate(outs):
        name = outname + ":" + str(t_idx)
        etab.set_expr(name, out)


def from_keras(model, shape=None):
    """Convert a Keras model to a Relay Function.

    Parameters
    ----------
    model : keras.engine.training.Model
        The Keras model to be converted.

    shape : dict of str to int list/tuple, optional
        Input shapes of the model.

    Returns
    -------
    mod : tvm.relay.Module
        The Relay module for compilation.

    params : dict of str to tvm.NDArray
        The parameter dict to be used by Relay.
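
    Examples
    --------
    A minimal sketch; the model and the default input name 'input_1' are
    illustrative only:

    .. code-block:: python

        import keras
        from tvm import relay

        keras_model = keras.applications.MobileNet(weights=None)
        # Relay expects NCHW inputs: (batch, channels, height, width).
        shape_dict = {'input_1': (1, 3, 224, 224)}
        mod, params = relay.frontend.from_keras(keras_model, shape_dict)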
758 """ 759 try: 760 import keras 761 except ImportError: 762 raise ImportError('Keras must be installed') 763 assert isinstance(model, keras.engine.training.Model) 764 if keras.backend.backend() != 'tensorflow': 765 raise ValueError("Keras frontend currently supports tensorflow backend only.") 766 if keras.backend.image_data_format() != 'channels_last': 767 raise ValueError("Keras frontend currently supports data_format = channels_last only.") 768 _check_unsupported_layers(model) 769 770 def _convert_input_layer(keras_layer): 771 input_name = keras_layer.name 772 input_shape = shape[input_name] if shape is not None and input_name in shape else None 773 etab.set_expr(input_name, new_var(input_name, shape=input_shape)) 774 775 etab = ExprTable() 776 for keras_layer in model.layers: 777 if isinstance(keras_layer, keras.engine.InputLayer): 778 _convert_input_layer(keras_layer) 779 else: 780 inbound_nodes = keras_layer.inbound_nodes if hasattr(keras_layer, 'inbound_nodes') \ 781 else keras_layer._inbound_nodes if hasattr(keras_layer, '_inbound_nodes') \ 782 else None 783 if inbound_nodes is None: 784 raise TypeError("Unknown layer type or unsupported Keras version : {}" 785 .format(keras_layer)) 786 for node_idx, node in enumerate(inbound_nodes): 787 # If some nodes in imported model is not relevant to the current model, 788 # skip such layers. model._network_nodes contains keys of all nodes relevant 789 # to the current model. 790 if not model._node_key(keras_layer, node_idx) in model._network_nodes: 791 continue 792 inexpr = [] 793 # Since Keras allows creating multiple layers from the same name instance, 794 # we append node index to the expr name to make it unique. 795 # The one exception is InputLayer. Changing input variable names after conversion 796 # would confuse users, so we should keep them as far as possible. Fortunately, 797 # they are named uniquely to input_1, input_2, input_3... by default. 798 zip_node = zip(node.node_indices, node.tensor_indices, node.inbound_layers) 799 for n_idx, t_idx, inbound_layer in zip_node: 800 if isinstance(inbound_layer, keras.engine.InputLayer): 801 expr_name = inbound_layer.name 802 _convert_input_layer(inbound_layer) 803 else: 804 expr_name = inbound_layer.name + ':' + str(n_idx) + ':' + str(t_idx) 805 expr = etab.get_expr(expr_name) 806 inexpr.append(expr) 807 if len(inexpr) == 1: 808 inexpr = inexpr[0] 809 keras_op_to_relay(inexpr, keras_layer, keras_layer.name + ':' + str(node_idx), etab) 810 # model._output_coordinates contains out_node(oc[0]), node_index(oc[1]) and tensor_index(oc[2]) 811 # Get all output nodes in etab using the name made from above values. 812 # The out exprs were added to etab in keras_op_to_relay using this name. 813 outexpr = [etab.get_expr(oc[0].name + ":" + str(oc[1]) + ":" + str(oc[2])) \ 814 for oc in model._output_coordinates] 815 outexpr = outexpr[0] if len(outexpr) == 1 else _expr.Tuple(outexpr) 816 func = _expr.Function(analysis.free_vars(outexpr), outexpr) 817 params = {k:_nd.array(np.array(v, dtype=np.float32)) for k, v in etab.params.items()} 818 return _module.Module.from_expr(func), params 819