# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
# pylint: disable=ungrouped-imports
"""Dataset generator."""
__all__ = ['DataLoader']

import pickle
import io
import sys
import signal
import multiprocessing
import multiprocessing.queues
from multiprocessing.reduction import ForkingPickler
from multiprocessing.pool import ThreadPool
import threading
import numpy as np

try:
    import multiprocessing.resource_sharer
except ImportError:
    pass

from . import sampler as _sampler
from ... import nd, context
from ...util import is_np_shape, is_np_array, set_np
from ... import numpy as _mx_np  # pylint: disable=reimported

if sys.platform == 'darwin' or sys.platform == 'win32':
    def rebuild_ndarray(*args):
        """Rebuild ndarray from pickled shared memory"""
        # pylint: disable=no-value-for-parameter
        return nd.NDArray(nd.ndarray._new_from_shared_mem(*args))

    def reduce_ndarray(data):
        """Reduce ndarray to shared memory handle"""
        return rebuild_ndarray, data._to_shared_mem()
else:
    def rebuild_ndarray(pid, fd, shape, dtype):
        """Rebuild ndarray from pickled shared memory"""
        # pylint: disable=no-value-for-parameter
        fd = fd.detach()
        return nd.NDArray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))

    def reduce_ndarray(data):
        """Reduce ndarray to shared memory handle"""
        # keep a local ref before duplicating fd
        data = data.as_in_context(context.Context('cpu_shared', 0))
        pid, fd, shape, dtype = data._to_shared_mem()
        fd = multiprocessing.reduction.DupFd(fd)
        return rebuild_ndarray, (pid, fd, shape, dtype)

ForkingPickler.register(nd.NDArray, reduce_ndarray)

if sys.platform == 'darwin' or sys.platform == 'win32':
    def rebuild_np_ndarray(*args):
        """Rebuild ndarray from pickled shared memory"""
        # pylint: disable=no-value-for-parameter
        return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(*args))

    def reduce_np_ndarray(data):
        """Reduce ndarray to shared memory handle"""
        return rebuild_np_ndarray, data._to_shared_mem()
else:
    def rebuild_np_ndarray(pid, fd, shape, dtype):
        """Rebuild ndarray from pickled shared memory"""
        # pylint: disable=no-value-for-parameter
        fd = fd.detach()
        return _mx_np.ndarray(nd.ndarray._new_from_shared_mem(pid, fd, shape, dtype))

    def reduce_np_ndarray(data):
        """Reduce ndarray to shared memory handle"""
        # keep a local ref before duplicating fd
        data = data.as_in_context(context.Context('cpu_shared', 0))
        pid, fd, shape, dtype = data._to_shared_mem()
        fd = multiprocessing.reduction.DupFd(fd)
        return rebuild_np_ndarray, (pid, fd, shape, dtype)

ForkingPickler.register(_mx_np.ndarray, reduce_np_ndarray)
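# Illustrative sketch (not executed at import): with the reducers registered
# above, anything serialized through ForkingPickler -- including the stdlib
# ``multiprocessing.Queue`` -- moves NDArrays between processes via shared
# memory instead of copying the raw bytes::
#
#     import multiprocessing as mp
#     from mxnet import nd
#     q = mp.Queue()
#     q.put(nd.zeros((2, 2)))  # reduced to a shared-memory handle on send
#     arr = q.get()            # rebuilt from the handle, no bulk copy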


class ConnectionWrapper(object):
    """Connection wrapper for multiprocessing that supports sending
    NDArray via shared memory."""

    def __init__(self, conn):
        self._conn = conn

    def send(self, obj):
        """Send object"""
        buf = io.BytesIO()
        ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
        self.send_bytes(buf.getvalue())

    def recv(self):
        """Receive object"""
        buf = self.recv_bytes()
        return pickle.loads(buf)

    def __getattr__(self, name):
        """Emulate conn"""
        attr = self.__dict__.get('_conn', None)
        return getattr(attr, name)


class Queue(multiprocessing.queues.Queue):
    """Wrapper for multiprocessing queue that dumps NDArray with shared memory."""
    def __init__(self, *args, **kwargs):
        super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs)
        self._reader = ConnectionWrapper(self._reader)
        self._writer = ConnectionWrapper(self._writer)
        self._send = self._writer.send
        self._recv = self._reader.recv


class SimpleQueue(multiprocessing.queues.SimpleQueue):
    """Wrapper for multiprocessing SimpleQueue that dumps NDArray with shared memory.
       SimpleQueue doesn't use threading internally.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, ctx=multiprocessing.get_context(), **kwargs)
        self._reader = ConnectionWrapper(self._reader)
        self._writer = ConnectionWrapper(self._writer)
        self._send = self._writer.send
        self._recv = self._reader.recv

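# Usage sketch (hypothetical payload): both wrappers behave like their stdlib
# counterparts, but every object travels through ForkingPickler, so NDArrays
# nested inside arbitrary payloads cross process boundaries via shared memory::
#
#     q = SimpleQueue()
#     q.put((0, nd.ones((4,))))  # in the producer process
#     idx, arr = q.get()         # in the consumer process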

def default_batchify_fn(data):
    """Collate data into batch."""
    if isinstance(data[0], nd.NDArray):
        return _mx_np.stack(data) if is_np_array() else nd.stack(*data)
    elif isinstance(data[0], tuple):
        data = zip(*data)
        return [default_batchify_fn(i) for i in data]
    else:
        data = np.asarray(data)
        array_fn = _mx_np.array if is_np_array() else nd.array
        return array_fn(data, dtype=data.dtype)

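# Example sketch (hypothetical samples): collating three (data, label) pairs
# recurses over the tuple and stacks each field, yielding a list of two
# batched arrays::
#
#     samples = [(nd.ones((2,)), 0), (nd.zeros((2,)), 1), (nd.ones((2,)), 2)]
#     batch = default_batchify_fn(samples)
#     # batch[0].shape == (3, 2) and batch[1].shape == (3,)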

def default_mp_batchify_fn(data):
    """Collate data into batch. Use shared memory for stacking."""
    if isinstance(data[0], nd.NDArray):
        empty_fn = _mx_np.empty if is_np_array() else nd.empty
        out = empty_fn((len(data),) + data[0].shape, dtype=data[0].dtype,
                       ctx=context.Context('cpu_shared', 0))
        if is_np_array():
            return _mx_np.stack(data, out=out)
        else:
            return nd.stack(*data, out=out)
    elif isinstance(data[0], tuple):
        data = zip(*data)
        return [default_mp_batchify_fn(i) for i in data]
    else:
        data = np.asarray(data)
        array_fn = _mx_np.array if is_np_array() else nd.array
        return array_fn(data, dtype=data.dtype,
                        ctx=context.Context('cpu_shared', 0))

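# The multiprocessing variant differs only in where the batch lives: the output
# is pre-allocated in the 'cpu_shared' context, so the collated batch can be
# handed back to the parent process without an extra copy, e.g.::
#
#     batch = default_mp_batchify_fn([nd.ones((2,)), nd.zeros((2,))])
#     # batch.context == context.Context('cpu_shared', 0)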

def _as_in_context(data, ctx):
    """Move data into new context."""
    if isinstance(data, nd.NDArray):
        return data.as_in_context(ctx)
    elif isinstance(data, (list, tuple)):
        return [_as_in_context(d, ctx) for d in data]
    return data


def worker_loop_v1(dataset, key_queue, data_queue, batchify_fn):
    """Worker loop for multiprocessing DataLoader."""
    while True:
        idx, samples = key_queue.get()
        if idx is None:
            break
        batch = batchify_fn([dataset[i] for i in samples])
        data_queue.put((idx, batch))

def fetcher_loop_v1(data_queue, data_buffer, pin_memory=False,
                    pin_device_id=0, data_buffer_lock=None):
    """Fetcher loop for fetching data from the queue and putting it in the reorder dict."""
    while True:
        idx, batch = data_queue.get()
        if idx is None:
            break
        if pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned(pin_device_id))
        else:
            batch = _as_in_context(batch, context.cpu())
        if data_buffer_lock is not None:
            with data_buffer_lock:
                data_buffer[idx] = batch
        else:
            data_buffer[idx] = batch


class _MultiWorkerIterV1(object):
    """Internal multi-worker iterator for DataLoader."""
    def __init__(self, num_workers, dataset, batchify_fn, batch_sampler,
                 pin_memory=False, pin_device_id=0, worker_fn=worker_loop_v1):
        assert num_workers > 0, "_MultiWorkerIter is not for {} workers".format(num_workers)
        self._num_workers = num_workers
        self._dataset = dataset
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._key_queue = Queue()
        self._data_queue = SimpleQueue()

        self._data_buffer = {}
        self._data_buffer_lock = threading.Lock()

        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._shutdown = False

        workers = []
        for _ in range(self._num_workers):
            worker = multiprocessing.Process(
                target=worker_fn,
                args=(self._dataset, self._key_queue, self._data_queue, self._batchify_fn))
            worker.daemon = True
            worker.start()
            workers.append(worker)
        self._workers = workers

        self._fetcher = threading.Thread(
            target=fetcher_loop_v1,
            args=(self._data_queue, self._data_buffer, pin_memory,
                  pin_device_id, self._data_buffer_lock))
        self._fetcher.daemon = True
        self._fetcher.start()

        # pre-fetch
        for _ in range(2 * self._num_workers):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        self.shutdown()

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        self._key_queue.put((self._sent_idx, r))
        self._sent_idx += 1

    def __next__(self):
        assert not self._shutdown, "calling __next__ after shutdown is forbidden"
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            self.shutdown()
            raise StopIteration

        while True:
            if self._rcvd_idx in self._data_buffer:
                with self._data_buffer_lock:
                    batch = self._data_buffer.pop(self._rcvd_idx)
                self._rcvd_idx += 1
                self._push_next()
                return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

    def shutdown(self):
        """Shutdown internal workers by pushing terminate signals."""
        if not self._shutdown:
            # send shutdown signal to the fetcher and join the data queue first
            # Remark: the fetcher needs to be joined prior to the workers;
            #         otherwise, the fetcher may fail to get data
            self._data_queue.put((None, None))
            self._fetcher.join()
            # send shutdown signal to all worker processes
            for _ in range(self._num_workers):
                self._key_queue.put((None, None))
            # force shut down any alive worker processes
            for w in self._workers:
                if w.is_alive():
                    w.terminate()
            self._shutdown = True


class DataLoaderV1(object):
    """Loads data from a dataset and returns mini-batches of data.

    Parameters
    ----------
    dataset : Dataset
        Source dataset. Note that numpy and mxnet arrays can be directly used
        as a Dataset.
    batch_size : int
        Size of mini-batch.
    shuffle : bool
        Whether to shuffle the samples.
    sampler : Sampler
        The sampler to use. Either specify sampler or shuffle, not both.
    last_batch : {'keep', 'discard', 'rollover'}
        How to handle the last batch if batch_size does not evenly divide
        `len(dataset)`.

        keep - A batch with fewer samples than previous batches is returned.
        discard - The last batch is discarded if it is incomplete.
        rollover - The remaining samples are rolled over to the next epoch.
    batch_sampler : Sampler
        A sampler that returns mini-batches. Do not specify batch_size,
        shuffle, sampler, and last_batch if batch_sampler is specified.
    batchify_fn : callable
        Callback function to allow users to specify how to merge samples
        into a batch. Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)

    num_workers : int, default 0
        The number of multiprocessing workers to use for data preprocessing.
    pin_memory : boolean, default False
        If ``True``, the dataloader will copy NDArrays into pinned memory
        before returning them. Copying from CPU pinned memory to GPU is faster
        than from normal CPU memory.
    pin_device_id : int, default 0
        The device id to use for allocating pinned memory if pin_memory is ``True``.
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0, pin_memory=False, pin_device_id=0):
        self._dataset = dataset
        self._pin_memory = pin_memory
        self._pin_device_id = pin_device_id

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers if num_workers >= 0 else 0
        if batchify_fn is None:
            if num_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn

    def __iter__(self):
        if self._num_workers == 0:
            def same_process_iter():
                for batch in self._batch_sampler:
                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
                    if self._pin_memory:
                        ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
                    yield ret
            return same_process_iter()

        # multi-worker
        return _MultiWorkerIterV1(self._num_workers, self._dataset,
                                  self._batchify_fn, self._batch_sampler,
                                  self._pin_memory, self._pin_device_id)

    def __len__(self):
        return len(self._batch_sampler)

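# Minimal usage sketch for the legacy loader (assumes `dataset` implements
# __getitem__ and __len__)::
#
#     loader = DataLoaderV1(dataset, batch_size=32, shuffle=True, num_workers=2)
#     for batch in loader:
#         pass  # train on batch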

def _thread_worker_initializer(active_shape, active_array):
    """Initializer for ThreadPool."""
    set_np(shape=active_shape, array=active_array)


_worker_dataset = None
def _worker_initializer(dataset, active_shape, active_array):
    """Initializer for the process pool."""
    # The global dataset is per-process and only available in worker processes.
    # This is only necessary to handle MXIndexedRecordIO, because otherwise the
    # dataset can be passed as an argument.
    global _worker_dataset
    _worker_dataset = dataset
    set_np(shape=active_shape, array=active_array)

def _worker_fn(samples, batchify_fn, dataset=None):
    """Function for processing data in worker process."""
    # pylint: disable=unused-argument
    # Each worker process has to fork a new MXIndexedRecordIO handle.
    # Keeping the dataset as a global variable saves significant overhead and
    # is safe in the new process.
    global _worker_dataset
    batch = batchify_fn([_worker_dataset[i] for i in samples])
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    return buf.getvalue()

def _thread_worker_fn(samples, batchify_fn, dataset):
    """Threadpool worker function for processing data."""
    return batchify_fn([dataset[i] for i in samples])
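# Note: ``_worker_fn`` returns the ForkingPickler-serialized bytes rather than
# the batch itself, so the NDArray-to-shared-memory reduction happens inside
# the worker process; the parent rebuilds the batch with ``pickle.loads`` (see
# ``_MultiWorkerIter.__next__`` below). A sketch of the round trip::
#
#     payload = _worker_fn([0, 1], default_mp_batchify_fn)  # in a worker
#     batch = pickle.loads(payload)                         # in the parent
# ``_thread_worker_fn`` skips this step because threads share one address space.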

class _MultiWorkerIter(object):
    """Internal multi-worker iterator for DataLoader."""
    def __init__(self, worker_pool, batchify_fn, batch_sampler, pin_memory=False,
                 pin_device_id=0, worker_fn=_worker_fn, prefetch=0, dataset=None,
                 data_loader=None, timeout=120):
        self._worker_pool = worker_pool
        self._batchify_fn = batchify_fn
        self._batch_sampler = batch_sampler
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._iter = iter(self._batch_sampler)
        self._worker_fn = worker_fn
        self._pin_memory = pin_memory
        self._pin_device_id = pin_device_id
        self._dataset = dataset
        self._data_loader = data_loader
        self._timeout = timeout
        # pre-fetch
        for _ in range(prefetch):
            self._push_next()

    def __len__(self):
        return len(self._batch_sampler)

    def _push_next(self):
        """Assign next batch workload to workers."""
        r = next(self._iter, None)
        if r is None:
            return
        async_ret = self._worker_pool.apply_async(
            self._worker_fn, (r, self._batchify_fn, self._dataset))
        self._data_buffer[self._sent_idx] = async_ret
        self._sent_idx += 1

    def __next__(self):
        self._push_next()
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, "Data buffer should be empty at this moment"
            raise StopIteration

        assert self._rcvd_idx < self._sent_idx, "rcvd_idx must be smaller than sent_idx"
        assert self._rcvd_idx in self._data_buffer, "fatal error with _push_next, rcvd_idx missing"
        ret = self._data_buffer.pop(self._rcvd_idx)
        try:
            if self._dataset is None:
                batch = pickle.loads(ret.get(self._timeout))
            else:
                batch = ret.get(self._timeout)
            if self._pin_memory:
                batch = _as_in_context(batch, context.cpu_pinned(self._pin_device_id))
            self._rcvd_idx += 1
            return batch
        except multiprocessing.context.TimeoutError:
            msg = '''Worker timed out after {} seconds. This might be caused by:
            - Slow transform. Please increase timeout to allow slower data loading in each worker.
            '''.format(self._timeout)
            if not isinstance(self._worker_pool, multiprocessing.pool.ThreadPool):
                msg += '''- Insufficient shared memory if `timeout` is large enough.
            Please consider reducing `num_workers` or increasing the system's shared memory.
            '''
            print(msg)
            raise
        except Exception:
            self._worker_pool.terminate()
            raise

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self

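# Scheduling sketch: the iterator keeps up to `prefetch` batches in flight,
# keyed by sent_idx; each __next__ first submits one more batch and then
# blocks on the oldest outstanding result, so batches are always returned in
# rcvd_idx order::
#
#     it = _MultiWorkerIter(pool, default_mp_batchify_fn, batch_sampler,
#                           prefetch=4)
#     first = next(it)  # submits one more batch, then waits for batch 0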

class DataLoader(object):
    """Loads data from a dataset and returns mini-batches of data.

    Parameters
    ----------
    dataset : Dataset
        Source dataset. Note that numpy and mxnet arrays can be directly used
        as a Dataset.
    batch_size : int
        Size of mini-batch.
    shuffle : bool
        Whether to shuffle the samples.
    sampler : Sampler
        The sampler to use. Either specify sampler or shuffle, not both.
    last_batch : {'keep', 'discard', 'rollover'}
        How to handle the last batch if batch_size does not evenly divide
        `len(dataset)`.

        keep - A batch with fewer samples than previous batches is returned.
        discard - The last batch is discarded if it is incomplete.
        rollover - The remaining samples are rolled over to the next epoch.
    batch_sampler : Sampler
        A sampler that returns mini-batches. Do not specify batch_size,
        shuffle, sampler, and last_batch if batch_sampler is specified.
    batchify_fn : callable
        Callback function to allow users to specify how to merge samples
        into a batch. Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)

    num_workers : int, default 0
        The number of multiprocessing workers to use for data preprocessing.
    pin_memory : boolean, default False
        If ``True``, the dataloader will copy NDArrays into pinned memory
        before returning them. Copying from CPU pinned memory to GPU is faster
        than from normal CPU memory.
    pin_device_id : int, default 0
        The device id to use for allocating pinned memory if pin_memory is ``True``.
    prefetch : int, default is `num_workers * 2`
        The number of batches to prefetch. Only takes effect if `num_workers` > 0.
        If `prefetch` > 0, worker processes fetch that many batches ahead of the
        batches the iterator has consumed.
        Note that a larger prefetch gives smoother start-up performance but
        consumes more shared memory. A value that is too small may defeat the
        purpose of using multiple worker processes; try reducing `num_workers`
        in that case.
    thread_pool : bool, default False
        If ``True``, use a thread pool instead of a multiprocessing pool. Using a
        thread pool avoids shared-memory usage. If `DataLoader` is more I/O bound
        or the GIL is not a bottleneck, the thread-pool version may achieve
        better performance than multiprocessing.
    timeout : int, default is 120
        The timeout in seconds for each worker to fetch a batch of data. Only
        modify this number if you are experiencing a timeout and you know it is
        due to slow data loading. Sometimes full `shared_memory` will cause all
        workers to hang and trigger a timeout; in these cases please reduce
        `num_workers` or increase the system `shared_memory` size instead.
    auto_reload : bool, default is False
        Controls whether to prefetch data for the next iterator after an epoch ends.

    Example:
    >>> from mxnet.gluon.data import DataLoader, ArrayDataset
    >>> train_data = ArrayDataset([i for i in range(10)], [9-i for i in range(10)])
    >>> def transform_train(sample):
    ...   if sample == 0: print('(pre)fetching data here')
    ...   return sample
    ...
    >>> train_iter = DataLoader(train_data.transform_first(transform_train),
    ...                         auto_reload=False, batch_size=1, num_workers=1)
    >>> # no prefetch is performed; the prefetch & autoload start after
    >>> # train_iter.__iter__() is called.
    >>> for i in train_iter: pass
    (pre)fetching data here
    >>> train_iter = DataLoader(train_data.transform_first(transform_train),
    ...                         auto_reload=True, batch_size=1, num_workers=1)
    (pre)fetching data here
    >>> it = iter(train_iter) # nothing is generated since lazy-evaluation occurs
    >>> it2 = iter(train_iter)
    >>> it3 = iter(train_iter)
    >>> it4 = iter(train_iter)
    >>> _ = next(it2) # the first iter we are using is the prefetched iter.
    >>> _ = next(it) # since the prefetched iter is consumed, we have to fetch data for `it`.
    (pre)fetching data here
    >>> _ = [None for _ in it3]
    (pre)fetching data here
    (pre)fetching data here
    >>> # Here, 2 prefetches are triggered: one fetches the first batch of `it3`, and
    >>> # when `it3` yields its last item, another prefetch is automatically performed.
    >>> _ = [None for _ in it]
    >>> # no prefetch happens since train_iter has already prefetched data.
    >>> _ = next(it4)
    >>> # since the prefetch was performed, it4 becomes the prefetched iter.
    >>>
    >>> test_data = ArrayDataset([i for i in range(10)], [9-i for i in range(10)])
    >>> test_iter = DataLoader(test_data, batch_size=1, num_workers=1)
    >>> for epoch in range(200):
    ...   # there is almost no difference between this and the default DataLoader
    ...   for data, label in train_iter:
    ...     pass  # training...
    ...   for data, label in test_iter:
    ...     pass  # testing...
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None, batchify_fn=None,
                 num_workers=0, pin_memory=False, pin_device_id=0,
                 prefetch=None, thread_pool=False, timeout=120, auto_reload=False):
        self._dataset = dataset
        self._pin_memory = pin_memory
        self._pin_device_id = pin_device_id
        self._thread_pool = thread_pool
        self._timeout = timeout
        assert timeout > 0, "timeout must be positive, given {}".format(timeout)

        if batch_sampler is None:
            if batch_size is None:
                raise ValueError("batch_size must be specified unless " \
                                 "batch_sampler is specified")
            if sampler is None:
                if shuffle:
                    sampler = _sampler.RandomSampler(len(dataset))
                else:
                    sampler = _sampler.SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must not be specified if sampler is specified")

            batch_sampler = _sampler.BatchSampler(
                sampler, batch_size, last_batch if last_batch else 'keep')
        elif batch_size is not None or shuffle or sampler is not None or \
                last_batch is not None:
            raise ValueError("batch_size, shuffle, sampler and last_batch must " \
                             "not be specified if batch_sampler is specified.")

        self._batch_sampler = batch_sampler
        self._num_workers = num_workers if num_workers >= 0 else 0
        self._worker_pool = None
        self._prefetch = max(0, int(prefetch) if prefetch is not None else 2 * self._num_workers)
        # wait for all async operations and collect garbage before forking
        # worker processes
        nd.waitall()
        import gc
        gc.collect()
        nd.waitall()
        if self._num_workers > 0:
            if self._thread_pool:
                self._worker_pool = ThreadPool(self._num_workers,
                                               initializer=_thread_worker_initializer,
                                               initargs=(is_np_shape(), is_np_array()))
            else:
                # ignore keyboard interrupt signal before forking processes
                original_sigint_handler = signal.signal(signal.SIGINT, signal.SIG_IGN)
                self._worker_pool = multiprocessing.Pool(
                    self._num_workers, initializer=_worker_initializer,
                    initargs=[self._dataset, is_np_shape(), is_np_array()])
                # restore keyboard interrupt signal in the main process
                signal.signal(signal.SIGINT, original_sigint_handler)
        if batchify_fn is None:
            if num_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn
        self.auto_reload = auto_reload
        if self.auto_reload:
            self.refresh()
        else:
            self.clean() # ensure self._iter exists.

    def __iter__(self):
        if self._iter is None:
            self.refresh()
        t = self._iter
        self._iter = None # ensure a single iter is not used twice.
        for item in t:
            yield item
        if self._iter is None and self.auto_reload:
            # ensure we do not waste any existing iter by mistake
            self.refresh()

    def _prefetch_iter(self):
        if self._num_workers == 0:
            def same_process_iter():
                for batch in self._batch_sampler:
                    ret = self._batchify_fn([self._dataset[idx] for idx in batch])
                    if self._pin_memory:
                        ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))
                    yield ret
            return same_process_iter()

        # multi-worker
        return _MultiWorkerIter(self._worker_pool, self._batchify_fn, self._batch_sampler,
                                pin_memory=self._pin_memory, pin_device_id=self._pin_device_id,
                                worker_fn=_thread_worker_fn if self._thread_pool else _worker_fn,
                                prefetch=self._prefetch,
                                dataset=self._dataset if self._thread_pool else None,
                                data_loader=self, timeout=self._timeout)

    def __len__(self):
        return len(self._batch_sampler)

    def __del__(self):
        if self._worker_pool:
            # manually terminate due to a bug where the pool is not
            # automatically terminated: https://bugs.python.org/issue34172
            assert isinstance(self._worker_pool, multiprocessing.pool.Pool)
            self._worker_pool.terminate()

    def refresh(self):
        """Refresh the internal iter and fetch data again from the dataset."""
        self._iter = self._prefetch_iter()

    def clean(self):
        """Remove the prefetched iter; prefetching will start after the next call to __iter__()."""
        self._iter = None

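# Usage sketch for manual prefetch control (assumes a `dataset` object)::
#
#     loader = DataLoader(dataset, batch_size=32, num_workers=2,
#                         auto_reload=False)
#     loader.refresh()              # start prefetching immediately
#     for batch in loader:
#         pass                      # consumes the prefetched iterator
#     loader.clean()                # drop any pending prefetched iterator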