# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Read and write for the RecordIO data format."""
from collections import namedtuple
from multiprocessing import current_process

import ctypes
import struct
import numbers
import numpy as np

from .base import _LIB
from .base import RecordIOHandle
from .base import check_call
from .base import c_str
try:
    import cv2
except ImportError:
    cv2 = None


class MXRecordIO(object):
    """Reads/writes `RecordIO` data format, supporting sequential read and write.

    Examples
    ---------
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
    <mxnet.recordio.MXRecordIO object at 0x10ef40ed0>
    >>> for i in range(5):
    ...    record.write('record_%d'%i)
    >>> record.close()
    >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
    >>> for i in range(5):
    ...    item = record.read()
    ...    print(item)
    record_0
    record_1
    record_2
    record_3
    record_4
    >>> record.close()

    Parameters
    ----------
    uri : string
        Path to the record file.
    flag : string
        'w' for write or 'r' for read.
    """
    def __init__(self, uri, flag):
        # Set the bookkeeping attributes first so that ``close`` (and hence
        # ``__del__``) stays safe even if one of the later steps raises.
        self.is_open = False
        self.pid = None
        self.flag = flag
        self.uri = c_str(uri)
        self.handle = RecordIOHandle()
        self.open()

    def open(self):
        """Opens the record file.

        Creates a native writer (flag 'w') or reader (flag 'r') and records
        the id of the process that opened it.

        Raises
        ------
        ValueError
            If ``flag`` is neither 'w' nor 'r'.
        """
        if self.flag == "w":
            check_call(_LIB.MXRecordIOWriterCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = True
        elif self.flag == "r":
            check_call(_LIB.MXRecordIOReaderCreate(self.uri, ctypes.byref(self.handle)))
            self.writable = False
        else:
            raise ValueError("Invalid flag %s"%self.flag)
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        self.pid = current_process().pid
        self.is_open = True

    def __del__(self):
        self.close()

    def __getstate__(self):
        """Override pickling behavior.

        The native handle cannot be pickled, so it is dropped from the state
        and the uri is stored as a plain string. Note that the record is
        closed as a side effect and is not reopened on this object; the
        previous open/closed status is stored so that ``__setstate__`` can
        reopen the restored copy.
        """
        # pickling pointer is not allowed
        is_open = self.is_open
        self.close()
        d = dict(self.__dict__)
        d['is_open'] = is_open
        uri = self.uri.value
        try:
            uri = uri.decode('utf-8')
        except AttributeError:
            # The value has no ``decode`` (already a text string); keep it.
            pass
        del d['handle']
        d['uri'] = uri
        return d

    def __setstate__(self, d):
        """Restore from pickled.

        Rebuilds the native handle and uri, and reopens the record if it was
        open at the time it was pickled.
        """
        self.__dict__ = d
        is_open = d['is_open']
        self.is_open = False
        self.handle = RecordIOHandle()
        self.uri = c_str(self.uri)
        if is_open:
            self.open()

    def _check_pid(self, allow_reset=False):
        """Check process id to ensure integrity, reset if in new process.

        Parameters
        ----------
        allow_reset : bool
            If True, the record is reset (closed and reopened) when the
            current process differs from the one that opened it; otherwise a
            ``RuntimeError`` is raised in that situation.
        """
        # pylint: disable=not-callable
        # It's bug from pylint(astroid). See https://github.com/PyCQA/pylint/issues/1699
        if self.pid != current_process().pid:
            if allow_reset:
                self.reset()
            else:
                raise RuntimeError("Forbidden operation in multiple processes")

    def close(self):
        """Closes the record file.

        Does nothing if the record is not open.
        """
        if not self.is_open:
            return
        if self.writable:
            check_call(_LIB.MXRecordIOWriterFree(self.handle))
        else:
            check_call(_LIB.MXRecordIOReaderFree(self.handle))
        self.is_open = False
        self.pid = None

    def reset(self):
        """Resets the pointer to first item.

        If the record is opened with 'w', this function will truncate the file to empty.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(2):
        ...    item = record.read()
        ...    print(item)
        record_0
        record_1
        >>> record.reset()  # Pointer is reset.
        >>> print(record.read()) # Started reading from start again.
        record_0
        >>> record.close()
        """
        self.close()
        self.open()

    def write(self, buf):
        """Inserts a string buffer as a record.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'w')
        >>> for i in range(5):
        ...    record.write('record_%d'%i)
        >>> record.close()

        Parameters
        ----------
        buf : string (python2), bytes (python3)
            Buffer to write.

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert self.writable
        self._check_pid(allow_reset=False)
        check_call(_LIB.MXRecordIOWriterWriteRecord(self.handle,
                                                    ctypes.c_char_p(buf),
                                                    ctypes.c_size_t(len(buf))))

    def read(self):
        """Returns record as a string.

        Examples
        ---------
        >>> record = mx.recordio.MXRecordIO('tmp.rec', 'r')
        >>> for i in range(5):
        ...    item = record.read()
        ...    print(item)
        record_0
        record_1
        record_2
        record_3
        record_4
        >>> record.close()

        Returns
        ----------
        buf : string
            Buffer read, or None when the native reader returns a null buffer
            (presumably at the end of the file).

        Raises
        ------
        RuntimeError
            If called from a process other than the one that opened the record.
        """
        assert not self.writable
        # trying to implicitly read from multiple processes is forbidden,
        # there's no elegant way to handle unless lock is introduced
        self._check_pid(allow_reset=False)
        buf = ctypes.c_char_p()
        size = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOReaderReadRecord(self.handle,
                                                   ctypes.byref(buf),
                                                   ctypes.byref(size)))
        if buf:
            # Reinterpret the pointer as a fixed-size char array so that the
            # full record (including embedded NUL bytes) is copied out.
            buf = ctypes.cast(buf, ctypes.POINTER(ctypes.c_char*size.value))
            return buf.contents.raw
        else:
            return None


class MXIndexedRecordIO(MXRecordIO):
    """Reads/writes `RecordIO` data format, supporting random access.

    Examples
    ---------
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
    >>> for i in range(5):
    ...     record.write_idx(i, 'record_%d'%i)
    >>> record.close()
    >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
    >>> record.read_idx(3)
    record_3

    Parameters
    ----------
    idx_path : str
        Path to the index file.
    uri : str
        Path to the record file. Only supports seekable file types.
    flag : str
        'w' for write or 'r' for read.
    key_type : type
        Data type for keys.
    """
    def __init__(self, idx_path, uri, flag, key_type=int):
        self.idx_path = idx_path
        self.idx = {}
        self.keys = []
        self.key_type = key_type
        self.fidx = None
        super(MXIndexedRecordIO, self).__init__(uri, flag)

    def open(self):
        """Opens the record file and the index file.

        When opened for reading, the index file is parsed into ``self.idx``
        (key -> position) and ``self.keys`` (keys in file order). Each index
        line holds a key and a position separated by a tab; blank lines are
        ignored. If the index file cannot be opened or parsed, the record is
        closed again before the error is re-raised.
        """
        super(MXIndexedRecordIO, self).open()
        self.idx = {}
        self.keys = []
        try:
            self.fidx = open(self.idx_path, self.flag)
            if not self.writable:
                for line in self.fidx:
                    line = line.strip()
                    if not line:
                        # Tolerate blank lines (e.g. a trailing newline).
                        continue
                    fields = line.split('\t')
                    key = self.key_type(fields[0])
                    self.idx[key] = int(fields[1])
                    self.keys.append(key)
        except Exception:
            # Do not leave the native handle (and index file) open behind a
            # half-initialised object.
            self.close()
            raise

    def close(self):
        """Closes the record file and the index file.

        Does nothing if the record is not open.
        """
        if not self.is_open:
            return
        super(MXIndexedRecordIO, self).close()
        # ``fidx`` is still None if opening the index file failed.
        if self.fidx is not None:
            self.fidx.close()

    def __getstate__(self):
        """Override pickling behavior.

        Same as the base class, with the index file object removed from the
        state because open files cannot be pickled.
        """
        d = super(MXIndexedRecordIO, self).__getstate__()
        d['fidx'] = None
        return d

    def seek(self, idx):
        """Sets the current read pointer position.

        This function is internally called by `read_idx(idx)` to find the current
        reader pointer position. It doesn't return anything.

        If called from a process other than the one that opened the record,
        the record is reset (reopened) first.

        Parameters
        ----------
        idx : key_type
            Key of the record to move to; must be present in the index.
        """
        assert not self.writable
        self._check_pid(allow_reset=True)
        pos = ctypes.c_size_t(self.idx[idx])
        check_call(_LIB.MXRecordIOReaderSeek(self.handle, pos))

    def tell(self):
        """Returns the current position of write head.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> print(record.tell())
        0
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        ...     print(record.tell())
        16
        32
        48
        64
        80
        """
        assert self.writable
        pos = ctypes.c_size_t()
        check_call(_LIB.MXRecordIOWriterTell(self.handle, ctypes.byref(pos)))
        return pos.value

    def read_idx(self, idx):
        """Returns the record at given index.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        >>> record.close()
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'r')
        >>> record.read_idx(3)
        record_3
        """
        self.seek(idx)
        return self.read()

    def write_idx(self, idx, buf):
        """Inserts input record at given index.

        The record is appended to the record file and a ``key<TAB>position``
        line is appended to the index file.

        Examples
        ---------
        >>> record = mx.recordio.MXIndexedRecordIO('tmp.idx', 'tmp.rec', 'w')
        >>> for i in range(5):
        ...     record.write_idx(i, 'record_%d'%i)
        >>> record.close()

        Parameters
        ----------
        idx : int
            Index of a file.
        buf :
            Record to write.
        """
        key = self.key_type(idx)
        # Capture the position before writing: it is where this record starts.
        pos = self.tell()
        self.write(buf)
        self.fidx.write('%s\t%d\n'%(str(key), pos))
        self.idx[key] = pos
        self.keys.append(key)


IRHeader = namedtuple('HEADER', ['flag', 'label', 'id', 'id2'])
"""An alias for HEADER. Used to store metadata (e.g. labels) accompanying a record.
See mxnet.recordio.pack and mxnet.recordio.pack_img for example uses.

Parameters
----------
    flag : int
        Available for convenience, can be set arbitrarily.
    label : float or an array of float
        Typically used to store label(s) for a record.
    id: int
        Usually a unique id representing record.
    id2: int
        Higher order bits of the unique id, should be set to 0 (in most cases).
"""
# Binary layout of the header: unsigned int flag, float label, two unsigned
# 64-bit ids (native byte order and alignment).
_IR_FORMAT = 'IfQQ'
_IR_SIZE = struct.calcsize(_IR_FORMAT)


def pack(header, s):
    """Pack a string into MXImageRecord.

    When ``header.label`` is a single number it is stored in the header and
    ``flag`` is set to 0. Otherwise the label is converted to a float32 array,
    its raw bytes are prepended to ``s``, ``flag`` is set to the number of
    label values and the header's label field is set to 0.

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    s : str
        Raw image string to be packed (bytes on Python 3).

    Returns
    -------
    s : str
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> with open(path, 'rb') as file:
    ...     s = file.read()
    >>> packed_s = mx.recordio.pack(header, s)
    """
    header = IRHeader(*header)
    if isinstance(header.label, numbers.Number):
        header = header._replace(flag=0)
    else:
        label = np.asarray(header.label, dtype=np.float32)
        header = header._replace(flag=label.size, label=0)
        # ``tobytes`` replaces the deprecated ``tostring`` alias, which was
        # removed in NumPy 2.0; the produced bytes are identical.
        s = label.tobytes() + s
    s = struct.pack(_IR_FORMAT, *header) + s
    return s


def unpack(s):
    """Unpack a MXImageRecord to string.

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.

    Returns
    -------
    header : IRHeader
        Header of the image record. If the stored ``flag`` is positive, the
        ``label`` field holds a float32 array of ``flag`` values.
    s : str
        Unpacked string.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, s = mx.recordio.unpack(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    """
    header = IRHeader(*struct.unpack(_IR_FORMAT, s[:_IR_SIZE]))
    s = s[_IR_SIZE:]
    if header.flag > 0:
        # A positive flag is the number of float32 labels stored right after
        # the header; each one takes 4 bytes.
        header = header._replace(label=np.frombuffer(s, np.float32, header.flag))
        s = s[header.flag*4:]
    return header, s


def unpack_img(s, iscolor=-1):
    """Unpack a MXImageRecord to image.

    Requires OpenCV (``cv2``).

    Parameters
    ----------
    s : str
        String buffer from ``MXRecordIO.read``.
    iscolor : int
        Image format option for ``cv2.imdecode``.

    Returns
    -------
    header : IRHeader
        Header of the image record.
    img : numpy.ndarray
        Unpacked image.

    Examples
    --------
    >>> record = mx.recordio.MXRecordIO('test.rec', 'r')
    >>> item = record.read()
    >>> header, img = mx.recordio.unpack_img(item)
    >>> header
    HEADER(flag=0, label=14.0, id=20129312, id2=0)
    >>> img
    array([[[ 23,  27,  45],
            [ 28,  32,  50],
            ...,
            [ 36,  40,  59],
            [ 35,  39,  58]],
           ...,
           [[ 91,  92, 113],
            [ 97,  98, 119],
            ...,
            [168, 169, 167],
            [166, 167, 165]]], dtype=uint8)
    """
    header, s = unpack(s)
    img = np.frombuffer(s, dtype=np.uint8)
    assert cv2 is not None, 'unpack_img requires OpenCV (cv2) to be installed'
    img = cv2.imdecode(img, iscolor)
    return header, img


def pack_img(header, img, quality=95, img_fmt='.jpg'):
    """Pack an image into ``MXImageRecord``.

    Requires OpenCV (``cv2``).

    Parameters
    ----------
    header : IRHeader
        Header of the image record.
        ``header.label`` can be a number or an array. See more detail in ``IRHeader``.
    img : numpy.ndarray
        Image to be packed.
    quality : int
        Quality for JPEG encoding in range 1-100, or compression for PNG encoding in range 1-9.
    img_fmt : str
        Encoding of the image (.jpg for JPEG, .png for PNG).

    Returns
    -------
    s : str
        The packed string.

    Examples
    --------
    >>> label = 4 # label can also be a 1-D array, for example: label = [1,2,3]
    >>> id = 2574
    >>> header = mx.recordio.IRHeader(0, label, id, 0)
    >>> img = cv2.imread('test.jpg')
    >>> packed_s = mx.recordio.pack_img(header, img)
    """
    assert cv2 is not None, 'pack_img requires OpenCV (cv2) to be installed'
    jpg_formats = ['.JPG', '.JPEG']
    png_formats = ['.PNG']
    encode_params = None
    fmt_upper = img_fmt.upper()
    if fmt_upper in jpg_formats:
        encode_params = [cv2.IMWRITE_JPEG_QUALITY, quality]
    elif fmt_upper in png_formats:
        encode_params = [cv2.IMWRITE_PNG_COMPRESSION, quality]

    ret, buf = cv2.imencode(img_fmt, img, encode_params)
    assert ret, 'failed to encode image'
    # ``tobytes`` replaces the deprecated ``tostring`` alias, which was
    # removed in NumPy 2.0; the produced bytes are identical.
    return pack(header, buf.tobytes())