# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Utility functions for BERT."""

import sys
import logging
import collections
import hashlib
import io

import mxnet as mx
import gluonnlp as nlp

__all__ = ['tf_vocab_to_gluon_vocab', 'load_text_vocab']


def tf_vocab_to_gluon_vocab(tf_vocab):
    """Convert a TensorFlow BERT vocabulary into a GluonNLP ``BERTVocab``.

    Parameters
    ----------
    tf_vocab : dict
        Mapping from token string to integer index, as produced by
        :func:`load_text_vocab`. Must contain all BERT special tokens.

    Returns
    -------
    nlp.vocab.BERTVocab
        Vocabulary preserving the original token-to-index assignment.
    """
    special_tokens = ['[UNK]', '[PAD]', '[SEP]', '[MASK]', '[CLS]']
    assert all(t in tf_vocab for t in special_tokens)
    counter = nlp.data.count_tokens(tf_vocab.keys())
    # token_to_idx pins the indices so they match the TF checkpoint's embedding rows.
    vocab = nlp.vocab.BERTVocab(counter, token_to_idx=tf_vocab)
    return vocab


def get_hash(filename):
    """Compute the SHA-1 of a file.

    Parameters
    ----------
    filename : str
        Path of the file to hash.

    Returns
    -------
    tuple of (str, str)
        The full 40-character hex digest and its first 8 characters
        (short form, e.g. for use in output file names).
    """
    sha1 = hashlib.sha1()
    with open(filename, 'rb') as f:
        # Read in 1 MiB chunks so arbitrarily large files fit in memory.
        while True:
            data = f.read(1048576)
            if not data:
                break
            sha1.update(data)
    # Compute the digest once; the original called hexdigest() twice.
    digest = sha1.hexdigest()
    return digest, digest[:8]


def read_tf_checkpoint(path):
    """Read all tensors from a TensorFlow checkpoint.

    Parameters
    ----------
    path : str
        Path to the TensorFlow checkpoint (prefix).

    Returns
    -------
    dict
        Mapping from variable name to its numpy tensor value, with keys
        visited in sorted order.
    """
    from tensorflow.python import pywrap_tensorflow  # pylint: disable=import-outside-toplevel
    tensors = {}
    reader = pywrap_tensorflow.NewCheckpointReader(path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for key in sorted(var_to_shape_map):
        tensor = reader.get_tensor(key)
        tensors[key] = tensor
    return tensors


def profile(curr_step, start_step, end_step, profile_name='profile.json',
            early_exit=True):
    """Profile the program between [start_step, end_step).

    Call once per training step; profiling starts when ``curr_step``
    equals ``start_step`` and stops (dumping results) when it equals
    ``end_step``.

    Parameters
    ----------
    curr_step : int
        Current step number.
    start_step : int
        Step at which to start the MXNet profiler.
    end_step : int
        Step at which to stop the profiler and dump results.
    profile_name : str
        Output file for the profiler trace.
    early_exit : bool
        If True, terminate the process after dumping the profile.
    """
    if curr_step == start_step:
        # Drain pending async ops so the trace starts from a clean point.
        mx.nd.waitall()
        mx.profiler.set_config(profile_memory=False, profile_symbolic=True,
                               profile_imperative=True, filename=profile_name,
                               aggregate_stats=True)
        mx.profiler.set_state('run')
    elif curr_step == end_step:
        mx.nd.waitall()
        mx.profiler.set_state('stop')
        logging.info(mx.profiler.dumps())
        mx.profiler.dump()
        if early_exit:
            sys.exit(0)


def load_text_vocab(vocab_file):
    """Load a BERT vocabulary file into an ordered token-to-index dict.

    Parameters
    ----------
    vocab_file : str
        Path to a text file with one token per line.

    Returns
    -------
    collections.OrderedDict
        Mapping from token to its zero-based line index.
    """
    vocab = collections.OrderedDict()
    # BERT vocab files are UTF-8; decode explicitly instead of relying on
    # the locale default encoding (which broke non-ASCII tokens).
    with io.open(vocab_file, 'r', encoding='utf-8') as reader:
        for index, token in enumerate(reader):
            vocab[token.strip()] = index
    return vocab