1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Gluon model block for the named entity recognition task."""
18
19from contextlib import ExitStack
20
21import mxnet as mx
22from mxnet.gluon import Block, nn
23
24
class BERTTagger(Block):
    """Token-level tagging head stacked on top of a BERT encoder.

    The encoder output is passed through dropout and then through a dense
    layer that produces one score per tag type.

    Parameters
    ----------
    bert_model: BERTModel
        Bidirectional transformer encoder whose output is fed to the tagger.
    num_tag_types: int
        Number of distinct tags; used as the number of units of the output
        dense layer.
    dropout_prob: float
        Dropout rate applied to the encoder output before classification.
    prefix: str or None
        Forwarded to `mx.gluon.Block`; see its documentation.
    params: ParameterDict or None
        Forwarded to `mx.gluon.Block`; see its documentation.
    """

    def __init__(self, bert_model, num_tag_types, dropout_prob, prefix=None, params=None):
        super().__init__(prefix=prefix, params=params)
        # The encoder is supplied by the caller, so it is attached outside of
        # this block's name scope; only the layers created here live inside it.
        self.bert_model = bert_model
        with self.name_scope():
            # flatten=False keeps the leading axes and applies the dense
            # layer along the last axis only.
            self.tag_classifier = nn.Dense(units=num_tag_types, flatten=False)
            self.dropout = nn.Dropout(rate=dropout_prob)

    def forward(self, token_ids, token_types, valid_length): # pylint: disable=arguments-differ
        """Compute unnormalized tag scores for every token position.

        Parameters
        ----------
        token_ids: NDArray, shape (batch_size, seq_length)
            IDs of the tokens in each sentence.
            See `input` of `gluonnlp.model.BERTModel`.
        token_types: NDArray, shape (batch_size, seq_length)
            See `gluonnlp.model.BERTModel`.
        valid_length: NDArray, shape (batch_size,)
            See `gluonnlp.model.BERTModel`.

        Returns
        -------
        NDArray, shape (batch_size, seq_length, num_tag_types):
            Unnormalized prediction scores for each tag at each position.
        """
        encoded = self.bert_model(token_ids, token_types, valid_length)
        regularized = self.dropout(encoded)
        return self.tag_classifier(regularized)
70
71
def attach_prediction(data_loader, net, ctx, is_train):
    """Iterate over a data loader and append the model prediction to each batch.

    This is a generator: for every batch it casts each field to ``float32``,
    moves it to `ctx`, runs `net` on the first three fields and yields the
    converted fields followed by the network output.

    Parameters
    ----------
    data_loader: mx.gluon.data.DataLoader
        Input data from `bert_model.BERTTaggingDataset._encode_as_input`.
        Every batch must contain exactly five fields.
    net: mx.gluon.Block
        Gluon `Block` used for making the prediction. It is called with the
        token ids, token types and valid lengths of the batch.
    ctx:
        The context the data should be loaded to.
    is_train:
        Whether the forward pass should be made with `mx.autograd.record()`.

    Yields
    ------
    tuple
        ``(text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag, out)``:
        all fields from `bert_model.BERTTaggingDataset._encode_as_input`
        (converted and placed on `ctx`), with the prediction of the model
        as the last element.

    """
    for batch in data_loader:
        on_device = [field.astype('float32').as_in_context(ctx) for field in batch]
        # Unpacking enforces that a batch carries exactly five fields.
        text_ids, token_types, valid_length, tag_ids, flag_nonnull_tag = on_device

        # ExitStack lets the recording scope be entered only when training,
        # while keeping a single call site for the forward pass.
        with ExitStack() as scope:
            if is_train:
                scope.enter_context(mx.autograd.record())
            prediction = net(text_ids, token_types, valid_length)

        yield (text_ids, token_types, valid_length,
               tag_ids, flag_nonnull_tag, prediction)
101