# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data preprocessing for transformer."""

import os
import io
import time
import logging
import numpy as np
from mxnet import gluon
import gluonnlp as nlp
import gluonnlp.data.batchify as btf
import _constants
import dataset as _dataset


def _cache_dataset(dataset, prefix):
    """Cache the processed dataset into an npz file.

    The variable-length source and target sequences are flattened into two 1-D
    arrays plus cumulative-length indices so the whole dataset fits in a single
    .npz archive under _constants.CACHE_PATH.

    Parameters
    ----------
    dataset : SimpleDataset
        Dataset of (src_npy, tgt_npy) pairs.
    prefix : str
        Prefix of the cache file name.
    """
    if not os.path.exists(_constants.CACHE_PATH):
        os.makedirs(_constants.CACHE_PATH)
    src_data = np.concatenate([e[0] for e in dataset])
    tgt_data = np.concatenate([e[1] for e in dataset])
    src_cumlen = np.cumsum([0] + [len(e[0]) for e in dataset])
    tgt_cumlen = np.cumsum([0] + [len(e[1]) for e in dataset])
    np.savez(os.path.join(_constants.CACHE_PATH, prefix + '.npz'),
             src_data=src_data, tgt_data=tgt_data,
             src_cumlen=src_cumlen, tgt_cumlen=tgt_cumlen)


def _load_cached_dataset(prefix):
    """Load a dataset cached by _cache_dataset, or return None if no cache exists."""
    cached_file_path = os.path.join(_constants.CACHE_PATH, prefix + '.npz')
    if os.path.exists(cached_file_path):
        print('Loading dataset...')
        npz_data = np.load(cached_file_path)
        src_data, tgt_data, src_cumlen, tgt_cumlen = \
            [npz_data[n] for n in ['src_data', 'tgt_data', 'src_cumlen', 'tgt_cumlen']]
        # Rebuild the per-sentence sequences from the flat arrays and the
        # cumulative-length indices written by _cache_dataset.
        src_data = np.array([src_data[low:high] for low, high
                             in zip(src_cumlen[:-1], src_cumlen[1:])])
        tgt_data = np.array([tgt_data[low:high] for low, high
                             in zip(tgt_cumlen[:-1], tgt_cumlen[1:])])
        return gluon.data.ArrayDataset(src_data, tgt_data)
    else:
        return None

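# The two helpers above form a simple round trip: a processed dataset is
# flattened and written once, then reloaded on later runs. A minimal sketch of
# that flow (the 'toy_train' prefix and the tiny dataset below are made up for
# illustration only):
#
#     toy = gluon.data.ArrayDataset(
#         [np.array([3, 4, 2], dtype=np.int32)],
#         [np.array([1, 5, 6, 2], dtype=np.int32)])
#     _cache_dataset(toy, 'toy_train')
#     reloaded = _load_cached_dataset('toy_train')  # ArrayDataset, or None if missing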

class TrainValDataTransform:
    """Transform the machine translation dataset.

    Clip the source and target sentences to the maximum length. Append EOS to
    the source sentence; wrap the target sentence in BOS and EOS.

    Parameters
    ----------
    src_vocab : Vocab
    tgt_vocab : Vocab
    src_max_len : int
        Maximum source length. A negative value disables clipping.
    tgt_max_len : int
        Maximum target length. A negative value disables clipping.
    """

    def __init__(self, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
        self._src_vocab = src_vocab
        self._tgt_vocab = tgt_vocab
        self._src_max_len = src_max_len
        self._tgt_max_len = tgt_max_len

    def __call__(self, src, tgt):
        """Map a raw (src, tgt) sentence pair to int32 arrays of token indices."""
        # For src_max_len < 0, we do not clip the sequence
        if self._src_max_len >= 0:
            src_sentence = self._src_vocab[src.split()[:self._src_max_len]]
        else:
            src_sentence = self._src_vocab[src.split()]
        # For tgt_max_len < 0, we do not clip the sequence
        if self._tgt_max_len >= 0:
            tgt_sentence = self._tgt_vocab[tgt.split()[:self._tgt_max_len]]
        else:
            tgt_sentence = self._tgt_vocab[tgt.split()]
        src_sentence.append(self._src_vocab[self._src_vocab.eos_token])
        tgt_sentence.insert(0, self._tgt_vocab[self._tgt_vocab.bos_token])
        tgt_sentence.append(self._tgt_vocab[self._tgt_vocab.eos_token])
        src_npy = np.array(src_sentence, dtype=np.int32)
        tgt_npy = np.array(tgt_sentence, dtype=np.int32)
        return src_npy, tgt_npy

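# A minimal illustration of the transform (the toy vocabulary is built here
# only for the example; real runs use the vocabularies shipped with the dataset):
#
#     vocab = nlp.Vocab(nlp.data.count_tokens('hello world !'.split()))
#     transform = TrainValDataTransform(vocab, vocab, src_max_len=50, tgt_max_len=50)
#     src_ids, tgt_ids = transform('hello world !', 'hello !')
#     # src_ids ends with the EOS index; tgt_ids is wrapped in BOS ... EOS.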

def process_dataset(dataset, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
    """Apply TrainValDataTransform to every sentence pair and report the time spent."""
    start = time.time()
    dataset_processed = dataset.transform(TrainValDataTransform(src_vocab, tgt_vocab,
                                                                src_max_len,
                                                                tgt_max_len), lazy=False)
    end = time.time()
    print('Processing time spent: {:.2f} s'.format(end - start))
    return dataset_processed

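# For example, the bundled TOY dataset can be processed eagerly like this
# (sketch only; the language pair is illustrative and assumes the local
# `dataset` module providing TOY is importable):
#
#     toy_train = _dataset.TOY('train', src_lang='en', tgt_lang='vi')
#     toy_processed = process_dataset(toy_train, toy_train.src_vocab,
#                                     toy_train.tgt_vocab,
#                                     src_max_len=50, tgt_max_len=50)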

def load_translation_data(dataset, bleu, args):
    """Load the translation dataset and its preprocessed (cached) variants.

    Parameters
    ----------
    dataset : str
        Name of the dataset: 'IWSLT2015', 'WMT2016BPE', 'WMT2014BPE' or 'TOY'.
    bleu : str
        BLEU scheme used for the reference sentences: 'tweaked', '13a' or 'intl'.
    args : argparse result

    Returns
    -------
    data_train_processed, data_val_processed, data_test_processed : Dataset
        Preprocessed train/val/test splits (token indices).
    val_tgt_sentences, test_tgt_sentences : list
        Reference target sentences for BLEU evaluation.
    src_vocab, tgt_vocab : Vocab
        Source and target vocabularies.
    """
    src_lang, tgt_lang = args.src_lang, args.tgt_lang
    if dataset == 'IWSLT2015':
        common_prefix = 'IWSLT2015_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                       args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.IWSLT2015('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.IWSLT2015('val', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.IWSLT2015('test', src_lang=src_lang, tgt_lang=tgt_lang)
    elif dataset == 'WMT2016BPE':
        common_prefix = 'WMT2016BPE_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                        args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.WMT2016BPE('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.WMT2016BPE('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.WMT2016BPE('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang)
    elif dataset == 'WMT2014BPE':
        common_prefix = 'WMT2014BPE_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                        args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.WMT2014BPE('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.WMT2014BPE('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.WMT2014BPE('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang,
                                        full=args.full)
    elif dataset == 'TOY':
        common_prefix = 'TOY_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                 args.src_max_len, args.tgt_max_len)
        data_train = _dataset.TOY('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = _dataset.TOY('val', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = _dataset.TOY('test', src_lang=src_lang, tgt_lang=tgt_lang)
    else:
        raise NotImplementedError
    src_vocab, tgt_vocab = data_train.src_vocab, data_train.tgt_vocab
    data_train_processed = _load_cached_dataset(common_prefix + '_train')
    if not data_train_processed:
        data_train_processed = process_dataset(data_train, src_vocab, tgt_vocab,
                                               args.src_max_len, args.tgt_max_len)
        _cache_dataset(data_train_processed, common_prefix + '_train')
    data_val_processed = _load_cached_dataset(common_prefix + '_val')
    if not data_val_processed:
        data_val_processed = process_dataset(data_val, src_vocab, tgt_vocab)
        _cache_dataset(data_val_processed, common_prefix + '_val')
    if dataset == 'WMT2014BPE':
        filename = common_prefix + '_' + str(args.full) + '_test'
    else:
        filename = common_prefix + '_test'
    data_test_processed = _load_cached_dataset(filename)
    if not data_test_processed:
        data_test_processed = process_dataset(data_test, src_vocab, tgt_vocab)
        _cache_dataset(data_test_processed, filename)
    if bleu == 'tweaked':
        # 'tweaked' BLEU scores against the tokenized (possibly BPE) references
        # that ship with the dataset itself.
        fetch_tgt_sentence = lambda src, tgt: tgt.split()
        val_tgt_sentences = list(data_val.transform(fetch_tgt_sentence))
        test_tgt_sentences = list(data_test.transform(fetch_tgt_sentence))
    elif bleu in ('13a', 'intl'):
        # The '13a'/'intl' schemes need raw reference text, so load the
        # plain-text WMT datasets where the training data is BPE-encoded.
        fetch_tgt_sentence = lambda src, tgt: tgt
        if dataset == 'WMT2016BPE':
            val_text = nlp.data.WMT2016('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
            test_text = nlp.data.WMT2016('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang)
        elif dataset == 'WMT2014BPE':
            val_text = nlp.data.WMT2014('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
            test_text = nlp.data.WMT2014('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang,
                                         full=args.full)
        elif dataset in ('IWSLT2015', 'TOY'):
            val_text = data_val
            test_text = data_test
        else:
            raise NotImplementedError
        val_tgt_sentences = list(val_text.transform(fetch_tgt_sentence))
        test_tgt_sentences = list(test_text.transform(fetch_tgt_sentence))
    else:
        raise NotImplementedError
    return data_train_processed, data_val_processed, data_test_processed, \
           val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab

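# A typical call, sketched with an illustrative argument namespace (the field
# values below are made up; the training script normally supplies them):
#
#     import argparse
#     toy_args = argparse.Namespace(src_lang='en', tgt_lang='vi', src_max_len=50,
#                                   tgt_max_len=50, full=False)
#     (train_set, val_set, test_set, val_refs, test_refs,
#      src_vocab, tgt_vocab) = load_translation_data('TOY', 'tweaked', toy_args)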

def get_data_lengths(dataset):
    """Return (src_length, tgt_length) for every item of a dataset whose items
    are (src, tgt, src_length, tgt_length) tuples."""
    get_lengths = lambda *args: (args[2], args[3])
    return list(dataset.transform(get_lengths))

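# The length fields are expected to be appended by the caller before the
# dataset reaches get_data_lengths/get_dataloader, e.g. (sketch, assuming the
# processed dataset yields (src_ids, tgt_ids) pairs):
#
#     data_train = data_train.transform(
#         lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)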
def get_dataloader(data_set, args, dataset_type,
                   use_average_length=False, num_shards=0, num_workers=8):
    """Create a data loader for the given dataset split (train/val/test)."""
    assert dataset_type in ['train', 'val', 'test']

    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError

    data_lengths = get_data_lengths(data_set)

    if dataset_type == 'train':
        train_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                      btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    else:
        data_lengths = list(map(lambda x: x[-1], data_lengths))
        test_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                     btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                     btf.Stack())

    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=(args.batch_size
                                                            if dataset_type == 'train'
                                                            else args.test_batch_size),
                                                num_buckets=args.num_buckets,
                                                ratio=args.bucket_ratio,
                                                shuffle=(dataset_type == 'train'),
                                                use_average_length=use_average_length,
                                                num_shards=num_shards,
                                                bucket_scheme=bucket_scheme)

    if dataset_type == 'train':
        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = nlp.data.ShardedDataLoader(data_set,
                                                 batch_sampler=batch_sampler,
                                                 batchify_fn=train_batchify_fn,
                                                 num_workers=num_workers)
    else:
        if dataset_type == 'val':
            logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
        else:
            logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())

        data_loader = gluon.data.DataLoader(data_set,
                                            batch_sampler=batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)

    return data_loader

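# FixedBucketSampler groups sequences of similar length to limit padding. A
# small self-contained sketch of the behaviour relied on above (toy lengths,
# not taken from any real dataset):
#
#     toy_lengths = [(5, 7), (6, 8), (20, 22), (21, 23)]
#     toy_sampler = nlp.data.FixedBucketSampler(lengths=toy_lengths, batch_size=2,
#                                               num_buckets=2, shuffle=False)
#     print(toy_sampler.stats())  # bucket sizes and padding efficiency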
def make_dataloader(data_train, data_val, data_test, args,
                    use_average_length=False, num_shards=0, num_workers=8):
    """Create data loaders for training/validation/test."""
    train_data_loader = get_dataloader(data_train, args, dataset_type='train',
                                       use_average_length=use_average_length,
                                       num_shards=num_shards,
                                       num_workers=num_workers)

    val_data_loader = get_dataloader(data_val, args, dataset_type='val',
                                     use_average_length=use_average_length,
                                     num_workers=num_workers)

    test_data_loader = get_dataloader(data_test, args, dataset_type='test',
                                      use_average_length=use_average_length,
                                      num_workers=num_workers)

    return train_data_loader, val_data_loader, test_data_loader

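# Once built, the loaders are iterated like any Gluon DataLoader. A sketch of
# consuming the validation loader, assuming the dataset items carry the five
# fields padded/stacked by the val/test batchify function above:
#
#     for src_seq, tgt_seq, src_len, tgt_len, idx in val_data_loader:
#         pass  # run evaluation on the batch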

def write_sentences(sentences, file_path):
    """Write sentences to file_path, one sentence per line.

    Each element may be either a list/tuple of tokens (joined with spaces) or
    an already-formed string.
    """
    with io.open(file_path, 'w', encoding='utf-8') as of:
        for sent in sentences:
            if isinstance(sent, (list, tuple)):
                of.write(' '.join(sent) + '\n')
            else:
                of.write(sent + '\n')

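# Example usage (the output path is chosen only for illustration):
#
#     write_sentences(['guten Tag .', ['hello', 'world', '.']],
#                     os.path.join(_constants.CACHE_PATH, 'example_out.txt'))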