1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
3from __future__ import annotations
4
5import logging
6from typing import Any, Callable, Dict, List, Optional, Tuple, Type
7
8import torch.utils.data
9from pytorchvideo.data.clip_sampling import ClipSampler
10from pytorchvideo.data.video import VideoPathHandler
11
12from .labeled_video_paths import LabeledVideoPaths
13from .utils import MultiProcessSampler
14
15
16logger = logging.getLogger(__name__)
17
18
19class LabeledVideoDataset(torch.utils.data.IterableDataset):
20    """
21    LabeledVideoDataset handles the storage, loading, decoding and clip sampling for a
22    video dataset. It assumes each video is stored as either an encoded video
23    (e.g. mp4, avi) or a frame video (e.g. a folder of jpg, or png)
24    """
25
26    _MAX_CONSECUTIVE_FAILURES = 10
27
28    def __init__(
29        self,
30        labeled_video_paths: List[Tuple[str, Optional[dict]]],
31        clip_sampler: ClipSampler,
32        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
33        transform: Optional[Callable[[dict], Any]] = None,
34        decode_audio: bool = True,
35        decoder: str = "pyav",
36    ) -> None:
37        """
38        Args:
39            labeled_video_paths (List[Tuple[str, Optional[dict]]]): List containing
40                    video file paths and associated labels. If video paths are a folder
41                    it's interpreted as a frame video, otherwise it must be an encoded
42                    video.
43
44            clip_sampler (ClipSampler): Defines how clips should be sampled from each
45                video. See the clip sampling documentation for more information.
46
47            video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
48                video container. This defines the order videos are decoded and,
49                if necessary, the distributed split.
50
51            transform (Callable): This callable is evaluated on the clip output before
52                the clip is returned. It can be used for user defined preprocessing and
53                augmentations on the clips. The clip output format is described in __next__().
54
55            decode_audio (bool): If True, also decode audio from video.
56
57            decoder (str): Defines what type of decoder used to decode a video. Not used for
58                frame videos.
59        """
60        self._decode_audio = decode_audio
61        self._transform = transform
62        self._clip_sampler = clip_sampler
63        self._labeled_videos = labeled_video_paths
64        self._decoder = decoder
65
66        # If a RandomSampler is used we need to pass in a custom random generator that
67        # ensures all PyTorch multiprocess workers have the same random seed.
68        self._video_random_generator = None
69        if video_sampler == torch.utils.data.RandomSampler:
70            self._video_random_generator = torch.Generator()
71            self._video_sampler = video_sampler(
72                self._labeled_videos, generator=self._video_random_generator
73            )
74        else:
75            self._video_sampler = video_sampler(self._labeled_videos)
76
77        self._video_sampler_iter = None  # Initialized on first call to self.__next__()
78
79        # Depending on the clip sampler type, we may want to sample multiple clips
80        # from one video. In that case, we keep the store video, label and previous sampled
81        # clip time in these variables.
82        self._loaded_video_label = None
83        self._loaded_clip = None
84        self._next_clip_start_time = 0.0
85        self.video_path_handler = VideoPathHandler()
86
87    @property
88    def video_sampler(self):
89        """
90        Returns:
91            The video sampler that defines video sample order. Note that you'll need to
92            use this property to set the epoch for a torch.utils.data.DistributedSampler.
93        """
94        return self._video_sampler
95
96    @property
97    def num_videos(self):
98        """
99        Returns:
100            Number of videos in dataset.
101        """
102        return len(self.video_sampler)
103
104    def __next__(self) -> dict:
105        """
106        Retrieves the next clip based on the clip sampling strategy and video sampler.
107
108        Returns:
109            A dictionary with the following format.
110
111            .. code-block:: text
112
113                {
114                    'video': <video_tensor>,
115                    'label': <index_label>,
116                    'video_label': <index_label>
117                    'video_index': <video_index>,
118                    'clip_index': <clip_index>,
119                    'aug_index': <aug_index>,
120                }
121        """
122        if not self._video_sampler_iter:
123            # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
124            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))
125
126        for i_try in range(self._MAX_CONSECUTIVE_FAILURES):
127            # Reuse previously stored video if there are still clips to be sampled from
128            # the last loaded video.
129            if self._loaded_video_label:
130                video, info_dict, video_index = self._loaded_video_label
131            else:
132                video_index = next(self._video_sampler_iter)
133                try:
134                    video_path, info_dict = self._labeled_videos[video_index]
135                    video = self.video_path_handler.video_from_path(
136                        video_path,
137                        decode_audio=self._decode_audio,
138                        decoder=self._decoder,
139                    )
140                    self._loaded_video_label = (video, info_dict, video_index)
141                except Exception as e:
142                    logger.debug(
143                        "Failed to load video with error: {}; trial {}".format(
144                            e,
145                            i_try,
146                        )
147                    )
148                    continue
149
150            (
151                clip_start,
152                clip_end,
153                clip_index,
154                aug_index,
155                is_last_clip,
156            ) = self._clip_sampler(
157                self._next_clip_start_time, video.duration, info_dict
158            )
159
160            # Only load the clip once and reuse previously stored clip if there are multiple
161            # views for augmentations to perform on the same clip.
162            if aug_index == 0:
163                self._loaded_clip = video.get_clip(clip_start, clip_end)
164
165            self._next_clip_start_time = clip_end
166
167            video_is_null = (
168                self._loaded_clip is None or self._loaded_clip["video"] is None
169            )
170            if is_last_clip or video_is_null:
171                # Close the loaded encoded video and reset the last sampled clip time ready
172                # to sample a new video on the next iteration.
173                self._loaded_video_label[0].close()
174                self._loaded_video_label = None
175                self._next_clip_start_time = 0.0
176
177                if video_is_null:
178                    logger.debug(
179                        "Failed to load clip {}; trial {}".format(video.name, i_try)
180                    )
181                    continue
182
183            frames = self._loaded_clip["video"]
184            audio_samples = self._loaded_clip["audio"]
185            sample_dict = {
186                "video": frames,
187                "video_name": video.name,
188                "video_index": video_index,
189                "clip_index": clip_index,
190                "aug_index": aug_index,
191                **info_dict,
192                **({"audio": audio_samples} if audio_samples is not None else {}),
193            }
194            if self._transform is not None:
195                sample_dict = self._transform(sample_dict)
196
197                # User can force dataset to continue by returning None in transform.
198                if sample_dict is None:
199                    continue
200
201            return sample_dict
202        else:
203            raise RuntimeError(
204                f"Failed to load video after {self._MAX_CONSECUTIVE_FAILURES} retries."
205            )
206
207    def __iter__(self):
208        self._video_sampler_iter = None  # Reset video sampler
209
210        # If we're in a PyTorch DataLoader multiprocessing context, we need to use the
211        # same seed for each worker's RandomSampler generator. The workers at each
212        # __iter__ call are created from the unique value: worker_info.seed - worker_info.id,
213        # which we can use for this seed.
214        worker_info = torch.utils.data.get_worker_info()
215        if self._video_random_generator is not None and worker_info is not None:
216            base_seed = worker_info.seed - worker_info.id
217            self._video_random_generator.manual_seed(base_seed)
218
219        return self
220
221
222def labeled_video_dataset(
223    data_path: str,
224    clip_sampler: ClipSampler,
225    video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
226    transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
227    video_path_prefix: str = "",
228    decode_audio: bool = True,
229    decoder: str = "pyav",
230) -> LabeledVideoDataset:
231    """
232    A helper function to create ``LabeledVideoDataset`` object for Ucf101 and Kinetics datasets.
233
234    Args:
235        data_path (str): Path to the data. The path type defines how the data
236            should be read:
237
238            * For a file path, the file is read and each line is parsed into a
239              video path and label.
240            * For a directory, the directory structure defines the classes
241              (i.e. each subdirectory is a class).
242
243        clip_sampler (ClipSampler): Defines how clips should be sampled from each
244                video. See the clip sampling documentation for more information.
245
246        video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
247                video container. This defines the order videos are decoded and,
248                if necessary, the distributed split.
249
250        transform (Callable): This callable is evaluated on the clip output before
251                the clip is returned. It can be used for user defined preprocessing and
252                augmentations to the clips. See the ``LabeledVideoDataset`` class for clip
253                output format.
254
255        video_path_prefix (str): Path to root directory with the videos that are
256                loaded in ``LabeledVideoDataset``. All the video paths before loading
257                are prefixed with this path.
258
259        decode_audio (bool): If True, also decode audio from video.
260
261        decoder (str): Defines what type of decoder used to decode a video.
262
263    """
264    labeled_video_paths = LabeledVideoPaths.from_path(data_path)
265    labeled_video_paths.path_prefix = video_path_prefix
266    dataset = LabeledVideoDataset(
267        labeled_video_paths,
268        clip_sampler,
269        video_sampler,
270        transform,
271        decode_audio=decode_audio,
272        decoder=decoder,
273    )
274    return dataset
275