# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import json
import logging
import os
from typing import Any, Callable, Dict, Optional, Type

import torch
from iopath.common.file_io import g_pathmgr
from pytorchvideo.data.clip_sampling import ClipInfo, ClipSampler
from pytorchvideo.data.labeled_video_dataset import LabeledVideoDataset


logger = logging.getLogger(__name__)


def video_only_dataset(
    data_path: str,
    clip_sampler: ClipSampler,
    video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
    transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    video_path_prefix: str = "",
    decode_audio: bool = True,
    decoder: str = "pyav",
):
    """
    Builds a LabeledVideoDataset with no annotations from a json file with the
    following format:

        .. code-block:: text

            {
              "video_name1": {...},
              "video_name2": {...},
              ...
              "video_nameN": {...}
            }

    Args:
        data_path (str): Path to the json file that lists the videos. Each video path
            is formed by joining video_path_prefix with the video name; if the
            resulting path is a folder it is interpreted as a frame video, otherwise
            it must be an encoded video.

        clip_sampler (ClipSampler): Defines how clips should be sampled from each
            video. See the clip sampling documentation for more information.

        video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
            video container. This defines the order videos are decoded and,
            if necessary, the distributed split.

        transform (Callable): This callable is evaluated on the clip output before
            the clip is returned. It can be used for user-defined preprocessing and
            augmentations on the clips. The clip output format is described in __next__().

        video_path_prefix (str): Prefix joined with each video name from the json file
            to form the full video path.

        decode_audio (bool): If True, also decode audio from video.

        decoder (str): Defines what type of decoder is used to decode a video. Not used
            for frame videos.
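
    Example:

        A minimal usage sketch; the manifest path, path prefix, and clip duration
        below are placeholders rather than values shipped with the library:

        .. code-block:: python

            from pytorchvideo.data.clip_sampling import make_clip_sampler

            dataset = video_only_dataset(
                data_path="manifests/videos.json",  # hypothetical json manifest
                clip_sampler=make_clip_sampler("random", 2.0),  # 2 second random clips
                video_path_prefix="/path/to/videos",
                decode_audio=False,
            )
            clip = next(iter(dataset))  # dict with "video", "video_name", indices, etc.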
    """

    torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.json_dataset.video_only_dataset")

    if g_pathmgr.isfile(data_path):
        try:
            with g_pathmgr.open(data_path, "r") as f:
                annotations = json.load(f)
        except Exception:
            raise FileNotFoundError(f"{data_path} must be json for Ego4D dataset")

        # LabeledVideoDataset requires the data to be a list of tuples with format:
        # (video_path, annotation_dict); with no annotations we just pass in an empty dict.
        video_paths = [
            (os.path.join(video_path_prefix, x), {}) for x in annotations.keys()
        ]
    else:
        raise FileNotFoundError(f"{data_path} not found.")

    dataset = LabeledVideoDataset(
        video_paths,
        clip_sampler,
        video_sampler,
        transform,
        decode_audio=decode_audio,
        decoder=decoder,
    )
    return dataset


def clip_recognition_dataset(
    data_path: str,
    clip_sampler: ClipSampler,
    video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
    transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    video_path_prefix: str = "",
    decode_audio: bool = True,
    decoder: str = "pyav",
):
    """
    Builds a LabeledVideoDataset with noun, verb annotations from a json file with the
    following format:

        .. code-block:: text

            {
              "video_name1": {
                  "benchmarks": {
                      "forecasting_hands_objects": [
                          {
                              "critical_frame_selection_parent_start_sec": <start_sec>,
                              "critical_frame_selection_parent_end_sec": <end_sec>,
                              "taxonomy": {
                                  "noun": <label>,
                                  "verb": <label>,
                                  "noun_unsure": <bool>,
                                  "verb_unsure": <bool>
                              }
                          },
                          {
                              ...
                          }
                      ]
                  }
              },
              "video_name2": {...},
              ...
              "video_nameN": {...}
            }

    Clips with a missing or unsure noun/verb label are skipped.

    Args:
        data_path (str): Path to the json file containing the clip annotations. Each
            video path is formed by joining video_path_prefix with the video name; if
            the resulting path is a folder it is interpreted as a frame video,
            otherwise it must be an encoded video.

        clip_sampler (ClipSampler): Defines how clips should be sampled from each
            video. See the clip sampling documentation for more information.

        video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
            video container. This defines the order videos are decoded and,
            if necessary, the distributed split.

        transform (Callable): This callable is evaluated on the clip output before
            the clip is returned. It can be used for user-defined preprocessing and
            augmentations on the clips. The clip output format is described in __next__().

        video_path_prefix (str): Prefix joined with each video name from the json file
            to form the full video path.

        decode_audio (bool): If True, also decode audio from video.

        decoder (str): Defines what type of decoder is used to decode a video. Not used
            for frame videos.
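
    Example:

        A minimal usage sketch; the annotation path, path prefix, and clip duration
        below are placeholders rather than values shipped with the library:

        .. code-block:: python

            from pytorchvideo.data.clip_sampling import make_clip_sampler

            dataset = clip_recognition_dataset(
                data_path="annotations/forecasting_train.json",  # hypothetical file
                clip_sampler=make_clip_sampler("random", 2.0),  # 2 second random clips
                video_path_prefix="/path/to/videos",
                decode_audio=False,
            )
            clip = next(iter(dataset))
            clip["noun_label"], clip["verb_label"]  # integer class indices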
    """
    if g_pathmgr.isfile(data_path):
        try:
            with g_pathmgr.open(data_path, "r") as f:
                annotations = json.load(f)
        except Exception:
            raise FileNotFoundError(f"{data_path} must be json for Ego4D dataset")

        # LabeledVideoDataset requires the data to be a list of tuples with format:
        # (video_path, annotation_dict). Here the annotation dict carries the untrimmed
        # clip boundaries and the noun/verb labels for each annotated clip.
        untrimmed_clip_annotations = []
        for video_name, child in annotations.items():
            video_path = os.path.join(video_path_prefix, video_name)
            for clip_annotation in child["benchmarks"]["forecasting_hands_objects"]:
                clip_start = clip_annotation[
                    "critical_frame_selection_parent_start_sec"
                ]
                clip_end = clip_annotation["critical_frame_selection_parent_end_sec"]
                taxonomy = clip_annotation["taxonomy"]
                noun_label = taxonomy["noun"]
                verb_label = taxonomy["verb"]
                verb_unsure = taxonomy["verb_unsure"]
                noun_unsure = taxonomy["noun_unsure"]
                # Skip clips with a missing or unsure noun/verb label.
                if (
                    noun_label is None
                    or verb_label is None
                    or verb_unsure
                    or noun_unsure
                ):
                    continue

                untrimmed_clip_annotations.append(
                    (
                        video_path,
                        {
                            "clip_start_sec": clip_start,
                            "clip_end_sec": clip_end,
                            "noun_label": noun_label,
                            "verb_label": verb_label,
                        },
                    )
                )
    else:
        raise FileNotFoundError(f"{data_path} not found.")

    # Map each noun and verb label name to a unique index.
    def map_labels_to_index(label_name):
        # Sort the label names so the label-to-index mapping is deterministic
        # across runs and processes.
        labels = sorted({info[label_name] for _, info in untrimmed_clip_annotations})
        label_to_idx = {label: i for i, label in enumerate(labels)}
        for i in range(len(untrimmed_clip_annotations)):
            label = untrimmed_clip_annotations[i][1][label_name]
            untrimmed_clip_annotations[i][1][label_name] = label_to_idx[label]

    map_labels_to_index("noun_label")
    map_labels_to_index("verb_label")

    dataset = LabeledVideoDataset(
        untrimmed_clip_annotations,
        UntrimmedClipSampler(clip_sampler),
        video_sampler,
        transform,
        decode_audio=decode_audio,
        decoder=decoder,
    )
    return dataset


class UntrimmedClipSampler:
    """
    A wrapper for adapting untrimmed annotated clips from the json_dataset to the
    standard `pytorchvideo.data.ClipSampler` expected format. Specifically, for each
    clip it uses the provided `clip_sampler` to sample between "clip_start_sec" and
    "clip_end_sec" from the json_dataset clip annotation.
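
    Example:

        An illustrative sketch; the 2 second uniform sampler and the annotation
        dict below are made-up values:

        .. code-block:: python

            from pytorchvideo.data.clip_sampling import UniformClipSampler

            sampler = UntrimmedClipSampler(UniformClipSampler(clip_duration=2.0))
            clip_info = sampler(
                0.0,    # last_clip_time
                100.0,  # video_duration (the annotated boundaries are used instead)
                {"clip_start_sec": 30.0, "clip_end_sec": 40.0},
            )
            # clip_info.clip_start_sec and clip_info.clip_end_sec now lie in [30.0, 40.0].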
    """

    def __init__(self, clip_sampler: ClipSampler) -> None:
        """
        Args:
            clip_sampler (`pytorchvideo.data.ClipSampler`): Strategy used for sampling
                between the untrimmed clip boundaries.
        """
        self._trimmed_clip_sampler = clip_sampler

    def __call__(
        self, last_clip_time: float, video_duration: float, clip_info: Dict[str, Any]
    ) -> ClipInfo:
        clip_start_boundary = clip_info["clip_start_sec"]
        clip_end_boundary = clip_info["clip_end_sec"]
        duration = clip_end_boundary - clip_start_boundary

        # Sample between 0 and duration of untrimmed clip, then add back start boundary.
        clip_info = self._trimmed_clip_sampler(last_clip_time, duration, clip_info)
        return ClipInfo(
            clip_info.clip_start_sec + clip_start_boundary,
            clip_info.clip_end_sec + clip_start_boundary,
            clip_info.clip_index,
            clip_info.aug_index,
            clip_info.is_last_clip,
        )