# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import csv
import functools
import itertools
import os
from collections import defaultdict
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.utils.data
from iopath.common.file_io import g_pathmgr
from pytorchvideo.data.clip_sampling import ClipSampler
from pytorchvideo.data.frame_video import FrameVideo

from .utils import MultiProcessSampler


class Charades(torch.utils.data.IterableDataset):
    """
    Action recognition video dataset for
    `Charades <https://prior.allenai.org/projects/charades>`_ stored as image frames.

    This dataset handles the parsing of frames, loading and clip sampling for the
    videos. All io is done through :code:`iopath.common.file_io.PathManager`, enabling
    non-local storage uri's to be used.
    """

    # Number of classes represented by this dataset's annotated labels.
    NUM_CLASSES = 157

    def __init__(
        self,
        data_path: str,
        clip_sampler: ClipSampler,
        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
        transform: Optional[Callable[[dict], Any]] = None,
        video_path_prefix: str = "",
        frames_per_clip: Optional[int] = None,
    ) -> None:
        """
        Args:
            data_path (str): Path to the data file. This file must be a space
                separated csv with the format: (original_vido_id video_id frame_id
                path labels)

            clip_sampler (ClipSampler): Defines how clips should be sampled from each
                video. See the clip sampling documentation for more information.

            video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
                video container. This defines the order videos are decoded and,
                if necessary, the distributed split.

            transform (Optional[Callable]): This callable is evaluated on the clip output
                before the clip is returned. It can be used for user defined
                preprocessing and augmentations on the clips. The clip output format is
                described in __next__().

            video_path_prefix (str): prefix path to add to all paths from data_path.

            frames_per_clip (Optional[int]): The number of frames per clip to sample.
                When None, no frame subsampling filter is applied to loaded clips.
        """

        torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Charades.__init__")

        self._transform = transform
        self._clip_sampler = clip_sampler
        (
            self._path_to_videos,
            self._labels,
            self._video_labels,
        ) = _read_video_paths_and_labels(data_path, prefix=video_path_prefix)
        self._video_sampler = video_sampler(self._path_to_videos)
        self._video_sampler_iter = None  # Initialized on first call to self.__next__()
        self._frame_filter = (
            functools.partial(
                Charades._sample_clip_frames,
                frames_per_clip=frames_per_clip,
            )
            if frames_per_clip is not None
            else None
        )

        # Depending on the clip sampler type, we may want to sample multiple clips
        # from one video. In that case, we keep the loaded video, the last loaded
        # clip and the start time of the next clip in these variables.
        self._loaded_video = None
        self._loaded_clip = None
        self._next_clip_start_time = 0.0

    @staticmethod
    def _sample_clip_frames(
        frame_indices: List[int], frames_per_clip: int
    ) -> List[int]:
        """
        Subsample a list of frame indices at evenly spaced positions.

        Args:
            frame_indices (list): list of frame indices.
            frames_per_clip (int): The number of frames per clip to sample.

        Returns:
            (list): Outputs a subsampled list with frames_per_clip frames.
        """
        num_frames = len(frame_indices)
        # Evenly spaced (fractional) positions spanning the first to the last frame.
        positions = torch.linspace(0, num_frames - 1, frames_per_clip)
        # Bound the positions to valid list offsets before truncating to integers.
        positions = torch.clamp(positions, 0, num_frames - 1).long()

        # Convert to plain Python ints so the list is indexed without tensor scalars.
        return [frame_indices[idx] for idx in positions.tolist()]

    @property
    def video_sampler(self) -> torch.utils.data.Sampler:
        """The sampler instance that selects which video is loaded next."""
        return self._video_sampler

    def __next__(self) -> dict:
        """
        Retrieves the next clip based on the clip sampling strategy and video sampler.

        Returns:
            A dictionary with the following format (passed through ``transform``
            first when one was supplied, in which case the transform's return value
            is what is returned).

            .. code-block:: text

                {
                    'video': <video_tensor>,
                    'label': <list of per-frame label lists for the clip>,
                    'video_label': <video-level label list>,
                    'video_name': <video_index as a string>,
                    'video_index': <video_index>,
                    'clip_index': <clip_index>,
                    'aug_index': <aug_index>,
                }
        """
        if self._video_sampler_iter is None:
            # Setup MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))

        if self._loaded_video is not None:
            # Still sampling clips from the video loaded on a previous call.
            video, video_index = self._loaded_video
        else:
            video_index = next(self._video_sampler_iter)
            path_to_video_frames = self._path_to_videos[video_index]
            video = FrameVideo.from_frame_paths(path_to_video_frames)
            self._loaded_video = (video, video_index)

        clip_start, clip_end, clip_index, aug_index, is_last_clip = self._clip_sampler(
            self._next_clip_start_time, video.duration, {}
        )
        # Only load the clip once and reuse previously stored clip if there are multiple
        # views for augmentations to perform on the same clip.
        if aug_index == 0:
            self._loaded_clip = video.get_clip(clip_start, clip_end, self._frame_filter)

        frames, frame_indices = (
            self._loaded_clip["video"],
            self._loaded_clip["frame_indices"],
        )
        self._next_clip_start_time = clip_end

        if is_last_clip:
            # Drop the cached video so the next call draws a new one from the sampler.
            self._loaded_video = None
            self._next_clip_start_time = 0.0

        # Collect the label list of every frame in the contiguous range spanned by the
        # clip (first to last sampled frame inclusive, including frames in between
        # that a frame filter may have skipped).
        labels_by_frame = [
            self._labels[video_index][i]
            for i in range(min(frame_indices), max(frame_indices) + 1)
        ]
        sample_dict = {
            "video": frames,
            "label": labels_by_frame,
            "video_label": self._video_labels[video_index],
            "video_name": str(video_index),
            "video_index": video_index,
            "clip_index": clip_index,
            "aug_index": aug_index,
        }
        if self._transform is not None:
            sample_dict = self._transform(sample_dict)

        return sample_dict

    def __iter__(self):
        """Return self; the dataset is its own iterator."""
        return self


def _read_video_paths_and_labels(
    video_path_label_file: str, prefix: str = ""
) -> Tuple[List[List[str]], List[List[List[int]]], List[List[int]]]:
    """
    Parse a Charades frame list file into per-video frame paths and labels.

    Args:
        video_path_label_file (str): a file that contains frame paths for each
            video and the corresponding frame label. The file must be a space separated
            csv of the format:
                `original_vido_id video_id frame_id path labels`

        prefix (str): prefix path to add to all paths from video_path_label_file.

    Returns:
        A tuple ``(image_paths, labels, video_labels)`` of three lists that share the
        same video ordering (order of first appearance in the file):

        - image_paths: for each video, the list of its frame paths in file order.
        - labels: for each video, a list holding one list of integer labels per frame
          (empty when the frame has no labels).
        - video_labels: for each video, the de-duplicated labels of all its frames.
    """
    image_paths = defaultdict(list)
    labels = defaultdict(list)
    with g_pathmgr.open(video_path_label_file, "r") as f:

        # Space separated CSV with format: original_vido_id video_id frame_id path labels
        csv_reader = csv.DictReader(f, delimiter=" ")
        for row in csv_reader:
            assert len(row) == 5, "Expected 5 columns per row, got: {}".format(row)
            video_name = row["original_vido_id"]
            path = os.path.join(prefix, row["path"])
            image_paths[video_name].append(path)
            # Labels are a comma separated list, optionally wrapped in double quotes.
            frame_labels = row["labels"].replace('"', "")
            label_list = []
            if frame_labels:
                label_list = [int(x) for x in frame_labels.split(",")]

            labels[video_name].append(label_list)

    # Fix the video ordering once so the three returned lists stay aligned.
    video_names = list(image_paths)
    paths_per_video = [image_paths[key] for key in video_names]
    labels_per_video = [labels[key] for key in video_names]
    # Aggregate labels from all frames to form video-level labels.
    video_labels = [
        list(set(itertools.chain.from_iterable(frame_label_lists)))
        for frame_label_lists in labels_per_video
    ]
    return paths_per_video, labels_per_video, video_labels