# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=ungrouped-imports
"""DatasetLoader. An extension of the Gluon data loader that allows
reading and processing multiple files on-the-fly.
"""

__all__ = ['DatasetLoader']

import io
import pickle
import warnings
import multiprocessing
from functools import partial
from mxnet import context
from mxnet.gluon.data.dataloader import ForkingPickler, _as_in_context
from mxnet.gluon.data.dataloader import default_mp_batchify_fn, default_batchify_fn
from .stream import _PathDataset


# Manager for creating objects shared across dataset workers.
_manager = None
_dataset = None


def _initialize_dataset_worker(manager):
    global _manager
    _manager = manager


def _dataset_worker_fn(urls, dataset_fn, batch_sampler_fn):
    """Generate a dataset and a batch sampler for each worker."""
    global _manager, _dataset
    dataset = dataset_fn(urls)
    batch_sampler = batch_sampler_fn(dataset)
    if _manager:
        # Re-expose the dataset as a manager-backed list so that batch
        # workers in other processes can index it by proxy.
        dataset = _manager.list(zip(*dataset._data))
    _dataset = dataset
    return dataset, batch_sampler


def _batch_worker_fn(samples, batchify_fn, dataset=None, counter=None):
    """Process a batch of samples in a worker process."""
    # pylint: disable=unused-argument
    # Each worker process has to fork its own MXIndexedRecordIO handle;
    # keeping the dataset in a global variable saves a lot of overhead and
    # is safe in a newly forked process.
    if len(dataset[0]) > 1:
        if isinstance(samples[0], (list, tuple)):
            batch = [batchify_fn([dataset[i] for i in shard]) for shard in samples]
        else:
            batch = batchify_fn([dataset[i] for i in samples])
    else:
        if isinstance(samples[0], (list, tuple)):
            batch = [batchify_fn([dataset[i][0] for i in shard]) for shard in samples]
        else:
            batch = batchify_fn([dataset[i][0] for i in samples])
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    return buf.getvalue(), counter
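

# The worker functions above communicate across processes through a
# ``multiprocessing.Manager``: datasets are re-exposed as manager-backed
# lists so batch workers can index them by proxy, and a shared ``Value``
# counter tracks how many in-flight batches still reference a dataset.
# A minimal standalone sketch of that sharing pattern (illustrative only,
# not part of this module):
#
#     import multiprocessing
#
#     def _read(shared, counter):
#         value = shared[0]    # proxy access routed through the manager
#         counter.value -= 1   # signal that this reader is done
#         return value
#
#     if __name__ == '__main__':
#         manager = multiprocessing.Manager()
#         shared = manager.list([42])
#         counter = manager.Value('i', 1)
#         with multiprocessing.Pool(processes=1) as pool:
#             assert pool.apply(_read, (shared, counter)) == 42
#         assert counter.value == 0
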

class _MultiBatchWorkerIter:
    """Internal multi-worker batch iterator for DatasetLoader."""
    def __init__(self, worker_pool, batchify_fn, dataset_iter=None,
                 pin_memory=False, worker_fn=_batch_worker_fn, prefetch=0,
                 manager=None):
        self._worker_pool = worker_pool
        self._batchify_fn = batchify_fn
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._dataset_iter = iter(dataset_iter)
        self._worker_fn = worker_fn
        self._pin_memory = pin_memory
        self._prefetch = prefetch
        self._dataset = None
        self._batch_iter = None
        self._manager = manager

        # References to datasets that may still be used by in-flight batches.
        self._dataset_refs = []

        # Per-dataset reference counters, keyed by id(dataset).
        self._counter_ref = {}

        # Pre-fetch.
        for _ in range(self._prefetch):
            self._push_next()

    def _count_dataset_ref(self, new_dataset):
        """Drop references to datasets that no in-flight batch uses anymore."""
        dataset_refs = []
        for dataset in self._dataset_refs:
            if self._counter_ref[id(dataset)].value > 0:
                dataset_refs.append(dataset)
            else:
                del self._counter_ref[id(dataset)]
        if self._dataset:
            if self._counter_ref[id(self._dataset)].value > 0:
                if id(new_dataset) != id(self._dataset):
                    dataset_refs.append(self._dataset)
            else:
                del self._counter_ref[id(self._dataset)]

        self._dataset_refs = dataset_refs

    def _next_dataset(self):
        try:
            dataset, batch_sampler = next(self._dataset_iter)
        except StopIteration:
            return None
        return dataset, batch_sampler

    def _push_next(self):
        """Assign the next batch workload to the workers."""
        if self._batch_iter is not None:
            r = next(self._batch_iter, None)
        else:
            r = None
        if r is None:
            result = self._next_dataset()
            if result is None:
                return
            dataset, batch_sampler = result
            # Without checking the reference counts of previous datasets in
            # the master process, a KeyError can occasionally be triggered.
            # This may be a bug in Python.
            self._count_dataset_ref(dataset)
            self._dataset = dataset
            # Initialize the reference counter for a newly seen dataset.
            if id(dataset) not in self._counter_ref:
                self._counter_ref[id(dataset)] = self._manager.Value('i', 0)
            self._batch_iter = iter(batch_sampler)
            self._push_next()
        else:
            counter = self._counter_ref[id(self._dataset)]
            counter.value += 1
            async_ret = self._worker_pool.apply_async(
                self._worker_fn, (r, self._batchify_fn, self._dataset, counter))
            self._data_buffer[self._sent_idx] = async_ret
            self._sent_idx += 1

    def __next__(self):
        self._push_next()
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, 'Data buffer should be empty at this moment'
            raise StopIteration

        assert self._rcvd_idx < self._sent_idx, 'rcvd_idx must be smaller than sent_idx'
        assert self._rcvd_idx in self._data_buffer, 'fatal error with _push_next, rcvd_idx missing'
        ret = self._data_buffer.pop(self._rcvd_idx)
        batch, counter = ret.get()
        batch = pickle.loads(batch)
        counter.value -= 1
        if self._pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        self._rcvd_idx += 1
        return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self
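

# ``_MultiBatchWorkerIter`` above is the consumer half of the pipeline: it
# pulls (dataset, batch_sampler) pairs from a dataset iterator and dispatches
# batchify jobs to the batch worker pool. ``_MultiDatasetWorkerIter`` below is
# the producer half, which loads datasets from files ahead of time. Note that
# every call to ``__next__`` pushes one more job before popping a finished
# one, so at most ``prefetch + 1`` jobs are ever in flight.
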

class _MultiDatasetWorkerIter:
    """Internal multi-worker dataset iterator for DatasetLoader."""
    def __init__(self, worker_pool, file_sampler,
                 dataset_fn, batch_sampler_fn,
                 worker_fn=_dataset_worker_fn,
                 prefetch=0, dataset=None, circle_length=1,
                 cached=False, num_max_cached=0):
        if cached:
            assert num_max_cached > 0, \
                'When cached is turned on, num_max_cached must be positive.'
        self._worker_pool = worker_pool
        self._dataset_fn = dataset_fn
        self._batch_sampler_fn = batch_sampler_fn
        self._worker_fn = worker_fn
        self._prefetch = prefetch
        self._circle_length = circle_length
        self._cached = cached
        self._num_max_cached = num_max_cached

        # Send and receive indices for datasets.
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._data_buffer = {}

        self._dataset = [dataset[i] for i in iter(file_sampler)]
        self._num_datasets = len(self._dataset)

        # Cache of processed (dataset, batch_sampler) pairs.
        self._cached_dataset = []

        # Pre-fetch.
        for _ in range(self._prefetch):
            self._push_next_dataset()

    def _push_next_dataset(self):
        """Assign the next dataset workload to the workers."""
        current_dataset_idx = self._sent_idx * self._circle_length
        if current_dataset_idx < self._num_datasets:
            circle_length = min(self._circle_length,
                                self._num_datasets - current_dataset_idx)
            urls = [self._dataset[current_dataset_idx + i] for i in range(circle_length)]
        else:
            return
        # Push to a worker asynchronously.
        async_ret = self._worker_pool.apply_async(
            self._worker_fn, (urls, self._dataset_fn, self._batch_sampler_fn))
        # The data buffer stores the async result.
        self._data_buffer[self._sent_idx] = async_ret
        self._sent_idx += 1

    def _next_dataset(self):
        """Retrieve the next dataset. Returns None if no dataset is available."""
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, 'Data buffer should be empty at this moment'
            return None

        assert self._rcvd_idx < self._sent_idx, \
            'rcvd_idx must be smaller than sent_idx'
        assert self._rcvd_idx in self._data_buffer, \
            'fatal error with _next_dataset, rcvd_idx missing'

        if len(self._cached_dataset) == 0 or self._data_buffer[self._rcvd_idx].ready():
            ret = self._data_buffer.pop(self._rcvd_idx)
            dataset, batch_sampler = ret.get()
            self._rcvd_idx += 1
            if self._cached and len(self._cached_dataset) < self._num_max_cached:
                self._cached_dataset.append((dataset, batch_sampler))
        else:
            dataset, batch_sampler = self._cached_dataset.pop(0)

        return dataset, batch_sampler

    def __next__(self):
        """Return the next dataset."""
        self._push_next_dataset()
        result = self._next_dataset()

        if result is None:
            raise StopIteration

        return result

    def next(self):
        """Return the next dataset."""
        return self.__next__()

    def __iter__(self):
        """Return the iterator object."""
        return self
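

# ``DatasetLoader`` below wires the two iterators into a two-stage pipeline:
#
#     files --(dataset worker pool)--> (dataset, batch_sampler) pairs
#           --(batch worker pool)----> pickled batches --> caller
#
# `dataset_prefetch` and `batch_prefetch` bound how far each stage may run
# ahead of the consumer.
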

class DatasetLoader:
    """Loads data from a list of datasets and returns mini-batches of data.

    One dataset is loaded at a time.

    Parameters
    ----------
    file_patterns : str
        Path pattern(s) of the input text files.
    file_sampler : str or gluon.data.Sampler
        The sampler used to sample a file. The following string values are
        supported:

        - 'sequential': SequentialSampler
        - 'random': RandomSampler
    dataset_fn : DatasetFn, callable
        Callable object to generate a gluon.data.Dataset given a url
        (or a list of urls when `circle_length` > 1).
    batch_sampler_fn : SamplerFn, callable
        Callable object to generate a gluon.data.sampler.Sampler given a dataset.
    dataset_params : dict, default is None
        Dictionary of parameters passed to dataset_fn.
    batch_sampler_params : dict, default is None
        Dictionary of parameters passed to batch_sampler_fn.
    batchify_fn : callable
        Callback function that specifies how to merge samples into a batch.
        Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)

    num_dataset_workers : int
        Number of worker processes for dataset creation.
    num_batch_workers : int
        Number of worker processes for batch creation.
    pin_memory : boolean, default False
        If ``True``, the dataloader will copy NDArrays into pinned memory
        before returning them. Copying from CPU pinned memory to GPU is
        faster than from normal CPU memory, but it also consumes more pinned
        host memory.
    circle_length : int, default is 1
        The number of files to be read at the same time. When `circle_length`
        is larger than 1, `circle_length` files are merged into one dataset.
    dataset_prefetch : int, default is `num_dataset_workers`
        The number of datasets to prefetch. Prefetching only takes effect
        when `num_dataset_workers` > 0: worker processes then prepare
        datasets before they are requested from the iterator. A larger value
        gives smoother bootstrapping performance but consumes more memory;
        a value that is too small may forfeit the purpose of using multiple
        worker processes, in which case try reducing `num_dataset_workers`.
    batch_prefetch : int, default is `num_batch_workers * 2`
        The number of batches to prefetch. Prefetching only takes effect
        when `num_batch_workers` > 0: worker processes then prepare batches
        before they are requested from the iterator. A larger value gives
        smoother bootstrapping performance but consumes more shared memory;
        a value that is too small may forfeit the purpose of using multiple
        worker processes, in which case try reducing `num_batch_workers`.
    dataset_cached : bool, default is False
        Whether to cache the last processed datasets. Each processed dataset
        is cached at most once. When no newly processed dataset is available
        to be fetched, a cached one is popped instead.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. Only valid if `dataset_cached`
        is True.
    """
    def __init__(self, file_patterns, file_sampler,
                 dataset_fn=None, batch_sampler_fn=None,
                 dataset_params=None, batch_sampler_params=None, batchify_fn=None,
                 num_dataset_workers=0, num_batch_workers=0,
                 pin_memory=False, circle_length=1,
                 dataset_prefetch=None, batch_prefetch=None,
                 dataset_cached=False, num_max_dataset_cached=0):
        assert num_dataset_workers >= 0, \
            'num_dataset_workers must be non-negative'
        assert num_batch_workers >= 0, \
            'num_batch_workers must be non-negative'
        if num_batch_workers > 0:
            assert num_dataset_workers > 0, \
                'num_dataset_workers must be positive when num_batch_workers > 0'
        elif num_dataset_workers > 0:
            warnings.warn('The multi-processing functionalities for both dataset and'
                          ' batch sampling are disabled when num_batch_workers=0 though'
                          ' num_dataset_workers={} > 0'.format(num_dataset_workers))
        assert circle_length >= 1, \
            'circle_length must be larger than or equal to 1'
        if dataset_cached:
            assert num_max_dataset_cached > 0, \
                'When dataset_cached is True, num_max_dataset_cached must be positive'

        self._dataset = _PathDataset(file_patterns)
        self._file_sampler = file_sampler

        assert dataset_fn is not None, 'dataset_fn is not given.'
        assert batch_sampler_fn is not None, 'batch_sampler_fn is not given.'
        if dataset_params is not None:
            self._dataset_fn = partial(dataset_fn, **dataset_params)
        else:
            self._dataset_fn = dataset_fn
        if batch_sampler_params is not None:
            self._batch_sampler_fn = partial(batch_sampler_fn, **batch_sampler_params)
        else:
            self._batch_sampler_fn = batch_sampler_fn

        self._num_dataset_workers = num_dataset_workers
        self._num_batch_workers = num_batch_workers
        self._dataset_prefetch = max(0, int(dataset_prefetch)
                                     if dataset_prefetch is not None
                                     else self._num_dataset_workers)
        self._batch_prefetch = max(0, int(batch_prefetch)
                                   if batch_prefetch is not None
                                   else 2 * self._num_batch_workers)

        self._pin_memory = pin_memory
        self._circle_length = circle_length
        self._dataset_cached = dataset_cached
        self._num_max_dataset_cached = num_max_dataset_cached

        self._manager = None
        self._dataset_worker_pool = None
        if self._num_dataset_workers > 0:
            # The manager owns the objects shared with the dataset workers.
            self._manager = multiprocessing.Manager()
            self._dataset_worker_pool = multiprocessing.Pool(
                self._num_dataset_workers,
                initializer=_initialize_dataset_worker,
                initargs=[self._manager])
        self._batch_worker_pool = None
        if self._num_batch_workers > 0:
            self._batch_worker_pool = multiprocessing.Pool(self._num_batch_workers)
        if batchify_fn is None:
            if self._num_batch_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn
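
    # Iteration takes one of two paths: with num_dataset_workers == 0, file
    # reading, batching, and (optionally) memory pinning all happen in this
    # process; otherwise __iter__ chains a _MultiDatasetWorkerIter into a
    # _MultiBatchWorkerIter so that dataset loading and batchification
    # overlap with consumption.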
    def __iter__(self):
        if self._num_dataset_workers == 0:
            def _same_process_iter():
                urls = []
                files = [self._dataset[i] for i in iter(self._file_sampler)]
                for i, url in enumerate(files):
                    urls.append(url)
                    # Keep collecting urls until `circle_length` files have
                    # been gathered, except at the end of the file list.
                    if i < len(files) - 1 and len(urls) < self._circle_length:
                        continue
                    if self._circle_length == 1:
                        urls = urls[0]
                    dataset, batch_sampler = _dataset_worker_fn(urls, self._dataset_fn,
                                                                self._batch_sampler_fn)
                    for batch in batch_sampler:
                        ret = self._batchify_fn([dataset[idx] for idx in batch])
                        if self._pin_memory:
                            ret = _as_in_context(ret, context.cpu_pinned())
                        yield ret
                    urls = []
            return _same_process_iter()

        # Multi-worker path: chain the dataset iterator into the batch iterator.
        dataset_iter = _MultiDatasetWorkerIter(self._dataset_worker_pool,
                                               worker_fn=_dataset_worker_fn,
                                               dataset=self._dataset,
                                               file_sampler=self._file_sampler,
                                               dataset_fn=self._dataset_fn,
                                               batch_sampler_fn=self._batch_sampler_fn,
                                               prefetch=self._dataset_prefetch,
                                               circle_length=self._circle_length,
                                               cached=self._dataset_cached,
                                               num_max_cached=self._num_max_dataset_cached)
        return _MultiBatchWorkerIter(self._batch_worker_pool, self._batchify_fn, dataset_iter,
                                     pin_memory=self._pin_memory, worker_fn=_batch_worker_fn,
                                     prefetch=self._batch_prefetch, manager=self._manager)

    def __del__(self):
        if self._dataset_worker_pool:
            # Terminate manually due to a bug where the pool is not
            # automatically terminated: https://bugs.python.org/issue34172
            assert isinstance(self._dataset_worker_pool, multiprocessing.pool.Pool)
            self._dataset_worker_pool.terminate()
        if self._batch_worker_pool:
            assert isinstance(self._batch_worker_pool, multiprocessing.pool.Pool)
            self._batch_worker_pool.terminate()
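

# A minimal usage sketch of ``DatasetLoader`` (run as a module with
# ``python -m``). The temporary files and the ``dataset_fn`` and
# ``batch_sampler_fn`` helpers below are illustrative assumptions, not part
# of the library, and they exercise only the in-process path
# (num_dataset_workers=0).
if __name__ == '__main__':
    import os
    import tempfile

    from mxnet.gluon.data import SimpleDataset
    from mxnet.gluon.data.sampler import BatchSampler, SequentialSampler

    # Create two small text files with one numeric sample per line.
    tmp_dir = tempfile.mkdtemp()
    for part in range(2):
        with open(os.path.join(tmp_dir, 'part-{}.txt'.format(part)), 'w') as f:
            f.write('\n'.join(str(part * 10 + j) for j in range(8)))

    def dataset_fn(url):
        # With circle_length == 1, each call receives a single file path.
        with open(url) as f:
            return SimpleDataset([float(line) for line in f])

    def batch_sampler_fn(dataset):
        # Yield mini-batches of 4 sequential indices into the dataset.
        return BatchSampler(SequentialSampler(len(dataset)), batch_size=4)

    loader = DatasetLoader(os.path.join(tmp_dir, '*.txt'),
                           file_sampler=SequentialSampler(2),
                           dataset_fn=dataset_fn,
                           batch_sampler_fn=batch_sampler_fn)
    for batch in loader:
        print(batch)  # an NDArray of shape (4,)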