# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=
"""Word embedding training datasets."""

__all__ = [
    'WikiDumpStream', 'preprocess_dataset', 'wiki', 'transform_data_fasttext',
    'transform_data_word2vec', 'skipgram_lookup', 'cbow_lookup',
    'skipgram_fasttext_batch', 'cbow_fasttext_batch', 'skipgram_batch',
    'cbow_batch']

import functools
import io
import itertools
import json
import math
import os
import warnings

import mxnet as mx
import numpy as np

import gluonnlp as nlp
from gluonnlp import Vocab
from gluonnlp.base import numba_njit
from gluonnlp.data import CorpusDataset, SimpleDatasetStream
from utils import print_time


def preprocess_dataset(data, min_freq=5, max_vocab_size=None):
    """Dataset preprocessing helper.

    Parameters
    ----------
    data : mxnet.gluon.data.Dataset
        Input Dataset, for example gluonnlp.data.Text8 or gluonnlp.data.Fil9.
    min_freq : int, default 5
        Minimum token frequency for a token to be included in the vocabulary
        and returned DataStream.
    max_vocab_size : int, optional
        Specifies a maximum size for the vocabulary.

    Returns
    -------
    gluonnlp.data.DataStream
        Each sample is a valid input to
        gluonnlp.data.EmbeddingCenterContextBatchify.
    gluonnlp.Vocab
        Vocabulary of all tokens in the input dataset that occur at least
        min_freq times, truncated to at most max_vocab_size tokens.
    idx_to_counts : list of int
        Mapping from token indices to their occurrence-counts in the input
        dataset.

    """
    with print_time('count and construct vocabulary'):
        counter = nlp.data.count_tokens(itertools.chain.from_iterable(data))
        vocab = nlp.Vocab(counter, unknown_token=None, padding_token=None,
                          bos_token=None, eos_token=None, min_freq=min_freq,
                          max_size=max_vocab_size)
        idx_to_counts = [counter[w] for w in vocab.idx_to_token]

    def code(sentence):
        # Map tokens to indices, dropping out-of-vocabulary tokens.
        return [vocab[token] for token in sentence if token in vocab]

    with print_time('code data'):
        data = data.transform(code, lazy=False)
    data = nlp.data.SimpleDataStream([data])
    return data, vocab, idx_to_counts

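# A minimal usage sketch for preprocess_dataset (kept as a comment so that
# importing this module stays side-effect free; assumes gluonnlp.data.Text8
# is available for download):
#
#   text8 = nlp.data.Text8()
#   data, vocab, idx_to_counts = preprocess_dataset(text8, min_freq=5)
#   # `data` yields coded shards accepted by EmbeddingCenterContextBatchify,
#   # and idx_to_counts[i] is the corpus count of vocab.idx_to_token[i].
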

def wiki(wiki_root, wiki_date, wiki_language, max_vocab_size=None):
    """Wikipedia dump helper.

    Parameters
    ----------
    wiki_root : str
        Parameter for WikiDumpStream
    wiki_date : str
        Parameter for WikiDumpStream
    wiki_language : str
        Parameter for WikiDumpStream
    max_vocab_size : int, optional
        Specifies a maximum size for the vocabulary.

    Returns
    -------
    gluonnlp.data.DataStream
        Each sample is a valid input to
        gluonnlp.data.EmbeddingCenterContextBatchify.
    gluonnlp.Vocab
        Vocabulary of all tokens in the Wikipedia corpus as provided by
        WikiDumpStream but with maximum size max_vocab_size.
    idx_to_counts : list of int
        Mapping from token indices to their occurrence-counts in the Wikipedia
        corpus.

    """
    data = WikiDumpStream(
        root=os.path.expanduser(wiki_root), language=wiki_language,
        date=wiki_date)
    vocab = data.vocab
    if max_vocab_size:
        # Truncate the vocabulary to the max_vocab_size most frequent tokens.
        for token in vocab.idx_to_token[max_vocab_size:]:
            vocab.token_to_idx.pop(token)
        vocab.idx_to_token = vocab.idx_to_token[:max_vocab_size]
    idx_to_counts = data.idx_to_counts

    def code(shard):
        return [[vocab[token] for token in sentence if token in vocab]
                for sentence in shard]

    data = data.transform(code)
    return data, vocab, idx_to_counts

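# A minimal usage sketch for wiki (comment only; the root directory, dump date
# and vocabulary size below are hypothetical and must point to a preprocessed
# dump laid out as described in WikiDumpStream):
#
#   data, vocab, idx_to_counts = wiki(
#       wiki_root='~/wikidumps', wiki_date='20190620', wiki_language='en',
#       max_vocab_size=400000)
#   # `data` streams shards of coded sentences; tokens outside the (possibly
#   # truncated) vocabulary are dropped by the `code` transform above.
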

def transform_data_fasttext(data, vocab, idx_to_counts, cbow, ngram_buckets,
                            ngrams, batch_size, window_size,
                            frequent_token_subsampling=1E-4, dtype='float32',
                            index_dtype='int64'):
    """Transform a DataStream of coded DataSets to a DataStream of batches.

    Parameters
    ----------
    data : gluonnlp.data.DataStream
        DataStream where each sample is a valid input to
        gluonnlp.data.EmbeddingCenterContextBatchify.
    vocab : gluonnlp.Vocab
        Vocabulary containing all tokens whose indices occur in data. For each
        token, its associated subwords will be computed and used for
        constructing the batches. No subwords are used if ngram_buckets is 0.
    idx_to_counts : list of int
        List of integers such that idx_to_counts[idx] represents the count of
        vocab.idx_to_token[idx] in the underlying dataset. The count
        information is used to subsample frequent words in the dataset.
        Each token is independently dropped with probability 1 - sqrt(t /
        (count / sum_counts)) where t is the hyperparameter
        frequent_token_subsampling.
    cbow : boolean
        If True, batches for CBOW are returned.
    ngram_buckets : int
        Number of hash buckets to consider for the fastText
        nlp.vocab.NGramHashes subword function.
    ngrams : list of int
        For each integer n in the list, all ngrams of length n will be
        considered by the nlp.vocab.NGramHashes subword function.
    batch_size : int
        The returned data stream iterates over batches of batch_size.
    window_size : int
        The context window size for
        gluonnlp.data.EmbeddingCenterContextBatchify.
    frequent_token_subsampling : float
        Hyperparameter for subsampling. See idx_to_counts above for more
        information.
    dtype : str or np.dtype, default 'float32'
        Data type of data array.
    index_dtype : str or np.dtype, default 'int64'
        Data type of index arrays.

    Returns
    -------
    gluonnlp.data.DataStream
        Stream over batches. After applying the batchify function returned as
        second element, each batch is a list corresponding to the arguments
        for the forward pass of model.SG or model.CBOW, depending on whether
        cbow is False or True. If ngram_buckets > 0, the batches contain
        ngrams. Both model.SG and model.CBOW handle them correctly as long as
        they are initialized with the subword_function returned by this
        function (see below).
    callable
        Batchify function that converts a sample of the returned stream into
        arrays for the forward pass of model.SG or model.CBOW.
    gluonnlp.vocab.NGramHashes
        The subword_function used for obtaining the subwords in the returned
        batches.

    """
    if ngram_buckets <= 0:
        raise ValueError('Invalid ngram_buckets. Use Word2Vec training '
                         'pipeline if not interested in ngrams.')

    sum_counts = float(sum(idx_to_counts))
    idx_to_pdiscard = [
        1 - math.sqrt(frequent_token_subsampling / (count / sum_counts))
        for count in idx_to_counts]

    def subsample(shard):
        # Drop each token independently with its precomputed discard
        # probability to subsample frequent words.
        return [[
            t for t, r in zip(sentence,
                              np.random.uniform(0, 1, size=len(sentence)))
            if r > idx_to_pdiscard[t]] for sentence in shard]

    data = data.transform(subsample)

    batchify = nlp.data.batchify.EmbeddingCenterContextBatchify(
        batch_size=batch_size, window_size=window_size, cbow=cbow,
        weight_dtype=dtype, index_dtype=index_dtype)
    data = data.transform(batchify)

    with print_time('prepare subwords'):
        subword_function = nlp.vocab.create_subword_function(
            'NGramHashes', ngrams=ngrams, num_subwords=ngram_buckets)

        # Store subword indices for all words in vocabulary
        idx_to_subwordidxs = list(subword_function(vocab.idx_to_token))
        subwordidxs = np.concatenate(idx_to_subwordidxs)
        subwordidxsptr = np.cumsum([
            len(subwordidxs) for subwordidxs in idx_to_subwordidxs])
        subwordidxsptr = np.concatenate([
            np.zeros(1, dtype=np.int64), subwordidxsptr])
        if cbow:
            subword_lookup = functools.partial(
                cbow_lookup, subwordidxs=subwordidxs,
                subwordidxsptr=subwordidxsptr, offset=len(vocab))
        else:
            subword_lookup = functools.partial(
                skipgram_lookup, subwordidxs=subwordidxs,
                subwordidxsptr=subwordidxsptr, offset=len(vocab))
        max_subwordidxs_len = max(len(s) for s in idx_to_subwordidxs)
        if max_subwordidxs_len > 500:
            warnings.warn(
                'The word with the largest number of subwords '
                'has {} subwords, suggesting there are '
                'some noisy words in your vocabulary. '
                'You should filter out very long words '
                'to avoid memory issues.'.format(max_subwordidxs_len))

    data = UnchainStream(data)

    if cbow:
        batchify_fn = cbow_fasttext_batch
    else:
        batchify_fn = skipgram_fasttext_batch
    batchify_fn = functools.partial(
        batchify_fn, num_tokens=len(vocab) + len(subword_function),
        subword_lookup=subword_lookup, dtype=dtype, index_dtype=index_dtype)

    return data, batchify_fn, subword_function

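# Sketch of how the returned stream and batchify function fit together
# (comment only; batch_size, window_size and the ngram setup are example
# values, not prescribed defaults):
#
#   data, batchify_fn, subword_function = transform_data_fasttext(
#       data, vocab, idx_to_counts, cbow=False, ngram_buckets=2000000,
#       ngrams=[3, 4, 5, 6], batch_size=4096, window_size=5)
#   for centers, contexts in data:
#       centers_csr, contexts_nd, centers_nd = batchify_fn(centers, contexts)
#       # centers_csr spreads each center's weight over the word and its
#       # ngram buckets; feed it to a model initialized with
#       # `subword_function`.
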

def transform_data_word2vec(data, vocab, idx_to_counts, cbow, batch_size,
                            window_size, frequent_token_subsampling=1E-4,
                            dtype='float32', index_dtype='int64'):
    """Transform a DataStream of coded DataSets to a DataStream of batches.

    Parameters
    ----------
    data : gluonnlp.data.DataStream
        DataStream where each sample is a valid input to
        gluonnlp.data.EmbeddingCenterContextBatchify.
    vocab : gluonnlp.Vocab
        Vocabulary containing all tokens whose indices occur in data.
    idx_to_counts : list of int
        List of integers such that idx_to_counts[idx] represents the count of
        vocab.idx_to_token[idx] in the underlying dataset. The count
        information is used to subsample frequent words in the dataset.
        Each token is independently dropped with probability 1 - sqrt(t /
        (count / sum_counts)) where t is the hyperparameter
        frequent_token_subsampling.
    cbow : boolean
        If True, batches for CBOW are returned.
    batch_size : int
        The returned data stream iterates over batches of batch_size.
    window_size : int
        The context window size for
        gluonnlp.data.EmbeddingCenterContextBatchify.
    frequent_token_subsampling : float
        Hyperparameter for subsampling. See idx_to_counts above for more
        information.
    dtype : str or np.dtype, default 'float32'
        Data type of data array.
    index_dtype : str or np.dtype, default 'int64'
        Data type of index arrays.

    Returns
    -------
    gluonnlp.data.DataStream
        Stream over batches.
    callable
        Batchify function that converts a sample of the returned stream into
        arrays for the forward pass of model.SG or model.CBOW (see
        skipgram_batch and cbow_batch).
    """

    sum_counts = float(sum(idx_to_counts))
    idx_to_pdiscard = [
        1 - math.sqrt(frequent_token_subsampling / (count / sum_counts))
        for count in idx_to_counts]

    def subsample(shard):
        # Drop each token independently with its precomputed discard
        # probability to subsample frequent words.
        return [[
            t for t, r in zip(sentence,
                              np.random.uniform(0, 1, size=len(sentence)))
            if r > idx_to_pdiscard[t]] for sentence in shard]

    data = data.transform(subsample)

    batchify = nlp.data.batchify.EmbeddingCenterContextBatchify(
        batch_size=batch_size, window_size=window_size, cbow=cbow,
        weight_dtype=dtype, index_dtype=index_dtype)
    data = data.transform(batchify)
    data = UnchainStream(data)

    if cbow:
        batchify_fn = cbow_batch
    else:
        batchify_fn = skipgram_batch
    batchify_fn = functools.partial(batchify_fn, num_tokens=len(vocab),
                                    dtype=dtype, index_dtype=index_dtype)

    return data, batchify_fn

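# Sketch of the word2vec (no-subword) variant (comment only; batch_size and
# window_size are example values):
#
#   data, batchify_fn = transform_data_word2vec(
#       data, vocab, idx_to_counts, cbow=True, batch_size=4096, window_size=5)
#   for centers, contexts in data:
#       centers_nd, contexts_csr = batchify_fn(centers, contexts)
#       # For cbow=True, contexts_csr holds the weighted context words of
#       # each center; for cbow=False see skipgram_batch below.
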

def cbow_fasttext_batch(centers, contexts, num_tokens, subword_lookup, dtype,
                        index_dtype):
    """Create a batch for CBOW training objective with subwords."""
    _, contexts_row, contexts_col = contexts
    data, row, col = subword_lookup(contexts_row, contexts_col)
    centers = mx.nd.array(centers, dtype=index_dtype)
    contexts = mx.nd.sparse.csr_matrix(
        (data, (row, col)), dtype=dtype,
        shape=(len(centers), num_tokens))  # yapf: disable
    return centers, contexts


def skipgram_fasttext_batch(centers, contexts, num_tokens, subword_lookup,
                            dtype, index_dtype):
    """Create a batch for SG training objective with subwords."""
    contexts = mx.nd.array(contexts[2], dtype=index_dtype)
    data, row, col = subword_lookup(centers)
    centers = mx.nd.array(centers, dtype=index_dtype)
    centers_csr = mx.nd.sparse.csr_matrix(
        (data, (row, col)), dtype=dtype,
        shape=(len(centers), num_tokens))  # yapf: disable
    return centers_csr, contexts, centers


def cbow_batch(centers, contexts, num_tokens, dtype, index_dtype):
    """Create a batch for CBOW training objective."""
    contexts_data, contexts_row, contexts_col = contexts
    centers = mx.nd.array(centers, dtype=index_dtype)
    contexts = mx.nd.sparse.csr_matrix(
        (contexts_data, (contexts_row, contexts_col)),
        dtype=dtype, shape=(len(centers), num_tokens))  # yapf: disable
    return centers, contexts


def skipgram_batch(centers, contexts, num_tokens, dtype, index_dtype):
    """Create a batch for SG training objective."""
    contexts = mx.nd.array(contexts[2], dtype=index_dtype)
    indptr = mx.nd.arange(len(centers) + 1)
    centers = mx.nd.array(centers, dtype=index_dtype)
    centers_csr = mx.nd.sparse.csr_matrix(
        (mx.nd.ones(centers.shape), centers, indptr), dtype=dtype,
        shape=(len(centers), num_tokens))
    return centers_csr, contexts, centers

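# Worked example of skipgram_batch (comment only): with centers = [3, 7] and
# num_tokens = 10, centers_csr is a 2 x 10 CSR matrix with one-hot rows, i.e.
# ones at positions (0, 3) and (1, 7); contexts contains one context-word
# index per row (taken from contexts[2]) and centers holds the original
# center-word indices.
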

class UnchainStream(nlp.data.DataStream):
    """Flatten a stream of iterables into a stream over their elements."""

    def __init__(self, iterable):
        self._stream = iterable

    def __iter__(self):
        return iter(itertools.chain.from_iterable(self._stream))


@numba_njit
def skipgram_lookup(indices, subwordidxs, subwordidxsptr, offset=0):
    """Get a sparse COO array of words and subwords for SkipGram.

    Parameters
    ----------
    indices : numpy.ndarray
        Array containing numbers in [0, vocabulary_size). The element at
        position idx is taken to be the word that occurs at row idx in the
        SkipGram batch.
    subwordidxs : numpy.ndarray
        Array containing concatenation of all subwords of all tokens in the
        vocabulary, in order of their occurrence in the vocabulary.
        For example np.concatenate(idx_to_subwordidxs)
    subwordidxsptr : numpy.ndarray
        Array containing pointers into subwordidxs array such that
        subwordidxs[subwordidxsptr[i]:subwordidxsptr[i+1]] returns all subwords
        of token i. For example subwordidxsptr = np.cumsum([
        len(subwordidxs) for subwordidxs in idx_to_subwordidxs])
    offset : int, default 0
        Offset to add to each subword index.

    Returns
    -------
    numpy.ndarray of dtype float32
        Array containing weights such that for each row, all weights sum to
        1. In particular, all elements in a row have weight 1 /
        num_elements_in_the_row. This array is the data array of a sparse
        array of COO format.
    numpy.ndarray of dtype int64
        This array is the row array of a sparse array of COO format.
    numpy.ndarray of dtype int64
        This array is the col array of a sparse array of COO format.

    """
    row = []
    col = []
    data = []
    for i, idx in enumerate(indices):
        start = subwordidxsptr[idx]
        end = subwordidxsptr[idx + 1]

        # The word itself and each of its subwords get equal weight.
        row.append(i)
        col.append(idx)
        data.append(1 / (1 + end - start))
        for subword in subwordidxs[start:end]:
            row.append(i)
            col.append(subword + offset)
            data.append(1 / (1 + end - start))

    return (np.array(data, dtype=np.float32), np.array(row, dtype=np.int64),
            np.array(col, dtype=np.int64))

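# Worked example of skipgram_lookup (comment only). Suppose the vocabulary has
# 2 words with subword index lists [[0, 1], [2]], i.e.
# subwordidxs = np.array([0, 1, 2]), subwordidxsptr = np.array([0, 2, 3]),
# and offset = 2 (the vocabulary size). Then
#
#   skipgram_lookup(np.array([0, 1]), subwordidxs, subwordidxsptr, offset=2)
#
# returns data = [1/3, 1/3, 1/3, 1/2, 1/2], row = [0, 0, 0, 1, 1] and
# col = [0, 2, 3, 1, 4]: each batch row spreads a total weight of 1 over the
# word itself and its offset subword buckets.
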

@numba_njit
def cbow_lookup(context_row, context_col, subwordidxs, subwordidxsptr,
                offset=0):
    """Get a sparse COO array of words and subwords for CBOW.

    Parameters
    ----------
    context_row : numpy.ndarray of dtype int64
        Array of same length as context_col containing numbers in [0,
        batch_size). For each idx, context_row[idx] specifies the row that
        context_col[idx] occurs in a sparse matrix.
    context_col : numpy.ndarray of dtype int64
        Array of same length as context_row containing numbers in [0,
        vocabulary_size). For each idx, context_col[idx] is one of the
        context words in the context_row[idx] row of the batch.
    subwordidxs : numpy.ndarray
        Array containing concatenation of all subwords of all tokens in the
        vocabulary, in order of their occurrence in the vocabulary.
        For example np.concatenate(idx_to_subwordidxs)
    subwordidxsptr : numpy.ndarray
        Array containing pointers into subwordidxs array such that
        subwordidxs[subwordidxsptr[i]:subwordidxsptr[i+1]] returns all subwords
        of token i. For example subwordidxsptr = np.cumsum([
        len(subwordidxs) for subwordidxs in idx_to_subwordidxs])
    offset : int, default 0
        Offset to add to each subword index.

    Returns
    -------
    numpy.ndarray of dtype float32
        Array containing weights summing to 1. The weights are chosen such
        that the sum of weights for all subwords and word units of a given
        context word is equal to 1 / number_of_context_words_in_the_row.
        This array is the data array of a sparse array of COO format.
    numpy.ndarray of dtype int64
        This array is the row array of a sparse array of COO format.
    numpy.ndarray of dtype int64
        This array is the col array of a sparse array of COO format.

    """
    row = []
    col = []
    data = []

    num_rows = np.max(context_row) + 1
    row_to_numwords = np.zeros(num_rows)

    for i, idx in enumerate(context_col):
        start = subwordidxsptr[idx]
        end = subwordidxsptr[idx + 1]

        row_ = context_row[i]
        row_to_numwords[row_] += 1

        # The context word and each of its subwords get equal weight.
        row.append(row_)
        col.append(idx)
        data.append(1 / (1 + end - start))
        for subword in subwordidxs[start:end]:
            row.append(row_)
            col.append(subword + offset)
            data.append(1 / (1 + end - start))

    # Normalize by the number of context words in each row
    for i, row_ in enumerate(row):
        assert 0 <= row_ < num_rows
        data[i] /= row_to_numwords[row_]

    return (np.array(data, dtype=np.float32), np.array(row, dtype=np.int64),
            np.array(col, dtype=np.int64))

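# Worked example of cbow_lookup (comment only), using the same toy subword
# setup as above (subwordidxs = [0, 1, 2], subwordidxsptr = [0, 2, 3],
# offset = 2). For a single batch row with the two context words 0 and 1:
#
#   cbow_lookup(np.array([0, 0]), np.array([0, 1]),
#               subwordidxs, subwordidxsptr, offset=2)
#
# returns data = [1/6, 1/6, 1/6, 1/4, 1/4], row = [0, 0, 0, 0, 0] and
# col = [0, 2, 3, 1, 4]; the per-word weights of skipgram_lookup are divided
# by the 2 context words in the row, so the row still sums to 1.
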

class WikiDumpStream(SimpleDatasetStream):
    """Stream for preprocessed Wikipedia Dumps.

    Expects data in the following format
    - root/date/wiki.language/*.txt
    - root/date/wiki.language/vocab.json
    - root/date/wiki.language/counts.json

    Parameters
    ----------
    root : str
        Path to a folder storing the dataset and preprocessed vocabulary.
    language : str
        Language of the Wikipedia dump, for example 'en'.
    date : str
        Date identifier of the Wikipedia dump; used as a directory name below
        root.
    skip_empty : bool, default True
        Whether to skip the empty samples produced from sample_splitters. If
        False, `bos` and `eos` will be added in empty samples.
    bos : str or None, default None
        The token to add at the beginning of each sentence. If None, nothing is
        added.
    eos : str or None, default None
        The token to add at the end of each sentence. If None, nothing is
        added.

    Attributes
    ----------
    vocab : gluonnlp.Vocab
        Vocabulary object constructed from vocab.json.
    idx_to_counts : list[int]
        Mapping from vocabulary word indices to word counts.

    """

    def __init__(self, root, language, date, skip_empty=True, bos=None,
                 eos=None):
        self._root = root
        self._language = language
        self._date = date
        self._path = os.path.join(root, date, 'wiki.' + language)

        if not os.path.isdir(self._path):
            raise ValueError('{} is not valid. '
                             'Please make sure that the path exists and '
                             'contains the preprocessed files.'.format(
                                 self._path))

        self._file_pattern = os.path.join(self._path, '*.txt')
        super(WikiDumpStream, self).__init__(
            dataset=CorpusDataset, file_pattern=self._file_pattern,
            skip_empty=skip_empty, bos=bos, eos=eos)

    @property
    def vocab(self):
        path = os.path.join(self._path, 'vocab.json')
        with io.open(path, 'r', encoding='utf-8') as in_file:
            return Vocab.from_json(in_file.read())

    @property
    def idx_to_counts(self):
        path = os.path.join(self._path, 'counts.json')
        with io.open(path, 'r', encoding='utf-8') as in_file:
            return json.load(in_file)

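# A minimal usage sketch for WikiDumpStream (comment only; the root and date
# below are hypothetical and must match an existing preprocessed dump laid
# out as root/date/wiki.language/):
#
#   stream = WikiDumpStream(root=os.path.expanduser('~/wikidumps'),
#                           language='en', date='20190620')
#   vocab, idx_to_counts = stream.vocab, stream.idx_to_counts
#   for corpus_dataset in stream:
#       pass  # each sample is a CorpusDataset over one *.txt shard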