# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# pylint: disable=undefined-all-variable
"""NLP Toolkit Data Stream API. It allows easy and customizable streaming of
corpora and dataset files. Files can be streamed into formats that are
ready for training and evaluation."""


import glob
import multiprocessing
import multiprocessing.pool
import os
import queue
import random
import sys
import threading
import traceback

import numpy as np

import mxnet as mx
from mxnet.gluon.data import RandomSampler, Sampler, SequentialSampler

__all__ = [
    'DataStream', 'SimpleDataStream', 'DatasetStream', 'SimpleDatasetStream',
    'PrefetchingStream']


class DataStream:
    """Abstract Data Stream Interface.

    DataStreams are useful to avoid loading big datasets to memory. A
    DataStream is an iterable object (it implements the `__iter__` function).
    Whenever an iteration over the DataStream is requested (e.g. in a for loop
    or by calling iter(datastream)), a new iterator over all samples in the
    DataStream is returned. DataStreams can be lazily transformed by calling
    `transform()`, which returns a DataStream over the transformed samples.
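
    Examples
    --------
    A minimal sketch of a lazy transformation over an in-memory iterable:

    >>> stream = SimpleDataStream(['Hello', 'World'])
    >>> lowercased = stream.transform(lambda s: s.lower())
    >>> list(lowercased)
    ['hello', 'world']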

    """

    def __iter__(self):
        """Return an iterator over all elements of the DataStream.

        This method returns a new iterator object that can iterate over
        all the objects in the DataStream.

        Returns
        -------
        iterator
            An object implementing the Python *iterator protocol*.

        """
        raise NotImplementedError

    def transform(self, fn):
        """Transform a DataStream lazily.

        Parameters
        ----------
        fn : callable
            The function applied to every sample of the stream. Samples that
            are tuples are unpacked into the positional arguments of `fn`.

        Returns
        -------
        DataStream
            The data stream that lazily transforms the data while streaming.
        """

        return _LazyTransformDataStream(self, fn)


class SimpleDataStream(DataStream):
    """SimpleDataStream wraps iterables to expose the DataStream API.

    Unlike the iterable itself, the SimpleDataStream exposes the DataStream API
    and allows lazy transformation of the iterable.
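
    Examples
    --------
    A small sketch wrapping a list. Note that a one-shot iterator (e.g. a
    generator) can only be streamed over once; a re-iterable object such as
    a list can be streamed repeatedly:

    >>> stream = SimpleDataStream([1, 2, 3])
    >>> doubled = stream.transform(lambda x: 2 * x)
    >>> list(doubled)
    [2, 4, 6]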

    """
    def __init__(self, iterable):
        self._stream = iterable

    def __iter__(self):
        return iter(self._stream)


class _LazyTransformDataStream(DataStream):
    """Data stream that lazily transforms the data."""
    def __init__(self, stream, fn):
        self._stream = stream
        self._fn = fn

    def __iter__(self):
        stream_iter = iter(self._stream)

        # The yield statements must be hidden in a closure: if __iter__ were
        # itself a generator function, its body (including iter(self._stream))
        # would only run once __next__ is first called. This is important, as
        # calling iter(self._stream) may trigger multi-threaded or
        # multi-processing prefetching of the stream.
        def _closure():
            try:
                item = next(stream_iter)
            except StopIteration:
                return
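            # Tuple samples are unpacked, so that each field of a sample
            # becomes a positional argument of self._fn.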
            istuple = isinstance(item, tuple)
            if istuple:
                yield self._fn(*item)
                while True:
                    try:
                        yield self._fn(*next(stream_iter))
                    except StopIteration:
                        return
            else:
                yield self._fn(item)
                while True:
                    try:
                        yield self._fn(next(stream_iter))
                    except StopIteration:
                        return

        return _closure()


class DatasetStream(DataStream):
    """Abstract Dataset Stream Interface.

    A DatasetStream is a DataStream where each sample is a
    `mxnet.gluon.data.Dataset`. An iteration over a DatasetStream iterates over
    `mxnet.gluon.data.Dataset` objects, each representing a chunk or shard of
    some large dataset.

    Iterating over sizeable chunks of a dataset can be helpful to speed up
    preprocessing, as the overhead of preprocessing each sample individually is
    reduced (this is similar to the idea of using batches for training a
    model).
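
    Examples
    --------
    A minimal sketch of a custom stream yielding fixed-size shards.
    `NumberShardStream` is a hypothetical illustration, not part of this
    module:

    >>> class NumberShardStream(DatasetStream):
    ...     def __iter__(self):
    ...         for start in range(0, 100, 10):
    ...             yield mx.gluon.data.SimpleDataset(
    ...                 list(range(start, start + 10)))
    >>> [len(shard) for shard in NumberShardStream()]
    [10, 10, 10, 10, 10, 10, 10, 10, 10, 10]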

    """

    def __iter__(self):
        raise NotImplementedError


class _PathDataset(mx.gluon.data.SimpleDataset):
    """A simple Dataset containing the list of paths matching `file_pattern`.

    Parameters
    ----------
    file_pattern : str
        Path to the input text files. Multiple patterns can be given,
        separated by commas.
    """
    def __init__(self, file_pattern):
        if not isinstance(file_pattern, str):
            raise TypeError('file_pattern must be str, but got %s'%type(file_pattern))
        files = []
        for pattern in file_pattern.split(','):
            files.extend(glob.glob(os.path.expanduser(pattern.strip())))
        files = sorted(files)
        if len(files) == 0:
            raise ValueError('Cannot find any file with path "%s"'%file_pattern)
        super(_PathDataset, self).__init__(files)


class SimpleDatasetStream(DatasetStream):
    """A simple stream of Datasets.

    The SimpleDatasetStream is created from multiple files based on the
    provided `file_pattern`. One file is read at a time and a corresponding
    Dataset is returned. The Dataset is created based on the file and the
    kwargs passed to SimpleDatasetStream.

    Parameters
    ----------
    dataset : class
        The dataset class to instantiate for every file. kwargs are passed
        to this class.
    file_pattern : str
        Path to the input text files. Multiple patterns can be given,
        separated by commas.
    file_sampler : str or gluon.data.Sampler, defaults to 'random'
        The sampler used to sample a file. The following string values are
        supported:

        - 'sequential': SequentialSampler
        - 'random': RandomSampler
    kwargs
        All other keyword arguments are passed to the dataset constructor.
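
    Examples
    --------
    A sketch that streams a corpus split across multiple text files,
    creating one `gluonnlp.data.CorpusDataset` per file (assumes files
    matching the pattern exist):

    >>> import gluonnlp
    >>> stream = SimpleDatasetStream(
    ...     gluonnlp.data.CorpusDataset, '~/corpus/*.txt',
    ...     file_sampler='sequential')  # doctest: +SKIP
    >>> for corpus_dataset in stream:  # doctest: +SKIP
    ...     for sentence in corpus_dataset:
    ...         pass  # preprocess each sentence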
    """
    def __init__(self, dataset, file_pattern, file_sampler='random', **kwargs):
        # TODO(haibin) reuse _SimpleDatasetPathStream here
        if not isinstance(file_pattern, str):
            raise TypeError('file_pattern must be str, but got %s'%type(file_pattern))
        self._dataset = dataset
        self._files = []
        for pattern in file_pattern.split(','):
            self._files.extend(glob.glob(os.path.expanduser(pattern.strip())))
        self._files = sorted(self._files)

        if len(self._files) == 0:
            raise ValueError('Cannot find any file with path "%s"'%file_pattern)
        self._file_sampler = self._get_sampler(file_sampler)
        self._kwargs = kwargs

    def _get_sampler(self, sampler):
        if isinstance(sampler, Sampler):
            return sampler
        if isinstance(sampler, str):
            length = len(self._files)
            if sampler == 'random':
                return RandomSampler(length)
            if sampler == 'sequential':
                return SequentialSampler(length)
        raise ValueError('file_sampler must be a supported str ("random", "sequential") or '
                         'a `gluon.data.Sampler`, but got %s'%(sampler))

    def __iter__(self):
        # generate file samples
        for file_idx in iter(self._file_sampler):
            filename = self._files[file_idx]
            yield self._dataset(filename, **self._kwargs)


class _Prefetcher:
    """Internal shared prefetcher logic."""
    _dataq = None  # Data queue transmits prefetched elements
    _controlq = None  # Control queue to instruct thread / process shutdown
    _errorq = None  # Error queue to transmit exceptions from worker to master

    _checked_start = False  # True once startup has been checked by _check_start

    def __init__(self, stream, num_prefetch, seed, np_seed, mx_seed):
        super(_Prefetcher, self).__init__()
        self.stream = stream
        assert num_prefetch > 0, 'Unbounded Prefetcher is unsupported.'
        self.num_prefetch = num_prefetch
        self.seed = seed
        self.np_seed = np_seed
        self.mx_seed = mx_seed

    def run(self):
        """Method representing the process's activity."""
        random.seed(self.seed)
        np.random.seed(self.np_seed)
        if not isinstance(self, multiprocessing.Process):
            # Calling mxnet methods in a subprocess will raise an exception if
            # mxnet is built with GPU support
            # https://github.com/apache/incubator-mxnet/issues/4659
            mx.random.seed(self.mx_seed)

        # Startup - Master waits for this
        try:
            stream_iter = iter(self.stream)
            self._errorq.put(None)
        except Exception as e:  # pylint: disable=broad-except
            tb = traceback.format_exc()
            self._errorq.put((e, tb))

        # Async work
        while True:
            try:  # Check control queue
                c = self._controlq.get(False)
                if c is None:
                    break
                raise RuntimeError('Got unexpected control code {}'.format(repr(c)))
            except queue.Empty:
                pass
            except RuntimeError as e:
                tb = traceback.format_exc()
                self._errorq.put((e, tb))
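                # Put a matching sentinel on the data queue so that a master
                # blocked on _dataq.get() wakes up and reads the error.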
                self._dataq.put(None)

            try:
                data = next(stream_iter)
                error = None
            except Exception as e:  # pylint: disable=broad-except
                tb = traceback.format_exc()
                error = (e, tb)
                data = None
            finally:
                self._errorq.put(error)
                self._dataq.put(data)

    def __next__(self):
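        # The worker fills _dataq and _errorq in lockstep, so each element
        # taken from _dataq is paired with exactly one entry of _errorq.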
        next_item = self._dataq.get()
        next_error = self._errorq.get()

        if next_error is None:
            return next_item
        else:
            self._controlq.put(None)
            if isinstance(next_error[0], StopIteration):
                raise StopIteration
            return self._reraise(*next_error)

    def _reraise(self, e, tb):
        print('Reraising exception from Prefetcher', file=sys.stderr)
        print(tb, file=sys.stderr)
        raise e

    def _check_start(self):
        assert not self._checked_start
        self._checked_start = True
        next_error = self._errorq.get(block=True)
        if next_error is not None:
            self._reraise(*next_error)

    def next(self):
        return self.__next__()


class _ProcessPrefetcher(_Prefetcher, multiprocessing.Process):
    """Internal multi-processing prefetcher."""

    def __init__(self, *args, **kwargs):
        super(_ProcessPrefetcher, self).__init__(*args, **kwargs)
        self._dataq = multiprocessing.Queue(self.num_prefetch)
        self._controlq = multiprocessing.Queue()
        self._errorq = multiprocessing.Queue(self.num_prefetch)
        self.daemon = True
        self.start()
        self._check_start()


class _ThreadPrefetcher(_Prefetcher, threading.Thread):
    """Internal threaded prefetcher."""

    def __init__(self, *args, **kwargs):
        super(_ThreadPrefetcher, self).__init__(*args, **kwargs)
        self._dataq = queue.Queue(self.num_prefetch)
        self._controlq = queue.Queue()
        self._errorq = queue.Queue(self.num_prefetch)
        self.daemon = True
        self.start()
        self._check_start()


class PrefetchingStream(DataStream):
    """Prefetch a DataStream in a separate Thread or Process.

    This iterator will create another thread or process that iterates over
    the source stream ahead of time and stores the prefetched elements in
    memory. It potentially accelerates the data read, at the cost of more
    memory usage.

    The python, numpy and mxnet random states in the launched Thread or
    Process are seeded from the caller's python, numpy and mxnet random
    generators, respectively (using random.getrandbits(32),
    numpy.random.randint(0, np.iinfo(np.int32).max) and
    int(mx.nd.random.uniform(0, np.iinfo(np.int32).max).asscalar())).

    Parameters
    ----------
    stream : DataStream
        Source stream.
    num_prefetch : int, default 1
        Number of elements to prefetch from the stream. Must be greater than 0.
    worker_type : 'thread' or 'process', default 'thread'
        Use a separate Python Thread or Process to prefetch.
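
    Examples
    --------
    A minimal sketch that overlaps an expensive transformation with
    consumption of the stream:

    >>> stream = SimpleDataStream(range(5)).transform(lambda x: x * x)
    >>> prefetched = PrefetchingStream(stream, num_prefetch=2)
    >>> for element in prefetched:
    ...     print(element)
    0
    1
    4
    9
    16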

    """

    def __init__(self, stream, num_prefetch=1, worker_type='thread'):
        self._stream = stream
        self._num_prefetch = num_prefetch
        if num_prefetch < 1:
            raise ValueError('num_prefetch must be greater than 0.')
        assert worker_type.lower() in ['thread', 'process']
        self._multiprocessing = worker_type.lower() == 'process'

    def __iter__(self):
        seed = random.getrandbits(32)
        # TODO should be possible to change to 64 bit in MXNet 1.6 (uses int64 by default?)
        np_seed = np.random.randint(0, np.iinfo(np.int32).max)
        mx_seed = int(mx.nd.random.uniform(0, np.iinfo(np.int32).max).asscalar())
        if self._multiprocessing:
            return _ProcessPrefetcher(self._stream, self._num_prefetch,
                                      seed=seed, np_seed=np_seed,
                                      mx_seed=mx_seed)
        else:
            return _ThreadPrefetcher(self._stream, self._num_prefetch,
                                     seed=seed, np_seed=np_seed,
                                     mx_seed=mx_seed)