1from functools import partial 2 3import tensorflow as tf 4 5from onnx_tf.common import get_unique_suffix 6from onnx_tf.common import exception 7from onnx_tf.handlers.backend_handler import BackendHandler 8from onnx_tf.handlers.handler import onnx_op 9from onnx_tf.handlers.handler import partial_support 10from onnx_tf.handlers.handler import ps_description 11from .rnn_mixin import RNNMixin 12 13 14@onnx_op("GRU") 15@partial_support(True) 16@ps_description( 17 "GRU with clip or GRU with linear_before_reset, or " + 18 "GRU not using sigmoid for z and r, or " + 19 "GRU using Elu as the activation function " + "with alpha != 1, or " + 20 "GRU using HardSigmoid as the activation function " + 21 "with alpha != 0.2 or beta != 0.5 " + "are not supported in TensorFlow.") 22class GRU(RNNMixin, BackendHandler): 23 24 @classmethod 25 def args_check(cls, node, **kwargs): 26 direction = node.attrs.get("direction", "forward") 27 num_directions = 2 if direction == "bidirectional" else 1 28 if "clip" in node.attrs: 29 exception.OP_UNSUPPORTED_EXCEPT("GRU with clip", "Tensorflow") 30 if node.attrs.get("linear_before_reset", 0): 31 exception.OP_UNSUPPORTED_EXCEPT("GRU with linear_before_reset", 32 "Tensorflow") 33 if "activations" in node.attrs: 34 activations = list(map(lambda x: x.lower(), node.attrs["activations"])) 35 if activations[0] != "sigmoid": 36 exception.OP_UNSUPPORTED_EXCEPT("GRU without sigmoid for `z` and `r`", 37 "Tensorflow") 38 if num_directions == 2: 39 if activations[2] != "sigmoid": 40 exception.OP_UNSUPPORTED_EXCEPT("GRU without sigmoid for `z` and `r`", 41 "Tensorflow") 42 43 @classmethod 44 def _custom_getter(cls, 45 getter, 46 name, 47 node=None, 48 tensor_dict=None, 49 is_bidirectional=None, 50 *args, 51 **kwargs): 52 names = name.split("/") 53 if is_bidirectional: 54 if "fw" in names: 55 index = 0 56 elif "bw" in names: 57 index = 1 58 else: 59 raise RuntimeError("Can not get {} for bidirectional. " 60 "Either fw and bw is not in name scope.".format( 61 names[-1])) 62 if names[-1] == "kernel": 63 # onnx W[zrh], R[zrh] 64 if is_bidirectional: 65 w = tf.split(tensor_dict[node.inputs[1]], 2)[index] 66 r = tf.split(tensor_dict[node.inputs[2]], 2)[index] 67 else: 68 w = tensor_dict[node.inputs[1]] 69 r = tensor_dict[node.inputs[2]] 70 w_z, w_r, w_h = tf.split(tf.squeeze(w), 3) 71 r_z, r_r, r_h = tf.split(tf.squeeze(r), 3) 72 if names[-2] == "gates": 73 new_w = tf.transpose(tf.concat([w_r, w_z], 0)) 74 new_r = tf.transpose(tf.concat([r_r, r_z], 0)) 75 elif names[-2] == "candidate": 76 new_w = tf.transpose(w_h) 77 new_r = tf.transpose(r_h) 78 kernel = tf.concat([new_w, new_r], 0) 79 return kernel 80 if names[-1] == "bias": 81 if len(node.inputs) >= 4: 82 # onnx Wb[zrh], Rb[zrh] 83 if is_bidirectional: 84 b = tf.split(tensor_dict[node.inputs[3]], 2)[index] 85 else: 86 b = tensor_dict[node.inputs[3]] 87 w_b, r_b = tf.split(tf.squeeze(b), 2) 88 w_b_z, w_b_r, w_b_h = tf.split(w_b, 3) 89 r_b_z, r_b_r, r_b_h = tf.split(r_b, 3) 90 if names[-2] == "gates": 91 w_b = tf.transpose(tf.concat([w_b_r, w_b_z], 0)) 92 r_b = tf.transpose(tf.concat([r_b_r, r_b_z], 0)) 93 elif names[-2] == "candidate": 94 w_b = tf.transpose(w_b_h) 95 r_b = tf.transpose(r_b_h) 96 return tf.add(w_b, r_b) 97 return getter(name, *args, **kwargs) 98 return getter(name, *args, **kwargs) 99 100 @classmethod 101 def _common(cls, node, **kwargs): 102 tensor_dict = kwargs["tensor_dict"] 103 x = tensor_dict[node.inputs[0]] 104 input_shape = x.get_shape().as_list() 105 input_size = len(node.inputs) 106 hidden_size = node.attrs["hidden_size"] 107 direction = node.attrs.get("direction", "forward") 108 num_directions = 2 if direction == "bidirectional" else 1 109 110 # removed from version 7, default is 0 111 output_sequence = node.attrs.get("output_sequence", 0) 112 113 # TODO(fumihwh): check if prev node is one of RNN 114 # process input if it comes from other previous cell 115 # which has shape [seq_length, num_directions, batch_size, hidden_size] 116 if len(input_shape) == 4 and input_shape[1] == 1: 117 x = tf.squeeze(x) 118 119 sequence_length = None 120 if input_size >= 5 and node.inputs[4] in tensor_dict: 121 sequence_length = tensor_dict[node.inputs[4]] 122 123 cell_kwargs = {} 124 125 tf_activations = [tf.nn.tanh] 126 if "activations" in node.attrs: 127 activations = list(map(lambda x: x.lower(), node.attrs["activations"])) 128 activation_alpha = node.attrs.get("activation_alpha", [None] * 4) 129 activation_beta = node.attrs.get("activation_beta", [None] * 4) 130 tf_activations = [ 131 cls.rnn_get_activation(activations[1], activation_alpha[1], 132 activation_beta[1]) 133 ] 134 if num_directions == 2: 135 tf_activations.append( 136 cls.rnn_get_activation(activations[3], activation_alpha[3], 137 activation_beta[3])) 138 139 # TODO(fumihwh): check if reverse and bidirectional works 140 with tf.compat.v1.variable_scope( 141 "GRU_" + get_unique_suffix(), 142 custom_getter=partial( 143 cls._custom_getter, 144 node=node, 145 tensor_dict=tensor_dict, 146 is_bidirectional=num_directions == 2)): 147 148 cell_kwargs["num_units"] = hidden_size 149 if input_size < 4 or node.inputs[3] not in tensor_dict: 150 cell_kwargs["bias_initializer"] = tf.zeros_initializer 151 initial_state = None 152 initial_state_bw = None 153 if input_size == 6: 154 initial_h = tensor_dict.get(node.inputs[5], None) 155 if initial_h is not None: 156 initial_state = (initial_h[0],) 157 if num_directions == 2: 158 initial_state_bw = (initial_h[1],) 159 160 rnn_kwargs = {} 161 if num_directions == 1: 162 rnn_kwargs["initial_state"] = initial_state 163 elif num_directions == 2: 164 rnn_kwargs["initial_state_fw"] = initial_state 165 rnn_kwargs["initial_state_bw"] = initial_state_bw 166 rnn_kwargs["sequence_length"] = sequence_length 167 rnn_kwargs["time_major"] = True 168 rnn_kwargs["dtype"] = tf.float32 169 170 outputs, states = cls.rnn(x, tf.compat.v1.nn.rnn_cell.GRUCell, 171 cell_kwargs, rnn_kwargs, tf_activations, 172 direction) 173 174 if num_directions == 1: 175 state = states[0] 176 h = tf.expand_dims(state, 0) 177 output = tf.expand_dims(outputs, 1) 178 else: 179 state_fw = states[0][0] 180 state_bw = states[1][0] 181 output_fw = outputs[0] 182 output_bw = outputs[1] 183 h_fw = tf.expand_dims(state_fw, 0) 184 h_bw = tf.expand_dims(state_bw, 0) 185 h = tf.concat((h_fw, h_bw), axis=0) 186 output_fw = tf.expand_dims(output_fw, 1) 187 output_bw = tf.expand_dims(output_bw, 1) 188 output = tf.concat((output_fw, output_bw), axis=1) 189 190 return [output, h] if output_sequence == 0 else [h] 191 192 @classmethod 193 def version_1(cls, node, **kwargs): 194 return cls._common(node, **kwargs) 195 196 @classmethod 197 def version_3(cls, node, **kwargs): 198 return cls._common(node, **kwargs) 199 200 @classmethod 201 def version_7(cls, node, **kwargs): 202 return cls._common(node, **kwargs) 203