# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional, Tuple, Type

import torch.utils.data
from pytorchvideo.data.clip_sampling import ClipSampler
from pytorchvideo.data.video import VideoPathHandler

from .labeled_video_paths import LabeledVideoPaths
from .utils import MultiProcessSampler


logger = logging.getLogger(__name__)


class LabeledVideoDataset(torch.utils.data.IterableDataset):
    """
    LabeledVideoDataset handles the storage, loading, decoding and clip sampling for a
    video dataset. It assumes each video is stored as either an encoded video
    (e.g. mp4, avi) or a frame video (e.g. a folder of jpg, or png)
    """

    # Upper bound on back-to-back failed attempts (video load failure, empty clip, or
    # a transform returning None) tolerated by a single __next__() call.
    _MAX_CONSECUTIVE_FAILURES = 10

    def __init__(
        self,
        labeled_video_paths: List[Tuple[str, Optional[dict]]],
        clip_sampler: ClipSampler,
        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
        transform: Optional[Callable[[dict], Any]] = None,
        decode_audio: bool = True,
        decoder: str = "pyav",
    ) -> None:
        """
        Args:
            labeled_video_paths (List[Tuple[str, Optional[dict]]]): List containing
                    video file paths and associated labels. If video paths are a folder
                    it's interpreted as a frame video, otherwise it must be an encoded
                    video.

            clip_sampler (ClipSampler): Defines how clips should be sampled from each
                video. See the clip sampling documentation for more information.

            video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
                video container. This defines the order videos are decoded and,
                if necessary, the distributed split.

            transform (Callable): This callable is evaluated on the clip output before
                the clip is returned. It can be used for user defined preprocessing and
                augmentations on the clips. The clip output format is described in __next__().

            decode_audio (bool): If True, also decode audio from video.

            decoder (str): Defines what type of decoder used to decode a video. Not used for
                frame videos.
        """
        self._decode_audio = decode_audio
        self._transform = transform
        self._clip_sampler = clip_sampler
        self._labeled_videos = labeled_video_paths
        self._decoder = decoder

        # If a RandomSampler is used we need to pass in a custom random generator that
        # ensures all PyTorch multiprocess workers have the same random seed.
        self._video_random_generator = None
        if video_sampler == torch.utils.data.RandomSampler:
            self._video_random_generator = torch.Generator()
            self._video_sampler = video_sampler(
                self._labeled_videos, generator=self._video_random_generator
            )
        else:
            self._video_sampler = video_sampler(self._labeled_videos)

        self._video_sampler_iter = None  # Initialized on first call to self.__next__()

        # Depending on the clip sampler type, we may want to sample multiple clips
        # from one video. In that case, we keep the stored video, label and previous
        # sampled clip time in these variables.
        self._loaded_video_label = None
        self._loaded_clip = None
        self._next_clip_start_time = 0.0
        self.video_path_handler = VideoPathHandler()

    @property
    def video_sampler(self):
        """
        Returns:
            The video sampler that defines video sample order. Note that you'll need to
            use this property to set the epoch for a torch.utils.data.DistributedSampler.
        """
        return self._video_sampler

    @property
    def num_videos(self):
        """
        Returns:
            Number of videos in dataset.
        """
        return len(self.video_sampler)

    def __next__(self) -> dict:
        """
        Retrieves the next clip based on the clip sampling strategy and video sampler.

        Up to ``_MAX_CONSECUTIVE_FAILURES`` attempts are made. An attempt is skipped
        when the video cannot be opened, when the decoded clip is empty, or when the
        transform returns None.

        Returns:
            The output of ``transform`` (if one was given) applied to a dictionary
            with the following format; otherwise the dictionary itself.

            .. code-block:: text

                {
                    'video': <video_tensor>,
                    'video_name': <video_name>,
                    'video_index': <video_index>,
                    'clip_index': <clip_index>,
                    'aug_index': <aug_index>,
                    **<info_dict entries stored alongside the video path>,
                    'audio': <audio_samples>,  # only present when audio was decoded
                }

        Raises:
            StopIteration: Propagated from the video sampler iterator once it has
                no more video indices to yield.
            RuntimeError: If no clip could be produced within
                ``_MAX_CONSECUTIVE_FAILURES`` consecutive attempts.
        """
        if self._video_sampler_iter is None:
            # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))

        for i_try in range(self._MAX_CONSECUTIVE_FAILURES):
            # Reuse previously stored video if there are still clips to be sampled from
            # the last loaded video.
            if self._loaded_video_label is not None:
                video, info_dict, video_index = self._loaded_video_label
            else:
                video_index = next(self._video_sampler_iter)
                try:
                    video_path, info_dict = self._labeled_videos[video_index]
                    video = self.video_path_handler.video_from_path(
                        video_path,
                        decode_audio=self._decode_audio,
                        decoder=self._decoder,
                    )
                    self._loaded_video_label = (video, info_dict, video_index)
                except Exception as e:
                    # Best effort: a video that cannot be opened is skipped and the
                    # next index from the sampler is tried instead.
                    logger.debug(
                        "Failed to load video with error: {}; trial {}".format(
                            e,
                            i_try,
                        )
                    )
                    continue

            (
                clip_start,
                clip_end,
                clip_index,
                aug_index,
                is_last_clip,
            ) = self._clip_sampler(
                self._next_clip_start_time, video.duration, info_dict
            )

            # Only load the clip once and reuse previously stored clip if there are multiple
            # views for augmentations to perform on the same clip.
            if aug_index == 0:
                self._loaded_clip = video.get_clip(clip_start, clip_end)

            self._next_clip_start_time = clip_end

            video_is_null = (
                self._loaded_clip is None or self._loaded_clip["video"] is None
            )
            if is_last_clip or video_is_null:
                # Close the loaded encoded video and reset the last sampled clip time ready
                # to sample a new video on the next iteration.
                self._loaded_video_label[0].close()
                self._loaded_video_label = None
                self._next_clip_start_time = 0.0

                if video_is_null:
                    logger.debug(
                        "Failed to load clip {}; trial {}".format(video.name, i_try)
                    )
                    continue

            frames = self._loaded_clip["video"]
            audio_samples = self._loaded_clip["audio"]
            # The per-video info dict is declared Optional, so guard against None
            # before unpacking it into the sample.
            label_info = info_dict if info_dict is not None else {}
            sample_dict = {
                "video": frames,
                "video_name": video.name,
                "video_index": video_index,
                "clip_index": clip_index,
                "aug_index": aug_index,
                **label_info,
                **({"audio": audio_samples} if audio_samples is not None else {}),
            }
            if self._transform is not None:
                sample_dict = self._transform(sample_dict)

                # User can force dataset to continue by returning None in transform.
                if sample_dict is None:
                    continue

            return sample_dict

        # Every attempt either failed to load or was rejected by the transform.
        raise RuntimeError(
            f"Failed to load video after {self._MAX_CONSECUTIVE_FAILURES} retries."
        )

    def __iter__(self):
        """
        Resets the video sampler iterator and, when running inside a DataLoader
        worker with a RandomSampler, re-seeds the shared random generator.

        Returns:
            This dataset object, which acts as its own iterator.
        """
        self._video_sampler_iter = None  # Reset video sampler

        # If we're in a PyTorch DataLoader multiprocessing context, we need to use the
        # same seed for each worker's RandomSampler generator. The workers at each
        # __iter__ call are created from the unique value: worker_info.seed - worker_info.id,
        # which we can use for this seed.
        worker_info = torch.utils.data.get_worker_info()
        if self._video_random_generator is not None and worker_info is not None:
            base_seed = worker_info.seed - worker_info.id
            self._video_random_generator.manual_seed(base_seed)

        return self


def labeled_video_dataset(
    data_path: str,
    clip_sampler: ClipSampler,
    video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
    transform: Optional[Callable[[Dict[str, Any]], Dict[str, Any]]] = None,
    video_path_prefix: str = "",
    decode_audio: bool = True,
    decoder: str = "pyav",
) -> LabeledVideoDataset:
    """
    A helper function to create ``LabeledVideoDataset`` object for Ucf101 and Kinetics datasets.

    Args:
        data_path (str): Path to the data. The path type defines how the data
            should be read:

            * For a file path, the file is read and each line is parsed into a
              video path and label.
            * For a directory, the directory structure defines the classes
              (i.e. each subdirectory is a class).

        clip_sampler (ClipSampler): Defines how clips should be sampled from each
                video. See the clip sampling documentation for more information.

        video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
                video container. This defines the order videos are decoded and,
                if necessary, the distributed split.

        transform (Callable): This callable is evaluated on the clip output before
                the clip is returned. It can be used for user defined preprocessing and
                augmentations to the clips. See the ``LabeledVideoDataset`` class for clip
                output format.

        video_path_prefix (str): Path to root directory with the videos that are
                loaded in ``LabeledVideoDataset``. All the video paths before loading
                are prefixed with this path.

        decode_audio (bool): If True, also decode audio from video.

        decoder (str): Defines what type of decoder used to decode a video.

    Returns:
        A ``LabeledVideoDataset`` built from the paths found at ``data_path``.
    """
    labeled_video_paths = LabeledVideoPaths.from_path(data_path)
    labeled_video_paths.path_prefix = video_path_prefix
    dataset = LabeledVideoDataset(
        labeled_video_paths,
        clip_sampler,
        video_sampler,
        transform,
        decode_audio=decode_audio,
        decoder=decoder,
    )
    return dataset