import os
import json
import logging
import hashlib
import pickle
import bisect
import numpy as np
from collections import defaultdict
from pathlib import Path
from typing import List

from .io.video_io import VideoSortedFolderReader, VideoFrameReader
from .utils.serialization_utils import save_json, load_json, ComplexEncoder, \
    save_pickle, load_pickle

# Assumed placeholder: the `version` property below references `__version__`,
# which is neither defined nor imported elsewhere in this module.
__version__ = "1.0"

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)


class FieldNames:
    """
    Keys found in the annotation dict.
    """
    # Entity fields
    TIME = 'time'
    LABELS = 'labels'
    ID = 'id'
    BLOB = 'blob'

    # Sample fields
    METADATA = 'metadata'
    ENTITY_LIST = 'entities'
    SAMPLE_FILE = 'sample_file'

    # Dataset fields
    SAMPLE_DICT = 'samples'
    SAMPLES = 'samples'
    DATASET_METADATA = 'metadata'
    DATASET_VERSION = 'version'
    CLASS_MAPPING = 'class_mapping'

    # Data fields
    DATE_ADDED = 'date_added'
    DURATION = 'duration'
    DATA_PATH = 'data_path'
    FILENAME = 'filename'
    BASE_DIR = 'base_dir'
    FILE_EXT = 'file_ext'
    FPS = 'fps'
    NUM_FRAMES = 'number_of_frames'
    RESOLUTION = 'resolution'
    WIDTH = 'width'
    HEIGHT = 'height'
    DATASET_SOURCE = 'data_source'
    SOURCE_ID = 'source_id'
    SAMPLE_SOURCE = 'sample_source'
    ORIG_ID = 'orig_id'
    TEMPORAL_SEGMENTS = 'temporal_segments'
    DESCRIPTION = 'description'
    BOUNDING_BOXES = 'bb'
    HEAD_BBOX = 'bb_head'
    FACE_BBOX = 'bb_face'
    MASK = 'mask'
    KEYPOINTS = 'keypoints'
    CONFIDENCE = 'confidence'
    DEFAULT_VALUES = 'default_values'
    DEFAULT_GT = 'default_gt'
    FORMAT_VERSION = 'format_version'
    DATE_MODIFIED = 'last_modified'
    CHANGES = 'changes'
    KEY_HASH = 'key_hash'
    FRAME_IDX = 'frame_idx'

    @classmethod
    def get_key_hash(cls):
        return 0


class SplitNames:
    TRAIN = "train"
    VAL = "val"
    TEST = "test"


def get_instance_list(cls, raw_entity_list):
    return [cls(raw_info=x) for x in raw_entity_list]


class AnnoEntity:
    """
    One annotation entity.
    An entity can refer to a video label, a time-segment annotation, or a frame annotation.
    It has a required field "time", which should be in milliseconds.
    Other than that, there are a few optional fields:
        id: if the entity has an id
        labels: if the entity has some categorical labels
        blob: any other annotation information for this entity
    """
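
    # Illustrative sketch (the values below are assumptions, not part of the
    # original module): building an entity for one frame and serializing it.
    #
    #   entity = AnnoEntity(time=40)          # timestamp in milliseconds
    #   entity.bbox = [10, 20, 50, 80]        # x, y, w, h in pixels
    #   entity.labels = {'person': 1}
    #   entity.confidence = 0.9
    #   as_dict = entity.to_dict()            # keys come from FieldNames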

    def __init__(self, time=None, id=None, raw_info=None, validate=True):
        self._raw_info = raw_info
        self._time = None
        self._labels = None
        self._mask = None
        self._blob = {}

        self.id = None
        self.bbox = None
        self.keypoints = None
        self.confidence = None

        if raw_info is not None:
            if time is not None or id is not None:
                log.warning("time or id were specified, but raw_info was set and so they will be ignored")
            self._parse_info()
        else:
            if validate:
                self.time = time
            else:
                self._time = time
            self.id = id

    @property
    def time(self):
        return self._time

    @time.setter
    def time(self, time):
        if isinstance(time, (int, float)):
            assert time >= 0
            self._time = time
        elif isinstance(time, str):
            self._time = float(time)
        elif isinstance(time, tuple):
            assert len(time) == 2
            self._time = float(time[0]), float(time[1])

    @property
    def labels(self):
        return self._labels

    @labels.setter
    def labels(self, new_labels):
        if self._labels:
            # compare keys and report warning
            diff_keys = set(self._labels.keys()).difference(new_labels.keys())
            log.warning("updating the labels with different keys: {}".format(diff_keys))
        self._labels = new_labels

    @property
    def keypoints_xyv(self):
        return np.vstack([np.asarray(self.keypoints[0::3], dtype=np.float32),
                          np.asarray(self.keypoints[1::3], dtype=np.float32),
                          np.asarray(self.keypoints[2::3], dtype=np.float32)]).T

    @keypoints_xyv.setter
    def keypoints_xyv(self, kp_xyv):
        self.keypoints = list(np.asarray(kp_xyv).flatten())
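
    # Illustrative sketch (assumed data, not from the original module): the flat
    # `keypoints` list stores [x0, y0, v0, x1, y1, v1, ...]; `keypoints_xyv`
    # views it as an (N, 3) array.
    #
    #   entity.keypoints = [10.0, 20.0, 2.0, 30.0, 40.0, 1.0]
    #   xyv = entity.keypoints_xyv            # array of shape (2, 3)
    #   entity.keypoints_xyv = xyv            # round-trips back to the flat list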

    @property
    def mask(self):
        return self._mask

    @mask.setter
    def mask(self, mask):
        self._mask = mask

    @property
    def frame_num(self):
        if FieldNames.FRAME_IDX not in self.blob:
            return None
        return self.blob[FieldNames.FRAME_IDX]

    @frame_num.setter
    def frame_num(self, frame_num):
        self.blob[FieldNames.FRAME_IDX] = frame_num

    @property
    def blob(self):
        return self._blob

    @blob.setter
    def blob(self, new_blob):
        if self._blob:
            # compare keys and report warning
            diff_keys = set(self._blob.keys()).difference(new_blob.keys())
            log.warning("updating the blob with different keys: {}".format(diff_keys))
        self._blob = new_blob

    def to_dict(self):
        FN = FieldNames
        out = dict()

        out[FN.TIME] = self.time

        if self.id is not None:
            out[FN.ID] = self.id

        if self.labels is not None:
            out[FN.LABELS] = self.labels

        if self.confidence is not None:
            out[FN.CONFIDENCE] = self.confidence

        if self.bbox is not None:
            out[FN.BOUNDING_BOXES] = self.bbox

        if self.mask is not None:
            out[FN.MASK] = self.mask

        if self.keypoints is not None:
            out[FN.KEYPOINTS] = self.keypoints

        if self.blob:
            out[FN.BLOB] = self.blob

        return out

    def _parse_info(self):
        FN = FieldNames

        raw_info = self._raw_info

        self._time = raw_info[FN.TIME]

        self.id = raw_info.get(FN.ID, None)
        self.bbox = raw_info.get(FN.BOUNDING_BOXES, None)
        self.keypoints = raw_info.get(FN.KEYPOINTS, None)
        self.confidence = raw_info.get(FN.CONFIDENCE, 1.0)
        self._labels = raw_info.get(FN.LABELS, None)
        self._mask = raw_info.get(FN.MASK, None)
        self._blob = raw_info.get(FN.BLOB, {})

    ### helper functions ###
    @property
    def x(self):
        return self.bbox[0] if self.bbox else None

    @property
    def y(self):
        return self.bbox[1] if self.bbox else None

    @property
    def w(self):
        return self.bbox[2] if self.bbox else None

    @property
    def h(self):
        return self.bbox[3] if self.bbox else None


class DataReader:
    def __init__(self, data_sample, max_frame_deviation=1, fps=None):
        self._data_sample = data_sample
        self._data_path = data_sample.data_path
        self._frame_reader = None
        self._frame_iter = None

        if fps is None:
            fps = data_sample.fps
        self._fps = fps

        self._max_frame_deviation = max_frame_deviation
        self._period = 1.0 / fps * 1000
        self._time_diff_cutoff = int(self._max_frame_deviation * self._period) - 1

        self._frame_reader = self._data_sample.frame_reader

    def __iter__(self):
        self._frame_iter = iter(self._frame_reader)
        return self

    def __next__(self):
        frame, ts = next(self._frame_iter)
        if frame is None:
            raise StopIteration
        entities = self._data_sample.get_entities_near_time(int(ts), self._time_diff_cutoff)
        return frame, ts, entities

    def __getitem__(self, item):
        frame, ts = self._frame_reader[item]
        entities = self._data_sample.get_entities_near_time(int(ts), self._time_diff_cutoff)
        return frame, ts, entities
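
    # Illustrative sketch (assumes `sample` is a DataSample backed by a video
    # on disk; not part of the original module):
    #
    #   reader = sample.get_data_reader()
    #   for frame, ts, entities in reader:    # image, timestamp (ms), AnnoEntity list
    #       ...
    #   frame, ts, entities = reader[0]       # random access by frame index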


class DataSample:
    """
    One sample in the dataset. This can be a video file or an image file.
    It contains a list of entities as annotations.
    Each data sample can also carry metadata. For example, a video sample may have:
        FPS
        Duration
        Number of frames
        Source id
        ...
    """
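
    # Illustrative sketch (the id and metadata values here are assumptions):
    #
    #   sample = DataSample('video_0001', metadata={
    #       FieldNames.FPS: 30,
    #       FieldNames.NUM_FRAMES: 900,
    #       FieldNames.DATA_PATH: 'video_0001.mp4',
    #   })
    #   sample.add_entity(AnnoEntity(time=0))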
    _NO_SERIALIZE_FIELDS = ("_dataset",)

    def __init__(self, id, raw_info=None, root_path=None, metadata=None, dataset=None):
        self._raw_info = raw_info
        self._id = id
        self._entities = []
        if metadata is not None:
            self._metadata = dict(metadata)
        else:
            self._metadata = {}
        self._dataset = dataset
        self._filepath = None
        self._lazy_loaded = False
        self._root_path = None
        self._data_root_path = None
        self._cache_root_path = None
        self._raw_entities = None
        self._time_entity_dict = None
        self._entity_times = None
        self._times_unsorted = True
        self._id_entity_dict = None
        self._frame_entity_dict = None
        self._label_entity_dict = None
        self._init_entity_fields()

        self.set_root_path(root_path)
        if self._raw_info:
            self._parse()

    def set_root_path(self, root_path):
        self._root_path = root_path
        if root_path:
            self._data_root_path = GluonCVMotionDataset.get_data_path_from_root(root_path)
            self._cache_root_path = GluonCVMotionDataset.get_cache_path_from_root(root_path)

    def _set_filepath(self, filepath, already_loaded=False):
        self._filepath = filepath
        self._lazy_loaded = already_loaded

    @property
    def id(self):
        return self._id

    @property
    def metadata(self):
        return self._metadata

    @metadata.setter
    def metadata(self, new_md):
        if self._metadata:
            # compare keys and report warning
            diff_keys = set(self._metadata.keys()).difference(new_md.keys())
            log.warning("updating the metadata with different keys: {}".format(diff_keys))
        self._metadata = new_md

    def _lazy_init(self):
        if self._filepath and not self._lazy_loaded:
            self._lazy_load()
        if (self._raw_entities is not None) and not self._entities:
            self._entities = get_instance_list(AnnoEntity, self._raw_entities)
            self._raw_entities = None
            self._init_entity_fields()

    def __len__(self):
        return self.metadata[FieldNames.NUM_FRAMES]

    @property
    def duration(self):
        return self.metadata[FieldNames.DURATION]

    @property
    def entities(self) -> List[AnnoEntity]:
        self._lazy_init()
        return self._entities

    @property
    def data_relative_path(self):
        # BASE_DIR takes precedence over DATA_PATH when both are present
        data_path = None
        if FieldNames.DATA_PATH in self.metadata:
            data_path = self.metadata[FieldNames.DATA_PATH]
        if FieldNames.BASE_DIR in self.metadata:
            data_path = self.metadata[FieldNames.BASE_DIR]
        return data_path

    @data_relative_path.setter
    def data_relative_path(self, data_path):
        self.metadata[FieldNames.DATA_PATH] = data_path

    @property
    def data_path(self):
        data_path = self.data_relative_path
        if self._data_root_path:
            data_path = os.path.join(self._data_root_path, data_path)

        return data_path

    @property
    def frame_reader(self):
        data_path = self.data_path
        if os.path.isdir(data_path):
            frame_reader = VideoSortedFolderReader(data_path, self.fps)
        else:
            frame_reader = VideoFrameReader(data_path)
        return frame_reader

    def get_cache_file(self, cache_name, extension=''):
        rel_path = os.path.splitext(self.data_relative_path)[0] \
            if os.path.isfile(self.data_path) else self.data_relative_path
        return os.path.join(self._cache_root_path, cache_name, rel_path + extension)

    @property
    def fps(self):
        fps = self.metadata.get(FieldNames.FPS, None)
        return fps

    @property
    def period(self):
        """Retrieves the period in milliseconds (1000 / fps); if fps is unset, returns None"""
        fps = self.fps
        return 1000 / fps if fps else None

    @property
    def width(self):
        width = self.metadata.get(FieldNames.RESOLUTION, {}).get(FieldNames.WIDTH, None)
        return width

    @property
    def height(self):
        height = self.metadata.get(FieldNames.RESOLUTION, {}).get(FieldNames.HEIGHT, None)
        return height

    @property
    def num_minutes(self):
        frames = self.metadata[FieldNames.NUM_FRAMES]
        # fall back to 30 fps when fps is missing or a nonsense value
        fps = self.fps if self.fps and self.fps > 1 else 30.
        return frames / fps / 60.

    def get_data_reader(self):
        return DataReader(self)

    def _init_entity_fields(self):
        self._time_entity_dict = defaultdict(list)
        self._id_entity_dict = defaultdict(list)
        self._frame_entity_dict = defaultdict(list)
        self._label_entity_dict = defaultdict(list)
        self._entity_times = []
        for entity in self.entities:
            self._update_key_dicts(entity)

    @property
    def frame_num_entity_dict(self):
        self._lazy_init()
        return self._frame_entity_dict

    def get_entities_for_frame_num(self, frame_idx):
        return self.frame_num_entity_dict[frame_idx]

    @property
    def id_entity_dict(self):
        self._lazy_init()
        return self._id_entity_dict

    def get_entities_with_id(self, id):
        return self.id_entity_dict[id]

    @property
    def time_entity_dict(self):
        self._lazy_init()
        return self._time_entity_dict

    def get_entities_at_time(self, time):
        return self.time_entity_dict[time]

    @property
    def label_entity_dict(self):
        self._lazy_init()
        return self._label_entity_dict

    def get_entities_with_label(self, label):
        return self.label_entity_dict[label]

    def _get_entity_times_sorted(self):
        if self._times_unsorted:
            self._entity_times.sort()
            self._times_unsorted = False
        return self._entity_times

    def get_entities_near_time(self, time, time_diff_cutoff=None):
        self._lazy_init()

        if time_diff_cutoff is None:
            # default cutoff: one frame period in milliseconds, minus 1
            time_diff_cutoff = int(1.0 / self.fps * 1000) - 1

        entity_times = self._get_entity_times_sorted()

        if not entity_times:
            return []

        insert_pos = bisect.bisect(entity_times, time)
        time_left = entity_times[max(insert_pos - 1, 0)]
        time_right = entity_times[min(insert_pos, len(entity_times) - 1)]

        diff_left = abs(time_left - time)
        diff_right = abs(time_right - time)

        if diff_right < diff_left:
            closest_time = time_right
            diff = diff_right
        else:
            closest_time = time_left
            diff = diff_left

        if diff < time_diff_cutoff:
            entities = self._time_entity_dict[closest_time]
        else:
            entities = []

        return entities
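
    # Illustrative sketch (assumes a 25 fps sample, so the default cutoff is
    # one 40 ms frame period minus 1; the values are assumptions):
    #
    #   near = sample.get_entities_near_time(1000)      # within ~39 ms of t=1s
    #   strict = sample.get_entities_near_time(1000, 5) # only within 5 ms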

    def _update_key_dicts(self, entity):
        self._id_entity_dict[entity.id].append(entity)
        new_time = entity.time not in self._time_entity_dict
        self._time_entity_dict[entity.time].append(entity)
        if entity.time is not None and new_time:
            self._entity_times.append(entity.time)
            self._times_unsorted = True
        # an entity could have multiple labels
        if entity.labels is not None:
            for k, v in entity.labels.items():
                self._label_entity_dict[k].append(entity)
        if entity.frame_num is not None:
            self._frame_entity_dict[entity.frame_num].append(entity)
        else:
            if entity.time is not None and self.fps:
                # time in seconds * fps = frame_num
                frame_num = round((entity.time / 1000) * self.fps)
            else:
                frame_num = None
            self._frame_entity_dict[frame_num].append(entity)

    def add_entity(self, entities):
        """
        Add a new entity or a list of entities to the sample
        :param entities: a single AnnoEntity or an iterable of AnnoEntity
        :return:
        """
        self._lazy_init()

        if isinstance(entities, AnnoEntity):
            self._entities.append(entities)
            self._update_key_dicts(entities)
        else:
            self._entities.extend(entities)
            for entity in entities:
                self._update_key_dicts(entity)

    def get_copy_without_entities(self, new_id=None):
        """
        :return: A new DataSample with the same id and metadata but no entities
        """
        if new_id is None:
            new_id = self.id
        return DataSample(new_id, root_path=self._root_path, metadata=self.metadata, dataset=self._dataset)

    def filter_entities(self, filter_fn):
        """
        :param filter_fn: predicate over an entity; when it returns True the entity
         is kept, otherwise it is omitted
        """
        new_sample = self.get_copy_without_entities()
        for entity in self.entities:
            if filter_fn(entity):
                new_sample.add_entity(entity)
        return new_sample
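
    # Illustrative sketch (the 'person' label key is an assumption):
    #
    #   confident = sample.filter_entities(lambda e: (e.confidence or 0) > 0.5)
    #   people = sample.filter_entities(lambda e: e.labels and 'person' in e.labels)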

    def get_non_empty_frames(self, filter_fn=None, fps=0):
        """
        Return the sorted indices of frames, sampled at the given fps, that have
        at least one annotation (after optional filtering)
        """
        if fps == 0:
            fps = self.fps
        interval = int(np.ceil(self.fps / fps))

        frame_idxs = []
        for idx in range(0, len(self), interval):
            entities = self.get_entities_for_frame_num(idx)
            if filter_fn is not None:
                entities, _ = filter_fn(entities)
            if len(entities) > 0:
                frame_idxs.append(idx)
        return sorted(frame_idxs)
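
    # Illustrative sketch: collect annotated frames sampled at roughly 1 fps
    # (the rate is an assumption):
    #
    #   annotated_idxs = sample.get_non_empty_frames(fps=1)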

    def to_dict(self, include_id=False, lazy_load_format=False):
        """
        Dump the information in this sample to a dict
        :return:
        """
        out = dict()

        if lazy_load_format and self._filepath:
            if self._metadata:
                out[FieldNames.METADATA] = self._metadata
            out[FieldNames.SAMPLE_FILE] = self._filepath
            return out

        self._lazy_init()

        if include_id:
            out[FieldNames.ID] = self.id
        out[FieldNames.METADATA] = self.metadata
        out[FieldNames.ENTITY_LIST] = [x.to_dict() for x in self.entities]
        return out

    def dump(self, filename, indent=0, include_id=True, **kwargs):
        save_json(self.to_dict(include_id=include_id), filename, indent=indent, **kwargs)

    def _get_lazy_load_path(self):
        if not self._filepath or not self._dataset:
            raise ValueError("Cannot get lazy load path without a filepath and dataset")
        base_path = Path(self._dataset.anno_path).with_suffix("")
        if self._filepath is True:
            filepath = base_path / (self.id + ".json")
        else:
            filepath = base_path / self._filepath
        return filepath

    def _lazy_load(self):
        filepath = self._get_lazy_load_path()
        self._raw_info = load_json(filepath)
        if self.metadata and self._raw_info.get(FieldNames.METADATA) != self.metadata:
            log.info("metadata did not match the lazy loaded value for sample: {}; ignoring the loaded value,"
                     " this should be resolved the next time you dump the dataset".format(self.id))
            self._raw_info[FieldNames.METADATA] = self.metadata
        self._parse()
        self._lazy_loaded = True

    def _dump_for_lazy_load(self):
        if not self._lazy_loaded:
            log.debug("nothing lazy loaded to dump, returning")
            return
        filepath = self._get_lazy_load_path()
        filepath.parent.mkdir(parents=True, exist_ok=True)
        self.dump(filepath, include_id=False)

    def clear_lazy_loaded(self, clear_metadata=False):
        if not self._lazy_loaded:
            log.debug("nothing lazy loaded so nothing to clear")
            return
        self._entities = []
        self._raw_info = None
        self._init_entity_fields()
        if clear_metadata:
            self._metadata = {}
        self._lazy_loaded = False

    @classmethod
    def load(cls, filename, **kwargs):
        data = load_json(filename)
        raw_info = data["raw_info"] if "raw_info" in data else data
        return cls(data[FieldNames.ID], raw_info=raw_info, **kwargs)

    def _parse(self):
        self._metadata = self._raw_info.get(FieldNames.METADATA, {})
        # entities are parsed lazily for speed when loading the dataset
        self._raw_entities = self._raw_info.get(FieldNames.ENTITY_LIST, [])
        filepath = self._raw_info.get(FieldNames.SAMPLE_FILE)
        if filepath is not None:
            self._filepath = filepath

    def copy(self):
        new = pickle.loads(pickle.dumps(self))
        for f in self._NO_SERIALIZE_FIELDS:
            setattr(new, f, getattr(self, f))
        return new

    def __getstate__(self):
        # used by pickle and deepcopy; this prevents copying the whole dataset through
        # the back reference in self._dataset
        return {k: v for k, v in vars(self).items() if k not in self._NO_SERIALIZE_FIELDS}

    def __setstate__(self, state):
        self.__dict__.update(state)
        for f in self._NO_SERIALIZE_FIELDS:
            setattr(self, f, None)

    def __enter__(self):
        self._lazy_init()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.clear_lazy_loaded()


class GluonCVMotionDataset:
    ANNO_DIR = "annotation"
    CACHE_DIR = "cache"
    DATA_DIR = "raw_data"

    _DEFAULT_ANNO_FILE = "anno.json"
    _DEFAULT_SPLIT_FILE = "splits.json"

    def __init__(self, annotation_file=None, root_path=None, split_file=None, load_anno=True):
        """
        GluonCVMotionDataset
        :param annotation_file: The path to the annotation file, either a full path or a path relative to the root
         annotation path (root_path/annotation/), defaults to 'anno.json'
        :param root_path: The root path of the dataset, containing the 'annotation', 'cache', and 'raw_data' folders.
         If left empty it will be inferred from the annotation_file path by searching upwards until the 'annotation'
         folder is found, then going one more level up
        :param split_file: The path to the split file relative to the annotation file. It will be relative to the root
         annotation path instead if it starts with './'
        :param load_anno: Whether to load the annotation file; raises an exception when true and the file does not
         exist. Set this to false if you are just writing a new annotation file, for example in an ingestion script
        """
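
        # Illustrative sketch (the dataset root below is an assumption):
        #
        #   dataset = GluonCVMotionDataset('anno.json', root_path='/data/my_dataset')
        #   for sample_id, sample in dataset.train_samples:
        #       ...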

        # a dict of DataSample instances
        import indexed
        self._samples = indexed.IndexedOrderedDict()
        self._splits = {}
        self._metadata = {}
        self._raw_info = None

        if annotation_file is None:
            annotation_file = self._DEFAULT_ANNO_FILE
            log.info("Annotation file not provided, defaulting to '{}'".format(annotation_file))

        self._root_path = self._get_root_path(root_path, annotation_file)

        if self._root_path:
            if not os.path.isdir(self._root_path):
                raise ValueError("Expected root folder but was not found at: {}".format(self._root_path))

            self._anno_path = os.path.join(self._root_path, self.ANNO_DIR, annotation_file)
            self._data_path = self.get_data_path_from_root(self._root_path)
            self._cache_path = self.get_cache_path_from_root(self._root_path)

            if not os.path.isdir(self._data_path):
                raise ValueError("Expected data folder but was not found at: {}".format(self._data_path))
        else:
            log.warning('Root path was not set for dataset, this should only happen when loading a lone annotation'
                        ' file for inspection')
            self._anno_path = annotation_file
            self._data_path = None
            self._cache_path = None

        if load_anno:
            if os.path.exists(self._anno_path):
                log.info('Loading annotation file {}...'.format(self._anno_path))
                # load the annotation file, preferring the pickle cache when present
                if self._get_pickle_path().exists():
                    log.info('Found pickle file, loading this instead')
                loaded_pickle = self._load_pickle()
                if not loaded_pickle:
                    self._parse_anno(self._anno_path)
                self._split_path = self._get_split_path(split_file, self._anno_path)
                self._load_split()
            else:
                raise ValueError(
                    "load_anno is true but the anno path does not exist at: {}".format(self._anno_path))
        else:
            log.info('Skipping loading for annotation file {}'.format(self._anno_path))
            self._split_path = self._get_split_path(split_file, self._anno_path)

    def __len__(self):
        return len(self._samples)

    def __contains__(self, item):
        return item in self._samples

    def __getitem__(self, item):
        return self._samples[item]

    def __iter__(self):
        for item in self._samples.items():
            yield item

    @classmethod
    def get_data_path_from_root(cls, root_path):
        return os.path.join(root_path, cls.DATA_DIR)

    @classmethod
    def get_cache_path_from_root(cls, root_path):
        return os.path.join(root_path, cls.CACHE_DIR)

    def _get_root_path(self, root_path, annotation_file):
        if root_path is None:
            dirpath = os.path.dirname(annotation_file)
            if self.ANNO_DIR in dirpath:
                while os.path.basename(dirpath) != self.ANNO_DIR and dirpath != '/':
                    dirpath = os.path.dirname(dirpath)
                root_path = os.path.abspath(os.path.dirname(dirpath))
                log.info("Dataset root path inferred to be: {}".format(root_path))
        return root_path

    def _get_split_path(self, split_file, anno_path):
        split_path = split_file
        if split_path is None:
            split_path = self._DEFAULT_SPLIT_FILE
        if not os.path.isabs(split_path):
            if split_path.startswith("./"):
                anno_dir = os.path.join(self._root_path, self.ANNO_DIR)
            else:
                anno_dir = os.path.dirname(anno_path)
            split_path = os.path.join(anno_dir, split_path)
            split_subpath = split_path.replace(self._root_path or "", "").lstrip(os.path.sep)
            log.info("Split subpath: {}".format(split_subpath))
        return split_path

    @property
    def iter_samples(self):
        """
        Returns an iterable of (id, sample) pairs
        :return:
        """
        return self._samples.items()

    @property
    def samples(self):
        return self._samples.items()

    @property
    def sample_ids(self):
        return self._samples.keys()

    @property
    def sample_values(self):
        return self._samples.values()

    def get_split_ids(self, splits=None):
        if splits is None:
            # default is to return all ids
            return self._samples.keys()
        if isinstance(splits, str):
            splits = [splits]

        all_ids = []
        for split in splits:
            if split not in self._splits:
                log.warning("Provided split: {} was not in dataset".format(split))
            else:
                all_ids.extend(self._splits[split])

        return all_ids

    def get_split_samples(self, splits=None):
        split_ids = self.get_split_ids(splits)
        samples = []
        for split_id in split_ids:
            if split_id in self._samples:
                samples.append((split_id, self._samples[split_id]))
            else:
                log.info(f"Dataset is missing sample: {split_id} in split {self.get_use_for_id(split_id)}, skipping")
        return samples
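
    # Illustrative sketch (split names come from SplitNames):
    #
    #   train = dataset.get_split_samples(SplitNames.TRAIN)
    #   trainval = dataset.get_split_samples([SplitNames.TRAIN, SplitNames.VAL])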

    @property
    def train_samples(self):
        return self.get_split_samples(SplitNames.TRAIN)

    @property
    def val_samples(self):
        return self.get_split_samples(SplitNames.VAL)

    @property
    def trainval_samples(self):
        return self.get_split_samples([SplitNames.TRAIN, SplitNames.VAL])

    @property
    def test_samples(self):
        return self.get_split_samples(SplitNames.TEST)

    @property
    def all_samples(self):
        samples = self.get_split_samples([SplitNames.TRAIN, SplitNames.VAL, SplitNames.TEST])
        if not samples:
            samples = self.samples
        return samples

    def get_use_for_id(self, id):
        for use in self._splits:
            if id in self._splits[use]:
                return use
        return None

    @property
    def version(self):
        return __version__

    @property
    def metadata(self):
        return self._metadata

    @property
    def name(self):
        if not self._root_path:
            return None
        return os.path.basename(self._root_path)

    @property
    def root_path(self):
        return self._root_path

    @property
    def anno_root_path(self):
        return os.path.join(self._root_path, self.ANNO_DIR)

    @property
    def cache_root_path(self):
        return self._cache_path

    @property
    def data_root_path(self):
        return self._data_path

    @property
    def anno_path(self):
        return self._anno_path

    def _get_anno_subpath(self, anno_path, with_ext):
        subpath = Path(anno_path).relative_to(self.anno_root_path)
        if not with_ext:
            subpath = subpath.with_suffix("")
        return str(subpath)

    def get_anno_subpath(self, with_ext=False):
        return self._get_anno_subpath(self._anno_path, with_ext)

    def get_anno_suffix(self):
        subpath = self.get_anno_subpath()
        return "_" + subpath.replace(os.sep, "_")

    @property
    def split_path(self):
        return self._split_path

    def get_split_subpath(self, with_ext=False):
        return self._get_anno_subpath(self._split_path, with_ext)

    def get_split_suffix(self):
        subpath = self.get_split_subpath()
        return "_" + subpath.replace(os.sep, "_")

    @metadata.setter
    def metadata(self, new_md):
        if self._metadata:
            # compare keys and report warning
            diff_keys = set(self._metadata.keys()).difference(new_md.keys())
            log.warning("updating the metadata with different keys: {}".format(diff_keys))
        self._metadata = new_md

    @property
    def description(self):
        return self._metadata.get(FieldNames.DESCRIPTION, "")

    @description.setter
    def description(self, description):
        self._metadata[FieldNames.DESCRIPTION] = description

    def add_sample(self, sample: DataSample, dump_directly=False):
        # create a new sample so it functions just as it would if it were loaded from disk
        new_sample = DataSample(sample.id, raw_info=sample.to_dict(), root_path=self.root_path, dataset=self)
        self._samples[sample.id] = new_sample
        if dump_directly:
            new_sample._set_filepath(True, already_loaded=True)
            new_sample._dump_for_lazy_load()
            new_sample.clear_lazy_loaded()
        return new_sample
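
    # Illustrative ingestion sketch (the sample id and root are assumptions):
    #
    #   dataset = GluonCVMotionDataset(root_path='/data/my_dataset', load_anno=False)
    #   sample = DataSample('video_0001', metadata={FieldNames.FPS: 30})
    #   dataset.add_sample(sample)
    #   dataset.dump()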

    def dumps(self, encoder=ComplexEncoder, **kwargs):
        return json.dumps(self._to_dict(), cls=encoder, **kwargs)

    def dump(self, filename=None, indent=0, **kwargs):
        if filename is None:
            filename = self._anno_path
            anno_dir = os.path.dirname(self._anno_path)
            os.makedirs(anno_dir, exist_ok=True)
        save_json(self._to_dict(), filename, indent=indent, **kwargs)

    def _get_pickle_path(self):
        return Path(self._anno_path).with_suffix(".pkl")

    def _anno_mod_time(self):
        return Path(self._anno_path).stat().st_mtime

    def dump_pickle(self, filepath=None, **kwargs):
        if filepath is None:
            filepath = str(self._get_pickle_path())
        modified_time = self._anno_mod_time()
        to_pickle = {
            "_samples": self._samples,
            "_metadata": self._metadata,
            "_raw_info": self._raw_info,
            "modified_time": modified_time
        }
        save_pickle(to_pickle, filepath, **kwargs)
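
    # Illustrative sketch: cache the parsed annotation next to the json so the
    # next load can skip json parsing (see _load_pickle below):
    #
    #   dataset.dump_pickle()                 # writes anno.pkl by default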

    def _load_pickle(self, filepath=None):
        import datetime
        if filepath is None:
            filepath = str(self._get_pickle_path())
        if not os.path.exists(filepath):
            return False
        log.info('Loading pickle file {}'.format(filepath))
        try:
            loaded_dict = load_pickle(filepath)
        except OSError:
            log.warning("Failed to load pickle")
            return False

        stored_time = loaded_dict["modified_time"]
        modified_time = self._anno_mod_time()
        if stored_time == modified_time:
            self._samples = loaded_dict["_samples"]
            self._metadata = loaded_dict["_metadata"]
            self._raw_info = loaded_dict["_raw_info"]
            return True
        else:
            log.info('The modification time stored in pickle {} did not match the annotation, so not loading it;'
                     ' renaming the pickle to .old, please regenerate it'.format(filepath))
            new_filepath = str(filepath) + ".old_" + str(datetime.datetime.now()).replace(" ", "_")
            try:
                Path(filepath).rename(new_filepath)
            except OSError:
                pass
            return False

    def _parse_anno(self, annotation_file):
        json_info = load_json(annotation_file)
        self._raw_info = json_info
        log.info("loaded anno json")

        # load metadata
        assert FieldNames.DATASET_METADATA in self._raw_info, \
            "key: {} must be present in the annotation file, we only got {}".format(
                FieldNames.DATASET_METADATA, self._raw_info.keys())
        self._metadata.update(self._raw_info[FieldNames.DATASET_METADATA])

        # check key map hash
        key_hash = self._metadata.get(FieldNames.KEY_HASH, '')
        assert key_hash == FieldNames.get_key_hash(), "Key list not matching. " \
                                                      "Maybe this annotation file was created with another version. " \
                                                      "Current version is {}".format(self.version)

        # load samples
        sample_dict = self._raw_info.get(FieldNames.SAMPLE_DICT, dict())
        root_path = self.root_path
        for k in sorted(sample_dict.keys()):
            self._samples[k] = DataSample(k, sample_dict[k], root_path=root_path, dataset=self)
        log.info("loaded {} samples".format(len(self._samples)))

    def _load_split(self):
        if not os.path.exists(self._split_path):
            log.warning("Split path {} not found, skipping loading".format(self._split_path))
            return
        self._splits = load_json(self._split_path)
        split_sample_nums = {k: len(v) for k, v in self._splits.items()}
        log.info("Loaded splits with # samples: {}".format(split_sample_nums))

    @property
    def splits(self):
        return dict(self._splits)

    @splits.setter
    def splits(self, split_dict):
        self._splits = split_dict

    def dump_splits(self, filename=None, indent=2):
        if filename is None:
            filename = self._split_path
        save_json(self._splits, filename, indent=indent)

    def _to_dict(self, dump_sample_files=True):
        # add the version information to metadata
        self._metadata[FieldNames.DATASET_VERSION] = self.version
        self._metadata[FieldNames.KEY_HASH] = FieldNames.get_key_hash()

        dump_dict = dict()
        dump_dict[FieldNames.DATASET_METADATA] = self._metadata
        all_samples_dict = {}
        for sample_id, sample in self._samples.items():
            sample_dict = sample.to_dict(lazy_load_format=True)
            if sample._lazy_loaded and dump_sample_files:
                sample._dump_for_lazy_load()
            all_samples_dict[sample_id] = sample_dict
        dump_dict[FieldNames.SAMPLE_DICT] = all_samples_dict
        return dump_dict


def get_resized_256_video_location(data_sample: DataSample) -> str:
    return get_resized_video_location(data_sample, 256)


def get_resized_video_location(data_sample: DataSample, short_edge_res: int) -> str:
    return data_sample.get_cache_file('rgb_{}_mp4'.format(short_edge_res), extension='.mp4')


def get_vis_video_location(data_sample: DataSample) -> str:
    return data_sample.get_cache_file('full_res_mp4', extension='.mp4')


def get_vis_thumb_location(data_sample: DataSample) -> str:
    return data_sample.get_cache_file('thumbnails', extension='.jpg')


def get_vis_gt_location(data_sample: DataSample, cache_subpath: str) -> str:
    return data_sample.get_cache_file(os.path.join('gt_vis_json', cache_subpath), extension='.json')


def get_gt_data_sample_location(data_sample: DataSample) -> str:
    return data_sample.get_cache_file('gt_data_sample_json', extension='.json')
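
# Illustrative sketch (the cache layout is derived from get_cache_file above;
# the sample and its data path are assumptions):
#
#   vis_path = get_vis_video_location(sample)
#   # -> <root>/cache/full_res_mp4/<relative data path, extension replaced>.mp4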