import os
import json
import logging
import hashlib
import pickle
import bisect
import datetime
import numpy as np
from collections import defaultdict
from pathlib import Path

from .io.video_io import VideoSortedFolderReader, VideoFrameReader
from .utils.serialization_utils import save_json, load_json, ComplexEncoder, \
    save_pickle, load_pickle


logging.basicConfig(level=logging.INFO)
log = logging.getLogger()

# NOTE(review): the original module referenced a bare ``__version__`` name in
# ``GluonCVMotionDataset.version`` without ever defining it, which made
# ``dump()`` raise NameError.  Define a fallback here; confirm the real
# version source against the package this module came from.
__version__ = "0.0"


class FieldNames:
    """
    Keys found in annotation dict
    """
    # Entity Fields
    TIME = 'time'
    LABELS = 'labels'
    ID = 'id'
    BLOB = 'blob'

    # Sample fields
    METADATA = 'metadata'
    ENTITY_LIST = 'entities'
    SAMPLE_FILE = 'sample_file'

    # Dataset fields
    SAMPLE_DICT = 'samples'
    SAMPLES = 'samples'
    DATASET_METADATA = 'metadata'
    DATASET_VERSION = 'version'
    CLASS_MAPPING = 'class_mapping'

    # Data fields
    DATE_ADDED = 'date_added'
    DURATION = 'duration'
    DATA_PATH = 'data_path'
    FILENAME = 'filename'
    BASE_DIR = 'base_dir'
    FILE_EXT = 'file_ext'
    FPS = 'fps'
    NUM_FRAMES = 'number_of_frames'
    RESOLUTION = 'resolution'
    WIDTH = 'width'
    HEIGHT = 'height'
    DATASET_SOURCE = 'data_source'
    SOURCE_ID = 'source_id'
    SAMPLE_SOURCE = 'sample_source'
    ORIG_ID = 'orig_id'
    TEMPORAL_SEGMENTS = 'temporal_segments'
    DESCRIPTION = 'description'
    BOUNDING_BOXES = 'bb'
    HEAD_BBOX = "bb_head"
    FACE_BBOX = "bb_face"
    MASK = 'mask'
    KEYPOINTS = 'keypoints'
    CONFIDENCE = 'confidence'
    DEFAULT_VALUES = 'default_values'
    DEFAULT_GT = 'default_gt'
    FORMAT_VERSION = 'format_version'
    DATE_MODIFIED = 'last_modified'
    CHANGES = 'changes'
    KEY_HASH = 'key_hash'
    FRAME_IDX = 'frame_idx'

    @classmethod
    def get_key_hash(cls):
        """Hash of the key list, stored in dumped annotations and checked on
        load to detect format drift.  Currently a constant 0 (no real hash)."""
        return 0


class SplitNames:
    """Canonical names of the dataset splits stored in the split file."""
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


def get_instance_list(cls, raw_entity_list):
    """Instantiate ``cls(raw_info=x)`` for every raw dict in ``raw_entity_list``."""
    return [cls(raw_info=x) for x in raw_entity_list]


class AnnoEntity:
    """
    One annotation entity.
    A entity can refer to a video label, or a time segment annotation, or a frame annotation
    It has a required field "time", which should be in milliseconds
    Other than that, there are a few optional fields:
        id: if the entity has an id
        labels: if the entity has some categorical labels
        blob: other annotation information of this entity
    """

    def __init__(self, time=None, id=None, raw_info=None, validate=True):
        """
        :param time: timestamp in milliseconds (or a 2-tuple for a segment)
        :param id: optional entity id
        :param raw_info: raw annotation dict; when given, ``time``/``id`` are ignored
        :param validate: when True, ``time`` goes through the validating setter
        """
        self._raw_info = raw_info
        self._time = None
        self._labels = None
        self._mask = None
        self._blob = {}

        self.id = None
        self.bbox = None
        self.keypoints = None
        self.confidence = None

        if raw_info is not None:
            if time is not None or id is not None:
                log.warning("time or id were specified, but raw_info was set and so they will be ignored")
            self._parse_info()
        else:
            if validate:
                self.time = time
            else:
                self._time = time
            self.id = id

    @property
    def time(self):
        """Timestamp in milliseconds, or a (start, end) tuple for segments."""
        return self._time

    @time.setter
    def time(self, time):
        # NOTE: silently ignores None and unsupported types (leaves _time unchanged)
        if isinstance(time, (int, float)):
            assert time >= 0
            self._time = time
        elif isinstance(time, str):
            self._time = float(time)
        elif isinstance(time, tuple):
            assert len(time) == 2
            self._time = float(time[0]), float(time[1])

    @property
    def labels(self):
        return self._labels

    @labels.setter
    def labels(self, new_labels):
        if self._labels:
            # compare keys and report warning
            diff_keys = set(self._labels.keys()).difference(new_labels.keys())
            log.warning("updating the labels with different keys: {}".format(diff_keys))
        self._labels = new_labels

    @property
    def keypoints_xyv(self):
        """Flat [x0, y0, v0, x1, y1, v1, ...] keypoints as an (N, 3) float32 array."""
        return np.vstack([np.asarray(self.keypoints[0::3], dtype=np.float32),
                          np.asarray(self.keypoints[1::3], dtype=np.float32),
                          np.asarray(self.keypoints[2::3], dtype=np.float32)]).T

    @keypoints_xyv.setter
    def keypoints_xyv(self, kp_xyv):
        self.keypoints = list(np.asarray(kp_xyv).flatten())

    @property
    def mask(self):
        return self._mask

    @mask.setter
    def mask(self, mask):
        self._mask = mask

    @property
    def frame_num(self):
        """Frame index stored in the blob, or None if absent."""
        if FieldNames.FRAME_IDX not in self.blob:
            return None
        return self.blob[FieldNames.FRAME_IDX]

    @frame_num.setter
    def frame_num(self, frame_num):
        self.blob[FieldNames.FRAME_IDX] = frame_num

    @property
    def blob(self):
        return self._blob

    @blob.setter
    def blob(self, new_blob):
        if self._blob:
            # compare keys and report warning
            diff_keys = set(self._blob.keys()).difference(new_blob.keys())
            log.warning("updating the blob with different keys: {}".format(diff_keys))
        self._blob = new_blob

    def to_dict(self):
        """Serialize this entity to a plain dict, omitting unset optional fields."""
        FN = FieldNames
        out = dict()

        out[FN.TIME] = self.time

        if self.id is not None:
            out[FN.ID] = self.id

        if self.labels is not None:
            out[FN.LABELS] = self.labels

        if self.confidence is not None:
            out[FN.CONFIDENCE] = self.confidence

        if self.bbox is not None:
            out[FN.BOUNDING_BOXES] = self.bbox

        if self.mask is not None:
            out[FN.MASK] = self.mask

        if self.keypoints is not None:
            out[FN.KEYPOINTS] = self.keypoints

        if self.blob:
            out[FN.BLOB] = self.blob

        return out

    def _parse_info(self):
        """Populate fields from the raw annotation dict (inverse of ``to_dict``)."""
        FN = FieldNames

        raw_info = self._raw_info

        self._time = raw_info[FN.TIME]

        self.id = raw_info.get(FN.ID, None)
        self.bbox = raw_info.get(FN.BOUNDING_BOXES, None)
        self.keypoints = raw_info.get(FN.KEYPOINTS, None)
        self.confidence = raw_info.get(FN.CONFIDENCE, 1.0)
        self._labels = raw_info.get(FN.LABELS, None)
        self._mask = raw_info.get(FN.MASK, None)
        self._blob = raw_info.get(FN.BLOB, {})

    ### helper functions ###
    # Convenience accessors for the bbox; presumably [x, y, w, h] layout -- confirm
    @property
    def x(self):
        return self.bbox[0] if self.bbox else None

    @property
    def y(self):
        return self.bbox[1] if self.bbox else None

    @property
    def w(self):
        return self.bbox[2] if self.bbox else None

    @property
    def h(self):
        return self.bbox[3] if self.bbox else None


class DataReader:
    """Iterates a sample's frames, pairing each frame with the entities whose
    timestamp falls within ``max_frame_deviation`` frame periods of it."""

    def __init__(self, data_sample, max_frame_deviation=1, fps=None):
        """
        :param data_sample: the DataSample to read frames/annotations from
        :param max_frame_deviation: max distance (in frames) between a frame's
            timestamp and an entity's timestamp for them to be matched
        :param fps: frames per second; defaults to the sample's fps
        """
        self._data_sample = data_sample
        self._data_path = data_sample.data_path
        self._frame_reader = None
        self._frame_iter = None

        if fps is None:
            fps = data_sample.fps
        self._fps = fps

        self._max_frame_deviation = max_frame_deviation
        # frame period in milliseconds
        self._period = 1.0 / fps * 1000
        # cutoff is just under the allowed deviation, to avoid matching the next frame
        self._time_diff_cutoff = int(self._max_frame_deviation * self._period) - 1

        self._frame_reader = self._data_sample.frame_reader

    def __iter__(self):
        self._frame_iter = iter(self._frame_reader)
        return self

    def __next__(self):
        """Return (frame, timestamp_ms, entities) for the next frame."""
        frame, ts = next(self._frame_iter)
        if frame is None:
            raise StopIteration
        entities = self._data_sample.get_entities_near_time(int(ts), self._time_diff_cutoff)
        return frame, ts, entities

    def __getitem__(self, item):
        frame, ts = self._frame_reader[item]
        entities = self._data_sample.get_entities_near_time(int(ts), self._time_diff_cutoff)
        return frame, ts, entities


class DataSample:
    """
    One sample in the dataset. This can be a video file or an image file.
    It contains a list of entities as annotations.
    Each data sample can have some meta data. For example in videos, one can have
        FPS
        Duration
        Number of frames
        Source id
        ...
    """
    # fields excluded from pickling/deepcopy to avoid dragging the whole dataset along
    _NO_SERIALIZE_FIELDS = ("_dataset",)

    def __init__(self, id, raw_info=None, root_path=None, metadata=None, dataset=None):
        """
        :param id: unique sample id within the dataset
        :param raw_info: raw annotation dict for this sample (parsed lazily)
        :param root_path: dataset root containing the data/cache folders
        :param metadata: initial metadata dict (copied)
        :param dataset: back-reference to the owning GluonCVMotionDataset
        """
        self._raw_info = raw_info
        self._id = id
        self._entities = []
        if metadata is not None:
            self._metadata = dict(metadata)
        else:
            self._metadata = {}
        self._dataset = dataset
        self._filepath = None
        self._lazy_loaded = False
        self._root_path = None
        self._data_root_path = None
        self._cache_root_path = None
        self._raw_entities = None
        self._time_entity_dict = None
        self._entity_times = None
        self._times_unsorted = True
        self._id_entity_dict = None
        self._frame_entity_dict = None
        self._label_entity_dict = None
        self._init_entity_fields()

        self.set_root_path(root_path)
        if self._raw_info:
            self._parse()

    def set_root_path(self, root_path):
        """Set the dataset root and derive the data/cache root paths from it."""
        self._root_path = root_path
        if root_path:
            self._data_root_path = GluonCVMotionDataset.get_data_path_from_root(root_path)
            self._cache_root_path = GluonCVMotionDataset.get_cache_path_from_root(root_path)

    def _set_filepath(self, filepath, already_loaded=False):
        # marks this sample as lazy-loadable from a per-sample json file
        self._filepath = filepath
        self._lazy_loaded = already_loaded

    @property
    def id(self):
        return self._id

    @property
    def metadata(self):
        return self._metadata

    @metadata.setter
    def metadata(self, new_md):
        if self._metadata:
            # compare keys and report warning
            diff_keys = set(self._metadata.keys()).difference(new_md.keys())
            log.warning("updating the metadata with different keys: {}".format(diff_keys))
        self._metadata = new_md

    def _lazy_init(self):
        """Load the per-sample file and materialize entities on first access."""
        if self._filepath and not self._lazy_loaded:
            self._lazy_load()
        if (self._raw_entities is not None) and not self._entities:
            self._entities = get_instance_list(AnnoEntity, self._raw_entities)
            self._raw_entities = None
            self._init_entity_fields()

    def __len__(self):
        """Number of frames, taken from metadata."""
        return self.metadata[FieldNames.NUM_FRAMES]

    @property
    def duration(self):
        return self.metadata[FieldNames.DURATION]

    @property
    def entities(self) -> "list[AnnoEntity]":
        self._lazy_init()
        return self._entities

    @property
    def data_relative_path(self):
        """Path of the media relative to the data root; None if unset.

        NOTE(review): when BASE_DIR is present it *replaces* DATA_PATH rather
        than being joined with it -- confirm this is intended."""
        if FieldNames.DATA_PATH in self.metadata:
            data_path = self.metadata[FieldNames.DATA_PATH]
            if FieldNames.BASE_DIR in self.metadata:
                data_path = self.metadata[FieldNames.BASE_DIR]
            return data_path

    @data_relative_path.setter
    def data_relative_path(self, data_path):
        self.metadata[FieldNames.DATA_PATH] = data_path

    @property
    def data_path(self):
        """Absolute path of the media (relative path joined with the data root)."""
        data_path = self.data_relative_path
        if self._data_root_path:
            data_path = os.path.join(self._data_root_path, data_path)

        return data_path

    @property
    def frame_reader(self):
        """A frame reader for the media: folder of frames or a video file."""
        data_path = self.data_path
        if os.path.isdir(data_path):
            frame_reader = VideoSortedFolderReader(data_path, self.fps)
        else:
            frame_reader = VideoFrameReader(data_path)
        return frame_reader

    def get_cache_file(self, cache_name, extension=''):
        """Path of a cache artifact for this sample under the cache root."""
        # for files, drop the media extension so the cache path mirrors the data tree
        rel_path = os.path.splitext(self.data_relative_path)[0] if os.path.isfile(self.data_path) else self.data_relative_path
        return os.path.join(self._cache_root_path, cache_name, rel_path + extension)

    @property
    def fps(self):
        fps = self.metadata.get(FieldNames.FPS, None)
        return fps

    @property
    def period(self):
        """Retrieves the period in milliseconds (1000 / fps), if fps is unset, returns None"""
        fps = self.fps
        return 1000 / fps if fps else None

    @property
    def width(self):
        width = self.metadata.get(FieldNames.RESOLUTION, {}).get(FieldNames.WIDTH, None)
        return width

    @property
    def height(self):
        height = self.metadata.get(FieldNames.RESOLUTION, {}).get(FieldNames.HEIGHT, None)
        return height

    @property
    def num_minutes(self):
        """Duration in minutes, assuming 30 fps when fps is missing or <= 1."""
        frames = self.metadata[FieldNames.NUM_FRAMES]
        # NOTE(review): raises TypeError if fps is None -- confirm callers guarantee fps
        fps = self.fps if self.fps > 1 else 30.
        return frames / fps / 60.

    def get_data_reader(self):
        return DataReader(self)

    def _init_entity_fields(self):
        """(Re)build all lookup indexes (by time, id, frame and label)."""
        self._time_entity_dict = defaultdict(list)
        self._id_entity_dict = defaultdict(list)
        self._frame_entity_dict = defaultdict(list)
        self._label_entity_dict = defaultdict(list)
        self._entity_times = []
        for entity in self.entities:
            self._update_key_dicts(entity)

    @property
    def frame_num_entity_dict(self):
        self._lazy_init()
        return self._frame_entity_dict

    def get_entities_for_frame_num(self, frame_idx):
        return self.frame_num_entity_dict[frame_idx]

    @property
    def id_entity_dict(self):
        self._lazy_init()
        return self._id_entity_dict

    def get_entities_with_id(self, id):
        return self.id_entity_dict[id]

    @property
    def time_entity_dict(self):
        self._lazy_init()
        return self._time_entity_dict

    def get_entities_at_time(self, time):
        return self.time_entity_dict[time]

    @property
    def label_entity_dict(self):
        self._lazy_init()
        return self._label_entity_dict

    def get_entities_with_label(self, label):
        return self.label_entity_dict[label]

    def _get_entity_times_sorted(self):
        # sort lazily, only when times changed since the last call
        if self._times_unsorted:
            self._entity_times.sort()
            self._times_unsorted = False
        return self._entity_times

    def get_entities_near_time(self, time, time_diff_cutoff=None):
        """Entities at the annotated timestamp closest to ``time``, or [] if
        the closest one is further away than ``time_diff_cutoff`` ms."""
        self._lazy_init()

        if time_diff_cutoff is None:
            # Set to the period msec - 1
            time_diff_cutoff = int(1.0 / self.fps * 1000) - 1

        entity_times = self._get_entity_times_sorted()

        if not entity_times:
            return []

        # binary search for the neighbors on both sides of `time`
        insert_pos = bisect.bisect(entity_times, time)
        time_left = entity_times[max(insert_pos - 1, 0)]
        time_right = entity_times[min(insert_pos, len(entity_times) - 1)]

        diff_left = abs(time_left - time)
        diff_right = abs(time_right - time)

        if diff_right < diff_left:
            closest_time = time_right
            diff = diff_right
        else:
            closest_time = time_left
            diff = diff_left

        if diff < time_diff_cutoff:
            entities = self._time_entity_dict[closest_time]
        else:
            entities = []

        return entities

    def _update_key_dicts(self, entity):
        """Insert ``entity`` into every lookup index."""
        self._id_entity_dict[entity.id].append(entity)
        new_time = entity.time not in self._time_entity_dict
        self._time_entity_dict[entity.time].append(entity)
        if entity.time is not None and new_time:
            self._entity_times.append(entity.time)
            self._times_unsorted = True
        # an entity could have multiple labels
        if entity.labels is not None:
            for k, v in entity.labels.items():
                self._label_entity_dict[k].append(entity)
        if entity.frame_num is not None:
            self._frame_entity_dict[entity.frame_num].append(entity)
        else:
            if entity.time is not None and self.fps:
                # Time in seconds * fps = frame_num
                frame_num = round((entity.time / 1000) * self.fps)
            else:
                frame_num = None
            self._frame_entity_dict[frame_num].append(entity)

    def add_entity(self, entities):
        """
        Add a new entity or a list of entities to the sample
        :param entities: an AnnoEntity or an iterable of them
        :return:
        """
        self._lazy_init()

        if isinstance(entities, AnnoEntity):
            self._entities.append(entities)
            self._update_key_dicts(entities)
        else:
            self._entities.extend(entities)
            for entity in entities:
                self._update_key_dicts(entity)

    def get_copy_without_entities(self, new_id=None):
        """
        :return: A new DataSample with the same id and metadata but no entities
        """
        if new_id is None:
            new_id = self.id
        return DataSample(new_id, root_path=self._root_path, metadata=self.metadata, dataset=self._dataset)

    def filter_entities(self, filter_fn):
        """
        :param filter_fn: When true, keep the entity, otherwise omit it
        """
        new_sample = self.get_copy_without_entities()
        for entity in self.entities:
            if filter_fn(entity):
                new_sample.add_entity(entity)
        return new_sample

    def get_non_empty_frames(self, filter_fn=None, fps=0):
        """
        Return indexes of all valid frames with the specified fps,
        whose annotation exists
        """
        if fps == 0:
            fps = self.fps
        interval = int(np.ceil(self.fps / fps))

        frame_idxs = []
        for idx in range(0, len(self), interval):
            entities = self.get_entities_for_frame_num(idx)
            if filter_fn is not None:
                entities, _ = filter_fn(entities)
            if len(entities) > 0:
                frame_idxs.append(idx)
        return sorted(frame_idxs)

    def to_dict(self, include_id=False, lazy_load_format=False):
        """
        Dump the information in this sample to a dict
        :param include_id: include the sample id in the dict
        :param lazy_load_format: when this sample has a per-sample file, emit
            only metadata plus the file reference instead of the entities
        :return:
        """
        out = dict()

        if lazy_load_format and self._filepath:
            if self._metadata:
                out[FieldNames.METADATA] = self._metadata
            out[FieldNames.SAMPLE_FILE] = self._filepath
            return out

        self._lazy_init()

        if include_id:
            out[FieldNames.ID] = self.id
        out[FieldNames.METADATA] = self.metadata
        out[FieldNames.ENTITY_LIST] = [x.to_dict() for x in self.entities]
        return out

    def dump(self, filename, indent=0, include_id=True, **kwargs):
        """Write this sample to ``filename`` as json."""
        save_json(self.to_dict(include_id=include_id), filename, indent=indent, **kwargs)

    def _get_lazy_load_path(self):
        """Path of the per-sample json file, rooted at the anno path sans extension."""
        if not self._filepath or not self._dataset:
            raise ValueError("Cannot get lazy load path without a filepath and dataset")
        base_path = Path(self._dataset.anno_path).with_suffix("")
        if self._filepath is True:
            # True means "derive from the sample id"
            filepath = base_path / (self.id + ".json")
        else:
            filepath = base_path / self._filepath
        return filepath

    def _lazy_load(self):
        """Load raw info from the per-sample file, preferring in-memory metadata."""
        filepath = self._get_lazy_load_path()
        self._raw_info = load_json(filepath)
        if self.metadata and self._raw_info.get(FieldNames.METADATA) != self.metadata:
            log.info("metadata did not match lazy loaded value for sample: {}, ignoring loaded value, this should be"
                     " resolved next time you dump the dataset".format(self.id))
            self._raw_info[FieldNames.METADATA] = self.metadata
        self._parse()
        self._lazy_loaded = True

    def _dump_for_lazy_load(self):
        """Write the loaded sample back to its per-sample file."""
        if not self._lazy_loaded:
            log.debug("nothing lazy loaded to dump, returning")
            return
        filepath = self._get_lazy_load_path()
        filepath.parent.mkdir(parents=True, exist_ok=True)
        self.dump(filepath, include_id=False)

    def clear_lazy_loaded(self, clear_metadata=False):
        """Drop loaded entities (and optionally metadata) to free memory."""
        if not self._lazy_loaded:
            log.debug("nothing lazy loaded so nothing to clear")
            return
        self._entities = []
        self._raw_info = None
        self._init_entity_fields()
        if clear_metadata:
            self._metadata = {}
        self._lazy_loaded = False

    @classmethod
    def load(cls, filename, **kwargs):
        """Construct a DataSample from a json file produced by ``dump``."""
        data = load_json(filename)
        raw_info = data["raw_info"] if "raw_info" in data else data
        return cls(data[FieldNames.ID], raw_info=raw_info, **kwargs)

    def _parse(self):
        """Parse metadata and (lazily) entities out of the raw info dict."""
        self._metadata = self._raw_info.get(FieldNames.METADATA, {})
        # Lazy load entities for speed when loading dataset
        self._raw_entities = self._raw_info.get(FieldNames.ENTITY_LIST, [])
        filepath = self._raw_info.get(FieldNames.SAMPLE_FILE)
        if filepath is not None:
            self._filepath = filepath

    def copy(self):
        """Deep copy via pickle, restoring the non-serialized dataset reference."""
        new = pickle.loads(pickle.dumps(self))
        for f in self._NO_SERIALIZE_FIELDS:
            setattr(new, f, getattr(self, f))
        return new

    def __getstate__(self):
        # Used by pickle and deepcopy, this prevents trying to copy the whole dataset due to the dataset back reference
        return {k: v for k, v in vars(self).items() if k not in self._NO_SERIALIZE_FIELDS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        for f in self._NO_SERIALIZE_FIELDS:
            setattr(self, f, None)

    def __enter__(self):
        self._lazy_init()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # BUGFIX: the original signature was ``__exit__(self)``, which raised
        # TypeError whenever the sample was used in a ``with`` statement.
        self.clear_lazy_loaded()


class GluonCVMotionDataset:
    """A dataset rooted at a folder containing 'annotation', 'cache' and
    'raw_data' subfolders, holding an ordered collection of DataSample."""

    ANNO_DIR = "annotation"
    CACHE_DIR = "cache"
    DATA_DIR = "raw_data"

    _DEFAULT_ANNO_FILE = "anno.json"
    _DEFAULT_SPLIT_FILE = "splits.json"

    def __init__(self, annotation_file=None, root_path=None, split_file=None, load_anno=True):
        """
        GluonCVMotionDataset
        :param annotation_file: The path to the annotation file, either a full path or a path relative to the root
         annotation path (root_path/annotation/), defaults to 'anno.json'
        :param root_path: The root path of the dataset, containing the 'annotation', 'cache', and 'raw_data' folders.
         If left empty it will be inferred from the annotation_file path by searching up until the 'annotation' folder
         is found, then going one more level up
        :param split_file: The path to the split file relative to the annotation file. It will be relative to the root
         annotation path instead if it starts with './'
        :param split_file: The path to the split file relative to the annotation file. It will be relative to the root
        :param load_anno: Whether to load the annotation file, will cause an exception if it is true and file does not
         exist. Set this to false if you are just trying to write a new annotation file for example in an ingestion
         script
        """

        # a dict of DataSample instances; third-party import kept local as in
        # the original so the module can be imported without it until needed
        import indexed
        self._samples = indexed.IndexedOrderedDict()
        self._splits = {}
        self._metadata = {}
        # BUGFIX: initialize so dump_pickle() works even when load_anno=False
        self._raw_info = None

        if annotation_file is None:
            annotation_file = self._DEFAULT_ANNO_FILE
            log.info("Annotation file not provided, defaulting to '{}'".format(annotation_file))

        self._root_path = self._get_root_path(root_path, annotation_file)

        if self._root_path:
            if not os.path.isdir(self._root_path):
                raise ValueError("Expected root folder but was not found at: {}".format(self._root_path))

            self._anno_path = os.path.join(self._root_path, self.ANNO_DIR, annotation_file)
            self._data_path = self.get_data_path_from_root(self._root_path)
            self._cache_path = self.get_cache_path_from_root(self._root_path)

            if not os.path.isdir(self._data_path):
                raise ValueError("Expected data folder but was not found at: {}".format(self._data_path))
        else:
            log.warning('Root path was not set for dataset, this should only happen when loading a lone annotation'
                        ' file for inspection')
            self._anno_path = annotation_file
            self._data_path = None
            self._cache_path = None

        if load_anno:
            if os.path.exists(self._anno_path):
                log.info('Loading annotation file {}...'.format(self._anno_path))
                # load annotation file, preferring the pickle cache when present
                if self._get_pickle_path().exists():
                    log.info('Found pickle file, loading this instead')
                loaded_pickle = self._load_pickle()
                if not loaded_pickle:
                    self._parse_anno(self._anno_path)
                self._split_path = self._get_split_path(split_file, self._anno_path)
                self._load_split()
            else:
                raise ValueError(
                    "load_anno is true but the anno path does not exist at: {}".format(self._anno_path))
        else:
            log.info('Skipping loading for annotation file {}'.format(self._anno_path))
            self._split_path = self._get_split_path(split_file, self._anno_path)

    def __len__(self):
        return len(self._samples)

    def __contains__(self, item):
        return item in self._samples

    def __getitem__(self, item):
        return self._samples[item]

    def __iter__(self):
        # yields (sample_id, DataSample) pairs
        for item in self._samples.items():
            yield item

    @classmethod
    def get_data_path_from_root(cls, root_path):
        return os.path.join(root_path, cls.DATA_DIR)

    @classmethod
    def get_cache_path_from_root(cls, root_path):
        return os.path.join(root_path, cls.CACHE_DIR)

    def _get_root_path(self, root_path, annotation_file):
        """Infer the dataset root from the annotation path when not given:
        walk up to the 'annotation' folder, then one more level."""
        if root_path is None:
            dirpath = os.path.dirname(annotation_file)
            if self.ANNO_DIR in dirpath:
                while os.path.basename(dirpath) != self.ANNO_DIR and dirpath != '/':
                    dirpath = os.path.dirname(dirpath)
                root_path = os.path.abspath(os.path.dirname(dirpath))
                log.info("Dataset root path inferred to be: {}".format(root_path))
        return root_path

    def _get_split_path(self, split_file, anno_path):
        """Resolve the split file path (see __init__ docstring for the rules)."""
        split_path = split_file
        if split_path is None:
            split_path = self._DEFAULT_SPLIT_FILE
        if not os.path.isabs(split_path):
            if split_path.startswith("./"):
                # './' is relative to the root annotation dir, not the anno file
                anno_dir = os.path.join(self._root_path, self.ANNO_DIR)
            else:
                anno_dir = os.path.dirname(anno_path)
            split_path = os.path.join(anno_dir, split_path)
        split_subpath = split_path.replace(self._root_path or "", "").lstrip(os.path.sep)
        log.info("Split subpath: {}".format(split_subpath))
        return split_path

    @property
    def iter_samples(self):
        """
        returns a iterator of samples
        :return:
        """
        return self._samples.items()

    @property
    def samples(self):
        return self._samples.items()

    @property
    def sample_ids(self):
        return self._samples.keys()

    @property
    def sample_values(self):
        return self._samples.values()

    def get_split_ids(self, splits=None):
        """Sample ids belonging to the given split name(s); all ids if None."""
        if splits is None:
            # Default is return all ids
            return self._samples.keys()
        if isinstance(splits, str):
            splits = [splits]

        all_ids = []
        for split in splits:
            if split not in self._splits:
                log.warning("Provided split: {} was not in dataset".format(split))
            else:
                all_ids.extend(self._splits[split])

        return all_ids

    def get_split_samples(self, splits=None):
        """(id, sample) pairs for the given split(s), skipping missing samples."""
        split_ids = self.get_split_ids(splits)
        samples = []
        for split_id in split_ids:
            if split_id in self._samples:
                samples.append((split_id, self._samples[split_id]))
            else:
                log.info(f"Dataset is missing sample: {split_id} in split {self.get_use_for_id(split_id)}, skipping")
        return samples

    @property
    def train_samples(self):
        return self.get_split_samples(SplitNames.TRAIN)

    @property
    def val_samples(self):
        return self.get_split_samples(SplitNames.VAL)

    @property
    def trainval_samples(self):
        return self.get_split_samples([SplitNames.TRAIN, SplitNames.VAL])

    @property
    def test_samples(self):
        return self.get_split_samples(SplitNames.TEST)

    @property
    def all_samples(self):
        # fall back to every sample when no splits are defined
        samples = self.get_split_samples([SplitNames.TRAIN, SplitNames.VAL, SplitNames.TEST])
        if not len(samples):
            samples = self.samples
        return samples

    def get_use_for_id(self, id):
        """Name of the split containing ``id``, or None."""
        for use in self._splits:
            if id in self._splits[use]:
                return use
        return None

    @property
    def version(self):
        return __version__

    @property
    def metadata(self):
        return self._metadata

    @property
    def name(self):
        """Dataset name, taken from the root folder name."""
        if not self._root_path:
            return None
        return os.path.basename(self._root_path)

    @property
    def root_path(self):
        return self._root_path

    @property
    def anno_root_path(self):
        return os.path.join(self._root_path, self.ANNO_DIR)

    @property
    def cache_root_path(self):
        return self._cache_path

    @property
    def data_root_path(self):
        return self._data_path

    @property
    def anno_path(self):
        return self._anno_path

    def _get_anno_subpath(self, anno_path, with_ext):
        subpath = Path(anno_path).relative_to(self.anno_root_path)
        if not with_ext:
            subpath = subpath.with_suffix("")
        return str(subpath)

    def get_anno_subpath(self, with_ext=False):
        return self._get_anno_subpath(self._anno_path, with_ext)

    def get_anno_suffix(self):
        subpath = self.get_anno_subpath()
        return "_" + subpath.replace(os.sep, "_")

    @property
    def split_path(self):
        return self._split_path

    def get_split_subpath(self, with_ext=False):
        return self._get_anno_subpath(self._split_path, with_ext)

    def get_split_suffix(self):
        subpath = self.get_split_subpath()
        return "_" + subpath.replace(os.sep, "_")

    @metadata.setter
    def metadata(self, new_md):
        if self._metadata:
            # compare keys and report warning
            diff_keys = set(self._metadata.keys()).difference(new_md.keys())
            log.warning("updating the metadata with different keys: {}".format(diff_keys))
        self._metadata = new_md

    @property
    def description(self):
        return self._metadata.get(FieldNames.DESCRIPTION, "")

    @description.setter
    def description(self, description):
        self._metadata[FieldNames.DESCRIPTION] = description

    def add_sample(self, sample: DataSample, dump_directly=False):
        """Add a sample (round-tripped through its dict form so it behaves as
        if loaded from disk); optionally dump it straight to its lazy file."""
        # create a new sample so it functions just as it would if it were loaded from disk
        new_sample = DataSample(sample.id, raw_info=sample.to_dict(), root_path=self.root_path, dataset=self)
        self._samples[sample.id] = new_sample
        if dump_directly:
            new_sample._set_filepath(True, already_loaded=True)
            new_sample._dump_for_lazy_load()
            new_sample.clear_lazy_loaded()
        return new_sample

    def dumps(self, encoder=ComplexEncoder, **kwargs):
        """Serialize the dataset to a json string."""
        return json.dumps(self._to_dict(), cls=encoder, **kwargs)

    def dump(self, filename=None, indent=0, **kwargs):
        """Write the dataset annotation json (to the anno path by default)."""
        if filename is None:
            filename = self._anno_path
        anno_dir = os.path.dirname(self._anno_path)
        if not os.path.exists(anno_dir):
            try:
                # makedirs: the annotation file may live in a nested subfolder
                os.makedirs(anno_dir, exist_ok=True)
            except OSError:
                pass
        save_json(self._to_dict(), filename, indent=indent, **kwargs)

    def _get_pickle_path(self):
        return Path(self._anno_path).with_suffix(".pkl")

    def _anno_mod_time(self):
        return Path(self._anno_path).stat().st_mtime

    def dump_pickle(self, filepath=None, **kwargs):
        """Cache the parsed dataset as a pickle keyed by the anno mtime."""
        if filepath is None:
            filepath = str(self._get_pickle_path())
        modified_time = self._anno_mod_time()
        to_pickle = {
            "_samples": self._samples,
            "_metadata": self._metadata,
            "_raw_info": self._raw_info,
            "modified_time": modified_time
        }
        save_pickle(to_pickle, filepath, **kwargs)

    def _load_pickle(self, filepath=None):
        """Load the pickle cache if present and still matching the anno mtime.

        SECURITY NOTE: unpickling executes arbitrary code; only load pickle
        caches from trusted dataset folders.
        """
        if filepath is None:
            filepath = str(self._get_pickle_path())
        if not os.path.exists(filepath):
            return False
        log.info('Loading pickle file {}'.format(filepath))
        try:
            loaded_dict = load_pickle(filepath)
        except OSError:
            log.warning("Failed to load pickle")
            return False

        stored_time = loaded_dict["modified_time"]
        modified_time = self._anno_mod_time()
        if stored_time == modified_time:
            self._samples = loaded_dict["_samples"]
            self._metadata = loaded_dict["_metadata"]
            self._raw_info = loaded_dict["_raw_info"]
            return True
        else:
            # BUGFIX: the original format string had no '{}' placeholder, so
            # the file path was silently dropped from the log message
            log.info(('The pickle {} stored modification time did not match the annotation, so not loading,'
                      ' please remove and regenerate the pickle, renaming to .old').format(filepath))
            new_filepath = str(filepath) + ".old_" + str(datetime.datetime.now()).replace(" ", "_")
            try:
                Path(filepath).rename(new_filepath)
            except OSError:
                pass
            return False

    def _parse_anno(self, annotation_file):
        """Parse the annotation json: metadata, key-hash check, then samples."""
        json_info = load_json(annotation_file)
        self._raw_info = json_info
        log.info("loaded anno json")

        # load metadata
        assert FieldNames.DATASET_METADATA in self._raw_info, \
            "key: {} should present in the annotation file, we only got {}".format(
                FieldNames.DATASET_METADATA, self._raw_info.keys())
        self._metadata.update(self._raw_info[FieldNames.DATASET_METADATA])

        # check key map hash
        key_hash = self._metadata.get(FieldNames.KEY_HASH, '')
        assert key_hash == FieldNames.get_key_hash(), "Key list not matching. " \
            "Maybe this annotation file is created with other versions. " \
            "Current version is {}".format(self.version)

        # load samples
        sample_dict = self._raw_info.get(FieldNames.SAMPLE_DICT, dict())
        root_path = self.root_path
        for k in sorted(sample_dict.keys()):
            self._samples[k] = DataSample(k, sample_dict[k], root_path=root_path, dataset=self)
        log.info("loaded {} samples".format(len(self._samples)))

    def _load_split(self):
        """Load the split file if it exists; missing file is not an error."""
        if not os.path.exists(self._split_path):
            log.warning("Split path {} not found, skipping loading".format(self._split_path))
            return
        self._splits = load_json(self._split_path)
        split_sample_nums = {k: len(v) for k, v in self._splits.items()}
        log.info("Loaded splits with # samples: {}".format(split_sample_nums))

    @property
    def splits(self):
        # return a shallow copy so callers cannot mutate internal state
        return dict(self._splits)

    @splits.setter
    def splits(self, split_dict):
        self._splits = split_dict

    def dump_splits(self, filename=None, indent=2):
        if filename is None:
            filename = self._split_path
        save_json(self._splits, filename, indent=indent)

    def _to_dict(self, dump_sample_files=True):
        """Build the serializable dataset dict (and dump lazy sample files)."""
        # add the version information to metadata
        self._metadata[FieldNames.DATASET_VERSION] = self.version
        self._metadata[FieldNames.KEY_HASH] = FieldNames.get_key_hash()

        dump_dict = dict()
        dump_dict[FieldNames.DATASET_METADATA] = self._metadata
        all_samples_dict = {}
        for sample_id, sample in self._samples.items():
            sample_dict = sample.to_dict(lazy_load_format=True)
            if sample._lazy_loaded and dump_sample_files:
                sample._dump_for_lazy_load()
            all_samples_dict[sample_id] = sample_dict
        dump_dict[FieldNames.SAMPLE_DICT] = all_samples_dict
        return dump_dict


def get_resized_256_video_location(data_sample: DataSample) -> str:
    """Cache path of the 256px short-edge resized mp4 for this sample."""
    return get_resized_video_location(data_sample, 256)


def get_resized_video_location(data_sample: DataSample, short_edge_res: int) -> str:
    """Cache path of the resized mp4 at the given short-edge resolution."""
    return data_sample.get_cache_file('rgb_{}_mp4'.format(short_edge_res), extension='.mp4')


def get_vis_video_location(data_sample: DataSample) -> str:
    """Cache path of the full-resolution visualization mp4."""
    return data_sample.get_cache_file('full_res_mp4', extension='.mp4')


def get_vis_thumb_location(data_sample: DataSample) -> str:
    """Cache path of the thumbnail jpg."""
    return data_sample.get_cache_file('thumbnails', extension='.jpg')


def get_vis_gt_location(data_sample: DataSample, cache_subpath: str) -> str:
    """Cache path of a ground-truth visualization json under ``cache_subpath``."""
    return data_sample.get_cache_file(os.path.join('gt_vis_json', cache_subpath), extension='.json')


def get_gt_data_sample_location(data_sample: DataSample) -> str:
    """Cache path of the ground-truth data-sample json."""
    return data_sample.get_cache_file('gt_data_sample_json', extension='.json')