1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17"""Utility functions for BERT."""
18
19import sys
20import logging
21import collections
22import hashlib
23import io
24
25import mxnet as mx
26import gluonnlp as nlp
27
28__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab']
29
30
def tf_vocab_to_gluon_vocab(tf_vocab):
    """Convert a TF BERT vocabulary (token -> index mapping) to a gluonnlp BERTVocab.

    The input mapping must already contain all of BERT's special tokens;
    the resulting vocab preserves the original token-to-index assignment.
    """
    # BERT vocabularies are required to define these special tokens.
    for token in ('[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]'):
        assert token in tf_vocab
    counter = nlp.data.count_tokens(tf_vocab.keys())
    return nlp.vocab.BERTVocab(counter, token_to_idx=tf_vocab)
37
38
def get_hash(filename):
    """Compute the SHA-1 digest of a file.

    The file is read in 1 MiB chunks so arbitrarily large files can be
    hashed without loading them fully into memory.

    Parameters
    ----------
    filename : str
        Path of the file to hash.

    Returns
    -------
    tuple of (str, str)
        The full 40-character hex digest and its first 8 characters
        (used as a short identifier).
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        while True:
            data = f.read(1048576)  # 1 MiB per read
            if not data:
                break
            sha1.update(data)
    # Compute the digest once; the original called hexdigest() twice and
    # wrapped one call in a redundant str().
    digest = sha1.hexdigest()
    return digest, digest[:8]
48
49
def read_tf_checkpoint(path):
    """Read every tensor from a TensorFlow checkpoint.

    Parameters
    ----------
    path : str
        Path of the TensorFlow checkpoint.

    Returns
    -------
    dict
        Mapping from variable name to its tensor value, with names
        visited in sorted order.
    """
    from tensorflow.python import pywrap_tensorflow  # pylint: disable=import-outside-toplevel
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    # Sorting gives a deterministic iteration order over the variables.
    return {name: reader.get_tensor(name)
            for name in sorted(reader.get_variable_to_shape_map())}
60
def profile(curr_step, start_step, end_step, profile_name='profile.json',
            early_exit=True):
    """Profile the program between [start_step, end_step).

    Starts the MXNet profiler when ``curr_step == start_step`` and stops
    it (dumping results) when ``curr_step == end_step``; all other steps
    are no-ops.

    Parameters
    ----------
    curr_step : int
        The current training step.
    start_step : int
        Step at which profiling begins.
    end_step : int
        Step at which profiling stops and results are dumped.
    profile_name : str
        File the profiler trace is written to.
    early_exit : bool
        If True, terminate the process after dumping the profile.
    """
    if curr_step == start_step:
        # Drain pending async operations so they are not attributed
        # to the profiled window.
        mx.nd.waitall()
        mx.profiler.set_config(profile_memory=False, profile_symbolic=True,
                               profile_imperative=True, filename=profile_name,
                               aggregate_stats=True)
        mx.profiler.set_state('run')
        return
    if curr_step != end_step:
        return
    # Make sure all profiled work has completed before stopping.
    mx.nd.waitall()
    mx.profiler.set_state('stop')
    logging.info(mx.profiler.dumps())
    mx.profiler.dump()
    if early_exit:
        sys.exit(0)
77
def load_text_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary.

    Each line of the file is one token; tokens are assigned consecutive
    indices in file order.

    Parameters
    ----------
    vocab_file : str
        Path of the vocabulary file, one token per line.

    Returns
    -------
    collections.OrderedDict
        Mapping from token to its line index.
    """
    vocab = collections.OrderedDict()
    # BERT vocabulary files are UTF-8; pass the encoding explicitly so the
    # result does not depend on the platform's default locale encoding.
    with io.open(vocab_file, 'r', encoding='utf-8') as reader:
        for index, line in enumerate(reader):
            vocab[line.strip()] = index
    return vocab
91