from functools import partial

import tensorflow as tf

from onnx_tf.common import get_unique_suffix
from onnx_tf.common import exception
from onnx_tf.handlers.backend_handler import BackendHandler
from onnx_tf.handlers.handler import onnx_op
from onnx_tf.handlers.handler import partial_support
from onnx_tf.handlers.handler import ps_description
from .rnn_mixin import RNNMixin


@onnx_op("GRU")
@partial_support(True)
@ps_description(
    "GRU with clip or GRU with linear_before_reset, or " +
    "GRU not using sigmoid for z and r, or " +
    "GRU using Elu as the activation function " + "with alpha != 1, or " +
    "GRU using HardSigmoid as the activation function " +
    "with alpha != 0.2 or beta != 0.5 " + "are not supported in TensorFlow.")
class GRU(RNNMixin, BackendHandler):

  @classmethod
  def args_check(cls, node, **kwargs):
    direction = node.attrs.get("direction", "forward")
    num_directions = 2 if direction == "bidirectional" else 1
    if "clip" in node.attrs:
      exception.OP_UNSUPPORTED_EXCEPT("GRU with clip", "Tensorflow")
    if node.attrs.get("linear_before_reset", 0):
      exception.OP_UNSUPPORTED_EXCEPT("GRU with linear_before_reset",
                                      "Tensorflow")
    if "activations" in node.attrs:
      activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
      if activations[0] != "sigmoid":
        exception.OP_UNSUPPORTED_EXCEPT("GRU without sigmoid for `z` and `r`",
                                        "Tensorflow")
      if num_directions == 2:
        if activations[2] != "sigmoid":
          exception.OP_UNSUPPORTED_EXCEPT("GRU without sigmoid for `z` and `r`",
                                          "Tensorflow")

  @classmethod
  def _custom_getter(cls,
                     getter,
                     name,
                     node=None,
                     tensor_dict=None,
                     is_bidirectional=None,
                     *args,
                     **kwargs):
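    # Custom variable getter for the GRUCell created below: requests for the
    # "gates" and "candidate" kernels and biases are built from the ONNX W, R
    # and B inputs instead of freshly created variables. Anything else falls
    # through to the default getter.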
    names = name.split("/")
    if is_bidirectional:
      if "fw" in names:
        index = 0
      elif "bw" in names:
        index = 1
      else:
        raise RuntimeError("Cannot get {} for bidirectional GRU: "
                           "neither fw nor bw is in the name scope.".format(
                               names[-1]))
    if names[-1] == "kernel":
      # onnx W[zrh], R[zrh]
      if is_bidirectional:
        w = tf.split(tensor_dict[node.inputs[1]], 2)[index]
        r = tf.split(tensor_dict[node.inputs[2]], 2)[index]
      else:
        w = tensor_dict[node.inputs[1]]
        r = tensor_dict[node.inputs[2]]
      w_z, w_r, w_h = tf.split(tf.squeeze(w), 3)
      r_z, r_r, r_h = tf.split(tf.squeeze(r), 3)
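      # TF's GRUCell keeps a [r, z]-ordered "gates" kernel and a separate
      # "candidate" kernel, each of shape [input_size + hidden_size, ...],
      # so the ONNX [z, r, h] slices are reordered, transposed and stacked
      # as [input weights; recurrence weights].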
      if names[-2] == "gates":
        new_w = tf.transpose(tf.concat([w_r, w_z], 0))
        new_r = tf.transpose(tf.concat([r_r, r_z], 0))
      elif names[-2] == "candidate":
        new_w = tf.transpose(w_h)
        new_r = tf.transpose(r_h)
      kernel = tf.concat([new_w, new_r], 0)
      return kernel
    if names[-1] == "bias":
      if len(node.inputs) >= 4:
        # onnx Wb[zrh], Rb[zrh]
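        # TF's GRUCell uses a single bias per kernel, so the ONNX input bias
        # Wb and recurrence bias Rb are split per gate, reordered to match
        # TF's layout and then summed into one bias tensor.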
        if is_bidirectional:
          b = tf.split(tensor_dict[node.inputs[3]], 2)[index]
        else:
          b = tensor_dict[node.inputs[3]]
        w_b, r_b = tf.split(tf.squeeze(b), 2)
        w_b_z, w_b_r, w_b_h = tf.split(w_b, 3)
        r_b_z, r_b_r, r_b_h = tf.split(r_b, 3)
        if names[-2] == "gates":
          w_b = tf.transpose(tf.concat([w_b_r, w_b_z], 0))
          r_b = tf.transpose(tf.concat([r_b_r, r_b_z], 0))
        elif names[-2] == "candidate":
          w_b = tf.transpose(w_b_h)
          r_b = tf.transpose(r_b_h)
        return tf.add(w_b, r_b)
      return getter(name, *args, **kwargs)
    return getter(name, *args, **kwargs)

  @classmethod
  def _common(cls, node, **kwargs):
    tensor_dict = kwargs["tensor_dict"]
    x = tensor_dict[node.inputs[0]]
    input_shape = x.get_shape().as_list()
    input_size = len(node.inputs)
    hidden_size = node.attrs["hidden_size"]
    direction = node.attrs.get("direction", "forward")
    num_directions = 2 if direction == "bidirectional" else 1

    # output_sequence was removed in opset 7; it defaults to 0
    output_sequence = node.attrs.get("output_sequence", 0)

    # TODO(fumihwh): check if prev node is one of RNN
    # process the input if it comes from a previous RNN cell,
    # whose output has shape [seq_length, num_directions, batch_size, hidden_size]
    if len(input_shape) == 4 and input_shape[1] == 1:
      x = tf.squeeze(x)

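    # Optional per-batch sequence lengths (ONNX input 4, sequence_lens).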
    sequence_length = None
    if input_size >= 5 and node.inputs[4] in tensor_dict:
      sequence_length = tensor_dict[node.inputs[4]]

    cell_kwargs = {}

    tf_activations = [tf.nn.tanh]
    if "activations" in node.attrs:
      activations = list(map(lambda x: x.lower(), node.attrs["activations"]))
      activation_alpha = node.attrs.get("activation_alpha", [None] * 4)
      activation_beta = node.attrs.get("activation_beta", [None] * 4)
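      # ONNX lists activations per direction as [f, g]; index 1 is the
      # candidate activation for the forward cell and index 3 the one for
      # the reverse cell of a bidirectional layer.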
      tf_activations = [
          cls.rnn_get_activation(activations[1], activation_alpha[1],
                                 activation_beta[1])
      ]
      if num_directions == 2:
        tf_activations.append(
            cls.rnn_get_activation(activations[3], activation_alpha[3],
                                   activation_beta[3]))

    # TODO(fumihwh): check if reverse and bidirectional works
    with tf.compat.v1.variable_scope(
        "GRU_" + get_unique_suffix(),
        custom_getter=partial(
            cls._custom_getter,
            node=node,
            tensor_dict=tensor_dict,
            is_bidirectional=num_directions == 2)):

      cell_kwargs["num_units"] = hidden_size
      if input_size < 4 or node.inputs[3] not in tensor_dict:
        cell_kwargs["bias_initializer"] = tf.zeros_initializer
      initial_state = None
      initial_state_bw = None
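      # The optional initial_h input (ONNX input 5) has shape
      # [num_directions, batch_size, hidden_size]; slice 0 seeds the forward
      # cell and slice 1 the backward cell.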
      if input_size == 6:
        initial_h = tensor_dict.get(node.inputs[5], None)
        if initial_h is not None:
          initial_state = (initial_h[0],)
          if num_directions == 2:
            initial_state_bw = (initial_h[1],)

      rnn_kwargs = {}
      if num_directions == 1:
        rnn_kwargs["initial_state"] = initial_state
      elif num_directions == 2:
        rnn_kwargs["initial_state_fw"] = initial_state
        rnn_kwargs["initial_state_bw"] = initial_state_bw
      rnn_kwargs["sequence_length"] = sequence_length
      rnn_kwargs["time_major"] = True
      rnn_kwargs["dtype"] = tf.float32

      outputs, states = cls.rnn(x, tf.compat.v1.nn.rnn_cell.GRUCell,
                                cell_kwargs, rnn_kwargs, tf_activations,
                                direction)

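    # Reshape to the ONNX layout: Y has shape
    # [seq_length, num_directions, batch_size, hidden_size] and Y_h has shape
    # [num_directions, batch_size, hidden_size].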
    if num_directions == 1:
      state = states[0]
      h = tf.expand_dims(state, 0)
      output = tf.expand_dims(outputs, 1)
    else:
      state_fw = states[0][0]
      state_bw = states[1][0]
      output_fw = outputs[0]
      output_bw = outputs[1]
      h_fw = tf.expand_dims(state_fw, 0)
      h_bw = tf.expand_dims(state_bw, 0)
      h = tf.concat((h_fw, h_bw), axis=0)
      output_fw = tf.expand_dims(output_fw, 1)
      output_bw = tf.expand_dims(output_bw, 1)
      output = tf.concat((output_fw, output_bw), axis=1)

    return [output, h] if output_sequence == 0 else [h]

  @classmethod
  def version_1(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_3(cls, node, **kwargs):
    return cls._common(node, **kwargs)

  @classmethod
  def version_7(cls, node, **kwargs):
    return cls._common(node, **kwargs)