# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Data preprocessing for transformer."""

import os
import io
import time
import logging
import numpy as np
from mxnet import gluon
import gluonnlp as nlp
import gluonnlp.data.batchify as btf
import _constants
import dataset as _dataset


def _cache_dataset(dataset, prefix):
    """Cache the processed dataset into an npz file.

    Parameters
    ----------
    dataset : SimpleDataset
        Processed dataset whose elements are (src, tgt) pairs of int32 arrays.
    prefix : str
        Prefix of the cache file name, stored under _constants.CACHE_PATH.
    """
    if not os.path.exists(_constants.CACHE_PATH):
        os.makedirs(_constants.CACHE_PATH)
    src_data = np.concatenate([e[0] for e in dataset])
    tgt_data = np.concatenate([e[1] for e in dataset])
    src_cumlen = np.cumsum([0] + [len(e[0]) for e in dataset])
    tgt_cumlen = np.cumsum([0] + [len(e[1]) for e in dataset])
    np.savez(os.path.join(_constants.CACHE_PATH, prefix + '.npz'),
             src_data=src_data, tgt_data=tgt_data,
             src_cumlen=src_cumlen, tgt_cumlen=tgt_cumlen)


def _load_cached_dataset(prefix):
    """Load a dataset cached by _cache_dataset, or return None if no cache exists."""
    cached_file_path = os.path.join(_constants.CACHE_PATH, prefix + '.npz')
    if os.path.exists(cached_file_path):
        print('Loading dataset...')
        npz_data = np.load(cached_file_path)
        src_data, tgt_data, src_cumlen, tgt_cumlen = \
            [npz_data[n] for n in ['src_data', 'tgt_data', 'src_cumlen', 'tgt_cumlen']]
        src_data = np.array([src_data[low:high] for low, high
                             in zip(src_cumlen[:-1], src_cumlen[1:])])
        tgt_data = np.array([tgt_data[low:high] for low, high
                             in zip(tgt_cumlen[:-1], tgt_cumlen[1:])])
        return gluon.data.ArrayDataset(np.array(src_data), np.array(tgt_data))
    else:
        return None
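
# Illustrative sketch (not invoked by the training scripts): the two helpers above
# form a simple cache round-trip. `raw_pairs` and the prefix string are hypothetical;
# any gluon Dataset of raw (src, tgt) sentence strings and any prefix that uniquely
# identifies the dataset/language/length configuration would work.
#
#   processed = process_dataset(raw_pairs, src_vocab, tgt_vocab, 50, 50)
#   _cache_dataset(processed, 'toy_en_vi_50_50_train')
#   cached = _load_cached_dataset('toy_en_vi_50_50_train')
#   assert cached is not None and len(cached) == len(processed)
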
class TrainValDataTransform:
    """Transform the machine translation dataset.

    Clip the source and target sentences to the maximum length. For the source
    sentence, append EOS. For the target sentence, prepend BOS and append EOS.

    Parameters
    ----------
    src_vocab : Vocab
        Source vocabulary used to map tokens to indices.
    tgt_vocab : Vocab
        Target vocabulary used to map tokens to indices.
    src_max_len : int
        Maximum source length; a negative value disables clipping.
    tgt_max_len : int
        Maximum target length; a negative value disables clipping.
    """

    def __init__(self, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
        self._src_vocab = src_vocab
        self._tgt_vocab = tgt_vocab
        self._src_max_len = src_max_len
        self._tgt_max_len = tgt_max_len

    def __call__(self, src, tgt):
        # For src_max_len < 0, we do not clip the sequence
        if self._src_max_len >= 0:
            src_sentence = self._src_vocab[src.split()[:self._src_max_len]]
        else:
            src_sentence = self._src_vocab[src.split()]
        # For tgt_max_len < 0, we do not clip the sequence
        if self._tgt_max_len >= 0:
            tgt_sentence = self._tgt_vocab[tgt.split()[:self._tgt_max_len]]
        else:
            tgt_sentence = self._tgt_vocab[tgt.split()]
        src_sentence.append(self._src_vocab[self._src_vocab.eos_token])
        tgt_sentence.insert(0, self._tgt_vocab[self._tgt_vocab.bos_token])
        tgt_sentence.append(self._tgt_vocab[self._tgt_vocab.eos_token])
        src_npy = np.array(src_sentence, dtype=np.int32)
        tgt_npy = np.array(tgt_sentence, dtype=np.int32)
        return src_npy, tgt_npy


def process_dataset(dataset, src_vocab, tgt_vocab, src_max_len=-1, tgt_max_len=-1):
    """Eagerly apply TrainValDataTransform to every sentence pair in the dataset."""
    start = time.time()
    dataset_processed = dataset.transform(TrainValDataTransform(src_vocab, tgt_vocab,
                                                                src_max_len,
                                                                tgt_max_len), lazy=False)
    end = time.time()
    print('Processing time spent: {}'.format(end - start))
    return dataset_processed
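
# Illustrative sketch (assumes `my_vocab` is a gluonnlp.Vocab built with BOS/EOS
# tokens; the sentences are made up): applying the transform to one raw pair.
#
#   transform = TrainValDataTransform(my_vocab, my_vocab, src_max_len=4, tgt_max_len=4)
#   src_ids, tgt_ids = transform('hello world', 'bonjour le monde')
#   # src_ids == token ids of 'hello world' followed by [eos]
#   # tgt_ids == [bos] + token ids of 'bonjour le monde' + [eos]
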
def load_translation_data(dataset, bleu, args):
    """Load a translation dataset, its vocabularies and the BLEU reference sentences.

    Parameters
    ----------
    dataset : str
        Name of the dataset: 'IWSLT2015', 'WMT2016BPE', 'WMT2014BPE' or 'TOY'.
    bleu : str
        Scheme used for the reference sentences: 'tweaked', '13a' or 'intl'.
    args : argparse result

    Returns
    -------
    data_train_processed : Dataset
    data_val_processed : Dataset
    data_test_processed : Dataset
    val_tgt_sentences : list
    test_tgt_sentences : list
    src_vocab : Vocab
    tgt_vocab : Vocab
    """
    src_lang, tgt_lang = args.src_lang, args.tgt_lang
    if dataset == 'IWSLT2015':
        common_prefix = 'IWSLT2015_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                       args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.IWSLT2015('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.IWSLT2015('val', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.IWSLT2015('test', src_lang=src_lang, tgt_lang=tgt_lang)
    elif dataset == 'WMT2016BPE':
        common_prefix = 'WMT2016BPE_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                        args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.WMT2016BPE('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.WMT2016BPE('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.WMT2016BPE('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang)
    elif dataset == 'WMT2014BPE':
        common_prefix = 'WMT2014BPE_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                        args.src_max_len, args.tgt_max_len)
        data_train = nlp.data.WMT2014BPE('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = nlp.data.WMT2014BPE('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = nlp.data.WMT2014BPE('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang,
                                        full=args.full)
    elif dataset == 'TOY':
        common_prefix = 'TOY_{}_{}_{}_{}'.format(src_lang, tgt_lang,
                                                 args.src_max_len, args.tgt_max_len)
        data_train = _dataset.TOY('train', src_lang=src_lang, tgt_lang=tgt_lang)
        data_val = _dataset.TOY('val', src_lang=src_lang, tgt_lang=tgt_lang)
        data_test = _dataset.TOY('test', src_lang=src_lang, tgt_lang=tgt_lang)
    else:
        raise NotImplementedError
    src_vocab, tgt_vocab = data_train.src_vocab, data_train.tgt_vocab
    data_train_processed = _load_cached_dataset(common_prefix + '_train')
    if not data_train_processed:
        data_train_processed = process_dataset(data_train, src_vocab, tgt_vocab,
                                               args.src_max_len, args.tgt_max_len)
        _cache_dataset(data_train_processed, common_prefix + '_train')
    data_val_processed = _load_cached_dataset(common_prefix + '_val')
    if not data_val_processed:
        data_val_processed = process_dataset(data_val, src_vocab, tgt_vocab)
        _cache_dataset(data_val_processed, common_prefix + '_val')
    if dataset == 'WMT2014BPE':
        filename = common_prefix + '_' + str(args.full) + '_test'
    else:
        filename = common_prefix + '_test'
    data_test_processed = _load_cached_dataset(filename)
    if not data_test_processed:
        data_test_processed = process_dataset(data_test, src_vocab, tgt_vocab)
        _cache_dataset(data_test_processed, filename)
    if bleu == 'tweaked':
        fetch_tgt_sentence = lambda src, tgt: tgt.split()
        val_tgt_sentences = list(data_val.transform(fetch_tgt_sentence))
        test_tgt_sentences = list(data_test.transform(fetch_tgt_sentence))
    elif bleu in ('13a', 'intl'):
        fetch_tgt_sentence = lambda src, tgt: tgt
        if dataset == 'WMT2016BPE':
            val_text = nlp.data.WMT2016('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
            test_text = nlp.data.WMT2016('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang)
        elif dataset == 'WMT2014BPE':
            val_text = nlp.data.WMT2014('newstest2013', src_lang=src_lang, tgt_lang=tgt_lang)
            test_text = nlp.data.WMT2014('newstest2014', src_lang=src_lang, tgt_lang=tgt_lang,
                                         full=args.full)
        elif dataset in ('IWSLT2015', 'TOY'):
            val_text = data_val
            test_text = data_test
        else:
            raise NotImplementedError
        val_tgt_sentences = list(val_text.transform(fetch_tgt_sentence))
        test_tgt_sentences = list(test_text.transform(fetch_tgt_sentence))
    else:
        raise NotImplementedError
    return data_train_processed, data_val_processed, data_test_processed, \
           val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab


def get_data_lengths(dataset):
    """Return the (src_length, tgt_length) pairs of a dataset whose elements are
    (src, tgt, src_length, tgt_length, ...) tuples."""
    get_lengths = lambda *args: (args[2], args[3])
    return list(dataset.transform(get_lengths))
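
# Illustrative sketch (assumes an argparse `args` with src_lang, tgt_lang,
# src_max_len, tgt_max_len and full set, e.g. by the accompanying training script):
#
#   train_ds, val_ds, test_ds, val_refs, test_refs, src_vocab, tgt_vocab = \
#       load_translation_data('TOY', bleu='tweaked', args=args)
#   # val_refs/test_refs hold the raw reference translations used for BLEU, while
#   # train_ds/val_ds/test_ds hold the index-mapped (src, tgt) integer arrays.
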
def get_dataloader(data_set, args, dataset_type,
                   use_average_length=False, num_shards=0, num_workers=8):
    """Create the data loader for a single split ('train', 'val' or 'test')."""
    assert dataset_type in ['train', 'val', 'test']

    if args.bucket_scheme == 'constant':
        bucket_scheme = nlp.data.ConstWidthBucket()
    elif args.bucket_scheme == 'linear':
        bucket_scheme = nlp.data.LinearWidthBucket()
    elif args.bucket_scheme == 'exp':
        bucket_scheme = nlp.data.ExpWidthBucket(bucket_len_step=1.2)
    else:
        raise NotImplementedError

    data_lengths = get_data_lengths(data_set)

    if dataset_type == 'train':
        train_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                      btf.Stack(dtype='float32'), btf.Stack(dtype='float32'))
    else:
        data_lengths = list(map(lambda x: x[-1], data_lengths))
        test_batchify_fn = btf.Tuple(btf.Pad(pad_val=0), btf.Pad(pad_val=0),
                                     btf.Stack(dtype='float32'), btf.Stack(dtype='float32'),
                                     btf.Stack())

    batch_sampler = nlp.data.FixedBucketSampler(lengths=data_lengths,
                                                batch_size=(args.batch_size
                                                            if dataset_type == 'train'
                                                            else args.test_batch_size),
                                                num_buckets=args.num_buckets,
                                                ratio=args.bucket_ratio,
                                                shuffle=(dataset_type == 'train'),
                                                use_average_length=use_average_length,
                                                num_shards=num_shards,
                                                bucket_scheme=bucket_scheme)

    if dataset_type == 'train':
        logging.info('Train Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = nlp.data.ShardedDataLoader(data_set,
                                                 batch_sampler=batch_sampler,
                                                 batchify_fn=train_batchify_fn,
                                                 num_workers=num_workers)
    else:
        if dataset_type == 'val':
            logging.info('Valid Batch Sampler:\n%s', batch_sampler.stats())
        else:
            logging.info('Test Batch Sampler:\n%s', batch_sampler.stats())
        data_loader = gluon.data.DataLoader(data_set,
                                            batch_sampler=batch_sampler,
                                            batchify_fn=test_batchify_fn,
                                            num_workers=num_workers)

    return data_loader


def make_dataloader(data_train, data_val, data_test, args,
                    use_average_length=False, num_shards=0, num_workers=8):
    """Create data loaders for training/validation/test."""
    train_data_loader = get_dataloader(data_train, args, dataset_type='train',
                                       use_average_length=use_average_length,
                                       num_shards=num_shards,
                                       num_workers=num_workers)

    val_data_loader = get_dataloader(data_val, args, dataset_type='val',
                                     use_average_length=use_average_length,
                                     num_workers=num_workers)

    test_data_loader = get_dataloader(data_test, args, dataset_type='test',
                                      use_average_length=use_average_length,
                                      num_workers=num_workers)

    return train_data_loader, val_data_loader, test_data_loader


def write_sentences(sentences, file_path):
    """Write sentences (strings or lists of tokens) to file_path, one per line."""
    with io.open(file_path, 'w', encoding='utf-8') as of:
        for sent in sentences:
            if isinstance(sent, (list, tuple)):
                of.write(' '.join(sent) + '\n')
            else:
                of.write(sent + '\n')
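
# Illustrative end-to-end sketch (hedged: `args` is an argparse namespace that the
# surrounding training script is expected to provide, with fields such as src_lang,
# tgt_lang, src_max_len, tgt_max_len, full, bucket_scheme, batch_size,
# test_batch_size, num_buckets and bucket_ratio):
#
#   data_train, data_val, data_test, val_refs, test_refs, src_vocab, tgt_vocab = \
#       load_translation_data('IWSLT2015', bleu='tweaked', args=args)
#   # Attach lengths (train) and lengths plus a sample index (val/test) so that the
#   # 4- and 5-field batchify functions in get_dataloader match the sample layout.
#   data_train = data_train.transform(
#       lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
#   data_val = gluon.data.SimpleDataset(
#       [(e[0], e[1], len(e[0]), len(e[1]), i) for i, e in enumerate(data_val)])
#   data_test = gluon.data.SimpleDataset(
#       [(e[0], e[1], len(e[0]), len(e[1]), i) for i, e in enumerate(data_test)])
#   train_loader, val_loader, test_loader = make_dataloader(data_train, data_val,
#                                                           data_test, args)
#   write_sentences(val_refs, 'val_gt.txt')
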