1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18# pylint: disable=undefined-all-variable 19"""NLP Toolkit Data Stream API. It allows easy and customizable streaming of 20corpora and dataset files. Files can be streamed into formats that are 21ready for training and evaluation.""" 22 23 24import glob 25import multiprocessing 26import multiprocessing.pool 27import os 28import queue 29import random 30import sys 31import threading 32import traceback 33 34import numpy as np 35 36import mxnet as mx 37from mxnet.gluon.data import RandomSampler, Sampler, SequentialSampler 38 39__all__ = [ 40 'DataStream', 'SimpleDataStream', 'DatasetStream', 'SimpleDatasetStream', 41 'PrefetchingStream'] 42 43class DataStream: 44 """Abstract Data Stream Interface. 45 46 DataStreams are useful to avoid loading big datasets to memory. A 47 DataStream is a iterable object (it implements the __iter__ function). 48 Whenever an iteration over the DataStream is requested (e.g. in a for loop 49 or by calling iter(datastream)), a new iterator over all samples in the 50 DataStream is returned. DataStreams can be lazily transformed by calling 51 `transform()` which returns a DataStream over the transformed samples. 52 53 """ 54 55 def __iter__(self): 56 """Return an iterator over all elements of the DataStream. 57 58 This method returns a new iterator object that can iterate over 59 all the objects in the DataStream. 60 61 Returns 62 ------- 63 iterator 64 An object implementing the Python *iterator protocol*. 65 66 """ 67 raise NotImplementedError 68 69 def transform(self, fn): 70 """Transform a DataStream lazily. 71 72 Returns 73 ------- 74 DataStream 75 The data stream that lazily transforms the data while streaming. 76 """ 77 78 return _LazyTransformDataStream(self, fn) 79 80 81class SimpleDataStream(DataStream): 82 """SimpleDataStream wraps iterables to expose the DataStream API. 83 84 Unlike the iterable itself, the SimpleDataStream exposes the DataStream API 85 and allows lazy transformation of the iterable. 86 87 """ 88 def __init__(self, iterable): 89 self._stream = iterable 90 91 def __iter__(self): 92 return iter(self._stream) 93 94 95class _LazyTransformDataStream(DataStream): 96 """Data stream that lazily transforms the data.""" 97 def __init__(self, stream, fn): 98 self._stream = stream 99 self._fn = fn 100 101 def __iter__(self): 102 stream_iter = iter(self._stream) 103 104 # Yield must be hidden in closure so that __iter__ is called before 105 # __next__ is called. This is important, as calling iter(self._stream) 106 # may trigger multi-threaded or multi-processing prefetching of the 107 # stream. 108 def _closure(): 109 try: 110 item = next(stream_iter) 111 except StopIteration: 112 return 113 istuple = isinstance(item, tuple) 114 if istuple: 115 yield self._fn(*item) 116 while True: 117 try: 118 yield self._fn(*next(stream_iter)) 119 except StopIteration: 120 return 121 else: 122 yield self._fn(item) 123 while True: 124 try: 125 yield self._fn(next(stream_iter)) 126 except StopIteration: 127 return 128 129 return _closure() 130 131 132class DatasetStream(DataStream): 133 """Abstract Dataset Stream Interface. 134 135 A DatasetStream is a DataStream where each sample is a 136 `mxnet.gluon.data.Dataset`. An iteration over a DatasetStream iterates over 137 `mxnet.gluon.data.Dataset` objects, representing a chunk or shards of some 138 large datasets. 139 140 Iterating over sizeable chunks of a dataset can be helpful to speed up 141 preprocessing as the overhead of preprocessing each sample individually is 142 reduced (this is similar to the idea of using batches for training a 143 model). 144 145 """ 146 147 def __iter__(self): 148 raise NotImplementedError 149 150 151class _PathDataset(mx.gluon.data.SimpleDataset): 152 """A simple Datasets containing a list of paths given the file_pattern. 153 154 Parameters 155 ---------- 156 file_pattern: str 157 Path to the input text files. 158 """ 159 def __init__(self, file_pattern): 160 if not isinstance(file_pattern, str): 161 raise TypeError('file_pattern must be str, but got %s'%type(file_pattern)) 162 files = [] 163 for pattern in file_pattern.split(','): 164 files.extend(glob.glob(os.path.expanduser(pattern.strip()))) 165 files = sorted(files) 166 if len(files) == 0: 167 raise ValueError('Cannot find any file with path "%s"'%file_pattern) 168 super(_PathDataset, self).__init__(files) 169 170 171class SimpleDatasetStream(DatasetStream): 172 """A simple stream of Datasets. 173 174 The SimpleDatasetStream is created from multiple files based on provided 175 `file_pattern`. One file is read at a time and a corresponding Dataset is 176 returned. The Dataset is created based on the file and the kwargs passed to 177 SimpleDatasetStream. 178 179 Parameters 180 ---------- 181 dataset : class 182 The class for which to create an object for every file. kwargs are 183 passed to this class. 184 file_pattern: str 185 Path to the input text files. 186 file_sampler : str or gluon.data.Sampler, defaults to 'random' 187 The sampler used to sample a file. The following string values are supported: 188 189 - 'sequential': SequentialSampler 190 - 'random': RandomSampler 191 kwargs 192 All other keyword arguments are passed to the dataset constructor. 193 """ 194 def __init__(self, dataset, file_pattern, file_sampler='random', **kwargs): 195 # TODO(haibin) reuse _SimpleDatasetPathStream here 196 if not isinstance(file_pattern, str): 197 raise TypeError('file_pattern must be str, but got %s'%type(file_pattern)) 198 self._dataset = dataset 199 self._files = [] 200 for pattern in file_pattern.split(','): 201 self._files.extend(glob.glob(os.path.expanduser(pattern.strip()))) 202 self._files = sorted(self._files) 203 204 if len(self._files) == 0: 205 raise ValueError('Cannot find any file with path "%s"'%file_pattern) 206 self._file_sampler = self._get_sampler(file_sampler) 207 self._kwargs = kwargs 208 209 def _get_sampler(self, sampler): 210 if isinstance(sampler, Sampler): 211 return sampler 212 if isinstance(sampler, str): 213 length = len(self._files) 214 if sampler == 'random': 215 return RandomSampler(length) 216 if sampler == 'sequential': 217 return SequentialSampler(length) 218 raise ValueError('file_sampler must be a supported str ("random", "sequential") or' 219 'a `gluon.data.Sampler`, but got %s'%(sampler)) 220 221 def __iter__(self): 222 # generate file samples 223 for file_idx in iter(self._file_sampler): 224 filename = self._files[file_idx] 225 yield self._dataset(filename, **self._kwargs) 226 227 228class _Prefetcher: 229 """Internal shared prefetcher logic.""" 230 _dataq = None # Data queue transmits prefetched elements 231 _controlq = None # Control queue to instruct thread / process shutdown 232 _errorq = None # Error queue to transmit exceptions from worker to master 233 234 _checked_start = False # True once startup has been checkd by _check_start 235 236 def __init__(self, stream, num_prefetch, seed, np_seed, mx_seed): 237 super(_Prefetcher, self).__init__() 238 self.stream = stream 239 assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.' 240 self.num_prefetch = num_prefetch 241 self.seed = seed 242 self.np_seed = np_seed 243 self.mx_seed = mx_seed 244 245 def run(self): 246 """Method representing the process’s activity.""" 247 random.seed(self.seed) 248 np.random.seed(self.np_seed) 249 if not isinstance(self, multiprocessing.Process): 250 # Calling mxnet methods in a subprocess will raise an exception if 251 # mxnet is built with GPU support 252 # https://github.com/apache/incubator-mxnet/issues/4659 253 mx.random.seed(self.mx_seed) 254 255 # Startup - Master waits for this 256 try: 257 stream_iter = iter(self.stream) 258 self._errorq.put(None) 259 except Exception as e: # pylint: disable=broad-except 260 tb = traceback.format_exc() 261 self._errorq.put((e, tb)) 262 263 # Async work 264 while True: 265 try: # Check control queue 266 c = self._controlq.get(False) 267 if c is None: 268 break 269 raise RuntimeError('Got unexpected control code {}'.format(repr(c))) 270 except queue.Empty: 271 pass 272 except RuntimeError as e: 273 tb = traceback.format_exc() 274 self._errorq.put((e, tb)) 275 self._dataq.put(None) 276 277 try: 278 data = next(stream_iter) 279 error = None 280 except Exception as e: # pylint: disable=broad-except 281 tb = traceback.format_exc() 282 error = (e, tb) 283 data = None 284 finally: 285 self._errorq.put(error) 286 self._dataq.put(data) 287 288 def __next__(self): 289 next_item = self._dataq.get() 290 next_error = self._errorq.get() 291 292 if next_error is None: 293 return next_item 294 else: 295 self._controlq.put(None) 296 if isinstance(next_error[0], StopIteration): 297 raise StopIteration 298 return self._reraise(*next_error) 299 300 def _reraise(self, e, tb): 301 print('Reraising exception from Prefetcher', file=sys.stderr) 302 print(tb, file=sys.stderr) 303 raise e 304 305 def _check_start(self): 306 assert not self._checked_start 307 self._checked_start = True 308 next_error = self._errorq.get(block=True) 309 if next_error is not None: 310 self._reraise(*next_error) 311 312 def next(self): 313 return self.__next__() 314 315class _ProcessPrefetcher(_Prefetcher, multiprocessing.Process): 316 """Internal multi-processing prefetcher.""" 317 318 def __init__(self, *args, **kwargs): 319 super(_ProcessPrefetcher, self).__init__(*args, **kwargs) 320 self._dataq = multiprocessing.Queue(self.num_prefetch) 321 self._controlq = multiprocessing.Queue() 322 self._errorq = multiprocessing.Queue(self.num_prefetch) 323 self.daemon = True 324 self.start() 325 self._check_start() 326 327 328class _ThreadPrefetcher(_Prefetcher, threading.Thread): 329 """Internal threaded prefetcher.""" 330 331 def __init__(self, *args, **kwargs): 332 super(_ThreadPrefetcher, self).__init__(*args, **kwargs) 333 self._dataq = queue.Queue(self.num_prefetch) 334 self._controlq = queue.Queue() 335 self._errorq = queue.Queue(self.num_prefetch) 336 self.daemon = True 337 self.start() 338 self._check_start() 339 340 341class PrefetchingStream(DataStream): 342 """Prefetch a DataStream in a separate Thread or Process. 343 344 This iterator will create another thread or process to perform 345 ``iter_next`` and then store the data in memory. It potentially accelerates 346 the data read, at the cost of more memory usage. 347 348 The python, numpy and mxnet random states in the launched Thread or Process 349 will be initialized randomly based on the next 32 bit integer in the 350 python, numpy and mxnet random generator of the caller respectively 351 (random.getrandbits(32), numpy.random.randint(0, 2**32), 352 int(mx.nd.random.uniform(0, 2**32).asscalar())). 353 354 Parameters 355 ---------- 356 stream : DataStream 357 Source stream. 358 num_prefetch : int, default 1 359 Number of elements to prefetch from the stream. Must be greater 0. 360 worker_type : 'thread' or 'process', default 'thread' 361 Use a separate Python Thread or Process to prefetch. 362 363 """ 364 365 def __init__(self, stream, num_prefetch=1, worker_type='thread'): 366 self._stream = stream 367 self._num_prefetch = num_prefetch 368 if num_prefetch < 1: 369 raise ValueError('num_prefetch must be greater 0.') 370 assert worker_type.lower() in ['thread', 'process'] 371 self._multiprocessing = worker_type.lower() == 'process' 372 373 def __iter__(self): 374 seed = random.getrandbits(32) 375 # TODO should be possible to change to 64 bit in MXNet 1.6 (uses int64 by default?) 376 np_seed = np.random.randint(0, np.iinfo(np.int32).max) 377 mx_seed = int(mx.nd.random.uniform(0, np.iinfo(np.int32).max).asscalar()) 378 if self._multiprocessing: 379 return _ProcessPrefetcher(self._stream, self._num_prefetch, 380 seed=seed, np_seed=np_seed, 381 mx_seed=mx_seed) 382 else: 383 return _ThreadPrefetcher(self._stream, self._num_prefetch, 384 seed=seed, np_seed=np_seed, 385 mx_seed=mx_seed) 386