1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: disable=
19"""Extract the vocabulary from a file and write it to disk."""
20
21import argparse
22import itertools
23import json
24import logging
25import time
26
27import gluonnlp as nlp
28
29
def parse_args():
    """Parse command line arguments.

    Returns:
        argparse.Namespace with attributes: max_size, min_freq,
        max_word_length, files, vocab_output, counts_output.
    """
    parser = argparse.ArgumentParser(
        description='Vocabulary extractor.',
        formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    # Integer tuning knobs, registered in the original order so --help
    # output is unchanged.
    for flag, default in (('--max-size', None), ('--min-freq', 5),
                          ('--max-word-length', 50)):
        parser.add_argument(flag, type=int, default=default)
    # One or more input text files to count tokens from.
    parser.add_argument('files', type=str, nargs='+')
    # Output destinations.
    for flag, default in (('--vocab-output', 'vocab.json'),
                          ('--counts-output', 'counts.json')):
        parser.add_argument(flag, type=str, default=default)
    return parser.parse_args()
43
44
def get_vocab(args):
    """Compute the vocabulary and write it (plus token counts) to disk.

    Reads whitespace-separated tokens from every file in ``args.files``,
    optionally drops words longer than ``args.max_word_length``, builds a
    ``gluonnlp.Vocab`` and writes it as JSON to ``args.vocab_output``; the
    per-token counts (aligned with ``vocab.idx_to_token``) go to
    ``args.counts_output``.

    Args:
        args: Namespace produced by ``parse_args``.
    """
    counter = nlp.data.Counter()
    start = time.time()
    for filename in args.files:
        print('Starting processing of {} after {:.1f} seconds.'.format(
            filename,
            time.time() - start))
        # Explicit encoding so decoding does not depend on the locale.
        with open(filename, 'r', encoding='utf-8') as f:
            tokens = itertools.chain.from_iterable(l.split() for l in f)
            counter.update(tokens)

    if args.max_word_length:
        # Rebuild a Counter (a plain dict-comprehension would silently drop
        # the Counter type that nlp.Vocab expects). Use <= so that words of
        # exactly --max-word-length characters are kept, matching the
        # flag's name (previously an off-by-one: `<` excluded them).
        counter = nlp.data.Counter({
            w: c
            for w, c in counter.items() if len(w) <= args.max_word_length
        })

    total_time = time.time() - start
    print('Finished after {:.1f} seconds.'.format(total_time))
    num_words = sum(counter.values())
    # Guard against division by zero on very small/fast inputs.
    rate = num_words / total_time if total_time > 0 else float('inf')
    print('Got {} words. Processed {:.1f} per second.'.format(
        num_words, rate))

    start = time.time()
    print('Starting creation of vocabulary.')
    # All special tokens disabled: the vocabulary contains corpus words only.
    vocab = nlp.Vocab(counter, max_size=args.max_size, min_freq=args.min_freq,
                      unknown_token=None, padding_token=None, bos_token=None,
                      eos_token=None)
    with open(args.vocab_output, 'w') as f:
        f.write(vocab.to_json())
    print('Finished creation of vocabulary after {:.1f} seconds.'.format(
        time.time() - start))

    print('Writing word counts.')
    start = time.time()
    # Counts are emitted in vocabulary-index order so counts[i] corresponds
    # to vocab.idx_to_token[i].
    idx_to_counts = [counter[t] for t in vocab.idx_to_token]
    with open(args.counts_output, 'w') as f:
        json.dump(idx_to_counts, f)
    # Fixed doubled period typo in the original message ("seconds..").
    print('Finished writing word counts after {:.1f} seconds.'.format(
        time.time() - start))
86
87
if __name__ == '__main__':
    # Configure root logging at INFO before running the extraction pipeline.
    logging.basicConfig()
    logging.getLogger().setLevel(logging.INFO)
    get_vocab(parse_args())
93