# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=
"""Extract the vocabulary from a file and write it to disk."""

import argparse
import itertools
import json
import logging
import time

import gluonnlp as nlp


def parse_args():
    """Parse command line arguments.

    Returns
    -------
    argparse.Namespace
        Parsed arguments: positional `files` (one or more input corpora),
        vocabulary pruning options (`max_size`, `min_freq`,
        `max_word_length`) and output paths (`vocab_output`,
        `counts_output`).
    """
    parser = argparse.ArgumentParser(
        description='Vocabulary extractor.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--max-size', type=int, default=None)
    parser.add_argument('--min-freq', type=int, default=5)
    parser.add_argument('--max-word-length', type=int, default=50)
    parser.add_argument('files', type=str, nargs='+')
    parser.add_argument('--vocab-output', type=str, default='vocab.json')
    parser.add_argument('--counts-output', type=str, default='counts.json')
    args = parser.parse_args()
    return args


def get_vocab(args):
    """Compute the vocabulary and write it and the token counts to disk.

    Reads every file in `args.files`, counts whitespace-separated tokens,
    optionally drops over-long words, builds a `gluonnlp.Vocab` and writes
    it as JSON to `args.vocab_output`; the per-token counts (aligned with
    the vocabulary's index order) are written to `args.counts_output`.

    Parameters
    ----------
    args : argparse.Namespace
        As returned by `parse_args`.
    """
    counter = nlp.data.Counter()
    start = time.time()
    for filename in args.files:
        print('Starting processing of {} after {:.1f} seconds.'.format(
            filename,
            time.time() - start))
        # Explicit encoding: otherwise decoding depends on the locale and
        # the same corpus could yield different tokens on different hosts.
        with open(filename, 'r', encoding='utf-8') as f:
            # Lazy flattening: one token stream over all lines without
            # materializing the file in memory.
            tokens = itertools.chain.from_iterable((l.split() for l in f))
            counter.update(tokens)

    if args.max_word_length:
        # NOTE(review): the bound is exclusive — words of exactly
        # max_word_length characters are dropped, which is at odds with
        # the flag name; confirm whether `<=` was intended.
        counter = {
            w: c
            for w, c in counter.items() if len(w) < args.max_word_length
        }

    total_time = time.time() - start
    print('Finished after {:.1f} seconds.'.format(total_time))
    num_words = sum(counter.values())
    # Guard against ZeroDivisionError for tiny inputs / coarse clocks.
    words_per_second = num_words / total_time if total_time else float('inf')
    print('Got {} words. Processed {:.1f} per second.'.format(
        num_words, words_per_second))

    start = time.time()
    print('Starting creation of vocabulary.')
    # All special tokens disabled: the vocabulary contains corpus words only.
    vocab = nlp.Vocab(counter, max_size=args.max_size, min_freq=args.min_freq,
                      unknown_token=None, padding_token=None, bos_token=None,
                      eos_token=None)
    with open(args.vocab_output, 'w', encoding='utf-8') as f:
        f.write(vocab.to_json())
    print('Finished creation of vocabulary after {:.1f} seconds.'.format(
        time.time() - start))

    print('Writing word counts.')
    start = time.time()
    # Counts in vocabulary index order, so counts[i] belongs to
    # vocab.idx_to_token[i].
    idx_to_counts = [counter[t] for t in vocab.idx_to_token]
    with open(args.counts_output, 'w', encoding='utf-8') as f:
        json.dump(idx_to_counts, f)
    print('Finished writing word counts after {:.1f} seconds..'.format(
        time.time() - start))


if __name__ == '__main__':
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)
    args_ = parse_args()
    get_vocab(args_)