"""XLNetForQA models."""

import mxnet as mx
from mxnet.gluon import HybridBlock, Block, loss, nn


class PoolerStartLogits(HybridBlock):
    """ Compute SQuAD start_logits from sequence hidden states."""
    def __init__(self, prefix=None, params=None):
        super(PoolerStartLogits, self).__init__(prefix=prefix, params=params)
        # Created outside ``name_scope()`` (unlike the other poolers in this file);
        # left as-is so existing parameter naming is not changed.
        self.dense = nn.Dense(1, flatten=False)

    def __call__(self, hidden_states, p_masks=None):
        # pylint: disable=arguments-differ
        return super(PoolerStartLogits, self).__call__(hidden_states, p_masks)

    def hybrid_forward(self, F, hidden_states, p_mask):
        """Get start logits from the model output.

        Parameters
        ----------
        hidden_states : NDArray, shape (batch_size, seq_length, hidden_size)
        p_mask : NDArray or None, shape(batch_size, seq_length)
            1 marks positions that must not be selected as an answer start.

        Returns
        -------
        x : NDarray, shape(batch_size, seq_length)
            Masked start logits.
        """
        # pylint: disable=arguments-differ
        x = self.dense(hidden_states).squeeze(-1)
        if p_mask is not None:
            # Masked positions get a huge negative logit, so after softmax they
            # receive (numerically) zero probability.
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x


class PoolerEndLogits(HybridBlock):
    """ Compute SQuAD end_logits from sequence hidden states and start token hidden state.

    Parameters
    ----------
    units : int
        Output size of the first dense layer and channel count of the layer norm.
    is_eval : bool
        If True, ``start_states`` must be supplied (inference); otherwise the
        start states are gathered from ``hidden_states`` at ``start_positions``.
    """
    def __init__(self, units=768, is_eval=False, prefix=None, params=None):
        super(PoolerEndLogits, self).__init__(prefix=prefix, params=params)
        self._eval = is_eval
        self._hsz = units
        with self.name_scope():
            self.dense_0 = nn.Dense(units, activation='tanh', flatten=False)
            self.dense_1 = nn.Dense(1, flatten=False)
            self.layernorm = nn.LayerNorm(epsilon=1e-12, in_channels=units)

    def __call__(self,
                 hidden_states,
                 start_states=None,
                 start_positions=None,
                 p_masks=None):
        # pylint: disable=arguments-differ
        return super(PoolerEndLogits,
                     self).__call__(hidden_states, start_states,
                                    start_positions, p_masks)

    def hybrid_forward(self, F, hidden_states, start_states, start_positions, p_mask):
        # pylint: disable=arguments-differ
        """Get end logits from the model output and start states or start positions.

        Parameters
        ----------
        hidden_states : NDArray
            shape (batch_size, seq_length, hidden_size) during training;
            shape (batch_size, seq_length, start_n_top, hidden_size) when ``is_eval``.
        start_states : NDArray, shape (batch_size, seq_length, start_n_top, hidden_size)
            Used during inference (ignored during training).
        start_positions : NDArray, shape (batch_size)
            Ground-truth start positions used during training.
        p_mask : NDArray or None, shape(batch_size, seq_length)

        Returns
        -------
        x : NDarray
            Masked end logits: shape (batch_size, seq_length) during training,
            (batch_size, seq_length, start_n_top) when ``is_eval``.
        """
        if not self._eval:
            # Gather hidden_states[b, start_positions[b]] for every batch row b;
            # gather_nd expects indices laid out as (2, batch_size).
            start_states = F.gather_nd(
                hidden_states,
                F.concat(
                    F.contrib.arange_like(hidden_states,
                                          axis=0).expand_dims(1),
                    start_positions.expand_dims(
                        1)).transpose())  # shape(bsz, hsz)
            start_states = start_states.expand_dims(1)
            start_states = F.broadcast_like(
                start_states, hidden_states)  # shape (bsz, slen, hsz)
        x = self.dense_0(F.concat(hidden_states, start_states, dim=-1))
        x = self.layernorm(x)
        x = self.dense_1(x).squeeze(-1)
        if p_mask is not None and self._eval:
            # In eval mode x has an extra start_n_top axis; broadcast the mask to it.
            p_mask = p_mask.expand_dims(-1)
            p_mask = F.broadcast_like(p_mask, x)
        if p_mask is not None:
            x = x * (1 - p_mask) - 1e30 * p_mask
        return x


class XLNetPoolerAnswerClass(HybridBlock):
    """ Compute SQuAD 2.0 answer class from classification and start tokens hidden states.
    """
    def __init__(self, units=768, dropout=0.1, prefix=None, params=None):
        super(XLNetPoolerAnswerClass, self).__init__(prefix=prefix,
                                                     params=params)
        with self.name_scope():
            self._units = units
            self.dense_0 = nn.Dense(units,
                                    in_units=2 * units,
                                    activation='tanh',
                                    use_bias=True,
                                    flatten=False)
            self.dense_1 = nn.Dense(1,
                                    in_units=units,
                                    use_bias=False,
                                    flatten=False)
            self._dropout = nn.Dropout(dropout)

    def __call__(self, hidden_states, start_states=None, cls_index=None):
        # pylint: disable=arguments-differ
        return super(XLNetPoolerAnswerClass,
                     self).__call__(hidden_states, start_states, cls_index)

    def hybrid_forward(self, F, hidden_states, start_states, cls_index):
        # pylint: disable=arguments-differ
        """Get answerability logits from the model output and start states.

        Parameters
        ----------
        hidden_states : NDArray, shape (batch_size, seq_length, hidden_size)
        start_states : NDArray, shape (batch_size, hidden_size)
            Typically weighted average hidden_states along second dimension.
        cls_index : NDArray, shape (batch_size)
            One past the position of the [CLS] token: the state is read from
            position ``cls_index - 1``. XLNetForQA passes the valid length here,
            i.e. [CLS] is expected to be the last valid token.

        Returns
        -------
        x : NDarray, shape(batch_size,)
            CLS logits.
        """
        index = F.contrib.arange_like(hidden_states,
                                      axis=0).expand_dims(1)
        valid_length_rs = cls_index.reshape((-1, 1)) - 1
        gather_index = F.transpose(F.concat(index, valid_length_rs), axes=(1, 0))
        cls_token_state = F.gather_nd(hidden_states, gather_index)

        x = self.dense_0(F.concat(start_states, cls_token_state, dim=-1))
        x = self._dropout(x)
        x = self.dense_1(x).squeeze(-1)
        return x


class XLNetForQA(Block):
    """Model for SQuAD task with XLNet.

    Parameters
    ----------
    xlnet_base: XLNet Block
    start_top_n : int
        Number of start position candidates during inference.
    end_top_n : int
        Number of end position candidates for each start position during inference.
    is_eval : Bool
        If set to True, do inference.
    units : int
        Hidden size of ``xlnet_base``'s output; sizes the end/answer-class poolers.
    prefix : str or None
        See document of `mx.gluon.Block`.
    params : ParameterDict or None
        See document of `mx.gluon.Block`.
    """
    def __init__(self,
                 xlnet_base,
                 start_top_n=None,
                 end_top_n=None,
                 is_eval=False,
                 units=768,
                 prefix=None,
                 params=None):
        super(XLNetForQA, self).__init__(prefix=prefix, params=params)
        with self.name_scope():
            self.xlnet = xlnet_base
            self.start_top_n = start_top_n
            self.end_top_n = end_top_n
            self.loss = loss.SoftmaxCELoss()
            self.start_logits = PoolerStartLogits()
            self.end_logits = PoolerEndLogits(units=units, is_eval=is_eval)
            self.eval = is_eval
            self.answer_class = XLNetPoolerAnswerClass(units=units)
            self.cls_loss = loss.SigmoidBinaryCrossEntropyLoss()

    def __call__(self,
                 inputs,
                 token_types,
                 valid_length=None,
                 label=None,
                 p_mask=None,
                 is_impossible=None,
                 mems=None):
        #pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences."""
        valid_length = [] if valid_length is None else valid_length
        # NOTE: forward() takes p_mask before label; the reorder here is intentional.
        return super(XLNetForQA,
                     self).__call__(inputs, token_types, valid_length, p_mask,
                                    label, is_impossible, mems)

    def _padding_mask(self, inputs, valid_length, left_pad=False):
        """Build a (batch_size, seq_length, seq_length) mask; 1 marks valid key positions.

        Only the right-padding branch is used by ``forward`` in this file.
        """
        F = mx.ndarray
        if left_pad:
            # NOTE(review): unused here; mixes an int64 valid length with float
            # steps and compares ``step > valid_length`` -- confirm semantics and
            # dtypes before relying on this branch.
            valid_length_start = valid_length.astype('int64')
            steps = F.contrib.arange_like(inputs, axis=1) + 1
            ones = F.ones_like(steps)
            mask = F.broadcast_greater(
                F.reshape(steps, shape=(1, -1)),
                F.reshape(valid_length_start, shape=(-1, 1)))
            mask = F.broadcast_mul(
                F.expand_dims(mask, axis=1),
                F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1))))
        else:
            # right padding: position t is valid iff t < valid_length
            valid_length = valid_length.astype(inputs.dtype)
            steps = F.contrib.arange_like(inputs, axis=1)
            ones = F.ones_like(steps)
            mask = F.broadcast_lesser(F.reshape(steps, shape=(1, -1)),
                                      F.reshape(valid_length, shape=(-1, 1)))
            # Replicate the per-key mask over every query position.
            mask = F.broadcast_mul(
                F.expand_dims(mask, axis=1),
                F.broadcast_mul(ones, F.reshape(ones, shape=(-1, 1))))
        return mask

    def forward(self, inputs, token_types, valid_length, p_mask, label,
                is_impossible, mems):
        # pylint: disable=arguments-differ
        """Generate the unnormalized score for the given the input sequences.

        Parameters
        ----------
        inputs : NDArray, shape (batch_size, seq_length)
            Input words for the sequences.
        token_types : NDArray, shape (batch_size, seq_length)
            Token types for the sequences, used to indicate whether the word belongs to the
            first sentence or the second one.
        valid_length : NDArray or None, shape (batch_size,)
            Valid length of the sequence. This is used to mask the padded tokens.
            If None (or an empty list), every sequence is treated as unpadded,
            i.e. of length seq_length.
        p_mask : NDArray or None, shape (batch_size, seq_length)
            We do not want special tokens(e.g., [SEP], [PAD]) and question tokens to be
            included in answer. Set to 1 to mask the token.
        label : NDArray, shape (batch_size, 1)
            Ground-truth label(start/end position) for loss computation.
        is_impossible : NDArray or None, shape (batch_size ,1)
            Ground-truth label(is impossible) for loss computation. Set to None for squad1,
            in which case only the span loss is computed.
        mems : NDArray
            We do not use memory(a Transformer XL component) during finetuning.

        Returns
        -------
        For training we have:
        total_loss : list of NDArray
            A span loss (batch_size, ) and, when ``is_impossible`` is given,
            a weighted cls_loss (batch_size, )
        total_loss_sum : NDArray

        For inference we have:
        start_top_log_probs : NDArray, shape (batch_size, start_n_top, )
        start_top_index : NDArray, shape (batch_size, start_n_top)
        end_top_log_probs : NDArray, shape (batch_size, start_n_top * end_n_top)
        end_top_index : NDArray, shape (batch_size, start_n_top * end_n_top)
        cls_logits : NDArray or None, shape (batch_size, )
        """
        if isinstance(valid_length, list) and len(valid_length) == 0:
            valid_length = None
        if valid_length is None:
            # No lengths given: every sequence spans the full padded length.
            # Previously this crashed inside _padding_mask (None.astype).
            valid_length = mx.nd.full((inputs.shape[0],), inputs.shape[1],
                                      ctx=inputs.context)
        attention_mask = self._padding_mask(inputs,
                                            valid_length).astype('float32')
        output, _ = self.xlnet(inputs, token_types, mems, attention_mask)
        start_logits = self.start_logits(output,
                                         p_masks=p_mask)  # shape (bsz, slen)
        if not self.eval:
            return self._forward_train(output, start_logits, valid_length,
                                       p_mask, label, is_impossible)
        return self._forward_inference(output, start_logits, valid_length,
                                       p_mask)

    def _cls_logits(self, output, start_logits, valid_length):
        """Score answerability from the start-probability-weighted hidden state."""
        start_probs = mx.nd.softmax(start_logits, axis=-1)  # shape (bsz, slen)
        # (bsz, hsz, slen) x (bsz, slen, 1) -> (bsz, hsz, 1) -> (bsz, hsz)
        start_states = mx.nd.batch_dot(output,
                                       start_probs.expand_dims(-1),
                                       transpose_a=True).squeeze(-1)
        return self.answer_class(output, start_states, valid_length)

    def _forward_train(self, output, start_logits, valid_length, p_mask, label,
                       is_impossible):
        """Compute the span loss plus (when labels are given) the answerability loss."""
        start_positions, end_positions = label
        end_logit = self.end_logits(output,
                                    start_positions=start_positions,
                                    p_masks=p_mask)
        span_loss = (self.loss(start_logits, start_positions) +
                     self.loss(end_logit, end_positions)) / 2

        total_loss = [span_loss]
        if is_impossible is None:
            # SQuAD 1.1: no answerability labels, so there is no cls loss.
            return total_loss, span_loss

        cls_logits = self._cls_logits(output, start_logits, valid_length)
        cls_loss = self.cls_loss(cls_logits, is_impossible)
        total_loss.append(0.5 * cls_loss)
        total_loss_sum = span_loss + 0.5 * cls_loss
        return total_loss, total_loss_sum

    def _forward_inference(self, output, start_logits, valid_length, p_mask):
        """Beam over top start/end candidates and score answerability."""
        bsz, slen, hsz = output.shape
        start_log_probs = mx.nd.log_softmax(start_logits,
                                            axis=-1)  # shape (bsz, slen)
        start_top_log_probs, start_top_index = mx.ndarray.topk(
            start_log_probs, k=self.start_top_n, axis=-1,
            ret_typ='both')  # shape (bsz, start_n_top)
        # Batch row ids [0]*start_n_top + [1]*start_n_top + ..., paired with the
        # flattened top-k start indices to gather each candidate's hidden state.
        index = mx.nd.concat(*[
            mx.nd.arange(bsz, ctx=start_log_probs.context).expand_dims(1)
        ] * self.start_top_n).reshape(bsz * self.start_top_n, 1)
        start_top_index_rs = start_top_index.reshape((-1, 1))
        gather_index = mx.nd.concat(
            index, start_top_index_rs).T  #shape(2, bsz * start_n_top)
        start_states = mx.nd.gather_nd(output, gather_index).reshape(
            (bsz, self.start_top_n, hsz))  #shape (bsz, start_n_top, hsz)

        start_states = start_states.expand_dims(1)
        start_states = mx.nd.broadcast_to(
            start_states, (bsz, slen, self.start_top_n,
                           hsz))  # shape (bsz, slen, start_n_top, hsz)
        hidden_states_expanded = output.expand_dims(2)
        hidden_states_expanded = mx.ndarray.broadcast_to(
            hidden_states_expanded, shape=start_states.shape
        )  # shape (bsz, slen, start_n_top, hsz)
        end_logits = self.end_logits(
            hidden_states_expanded,
            start_states=start_states,
            p_masks=p_mask)  # shape (bsz, slen, start_n_top)
        end_log_probs = mx.nd.log_softmax(
            end_logits, axis=1)  # shape (bsz, slen, start_n_top)
        # Note that end_top_index and end_top_log_probs have shape (bsz, END_N_TOP, start_n_top)
        # So that for each start position, there are end_n_top end positions on the second dim.
        end_top_log_probs, end_top_index = mx.ndarray.topk(
            end_log_probs, k=self.end_top_n, axis=1,
            ret_typ='both')  # shape (bsz, end_n_top, start_n_top)
        end_top_log_probs = end_top_log_probs.reshape(
            (-1, self.start_top_n * self.end_top_n))
        end_top_index = end_top_index.reshape(
            (-1, self.start_top_n * self.end_top_n))

        cls_logits = self._cls_logits(output, start_logits, valid_length)

        outputs = (start_top_log_probs, start_top_index, end_top_log_probs,
                   end_top_index, cls_logits)
        return outputs