1# Licensed to the Apache Software Foundation (ASF) under one 2# or more contributor license agreements. See the NOTICE file 3# distributed with this work for additional information 4# regarding copyright ownership. The ASF licenses this file 5# to you under the Apache License, Version 2.0 (the 6# "License"); you may not use this file except in compliance 7# with the License. You may obtain a copy of the License at 8# 9# http://www.apache.org/licenses/LICENSE-2.0 10# 11# Unless required by applicable law or agreed to in writing, 12# software distributed under the License is distributed on an 13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 14# KIND, either express or implied. See the License for the 15# specific language governing permissions and limitations 16# under the License. 17 18# coding: utf-8 19# pylint: disable=ungrouped-imports 20"""Dataset generator.""" 21__all__ = ['DataLoader'] 22 23import pickle 24import io 25import sys 26import signal 27import multiprocessing 28import multiprocessing.queues 29from multiprocessing.reduction import ForkingPickler 30from multiprocessing.pool import ThreadPool 31import threading 32import numpy as np 33 34try: 35 import multiprocessing.resource_sharer 36except ImportError: 37 pass 38 39from . import sampler as _sampler 40from ... import nd, context 41from ...util import is_np_shape, is_np_array, set_np 42from ... import numpy as _mx_np # pylint: disable=reimported 43 44if sys.platform == 'darwin' or sys.platform == 'win32': 45 def rebuild_ndarray(*args): 46 """Rebuild ndarray from pickled shared memory""" 47 # pylint: disable=no-value-for-parameter 48 return nd.NDArray(nd.ndarray._new_from_shared_mem(*args)) 49 50 def reduce_ndarray(data): 51 """Reduce ndarray to shared memory handle""" 52 return rebuild_ndarray, data._to_shared_mem() 53else: 54 def rebuild_ndarray(pid, fd, shape, dtype): 55 """Rebuild ndarray from pickled shared memory""" 56 # pylint: disable=no-value-for-parameter 57 fd = fd.detach() 58 return nd.NDArray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype)) 59 60 def reduce_ndarray(data): 61 """Reduce ndarray to shared memory handle""" 62 # keep a local ref before duplicating fd 63 data = data.as_in_context(context.Context('cpu_shared', 0)) 64 pid, fd, shape, dtype = data._to_shared_mem() 65 fd = multiprocessing.reduction.DupFd(fd) 66 return rebuild_ndarray, (pid, fd, shape, dtype) 67 68ForkingPickler.register(nd.NDArray, reduce_ndarray) 69 70if sys.platform == 'darwin' or sys.platform == 'win32': 71 def rebuild_np_ndarray(*args): 72 """Rebuild ndarray from pickled shared memory""" 73 # pylint: disable=no-value-for-parameter 74 return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args)) 75 76 def reduce_np_ndarray(data): 77 """Reduce ndarray to shared memory handle""" 78 return rebuild_np_ndarray, data._to_shared_mem() 79else: 80 def rebuild_np_ndarray(pid, fd, shape, dtype): 81 """Rebuild ndarray from pickled shared memory""" 82 # pylint: disable=no-value-for-parameter 83 fd = fd.detach() 84 return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype)) 85 86 def reduce_np_ndarray(data): 87 """Reduce ndarray to shared memory handle""" 88 # keep a local ref before duplicating fd 89 data = data.as_in_context(context.Context('cpu_shared', 0)) 90 pid, fd, shape, dtype = data._to_shared_mem() 91 fd = multiprocessing.reduction.DupFd(fd) 92 return rebuild_np_ndarray, (pid, fd, shape, dtype) 93 94ForkingPickler.register(_mx_np.ndarray, reduce_np_ndarray) 95 96 97class ConnectionWrapper(object): 98 """Connection wrapper for multiprocessing that supports sending 99 NDArray via shared memory.""" 100 101 def __init__(self, conn): 102 self._conn = conn 103 104 def send(self, obj): 105 """Send object""" 106 buf = io.BytesIO() 107 ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj) 108 self.send_bytes(buf.getvalue()) 109 110 def recv(self): 111 """Receive object""" 112 buf = self.recv_bytes() 113 return pickle.loads(buf) 114 115 def __getattr__(self, name): 116 """Emmulate conn""" 117 attr = self.__dict__.get('_conn', None) 118 return getattr(attr, name) 119 120 121class Queue(multiprocessing.queues.Queue): 122 """Wrapper for multiprocessing queue that dumps NDArray with shared memory.""" 123 def __init__(self, *args, **kwargs): 124 super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs) 125 self._reader = ConnectionWrapper(self._reader) 126 self._writer = ConnectionWrapper(self._writer) 127 self._send = self._writer.send 128 self._recv = self._reader.recv 129 130 131class SimpleQueue(multiprocessing.queues.SimpleQueue): 132 """Wrapper for multiprocessing SimpleQueue that dumps NDArray with shared memory. 133 SimpleQueue don't use threading internally. 134 """ 135 def __init__(self, *args, **kwargs): 136 super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs) 137 self._reader = ConnectionWrapper(self._reader) 138 self._writer = ConnectionWrapper(self._writer) 139 self._send = self._writer.send 140 self._recv = self._reader.recv 141 142 143def default_batchify_fn(data): 144 """Collate data into batch.""" 145 if isinstance(data[0], nd.NDArray): 146 return _mx_np.stack(data) if is_np_array() else nd.stack(*data) 147 elif isinstance(data[0], tuple): 148 data = zip(*data) 149 return [default_batchify_fn(i) for i in data] 150 else: 151 data = np.asarray(data) 152 array_fn = _mx_np.array if is_np_array() else nd.array 153 return array_fn(data, dtype=data.dtype) 154 155 156def default_mp_batchify_fn(data): 157 """Collate data into batch. Use shared memory for stacking.""" 158 if isinstance(data[0], nd.NDArray): 159 empty_fn = _mx_np.empty if is_np_array() else nd.empty 160 out = empty_fn((len(data),) + data[0].shape, dtype=data[0].dtype, 161 ctx=context.Context('cpu_shared', 0)) 162 if is_np_array(): 163 return _mx_np.stack(data, out=out) 164 else: 165 return nd.stack(*data, out=out) 166 elif isinstance(data[0], tuple): 167 data = zip(*data) 168 return [default_mp_batchify_fn(i) for i in data] 169 else: 170 data = np.asarray(data) 171 array_fn = _mx_np.array if is_np_array() else nd.array 172 return array_fn(data, dtype=data.dtype, 173 ctx=context.Context('cpu_shared', 0)) 174 175 176def _as_in_context(data, ctx): 177 """Move data into new context.""" 178 if isinstance(data, nd.NDArray): 179 return data.as_in_context(ctx) 180 elif isinstance(data, (list, tuple)): 181 return [_as_in_context(d, ctx) for d in data] 182 return data 183 184 185def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn): 186 """Worker loop for multiprocessing DataLoader.""" 187 while True: 188 idx, samples = key_queue.get() 189 if idx is None: 190 break 191 batch = batchify_fn([dataset[i] for i in samples]) 192 data_queue.put((idx, batch)) 193 194def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False, 195 pin_device_id=0, data_buffer_lock=None): 196 """Fetcher loop for fetching data from queue and put in reorder dict.""" 197 while True: 198 idx, batch = data_queue.get() 199 if idx is None: 200 break 201 if pin_memory: 202 batch = _as_in_context(batch, context.cpu_pinned(pin_device_id)) 203 else: 204 batch = _as_in_context(batch, context.cpu()) 205 if data_buffer_lock is not None: 206 with data_buffer_lock: 207 data_buffer[idx] = batch 208 else: 209 data_buffer[idx] = batch 210 211 212class _MultiWorkerIterV1(object): 213 """Internal multi-worker iterator for DataLoader.""" 214 def __init__(self, num_workers, dataset, batchify_fn, batch_sampler, 215 pin_memory=False, pin_device_id=0, worker_fn=worker_loop_v1): 216 assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers) 217 self._num_workers = num_workers 218 self._dataset = dataset 219 self._batchify_fn = batchify_fn 220 self._batch_sampler = batch_sampler 221 self._key_queue = Queue() 222 self._data_queue = SimpleQueue() 223 224 self._data_buffer = {} 225 self._data_buffer_lock = threading.Lock() 226 227 self._rcvd_idx = 0 228 self._sent_idx = 0 229 self._iter = iter(self._batch_sampler) 230 self._shutdown = False 231 232 workers = [] 233 for _ in range(self._num_workers): 234 worker = multiprocessing.Process( 235 target=worker_fn, 236 args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn)) 237 worker.daemon = True 238 worker.start() 239 workers.append(worker) 240 self._workers = workers 241 242 self._fetcher = threading.Thread( 243 target=fetcher_loop_v1, 244 args=(self._data_queue, self._data_buffer, pin_memory, 245 pin_device_id, self._data_buffer_lock)) 246 self._fetcher.daemon = True 247 self._fetcher.start() 248 249 # pre-fetch 250 for _ in range(2 * self._num_workers): 251 self._push_next() 252 253 def __len__(self): 254 return len(self._batch_sampler) 255 256 def __del__(self): 257 self.shutdown() 258 259 def _push_next(self): 260 """Assign next batch workload to workers.""" 261 r = next(self._iter, None) 262 if r is None: 263 return 264 self._key_queue.put((self._sent_idx, r)) 265 self._sent_idx += 1 266 267 def __next__(self): 268 assert not self._shutdown, "call __next__ after shutdown is forbidden" 269 if self._rcvd_idx == self._sent_idx: 270 assert not self._data_buffer, "Data buffer should be empty at this moment" 271 self.shutdown() 272 raise StopIteration 273 274 while True: 275 if self._rcvd_idx in self._data_buffer: 276 with self._data_buffer_lock: 277 batch = self._data_buffer.pop(self._rcvd_idx) 278 self._rcvd_idx += 1 279 self._push_next() 280 return batch 281 282 def next(self): 283 return self.__next__() 284 285 def __iter__(self): 286 return self 287 288 def shutdown(self): 289 """Shutdown internal workers by pushing terminate signals.""" 290 if not self._shutdown: 291 # send shutdown signal to the fetcher and join data queue first 292 # Remark: loop_fetcher need to be joined prior to the workers. 293 # otherwise, the fetcher may fail at getting data 294 self._data_queue.put((None, None)) 295 self._fetcher.join() 296 # send shutdown signal to all worker processes 297 for _ in range(self._num_workers): 298 self._key_queue.put((None, None)) 299 # force shut down any alive worker processes 300 for w in self._workers: 301 if w.is_alive(): 302 w.terminate() 303 self._shutdown = True 304 305 306class DataLoaderV1(object): 307 """Loads data from a dataset and returns mini-batches of data. 308 309 Parameters 310 ---------- 311 dataset : Dataset 312 Source dataset. Note that numpy and mxnet arrays can be directly used 313 as a Dataset. 314 batch_size : int 315 Size of mini-batch. 316 shuffle : bool 317 Whether to shuffle the samples. 318 sampler : Sampler 319 The sampler to use. Either specify sampler or shuffle, not both. 320 last_batch : {'keep', 'discard', 'rollover'} 321 How to handle the last batch if batch_size does not evenly divide 322 `len(dataset)`. 323 324 keep - A batch with less samples than previous batches is returned. 325 discard - The last batch is discarded if its incomplete. 326 rollover - The remaining samples are rolled over to the next epoch. 327 batch_sampler : Sampler 328 A sampler that returns mini-batches. Do not specify batch_size, 329 shuffle, sampler, and last_batch if batch_sampler is specified. 330 batchify_fn : callable 331 Callback function to allow users to specify how to merge samples 332 into a batch. Defaults to `default_batchify_fn`:: 333 334 def default_batchify_fn(data): 335 if isinstance(data[0], nd.NDArray): 336 return nd.stack(*data) 337 elif isinstance(data[0], tuple): 338 data = zip(*data) 339 return [default_batchify_fn(i) for i in data] 340 else: 341 data = np.asarray(data) 342 return nd.array(data, dtype=data.dtype) 343 344 num_workers : int, default 0 345 The number of multiprocessing workers to use for data preprocessing. 346 pin_memory : boolean, default False 347 If ``True``, the dataloader will copy NDArrays into pinned memory 348 before returning them. Copying from CPU pinned memory to GPU is faster 349 than from normal CPU memory. 350 pin_device_id : int, default 0 351 The device id to use for allocating pinned memory if pin_memory is ``True`` 352 """ 353 def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, 354 last_batch=None, batch_sampler=None, batchify_fn=None, 355 num_workers=0, pin_memory=False, pin_device_id=0): 356 self._dataset = dataset 357 self._pin_memory = pin_memory 358 self._pin_device_id = pin_device_id 359 360 if batch_sampler is None: 361 if batch_size is None: 362 raise ValueError("batch_size must be specified unless " \ 363 "batch_sampler is specified") 364 if sampler is None: 365 if shuffle: 366 sampler = _sampler.RandomSampler(len(dataset)) 367 else: 368 sampler = _sampler.SequentialSampler(len(dataset)) 369 elif shuffle: 370 raise ValueError("shuffle must not be specified if sampler is specified") 371 372 batch_sampler = _sampler.BatchSampler( 373 sampler, batch_size, last_batch if last_batch else 'keep') 374 elif batch_size is not None or shuffle or sampler is not None or \ 375 last_batch is not None: 376 raise ValueError("batch_size, shuffle, sampler and last_batch must " \ 377 "not be specified if batch_sampler is specified.") 378 379 self._batch_sampler = batch_sampler 380 self._num_workers = num_workers if num_workers >= 0 else 0 381 if batchify_fn is None: 382 if num_workers > 0: 383 self._batchify_fn = default_mp_batchify_fn 384 else: 385 self._batchify_fn = default_batchify_fn 386 else: 387 self._batchify_fn = batchify_fn 388 389 def __iter__(self): 390 if self._num_workers == 0: 391 def same_process_iter(): 392 for batch in self._batch_sampler: 393 ret = self._batchify_fn([self._dataset[idx] for idx in batch]) 394 if self._pin_memory: 395 ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id)) 396 yield ret 397 return same_process_iter() 398 399 # multi-worker 400 return _MultiWorkerIterV1(self._num_workers, self._dataset, 401 self._batchify_fn, self._batch_sampler, 402 self._pin_memory, self._pin_device_id) 403 404 def __len__(self): 405 return len(self._batch_sampler) 406 407 408def _thread_worker_initializer(active_shape, active_array): 409 """Initializer for ThreadPool.""" 410 set_np(shape=active_shape, array=active_array) 411 412 413_worker_dataset = None 414def _worker_initializer(dataset, active_shape, active_array): 415 """Initialier for processing pool.""" 416 # global dataset is per-process based and only available in worker processes 417 # this is only necessary to handle MXIndexedRecordIO because otherwise dataset 418 # can be passed as argument 419 global _worker_dataset 420 _worker_dataset = dataset 421 set_np(shape=active_shape, array=active_array) 422 423def _worker_fn(samples, batchify_fn, dataset=None): 424 """Function for processing data in worker process.""" 425 # pylint: disable=unused-argument 426 # it is required that each worker process has to fork a new MXIndexedRecordIO handle 427 # preserving dataset as global variable can save tons of overhead and is safe in new process 428 global _worker_dataset 429 batch = batchify_fn([_worker_dataset[i] for i in samples]) 430 buf = io.BytesIO() 431 ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch) 432 return buf.getvalue() 433 434def _thread_worker_fn(samples, batchify_fn, dataset): 435 """Threadpool worker function for processing data.""" 436 return batchify_fn([dataset[i] for i in samples]) 437 438class _MultiWorkerIter(object): 439 """Internal multi-worker iterator for DataLoader.""" 440 def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False, 441 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None, 442 data_loader=None, timeout=120): 443 self._worker_pool = worker_pool 444 self._batchify_fn = batchify_fn 445 self._batch_sampler = batch_sampler 446 self._data_buffer = {} 447 self._rcvd_idx = 0 448 self._sent_idx = 0 449 self._iter = iter(self._batch_sampler) 450 self._worker_fn = worker_fn 451 self._pin_memory = pin_memory 452 self._pin_device_id = pin_device_id 453 self._dataset = dataset 454 self._data_loader = data_loader 455 self._timeout = timeout 456 # pre-fetch 457 for _ in range(prefetch): 458 self._push_next() 459 460 def __len__(self): 461 return len(self._batch_sampler) 462 463 def _push_next(self): 464 """Assign next batch workload to workers.""" 465 r = next(self._iter, None) 466 if r is None: 467 return 468 async_ret = self._worker_pool.apply_async( 469 self._worker_fn, (r, self._batchify_fn, self._dataset)) 470 self._data_buffer[self._sent_idx] = async_ret 471 self._sent_idx += 1 472 473 def __next__(self): 474 self._push_next() 475 if self._rcvd_idx == self._sent_idx: 476 assert not self._data_buffer, "Data buffer should be empty at this moment" 477 raise StopIteration 478 479 assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx" 480 assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing" 481 ret = self._data_buffer.pop(self._rcvd_idx) 482 try: 483 if self._dataset is None: 484 batch = pickle.loads(ret.get(self._timeout)) 485 else: 486 batch = ret.get(self._timeout) 487 if self._pin_memory: 488 batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id)) 489 self._rcvd_idx += 1 490 return batch 491 except multiprocessing.context.TimeoutError: 492 msg = '''Worker timed out after {} seconds. This might be caused by \n 493 - Slow transform. Please increase timeout to allow slower data loading in each worker. 494 '''.format(self._timeout) 495 if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool): 496 msg += '''- Insufficient shared_memory if `timeout` is large enough. 497 Please consider reduce `num_workers` or increase shared_memory in system. 498 ''' 499 print(msg) 500 raise 501 except Exception: 502 self._worker_pool.terminate() 503 raise 504 505 def next(self): 506 return self.__next__() 507 508 def __iter__(self): 509 return self 510 511 512class DataLoader(object): 513 """Loads data from a dataset and returns mini-batches of data. 514 515 Parameters 516 ---------- 517 dataset : Dataset 518 Source dataset. Note that numpy and mxnet arrays can be directly used 519 as a Dataset. 520 batch_size : int 521 Size of mini-batch. 522 shuffle : bool 523 Whether to shuffle the samples. 524 sampler : Sampler 525 The sampler to use. Either specify sampler or shuffle, not both. 526 last_batch : {'keep', 'discard', 'rollover'} 527 How to handle the last batch if batch_size does not evenly divide 528 `len(dataset)`. 529 530 keep - A batch with less samples than previous batches is returned. 531 discard - The last batch is discarded if its incomplete. 532 rollover - The remaining samples are rolled over to the next epoch. 533 batch_sampler : Sampler 534 A sampler that returns mini-batches. Do not specify batch_size, 535 shuffle, sampler, and last_batch if batch_sampler is specified. 536 batchify_fn : callable 537 Callback function to allow users to specify how to merge samples 538 into a batch. Defaults to `default_batchify_fn`:: 539 540 def default_batchify_fn(data): 541 if isinstance(data[0], nd.NDArray): 542 return nd.stack(*data) 543 elif isinstance(data[0], tuple): 544 data = zip(*data) 545 return [default_batchify_fn(i) for i in data] 546 else: 547 data = np.asarray(data) 548 return nd.array(data, dtype=data.dtype) 549 550 num_workers : int, default 0 551 The number of multiprocessing workers to use for data preprocessing. 552 pin_memory : boolean, default False 553 If ``True``, the dataloader will copy NDArrays into pinned memory 554 before returning them. Copying from CPU pinned memory to GPU is faster 555 than from normal CPU memory. 556 pin_device_id : int, default 0 557 The device id to use for allocating pinned memory if pin_memory is ``True`` 558 prefetch : int, default is `num_workers * 2` 559 The number of prefetching batches only works if `num_workers` > 0. 560 If `prefetch` > 0, it allow worker process to prefetch certain batches before 561 acquiring data from iterators. 562 Note that using large prefetching batch will provide smoother bootstrapping performance, 563 but will consume more shared_memory. Using smaller number may forfeit the purpose of using 564 multiple worker processes, try reduce `num_workers` in this case. 565 By default it defaults to `num_workers * 2`. 566 thread_pool : bool, default False 567 If ``True``, use threading pool instead of multiprocessing pool. Using threadpool 568 can avoid shared memory usage. If `DataLoader` is more IO bounded or GIL is not a killing 569 problem, threadpool version may achieve better performance than multiprocessing. 570 timeout : int, default is 120 571 The timeout in seconds for each worker to fetch a batch data. Only modify this number 572 unless you are experiencing timeout and you know it's due to slow data loading. 573 Sometimes full `shared_memory` will cause all workers to hang and causes timeout. In these 574 cases please reduce `num_workers` or increase system `shared_memory` size instead. 575 auto_reload : bool, default is True 576 control whether prefetch data after a batch is ended. 577 578 Example: 579 >>> from mxnet.gluon.data import DataLoader, ArrayDataset 580 >>> train_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)]) 581 >>> def transform_train(sample): 582 ... if sample == 0 : print('(pre)fetching data here') 583 ... return sample 584 ... 585 >>> train_iter = DataLoader(train_data.transform_first(transform_train), 586 ... auto_reload=False, batch_size=1,num_workers=1) 587 >>> # no prefetch is performed, the prefetch & autoload start after 588 >>> # train_iter.__iter__() is called. 589 >>> for i in train_iter:pass 590 (pre)fetching data here 591 >>> train_iter = DataLoader(train_data.transform_first(transform_train), 592 ... batch_size=1,num_workers=1) 593 (pre)fetching data here 594 >>> it = iter(train_iter) # nothing is generated since lazy-evaluation occurs 595 >>> it2 = iter(train_iter) 596 >>> it3 = iter(train_iter) 597 >>> it4 = iter(train_iter) 598 >>> _ = next(it2) # the first iter we are using is the prefetched iter. 599 >>> _ = next(it) # since the prefetched iter is consumed, we have to fetch data for `it`. 600 (pre)fetching data here 601 >>> _ = [None for _ in it3] 602 (pre)fetching data here 603 (pre)fetching data here 604 >>> # Here, 2 prefetches are triggered, one is fetching the first batch of `it3` and 605 >>> # another is when `it3` yield its last item, a prefetch is automatically performed. 606 >>> _ = [None for _ in it] 607 >>> # no prefetch is happened since train_loader has already prefetch data. 608 >>> _ = next(it4) 609 >>> # since the prefetch is performed, it4 become the prefetched iter. 610 >>> 611 >>> test_data = ArrayDataset([i for i in range(10)],[9-i for i in range(10)]) 612 >>> test_iter = DataLoader(test_data, batch_size=1,num_workers=1) 613 >>> for epoch in range(200): 614 ... # there is almost no difference between it and the default DataLoader 615 ... for data, label in train_iter: 616 ... # training... 617 ... for data, label in test_iter: 618 ... # testing... 619 """ 620 def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None, 621 last_batch=None, batch_sampler=None, batchify_fn=None, 622 num_workers=0, pin_memory=False, pin_device_id=0, 623 prefetch=None, thread_pool=False, timeout=120, auto_reload=False): 624 self._dataset = dataset 625 self._pin_memory = pin_memory 626 self._pin_device_id = pin_device_id 627 self._thread_pool = thread_pool 628 self._timeout = timeout 629 assert timeout > 0, "timeout must be positive, given {}".format(timeout) 630 631 if batch_sampler is None: 632 if batch_size is None: 633 raise ValueError("batch_size must be specified unless " \ 634 "batch_sampler is specified") 635 if sampler is None: 636 if shuffle: 637 sampler = _sampler.RandomSampler(len(dataset)) 638 else: 639 sampler = _sampler.SequentialSampler(len(dataset)) 640 elif shuffle: 641 raise ValueError("shuffle must not be specified if sampler is specified") 642 643 batch_sampler = _sampler.BatchSampler( 644 sampler, batch_size, last_batch if last_batch else 'keep') 645 elif batch_size is not None or shuffle or sampler is not None or \ 646 last_batch is not None: 647 raise ValueError("batch_size, shuffle, sampler and last_batch must " \ 648 "not be specified if batch_sampler is specified.") 649 650 self._batch_sampler = batch_sampler 651 self._num_workers = num_workers if num_workers >= 0 else 0 652 self._worker_pool = None 653 self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers) 654 nd.waitall() 655 import gc 656 gc.collect() 657 nd.waitall() 658 if self._num_workers > 0: 659 if self._thread_pool: 660 self._worker_pool = ThreadPool(self._num_workers, 661 initializer=_thread_worker_initializer, 662 initargs=(is_np_shape(), is_np_array())) 663 else: 664 # set ignore keyboard interupt signal before forking processes 665 original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN) 666 self._worker_pool = multiprocessing.Pool( 667 self._num_workers, initializer=_worker_initializer, 668 initargs=[self._dataset, is_np_shape(), is_np_array()]) 669 # resume keyboard interupt signal in main process 670 signal.signal(signal.SIGINT, original_sigint_handler) 671 if batchify_fn is None: 672 if num_workers > 0: 673 self._batchify_fn = default_mp_batchify_fn 674 else: 675 self._batchify_fn = default_batchify_fn 676 else: 677 self._batchify_fn = batchify_fn 678 self.auto_reload = auto_reload 679 if self.auto_reload: 680 self.refresh() 681 else: 682 self.clean() # ensure self._iter exists. 683 684 def __iter__(self): 685 if self._iter is None: 686 self.refresh() 687 t = self._iter 688 self._iter = None # ensure a single iter would not using twice. 689 for item in t: 690 yield item 691 if self._iter is None and self.auto_reload: 692 # ensure we do not waste any exist iter by mistake 693 self.refresh() 694 695 def _prefetch_iter(self): 696 if self._num_workers == 0: 697 def same_process_iter(): 698 for batch in self._batch_sampler: 699 ret = self._batchify_fn([self._dataset[idx] for idx in batch]) 700 if self._pin_memory: 701 ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id)) 702 yield ret 703 return same_process_iter() 704 705 # multi-worker 706 return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler, 707 pin_memory=self._pin_memory, pin_device_id=self._pin_device_id, 708 worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn, 709 prefetch=self._prefetch, 710 dataset=self._dataset if self._thread_pool else None, 711 data_loader=self, timeout=self._timeout) 712 713 def __len__(self): 714 return len(self._batch_sampler) 715 716 def __del__(self): 717 if self._worker_pool: 718 # manually terminate due to a bug that pool is not automatically terminated 719 # https://bugs.python.org/issue34172 720 assert isinstance(self._worker_pool, multiprocessing.pool.Pool) 721 self._worker_pool.terminate() 722 723 def refresh(self): 724 """Refresh its iter, fetch data again from its dataset""" 725 self._iter = self._prefetch_iter() 726 727 def clean(self): 728 """Remove its prefetched iter, the prefetch step will start after call its __iter__()""" 729 self._iter = None 730