1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18"""Read and write for the RecordIO data format."""
19from collections import namedtuple
20from multiprocessing import current_process
21
22import ctypes
23import struct
24import numbers
25import numpy as np
26
27from .base import _LIB
28from .base import RecordIOHandle
29from .base import check_call
30from .base import c_str
31try:
32    import cv2
33except ImportError:
34    cv2 = None
35
class MXRecordIO(object):
    """Reads/writes `RecordIO` data format, supporting sequential read and write.

    The object wraps a native reader or writer handle. The handle is created
    by ``open()`` and released by ``close()``; it is never pickled, so a
    pickled instance is re-opened from its ``uri`` when it is restored.

    Examples
    ---------
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
    >>> for i in range(5):
    ...    record.write(b'record_%d' % i)
    >>> record.close()
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
    >>> for i in range(5):
    ...    item = record.read()
    ...    print(item)
    b'record_0'
    b'record_1'
    b'record_2'
    b'record_3'
    b'record_4'
    >>> record.close()

    Parameters
    ----------
    uri : string
        Path to the record file.
    flag : string
        'w' for write or 'r' for read.
    """
    def __init__(self, uri, flag):
        # Lifecycle flags are assigned first so that close()/__del__ stay safe
        # even if one of the following statements raises.
        self.is_open = False
        self.pid = None
        self.flag = flag
        self.uri = c_str(uri)
        self.handle = RecordIOHandle()
        self.open()

    def open(self):
        """Opens the record file.

        Creates a native writer (flag 'w') or reader (flag 'r') for ``self.uri``
        and remembers the id of the process that opened it.

        Raises
        ------
        ValueError
            If ``self.flag`` is neither 'w' nor 'r'.
        """
        if self.flag == "w":
            check_call(_LIB.MXRecordIOWriterCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = True
        elif self.flag == "r":
            check_call(_LIB.MXRecordIOReaderCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = False
        else:
            raise ValueError("Invalid flag %s"%self.flag)
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        self.pid = current_process().pid
        self.is_open = True

    def __del__(self):
        # The finaliser also runs for instances whose __init__ failed part-way
        # (or that were created without __init__); those may not have
        # ``is_open`` at all, so do not assume the attribute exists.
        if getattr(self, 'is_open', False):
            self.close()

    def __getstate__(self):
        """Override pickling behavior.

        Note that this closes the record as a side effect; the previous open
        state is stored under ``'is_open'`` so that ``__setstate__`` can
        re-open the restored copy.
        """
        # pickling pointer is not allowed
        is_open = self.is_open
        self.close()
        d = dict(self.__dict__)
        d['is_open'] = is_open
        uri = self.uri.value
        try:
            uri = uri.decode('utf-8')
        except AttributeError:
            # ``uri`` has no ``decode`` (already a text string); keep it as is.
            pass
        del d['handle']
        d['uri'] = uri
        return d

    def __setstate__(self, d):
        """Restore from pickled."""
        self.__dict__ = d
        is_open = d['is_open']
        # A fresh handle is needed; the restored object starts out closed and
        # is re-opened only if the pickled one was open.
        self.is_open = False
        self.handle = RecordIOHandle()
        self.uri = c_str(self.uri)
        if is_open:
            self.open()

    def _check_pid(self, allow_reset=False):
        """Check process id to ensure integrity, reset if in new process.

        Parameters
        ----------
        allow_reset : bool
            If True, a pid mismatch triggers ``reset()`` (close and re-open in
            the current process) instead of an error.

        Raises
        ------
        RuntimeError
            If the stored pid differs from the current process id and
            ``allow_reset`` is False. A closed record has ``pid`` set to None
            and therefore also takes this path.
        """
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        if self.pid != current_process().pid:
            if allow_reset:
                self.reset()
            else:
                raise RuntimeError("Forbidden operation in multiple processes")

    def close(self):
        """Closes the record file.

        Does nothing if the record is not open.
        """
        if not self.is_open:
            return
        if self.writable:
            check_call(_LIB.MXRecordIOWriterFree(self.handle))
        else:
            check_call(_LIB.MXRecordIOReaderFree(self.handle))
        self.is_open = False
        self.pid = None

    def reset(self):
        """Resets the pointer to first item.

        If the record is opened with 'w', this function will truncate the file to empty.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(2):
        ...    item = record.read()
        ...    print(item)
        b'record_0'
        b'record_1'
        >>> record.reset()  # Pointer is reset.
        >>> print(record.read()) # Started reading from start again.
        b'record_0'
        >>> record.close()
        """
        self.close()
        self.open()

    def write(self, buf):
        """Inserts a string buffer as a record.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
        >>> for i in range(5):
        ...    record.write(b'record_%d' % i)
        >>> record.close()

        Parameters
        ----------
        buf : string (python2), bytes (python3)
            Buffer to write.

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert self.writable
        self._check_pid(allow_reset=False)
        # The length is passed explicitly alongside the pointer.
        check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                    ctypes.c_char_p(buf),
                                                    ctypes.c_size_t(len(buf))))

    def read(self):
        """Returns record as a string.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(5):
        ...    item = record.read()
        ...    print(item)
        b'record_0'
        b'record_1'
        b'record_2'
        b'record_3'
        b'record_4'
        >>> record.close()

        Returns
        ----------
        buf : string (python2), bytes (python3), or None
            Buffer read. None is returned when the native reader hands back a
            null pointer (presumably because no record is left).

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert not self.writable
        # trying to implicitly read from multiple processes is forbidden,
        # there's no elegant way to handle unless lock is introduced
        self._check_pid(allow_reset=False)
        buf = ctypes.c_char_p()
        size = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
                                                   ctypes.byref(buf),
                                                   ctypes.byref(size)))
        if buf:
            # Reinterpret the pointer as a fixed-size char array so that the
            # whole record, including embedded NUL bytes, is copied out.
            buf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char*size.value))
            return buf.contents.raw
        else:
            return None
214
class MXIndexedRecordIO(MXRecordIO):
    """Reads/writes `RecordIO` data format, supporting random access.

    Alongside the record file, a text index file is kept in which every line
    is ``key<TAB>position``. In read mode the index is loaded into ``self.idx``
    (key -> position) and ``self.keys`` (keys in file order) when the record
    is opened.

    Examples
    ---------
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
    >>> for i in range(5):
    ...     record.write_idx(i, b'record_%d' % i)
    >>> record.close()
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
    >>> record.read_idx(3)
    b'record_3'

    Parameters
    ----------
    idx_path : str
        Path to the index file.
    uri : str
        Path to the record file. Only supports seekable file types.
    flag : str
        'w' for write or 'r' for read.
    key_type : type
        Data type for keys.
    """
    def __init__(self, idx_path, uri, flag, key_type=int):
        # These attributes must exist before the base constructor runs,
        # because it calls open(), which uses them.
        self.idx_path = idx_path
        self.idx = {}
        self.keys = []
        self.key_type = key_type
        self.fidx = None
        super(MXIndexedRecordIO, self).__init__(uri, flag)

    def open(self):
        """Opens the record file and the index file.

        In read mode the index file is parsed into ``self.idx`` and
        ``self.keys``; blank lines are ignored. If the index file cannot be
        opened or parsed, everything opened so far is closed again before the
        error is re-raised.
        """
        super(MXIndexedRecordIO, self).open()
        self.idx = {}
        self.keys = []
        try:
            self.fidx = open(self.idx_path, self.flag)
            if not self.writable:
                for line in self.fidx:
                    fields = line.strip().split('\t')
                    if fields == ['']:
                        # Blank line (e.g. trailing newline at end of file).
                        continue
                    key = self.key_type(fields[0])
                    self.idx[key] = int(fields[1])
                    self.keys.append(key)
        except Exception:
            # The native handle is already open at this point; release it (and
            # the index file, if it was opened) instead of leaving the object
            # half-open.
            self.close()
            raise

    def close(self):
        """Closes the record file and the index file.

        Does nothing if the record is not open.
        """
        if not self.is_open:
            return
        try:
            super(MXIndexedRecordIO, self).close()
        finally:
            # ``fidx`` is None when opening the index file failed; the index
            # file must be closed even if freeing the native handle raised.
            if self.fidx is not None:
                self.fidx.close()
                self.fidx = None

    def __getstate__(self):
        """Override pickling behavior."""
        d = super(MXIndexedRecordIO, self).__getstate__()
        # File objects cannot be pickled; the index file is re-opened by open().
        d['fidx'] = None
        return d

    def seek(self, idx):
        """Sets the current read pointer position.

        This function is internally called by `read_idx(idx)` to find the current
        reader pointer position. It doesn't return anything.

        Parameters
        ----------
        idx : key_type
            Key of the record to move to; must be present in ``self.idx``
            (a missing key raises ``KeyError``).
        """
        assert not self.writable
        # Unlike read(), a pid mismatch here re-opens the record in the
        # current process instead of raising.
        self._check_pid(allow_reset=True)
        pos = ctypes.c_size_t(self.idx[idx])
        check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))

    def tell(self):
        """Returns the current position of write head.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> print(record.tell())
        0
        >>> for i in range(5):
        ...     record.write_idx(i, b'record_%d' % i)
        ...     print(record.tell())
        16
        32
        48
        64
        80
        """
        assert self.writable
        pos = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOWriterTell(self.handle, ctypes.byref(pos)))
        return pos.value

    def read_idx(self, idx):
        """Returns the record at given index.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write_idx(i, b'record_%d' % i)
        >>> record.close()
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
        >>> record.read_idx(3)
        b'record_3'
        """
        self.seek(idx)
        return self.read()

    def write_idx(self, idx, buf):
        """Inserts input record at given index.

        The write position before the record is appended is stored as the
        record's position, both in the index file and in ``self.idx``.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write_idx(i, b'record_%d' % i)
        >>> record.close()

        Parameters
        ----------
        idx : int
            Index of a file.
        buf :
            Record to write.
        """
        key = self.key_type(idx)
        pos = self.tell()
        self.write(buf)
        self.fidx.write('%s\t%d\n'%(str(key), pos))
        self.idx[key] = pos
        self.keys.append(key)
341
342
IRHeader = namedtuple('HEADER', 'flag label id id2')
"""An alias for HEADER. Used to store metadata (e.g. labels) accompanying a record.
See mxnet.recordio.pack and mxnet.recordio.pack_img for example uses.

Parameters
----------
    flag : int
        Available for convenience, can be set arbitrarily.
    label : float or an array of float
        Typically used to store label(s) for a record.
    id: int
        Usually a unique id representing record.
    id2: int
        Higher order bits of the unique id, should be set to 0 (in most cases).
"""
# Binary layout of the header, in IRHeader field order: unsigned int (flag),
# float (label), then two unsigned long longs (id, id2). No byte-order prefix
# is given, so ``struct`` uses native size and alignment.
_IR_FORMAT = 'IfQQ'
# Number of bytes the packed header occupies at the start of every record.
_IR_SIZE = struct.Struct(_IR_FORMAT).size
360
def pack(header, s):
    """Pack a string into MXImageRecord.

    A scalar label is stored directly in the header and ``flag`` is forced to
    0. Any other label is converted to a float32 array whose raw bytes are
    placed between the header and ``s``; ``flag`` then holds the number of
    label values and the header's own ``label`` field is set to 0.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    s : str (python2), bytes (python3)
        Raw image string to be packed.

    Returns
    -------
    s : str (python2), bytes (python3)
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> with open(path, 'rb') as file:
    ...     s = file.read()
    >>> packed_s = mx.recordio.pack(header, s)
    """
    header = IRHeader(*header)
    if isinstance(header.label, numbers.Number):
        header = header._replace(flag=0)
    else:
        label = np.asarray(header.label, dtype=np.float32)
        header = header._replace(flag=label.size, label=0)
        # ``tobytes`` is the supported spelling of the deprecated ``tostring``;
        # both produce the same raw bytes.
        s = label.tobytes() + s
    return struct.pack(_IR_FORMAT, *header) + s
395
def unpack(s):
    """Unpack a MXImageRecord to string.

    Splits a packed record into its header and payload. When the header's
    ``flag`` is positive, that many float32 values are read from the start of
    the payload, placed in the header's ``label`` field, and removed from the
    returned payload.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    s : str
        Unpacked string.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, s = mx.recordio.unpack(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    """
    raw_fields = struct.unpack(_IR_FORMAT, s[:_IR_SIZE])
    header = IRHeader(*raw_fields)
    payload = s[_IR_SIZE:]
    if not header.flag > 0:
        # Scalar label: it already sits in the header.
        return header, payload
    # Array label: ``flag`` float32 values (4 bytes each) precede the data.
    labels = np.frombuffer(payload, np.float32, header.flag)
    header = header._replace(label=labels)
    return header, payload[header.flag*4:]
425
def unpack_img(s, iscolor=-1):
    """Unpack a MXImageRecord to image.

    The record is split with ``unpack`` and the remaining payload is decoded
    with ``cv2.imdecode``. OpenCV (``cv2``) must be importable.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.
    iscolor : int
        Image format option for ``cv2.imdecode``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    img : numpy.ndarray
        Unpacked image.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, img = mx.recordio.unpack_img(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    >>> img
    array([[[ 23,  27,  45],
            [ 28,  32,  50],
            ...,
            [ 36,  40,  59],
            [ 35,  39,  58]],
           ...,
           [[ 91,  92, 113],
            [ 97,  98, 119],
            ...,
            [168, 169, 167],
            [166, 167, 165]]], dtype=uint8)
    """
    header, payload = unpack(s)
    # View the encoded image bytes as a uint8 vector for the decoder.
    encoded = np.frombuffer(payload, dtype=np.uint8)
    assert cv2 is not None
    return header, cv2.imdecode(encoded, iscolor)
468
def pack_img(header, img, quality=95, img_fmt='.jpg'):
    """Pack an image into ``MXImageRecord``.

    The image is encoded with ``cv2.imencode`` and the encoded bytes are
    passed to ``pack`` together with ``header``. OpenCV (``cv2``) must be
    importable.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    img : numpy.ndarray
        Image to be packed.
    quality : int
        Quality for JPEG encoding in range 1-100, or compression level for PNG
        encoding. NOTE(review): OpenCV documents PNG compression as 0-9, so the
        default of 95 looks unsuitable for PNG; pass ``quality`` explicitly
        when using '.png' (confirm against the OpenCV version in use).
        Ignored for any other format.
    img_fmt : str
        Encoding of the image (.jpg for JPEG, .png for PNG). Matched
        case-insensitively when choosing the encoder parameter.

    Returns
    -------
    s : str (python2), bytes (python3)
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> img = cv2.imread('test.jpg')
    >>> packed_s = mx.recordio.pack_img(header, img)
    """
    assert cv2 is not None
    jpg_formats = ['.JPG', '.JPEG']
    png_formats = ['.PNG']
    fmt = img_fmt.upper()
    # Formats other than JPEG/PNG are encoded without extra parameters.
    encode_params = None
    if fmt in jpg_formats:
        encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
    elif fmt in png_formats:
        encode_params = [cv2.IMWRITE_PNG_COMPRESSION, quality]

    ret, buf = cv2.imencode(img_fmt, img, encode_params)
    assert ret, 'failed to encode image'
    # ``tobytes`` is the supported spelling of the deprecated ``tostring``.
    return pack(header, buf.tobytes())
509