# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=ungrouped-imports
19"""DatasetLoader. An extension of Gluon data loader that allows
20reading and processing multiple files on-the-fly.
21"""
22
23__all__ = ['DatasetLoader']
24
25import io
26import pickle
27import warnings
28import multiprocessing
29from functools import partial
30from mxnet import context
31from mxnet.gluon.data.dataloader import ForkingPickler, _as_in_context
32from mxnet.gluon.data.dataloader import default_mp_batchify_fn, default_batchify_fn
33from .stream import _PathDataset


# manager for creating shared objects
_manager = None
_dataset = None
def _initialize_dataset_worker(manager):
    global _manager
    _manager = manager


def _dataset_worker_fn(urls, dataset_fn, batch_sampler_fn):
    """Function to generate a dataset and a batch sampler for each worker."""
    global _manager, _dataset
    dataset = dataset_fn(urls)
    batch_sampler = batch_sampler_fn(dataset)
    if _manager:
        # store the samples in a managed list so they can be shared with batch worker processes
        dataset = _manager.list(zip(*dataset._data))
    _dataset = dataset
    return dataset, batch_sampler


def _batch_worker_fn(samples, batchify_fn, dataset=None, counter=None):
    """Function for processing data in a worker process."""
    # pylint: disable=unused-argument
    # each worker process is required to fork a new MXIndexedRecordIO handle;
    # preserving the dataset as a global variable saves a lot of overhead and is safe
    # in the new process
    if len(dataset[0]) > 1:
        if isinstance(samples[0], (list, tuple)):
            batch = [batchify_fn([dataset[i] for i in shard]) for shard in samples]
        else:
            batch = batchify_fn([dataset[i] for i in samples])
    else:
        if isinstance(samples[0], (list, tuple)):
            batch = [batchify_fn([dataset[i][0] for i in shard]) for shard in samples]
        else:
            batch = batchify_fn([dataset[i][0] for i in samples])
    buf = io.BytesIO()
    ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(batch)
    return buf.getvalue(), counter


class _MultiBatchWorkerIter:
    """Internal multi-worker batch iterator for DatasetLoader."""
    def __init__(self, worker_pool, batchify_fn, dataset_iter=None,
                 pin_memory=False, worker_fn=_batch_worker_fn, prefetch=0,
                 manager=None):
        self._worker_pool = worker_pool
        self._batchify_fn = batchify_fn
        self._data_buffer = {}
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._dataset_iter = iter(dataset_iter)
        self._worker_fn = worker_fn
        self._pin_memory = pin_memory
        self._prefetch = prefetch
        self._dataset = None
        self._batch_iter = None
        self._manager = manager

        # datasets reference list
        self._dataset_refs = []

        # counter reference dict
        self._counter_ref = {}

        # pre-fetch
        for _ in range(self._prefetch):
            self._push_next()

    def _count_dataset_ref(self, new_dataset):
        """Keep references to datasets that still have in-flight batches.

        Datasets whose reference counters have dropped to zero are released.
        """
        dataset_refs = []
        for dataset in self._dataset_refs:
            if self._counter_ref[id(dataset)].value > 0:
                dataset_refs.append(dataset)
            else:
                del self._counter_ref[id(dataset)]
        if self._dataset:
            if self._counter_ref[id(self._dataset)].value > 0:
                if id(new_dataset) != id(self._dataset):
                    dataset_refs.append(self._dataset)
            else:
                del self._counter_ref[id(self._dataset)]

        self._dataset_refs = dataset_refs

    def _next_dataset(self):
        try:
            dataset, batch_sampler = next(self._dataset_iter)
        except StopIteration:
            return None
        return dataset, batch_sampler

    def _push_next(self):
        """Assign next batch workload to workers."""
        if self._batch_iter is not None:
            r = next(self._batch_iter, None)
        else:
            r = None
        if r is None:
            result = self._next_dataset()
            if result is None:
                return
            else:
                dataset, batch_sampler = result
                # Without checking the reference counts of previous datasets in the master
                # process, a KeyError can occasionally be triggered. This may be a bug in Python.
                self._count_dataset_ref(dataset)
                self._dataset = dataset
                # initialize reference counter
                if id(dataset) not in self._counter_ref:
                    self._counter_ref[id(dataset)] = self._manager.Value('i', 0)
                self._batch_iter = iter(batch_sampler)
                self._push_next()
        else:
            counter = self._counter_ref[id(self._dataset)]
            counter.value += 1
            async_ret = self._worker_pool.apply_async(
                self._worker_fn, (r, self._batchify_fn, self._dataset, counter))
            self._data_buffer[self._sent_idx] = async_ret
            self._sent_idx += 1

    def __next__(self):
        self._push_next()
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, 'Data buffer should be empty at this moment'
            raise StopIteration

        assert self._rcvd_idx < self._sent_idx, 'rcvd_idx must be smaller than sent_idx'
        assert self._rcvd_idx in self._data_buffer, 'fatal error with _push_next, rcvd_idx missing'
        ret = self._data_buffer.pop(self._rcvd_idx)
        batch, counter = ret.get()
        batch = pickle.loads(batch)
        counter.value -= 1
        if self._pin_memory:
            batch = _as_in_context(batch, context.cpu_pinned())
        self._rcvd_idx += 1
        return batch

    def next(self):
        return self.__next__()

    def __iter__(self):
        return self


class _MultiDatasetWorkerIter:
    """Internal multi-worker dataset iterator for DatasetLoader."""
    def __init__(self, worker_pool, file_sampler,
                 dataset_fn, batch_sampler_fn,
                 worker_fn=_dataset_worker_fn,
                 prefetch=0, dataset=None, circle_length=1,
                 cached=False, num_max_cached=0):
        if cached:
            assert num_max_cached > 0, \
                'When cached is turned on, num_max_cached must be positive.'
        self._worker_pool = worker_pool
        self._dataset_fn = dataset_fn
        self._batch_sampler_fn = batch_sampler_fn
        self._worker_fn = worker_fn
        self._prefetch = prefetch
        self._circle_length = circle_length
        self._cached = cached
        self._num_max_cached = num_max_cached

        # send and receive index for datasets
        self._rcvd_idx = 0
        self._sent_idx = 0
        self._data_buffer = {}

        self._dataset = [dataset[i] for i in iter(file_sampler)]
        self._num_datasets = len(self._dataset)

        # construct cached list
        self._cached_dataset = []

        # pre-fetch
        for _ in range(self._prefetch):
            self._push_next_dataset()

    def _push_next_dataset(self):
        """Assign next dataset workload to workers."""
        current_dataset_idx = self._sent_idx * self._circle_length
        if current_dataset_idx < self._num_datasets:
            circle_length = min(self._circle_length,
                                self._num_datasets - current_dataset_idx)
            urls = [self._dataset[current_dataset_idx + i] for i in range(circle_length)]
        else:
            return
        # push to worker asynchronously
        async_ret = self._worker_pool.apply_async(
            self._worker_fn, (urls, self._dataset_fn, self._batch_sampler_fn))
        # data buffer stores the async result
        self._data_buffer[self._sent_idx] = async_ret
        self._sent_idx += 1

    def _next_dataset(self):
        """Retrieve the next dataset. Returns None if no dataset is available."""
        if self._rcvd_idx == self._sent_idx:
            assert not self._data_buffer, 'Data buffer should be empty at this moment'
            return None

        assert self._rcvd_idx < self._sent_idx, \
               'rcvd_idx must be smaller than sent_idx'
        assert self._rcvd_idx in self._data_buffer, \
               'fatal error with _next_dataset, rcvd_idx missing'

        if len(self._cached_dataset) == 0 or self._data_buffer[self._rcvd_idx].ready():
            ret = self._data_buffer.pop(self._rcvd_idx)
            dataset, batch_sampler = ret.get()
            self._rcvd_idx += 1
            if self._cached and len(self._cached_dataset) < self._num_max_cached:
                self._cached_dataset.append((dataset, batch_sampler))
        else:
            dataset, batch_sampler = self._cached_dataset.pop(0)

        return dataset, batch_sampler

    def __next__(self):
        """Next dataset"""
        self._push_next_dataset()
        result = self._next_dataset()

        if result is None:
            raise StopIteration

        return result

    def next(self):
        """Next dataset"""
        return self.__next__()

    def __iter__(self):
        """Returns the iterator object"""
        return self


class DatasetLoader:
    """Loads data from a list of datasets and returns mini-batches of data.

    One dataset is loaded at a time.

    Parameters
    ----------
    file_patterns : str
        Path to the input text files.
    file_sampler : str or gluon.data.Sampler, defaults to 'random'
        The sampler used to sample a file. The following string values are supported:

        - 'sequential': SequentialSampler
        - 'random': RandomSampler
    dataset_fn : DatasetFn, callable
        Callable object that generates a gluon.data.Dataset given a URL.
    batch_sampler_fn : SamplerFn, callable
        Callable object that generates a gluon.data.sampler.Sampler given a dataset.
    dataset_params : dict, default is None
        Dictionary of parameters passed to dataset_fn.
    batch_sampler_params : dict, default is None
        Dictionary of parameters passed to batch_sampler_fn.
    batchify_fn : callable
        Callback function that specifies how to merge samples
        into a batch. Defaults to `default_batchify_fn`::

            def default_batchify_fn(data):
                if isinstance(data[0], nd.NDArray):
                    return nd.stack(*data)
                elif isinstance(data[0], tuple):
                    data = zip(*data)
                    return [default_batchify_fn(i) for i in data]
                else:
                    data = np.asarray(data)
                    return nd.array(data, dtype=data.dtype)
    num_dataset_workers : int
        Number of worker processes for dataset creation.
    num_batch_workers : int
        Number of worker processes for batch creation.
    pin_memory : boolean, default False
        If ``True``, the dataloader copies NDArrays into pinned memory
        before returning them. Copying from CPU pinned memory to GPU is faster
        than from normal CPU memory. At the same time, it increases GPU memory usage.
    circle_length : int, default is 1
        The number of files to read at the same time. When `circle_length` is larger than 1,
        `circle_length` files are merged into a single dataset.
    dataset_prefetch : int, default is `num_dataset_workers`
        The number of datasets to prefetch. Only effective if `num_dataset_workers` > 0.
        If it is larger than 0, worker processes can prefetch datasets before data is
        requested from the iterator.
        A larger prefetch value provides smoother bootstrapping performance but consumes
        more memory. A value that is too small may defeat the purpose of using multiple
        worker processes; in that case, try reducing `num_dataset_workers`.
        Defaults to `num_dataset_workers`.
    batch_prefetch : int, default is `num_batch_workers * 2`
        The number of batches to prefetch. Only effective if `num_batch_workers` > 0.
        If it is larger than 0, worker processes can prefetch batches before data is
        requested from the iterator.
        A larger prefetch value provides smoother bootstrapping performance but consumes
        more shared memory. A value that is too small may defeat the purpose of using
        multiple worker processes; in that case, try reducing `num_batch_workers`.
        Defaults to `num_batch_workers * 2`.
    dataset_cached : bool, default is False
        Whether to cache the last processed dataset. Each processed dataset can
        only be cached once. When no newly processed dataset is available to be fetched,
        a cached processed dataset is popped instead.
    num_max_dataset_cached : int, default is 0
        Maximum number of cached datasets. Only valid if `dataset_cached` is True.
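
    Examples
    --------
    A minimal usage sketch, following the parameter descriptions above.
    ``my_dataset_fn`` and ``my_batch_sampler_fn`` stand for user-defined callables
    and are not provided by this module::

        loader = DatasetLoader('data/part-*.train',
                               file_sampler='random',
                               dataset_fn=my_dataset_fn,
                               batch_sampler_fn=my_batch_sampler_fn,
                               num_dataset_workers=2,
                               num_batch_workers=4)
        for batch in loader:
            pass  # each batch is produced by `batchify_fn`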
    """
    def __init__(self, file_patterns, file_sampler,
                 dataset_fn=None, batch_sampler_fn=None,
                 dataset_params=None, batch_sampler_params=None, batchify_fn=None,
                 num_dataset_workers=0, num_batch_workers=0,
                 pin_memory=False, circle_length=1,
                 dataset_prefetch=None, batch_prefetch=None,
                 dataset_cached=False, num_max_dataset_cached=0):
        assert num_dataset_workers >= 0, \
               'num_dataset_workers must be non-negative'
        assert num_batch_workers >= 0, \
               'num_batch_workers must be non-negative'
        if num_batch_workers > 0:
            assert num_dataset_workers > 0, \
                'num_dataset_workers must be positive when num_batch_workers > 0'
        else:
            if num_dataset_workers > 0:
                warnings.warn('The multi-processing functionalities for both dataset and'
                              ' batch sampling are disabled when num_batch_workers=0 though '
                              'num_dataset_workers={} > 0'.format(num_dataset_workers))
        assert circle_length >= 1, \
               'circle_length must be larger than or equal to 1'
        if dataset_cached:
            assert num_max_dataset_cached > 0, \
                'When dataset_cached is True, num_max_dataset_cached must be positive'

        self._dataset = _PathDataset(file_patterns)
        self._file_sampler = file_sampler

        assert dataset_fn is not None, 'dataset_fn is not given.'
        assert batch_sampler_fn is not None, 'batch_sampler_fn is not given.'
        if dataset_params is not None:
            self._dataset_fn = partial(dataset_fn, **dataset_params)
        else:
            self._dataset_fn = dataset_fn
        if batch_sampler_params is not None:
            self._batch_sampler_fn = partial(batch_sampler_fn, **batch_sampler_params)
        else:
            self._batch_sampler_fn = batch_sampler_fn

        self._num_dataset_workers = num_dataset_workers
        self._num_batch_workers = num_batch_workers
        self._dataset_prefetch = max(0, int(dataset_prefetch)
                                     if dataset_prefetch is not None
                                     else self._num_dataset_workers)
        self._batch_prefetch = max(0, int(batch_prefetch)
                                   if batch_prefetch is not None
                                   else 2 * self._num_batch_workers)

        self._pin_memory = pin_memory
        self._circle_length = circle_length
        self._dataset_cached = dataset_cached
        self._num_max_dataset_cached = num_max_dataset_cached

        self._manager = None
        self._dataset_worker_pool = None
        if self._num_dataset_workers > 0:
            self._manager = multiprocessing.Manager()
            self._dataset_worker_pool = multiprocessing.Pool(self._num_dataset_workers,
                                                             initializer=_initialize_dataset_worker,
                                                             initargs=[self._manager])
        self._batch_worker_pool = None
        if self._num_batch_workers > 0:
            self._batch_worker_pool = multiprocessing.Pool(self._num_batch_workers)
        if batchify_fn is None:
            if self._num_batch_workers > 0:
                self._batchify_fn = default_mp_batchify_fn
            else:
                self._batchify_fn = default_batchify_fn
        else:
            self._batchify_fn = batchify_fn

    def __iter__(self):
        if self._num_dataset_workers == 0:
            def _same_process_iter():
                urls = []
                files = [self._dataset[i] for i in iter(self._file_sampler)]
                for i, url in enumerate(files):
                    # gather up to `circle_length` files before constructing a dataset
                    urls.append(url)
                    if i < len(files) - 1:
                        if len(urls) < self._circle_length:
                            continue
                    if self._circle_length == 1:
                        urls = urls[0]
                    dataset, batch_sampler = _dataset_worker_fn(urls, self._dataset_fn,
                                                                self._batch_sampler_fn)
                    for batch in batch_sampler:
                        ret = self._batchify_fn([dataset[idx] for idx in batch])
                        if self._pin_memory:
                            ret = _as_in_context(ret, context.cpu_pinned())
                        yield ret
                    urls = []
            return _same_process_iter()

        # multi-worker
        dataset_iter = _MultiDatasetWorkerIter(self._dataset_worker_pool,
                                               worker_fn=_dataset_worker_fn,
                                               dataset=self._dataset,
                                               file_sampler=self._file_sampler,
                                               dataset_fn=self._dataset_fn,
                                               batch_sampler_fn=self._batch_sampler_fn,
                                               prefetch=self._dataset_prefetch,
                                               circle_length=self._circle_length,
                                               cached=self._dataset_cached,
                                               num_max_cached=self._num_max_dataset_cached)
        return _MultiBatchWorkerIter(self._batch_worker_pool, self._batchify_fn, dataset_iter,
                                     pin_memory=self._pin_memory, worker_fn=_batch_worker_fn,
                                     prefetch=self._batch_prefetch, manager=self._manager)

    def __del__(self):
        if self._dataset_worker_pool:
            # manually terminate due to a bug that the pool is not automatically terminated
            # https://bugs.python.org/issue34172
            assert isinstance(self._dataset_worker_pool, multiprocessing.pool.Pool)
            self._dataset_worker_pool.terminate()
        if self._batch_worker_pool:
            assert isinstance(self._batch_worker_pool, multiprocessing.pool.Pool)
            self._batch_worker_pool.terminate()