1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# pylint: skip-file
import os
import random
import string
import sys
import tempfile

import mxnet as mx
import numpy as np

from common import setup_module, with_seed, teardown
26
@with_seed()
def test_recordio():
    """Round-trip N single-character records through ``MXRecordIO``.

    Records are written sequentially and the writer is closed. The file is
    then reopened for reading, and each record is read back in order and
    compared byte-for-byte. All files live in a temporary directory that is
    removed when the test finishes.
    """
    N = 255

    # TemporaryDirectory replaces the deprecated, race-prone tempfile.mktemp()
    # and guarantees the record file is cleaned up afterwards.
    with tempfile.TemporaryDirectory() as tmpdir:
        frec = os.path.join(tmpdir, 'test.rec')

        writer = mx.recordio.MXRecordIO(frec, 'w')
        try:
            for i in range(N):
                writer.write(chr(i).encode('utf-8'))
        finally:
            # Close explicitly rather than relying on `del` triggering the
            # destructor, which is only immediate under CPython refcounting.
            writer.close()

        reader = mx.recordio.MXRecordIO(frec, 'r')
        try:
            for i in range(N):
                res = reader.read()
                assert res == chr(i).encode('utf-8')
        finally:
            reader.close()
41
@with_seed()
def test_indexed_recordio():
    """Write N keyed records with ``MXIndexedRecordIO`` and read them back randomly.

    The test checks three things:
    - the reader exposes exactly the keys 0..N-1;
    - each key can be fetched with ``read_idx`` after the keys are shuffled;
    - each fetched payload matches what was written for that key.
    """
    N = 255

    # TemporaryDirectory replaces the deprecated, race-prone tempfile.mktemp()
    # and removes both the index and record files afterwards.
    with tempfile.TemporaryDirectory() as tmpdir:
        fidx = os.path.join(tmpdir, 'test.idx')
        frec = os.path.join(tmpdir, 'test.rec')

        writer = mx.recordio.MXIndexedRecordIO(fidx, frec, 'w')
        try:
            for i in range(N):
                writer.write_idx(i, chr(i).encode('utf-8'))
        finally:
            # Close explicitly (instead of `del writer`) so both files are
            # closed before the reader opens them, independent of when the
            # interpreter happens to run the destructor.
            writer.close()

        reader = mx.recordio.MXIndexedRecordIO(fidx, frec, 'r')
        try:
            # Copy before shuffling: `reader.keys` itself would otherwise be
            # shuffled in place, mutating the reader's own state.
            keys = list(reader.keys)
            assert sorted(keys) == list(range(N))
            random.shuffle(keys)
            for i in keys:
                res = reader.read_idx(i)
                assert res == chr(i).encode('utf-8')
        finally:
            reader.close()
60
@with_seed()
def test_recordio_pack_label():
    """Check that ``pack``/``unpack`` round-trip float32 array labels and payloads.

    The test covers every label length 1..N-1 crossed with every payload
    length 0..N-1. It works entirely in memory, so no temporary file is
    needed. The unused ``tempfile.mktemp()`` call that used to be here has
    been removed.
    """
    N = 255
    # Loop-invariant character set for the random payloads.
    alphabet = string.ascii_uppercase + string.digits

    for i in range(1, N):
        for j in range(N):
            content = ''.join(random.choice(alphabet) for _ in range(j)).encode('utf-8')
            label = np.random.uniform(size=i).astype(np.float32)
            # NOTE(review): presumably the mx.recordio.IRHeader field order
            # (flag, label, id, id2) -- confirm against the mxnet docs.
            header = (0, label, 0, 0)
            s = mx.recordio.pack(header, content)
            rheader, rcontent = mx.recordio.unpack(s)
            # array_equal also requires matching shapes; the previous
            # `(label == rheader.label).all()` would broadcast and could
            # accept a wrongly shaped label.
            assert np.array_equal(label, rheader.label)
            assert content == rcontent
76
if __name__ == '__main__':
    # Run each test directly, in order, when executed as a script
    # (no external test runner involved).
    for _test_fn in (test_recordio_pack_label, test_recordio, test_indexed_recordio):
        _test_fn()
81