# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Gluon model block for the named entity recognition task."""

from contextlib import ExitStack

import mxnet as mx
from mxnet.gluon import Block, nn


class BERTTagger(Block):
    """Model for sequence tagging with BERT.

    The output of ``bert_model`` is passed through a dropout layer and then a
    dense classifier that scores every tag at every position.

    Parameters
    ----------
    bert_model: BERTModel
        Bidirectional encoder with transformer.
    num_tag_types: int
        Number of possible tags.
    dropout_prob: float
        Dropout probability applied to the encoder output before the classifier.
    prefix: str or None
        See document of `mx.gluon.Block`.
    params: ParameterDict or None
        See document of `mx.gluon.Block`.
    """

    def __init__(self, bert_model, num_tag_types, dropout_prob, prefix=None, params=None):
        super(BERTTagger, self).__init__(prefix=prefix, params=params)
        self.bert_model = bert_model
        with self.name_scope():
            # flatten=False keeps the leading (batch, seq) axes and projects only
            # the last axis, so every position gets its own tag scores.
            self.tag_classifier = nn.Dense(units=num_tag_types, flatten=False)
            self.dropout = nn.Dropout(rate=dropout_prob)

    def forward(self, token_ids, token_types, valid_length):  # pylint: disable=arguments-differ
        """Generate an unnormalized score for the tag of each token.

        Parameters
        ----------
        token_ids: NDArray, shape (batch_size, seq_length)
            ID of tokens in sentences.
            See `input` of `gluonnlp.model.BERTModel`.
        token_types: NDArray, shape (batch_size, seq_length)
            See `gluonnlp.model.BERTModel`.
        valid_length: NDArray, shape (batch_size,)
            See `gluonnlp.model.BERTModel`.

        Returns
        -------
        NDArray, shape (batch_size, seq_length, num_tag_types):
            Unnormalized prediction scores for each tag on each position.
        """
        bert_output = self.dropout(self.bert_model(token_ids, token_types, valid_length))
        output = self.tag_classifier(bert_output)
        return output


def attach_prediction(data_loader, net, ctx, is_train):
    """Attach the prediction from a model to a data loader as the last field.

    Parameters
    ----------
    data_loader: mx.gluon.data.DataLoader
        Input data from `bert_model.BERTTaggingDataset._encode_as_input`.
        Each batch must unpack into exactly five fields:
        (text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag).
    net: mx.gluon.Block
        Gluon `Block` used to make the prediction.
    ctx: mx.Context
        The context the data should be loaded to.
    is_train: bool
        Whether the forward pass should be made inside `mx.autograd.record()`.

    Yields
    ------
    tuple
        ``(text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag, out)``:
        all fields from `bert_model.BERTTaggingDataset._encode_as_input`, each
        cast to float32 and moved to `ctx`, followed by the output of `net`.
    """
    for data in data_loader:
        # Every field, including the integer ids, is cast to float32 before
        # being copied to the target context.
        text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag = \
            [x.astype('float32').as_in_context(ctx) for x in data]

        with ExitStack() as stack:
            # Record the graph only when training; otherwise the stack is empty
            # and the forward pass runs without autograd recording.
            if is_train:
                stack.enter_context(mx.autograd.record())
            out = net(text_ids, token_types, valid_length)
        # Yield after leaving the recording scope, so the consumer (e.g. loss
        # computation or backward) runs outside this function's context manager.
        yield text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag, out