1"""XLNetForQA models."""
2
3import mxnet as mx
4from mxnet.gluon import HybridBlock, Block, loss, nn
5
6
class PoolerStartLogits(HybridBlock):
    """Project per-token hidden states to SQuAD start logits.

    A single ``Dense(1)`` layer scores every token; an optional mask then
    suppresses the positions that must not be selected.
    """
    def __init__(self, prefix=None, params=None):
        super(PoolerStartLogits, self).__init__(prefix=prefix, params=params)
        self.dense = nn.Dense(1, flatten=False)

    def __call__(self, hidden_states, p_masks=None):
        # pylint: disable=arguments-differ
        # The mask is forwarded positionally so that callers may omit it.
        parent = super(PoolerStartLogits, self)
        return parent.__call__(hidden_states, p_masks)

    def hybrid_forward(self, F, hidden_states, p_mask):
        """Compute (optionally masked) start logits.

        Parameters
        ----------
        hidden_states : NDArray, shape (batch_size, seq_length, hidden_size)
            Sequence output of the encoder.
        p_mask : NDArray or None, shape (batch_size, seq_length)
            Positions to suppress. When None the raw logits are returned.

        Returns
        -------
        logits : NDArray, shape (batch_size, seq_length)
            Start logits; for a 0/1 mask, masked positions become -1e30.
        """
        # pylint: disable=arguments-differ
        logits = self.dense(hidden_states)
        logits = logits.squeeze(-1)  # drop the trailing size-1 axis
        if p_mask is None:
            return logits
        # Scale kept positions by (1 - mask) and shift masked ones by -1e30.
        keep = 1 - p_mask
        return logits * keep - 1e30 * p_mask
35
36
class PoolerEndLogits(HybridBlock):
    """Score SQuAD end positions from token states and start-token states."""
    def __init__(self, units=768, is_eval=False, prefix=None, params=None):
        super(PoolerEndLogits, self).__init__(prefix=prefix, params=params)
        self._hsz = units
        self._eval = is_eval
        with self.name_scope():
            self.dense_0 = nn.Dense(units, activation='tanh', flatten=False)
            self.dense_1 = nn.Dense(1, flatten=False)
            self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=units)

    def __call__(self,
                 hidden_states,
                 start_states=None,
                 start_positions=None,
                 p_masks=None):
        # pylint: disable=arguments-differ
        # All optional inputs are forwarded positionally to hybrid_forward.
        parent = super(PoolerEndLogits, self)
        return parent.__call__(hidden_states, start_states, start_positions,
                               p_masks)

    def hybrid_forward(self, F, hidden_states, start_states, start_positions, p_mask):
        # pylint: disable=arguments-differ
        """Compute end logits conditioned on the start token.

        Outside eval mode the start state is gathered from ``hidden_states``
        at ``start_positions`` and any ``start_states`` passed in is ignored.
        In eval mode ``start_states`` is used as given.

        Parameters
        ----------
        hidden_states : NDArray, shape (batch_size, seq_length, hidden_size)
            In eval mode the caller passes states that already match the shape
            of ``start_states``.
        start_states : NDArray, shape (batch_size, seq_length, start_n_top, hidden_size)
            Used during inference.
        start_positions : NDArray, shape (batch_size)
            Ground-truth start positions used during training.
        p_mask : NDArray or None, shape (batch_size, seq_length)
            Positions to suppress. When None no masking is applied.

        Returns
        -------
        logits : NDArray
            Masked end logits, shape (batch_size, seq_length) in training.
            With the eval-mode inputs described above the ``start_n_top``
            axis is kept as the last axis.
        """
        if not self._eval:
            # Build (row, position) index pairs and pick one state per sample.
            batch_idx = F.contrib.arange_like(hidden_states,
                                              axis=0).expand_dims(1)
            position_idx = start_positions.expand_dims(1)
            gather_idx = F.concat(batch_idx, position_idx).transpose()
            picked = F.gather_nd(hidden_states, gather_idx)  # shape (bsz, hsz)
            # Repeat the start state along the sequence axis.
            start_states = F.broadcast_like(
                picked.expand_dims(1), hidden_states)  # shape (bsz, slen, hsz)

        joint = F.concat(hidden_states, start_states, dim=-1)
        logits = self.dense_1(self.layernorm(self.dense_0(joint))).squeeze(-1)

        if p_mask is not None:
            if self._eval:
                # Give the mask the extra trailing axis the eval logits carry.
                p_mask = F.broadcast_like(p_mask.expand_dims(-1), logits)
            logits = logits * (1 - p_mask) - 1e30 * p_mask
        return logits
96
97
class XLNetPoolerAnswerClass(HybridBlock):
    """Compute the SQuAD 2.0 answerability logit from [CLS] and start states."""
    def __init__(self, units=768, dropout=0.1, prefix=None, params=None):
        super(XLNetPoolerAnswerClass, self).__init__(prefix=prefix,
                                                     params=params)
        with self.name_scope():
            self._units = units
            # Hidden projection over the concatenated (start, cls) states.
            self.dense_0 = nn.Dense(units,
                                    in_units=2 * units,
                                    activation='tanh',
                                    use_bias=True,
                                    flatten=False)
            # Bias-free projection down to a single logit.
            self.dense_1 = nn.Dense(1,
                                    in_units=units,
                                    use_bias=False,
                                    flatten=False)
            self._dropout = nn.Dropout(dropout)

    def __call__(self, hidden_states, start_states=None, cls_index=None):
        # pylint: disable=arguments-differ
        parent = super(XLNetPoolerAnswerClass, self)
        return parent.__call__(hidden_states, start_states, cls_index)

    def hybrid_forward(self, F, hidden_states, start_states, cls_index):
        # pylint: disable=arguments-differ
        """Compute answerability logits from the encoder output and start states.

        Parameters
        ----------
        hidden_states : NDArray, shape (batch_size, seq_length, hidden_size)
        start_states : NDArray, shape (batch_size, hidden_size)
            Typically weighted average hidden_states along second dimension.
        cls_index : NDArray, shape (batch_size)
            Used to locate the [CLS] token: the state at position
            ``cls_index - 1`` of every sequence is gathered. ``XLNetForQA`` in
            this file passes the valid length of each sequence here.

        Returns
        -------
        logits : NDArray, shape (batch_size,)
            CLS logits.
        """
        # Pair each batch row with the position of its [CLS] token.
        batch_ids = F.contrib.arange_like(hidden_states, axis=0).expand_dims(1)
        cls_positions = cls_index.reshape((-1, 1)) - 1
        index_pairs = F.concat(batch_ids, cls_positions)
        gather_index = F.transpose(index_pairs, axes=(1, 0))
        cls_token_state = F.gather_nd(hidden_states, gather_index)

        features = F.concat(start_states, cls_token_state, dim=-1)
        hidden = self._dropout(self.dense_0(features))
        return self.dense_1(hidden).squeeze(-1)
148
149
class XLNetForQA(Block):
    """Model for SQuAD task with XLNet.

    Parameters
    ----------
    xlnet_base: XLNet Block
    start_top_n : int
        Number of start position candidates during inference. Must be set
        when ``is_eval`` is True.
    end_top_n : int
        Number of end position candidates for each start position during inference.
        Must be set when ``is_eval`` is True.
    is_eval : Bool
        If set to True, do inference.
    units : int
        Hidden size handed to the end-logit and answer-class heads.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """
    def __init__(self,
                 xlnet_base,
                 start_top_n=None,
                 end_top_n=None,
                 is_eval=False,
                 units=768,
                 prefix=None,
                 params=None):
        super(XLNetForQA, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.xlnet = xlnet_base
            self.start_top_n = start_top_n
            self.end_top_n = end_top_n
            self.loss = loss.SoftmaxCELoss()
            self.start_logits = PoolerStartLogits()
            self.end_logits = PoolerEndLogits(units=units, is_eval=is_eval)
            self.eval = is_eval
            self.answer_class = XLNetPoolerAnswerClass(units=units)
            self.cls_loss = loss.SigmoidBinaryCrossEntropyLoss()

    def __call__(self,
                 inputs,
                 token_types,
                 valid_length=None,
                 label=None,
                 p_mask=None,
                 is_impossible=None,
                 mems=None):
        #pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences."""
        # A missing valid_length travels as an empty list placeholder; forward()
        # recognises it. Note that forward() takes p_mask before label.
        valid_length = [] if valid_length is None else valid_length
        return super(XLNetForQA,
                     self).__call__(inputs, token_types, valid_length, p_mask,
                                    label, is_impossible, mems)

    def _padding_mask(self, inputs, valid_length, left_pad=False):
        """Build the attention mask that hides padded positions.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Only its second axis and dtype are used.
        valid_length : NDArray, shape (batch_size,)
            Per-sequence length (right padding) or the threshold compared
            against 1-based steps (left padding).
        left_pad : bool
            Select the left-padding variant of the mask.

        Returns
        -------
        mask : NDArray, shape (batch_size, seq_length, seq_length)
            The per-position mask repeated along the second axis.
        """
        F = mx.ndarray
        if left_pad:
            # left padding: keep steps (1-based) strictly greater than the threshold
            valid_length_start = valid_length.astype('int64')
            steps = F.contrib.arange_like(inputs, axis=1) + 1
            mask = F.broadcast_greater(
                F.reshape(steps, shape=(1, -1)),
                F.reshape(valid_length_start, shape=(-1, 1)))
        else:
            # right padding: keep steps (0-based) strictly below the valid length
            valid_length = valid_length.astype(inputs.dtype)
            steps = F.contrib.arange_like(inputs, axis=1)
            mask = F.broadcast_lesser(F.reshape(steps, shape=(1, -1)),
                                      F.reshape(valid_length, shape=(-1, 1)))
        # Expand (bsz, slen) -> (bsz, slen, slen) by multiplying with an
        # all-ones (slen, slen) matrix.
        ones = F.ones_like(steps)
        return F.broadcast_mul(
            F.expand_dims(mask, axis=1),
            F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1))))

    def _answerability_logits(self, output, start_logits, valid_length):
        """Return CLS logits from the start-probability-weighted hidden state."""
        start_probs = mx.nd.softmax(start_logits, axis=-1)
        # (bsz, hsz, slen) x (bsz, slen, 1) -> (bsz, hsz)
        start_states = mx.nd.batch_dot(output,
                                       start_probs.expand_dims(-1),
                                       transpose_a=True).squeeze(-1)
        return self.answer_class(output, start_states, valid_length)

    def forward(self, inputs, token_types, valid_length, p_mask, label,
                is_impossible, mems):
        # pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the word belongs to the
            first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size,)
            Valid length of the sequence. This is used to mask the padded tokens.
            When None (or the empty-list placeholder produced by ``__call__``)
            every sequence is treated as having no padding.
        p_mask : NDArray or None, shape (batch_size, seq_length)
            We do not want special tokens(e.g., [SEP], [PAD]) and question tokens to be
            included in answer. Set to 1 to mask the token.
        label : pair of NDArray
            ``(start_positions, end_positions)`` ground truth used for the
            span loss. Only read in training mode.
        is_impossible : NDArray or None, shape (batch_size ,1)
            Ground-truth label(is impossible) for loss computation. Only read
            in training mode, where it is handed to the classification loss
            unconditionally.
        mems : NDArray
            We do not use memory(a Transformer XL component) during finetuning.

        Returns
        -------
        For training we have:
        total_loss : list of NDArray
            Specifically, we have a span loss (batch_size, ) and a cls_loss (batch_size, )
        total_loss_sum : NDArray

        For inference we have:
        start_top_log_probs : NDArray, shape (batch_size, start_n_top, )
        start_top_index :  NDArray, shape (batch_size, start_n_top)
        end_top_log_probs : NDArray, shape (batch_size, start_n_top * end_n_top)
        end_top_index : NDArray, shape (batch_size, start_n_top * end_n_top)
        cls_logits : NDArray or None, shape (batch_size, )
        """
        if isinstance(valid_length, list) and not valid_length:
            valid_length = None
        if valid_length is None:
            # Without lengths the mask and the [CLS] lookup below would fail on
            # None, so fall back to "every token is valid" (full length).
            valid_length = mx.nd.full((inputs.shape[0],),
                                      inputs.shape[1],
                                      ctx=inputs.context,
                                      dtype='float32')
        attention_mask = self._padding_mask(inputs,
                                            valid_length).astype('float32')
        output, _ = self.xlnet(inputs, token_types, mems, attention_mask)
        start_logits = self.start_logits(output,
                                         p_masks=p_mask)  # shape (bsz, slen)
        bsz, slen, hsz = output.shape
        if not self.eval:
            # training
            start_positions, end_positions = label
            end_logit = self.end_logits(output,
                                        start_positions=start_positions,
                                        p_masks=p_mask)
            span_loss = (self.loss(start_logits, start_positions) +
                         self.loss(end_logit, end_positions)) / 2

            # get cls loss
            cls_logits = self._answerability_logits(output, start_logits,
                                                    valid_length)
            cls_loss = self.cls_loss(cls_logits, is_impossible)
            weighted_cls_loss = 0.5 * cls_loss

            total_loss = [span_loss, weighted_cls_loss]
            total_loss_sum = span_loss + weighted_cls_loss
            return total_loss, total_loss_sum

        #inference
        start_log_probs = mx.nd.log_softmax(start_logits,
                                            axis=-1)  # shape (bsz, slen)
        start_top_log_probs, start_top_index = mx.ndarray.topk(
            start_log_probs, k=self.start_top_n, axis=-1,
            ret_typ='both')  # shape (bsz, start_n_top)
        # Batch row id repeated start_n_top times per sample, flattened so it
        # lines up with the flattened start_top_index below.
        index = mx.nd.concat(*[
            mx.nd.arange(bsz, ctx=start_log_probs.context).expand_dims(1)
        ] * self.start_top_n).reshape(bsz * self.start_top_n, 1)
        start_top_index_rs = start_top_index.reshape((-1, 1))
        gather_index = mx.nd.concat(
            index, start_top_index_rs).T  #shape(2, bsz * start_n_top)
        start_states = mx.nd.gather_nd(output, gather_index).reshape(
            (bsz, self.start_top_n, hsz))  #shape (bsz, start_n_top, hsz)

        start_states = start_states.expand_dims(1)
        start_states = mx.nd.broadcast_to(
            start_states, (bsz, slen, self.start_top_n,
                           hsz))  # shape (bsz, slen, start_n_top, hsz)
        hidden_states_expanded = output.expand_dims(2)
        hidden_states_expanded = mx.ndarray.broadcast_to(
            hidden_states_expanded, shape=start_states.shape
        )  # shape (bsz, slen, start_n_top, hsz)
        end_logits = self.end_logits(
            hidden_states_expanded,
            start_states=start_states,
            p_masks=p_mask)  # shape (bsz, slen, start_n_top)
        end_log_probs = mx.nd.log_softmax(
            end_logits, axis=1)  # shape (bsz, slen, start_n_top)
        # Note that end_top_index and end_top_log_probs have shape (bsz, END_N_TOP, start_n_top)
        # So that for each start position, there are end_n_top end positions on the second dim.
        end_top_log_probs, end_top_index = mx.ndarray.topk(
            end_log_probs, k=self.end_top_n, axis=1,
            ret_typ='both')  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.reshape(
            (-1, self.start_top_n * self.end_top_n))
        end_top_index = end_top_index.reshape(
            (-1, self.start_top_n * self.end_top_n))

        cls_logits = self._answerability_logits(output, start_logits,
                                                valid_length)

        outputs = (start_top_log_probs, start_top_index, end_top_log_probs,
                   end_top_index, cls_logits)
        return outputs
346