# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import csv
import functools
import itertools
import os
from collections import defaultdict
from typing import Any, Callable, List, Optional, Tuple, Type

import torch
import torch.utils.data
from iopath.common.file_io import g_pathmgr
from pytorchvideo.data.clip_sampling import ClipSampler
from pytorchvideo.data.frame_video import FrameVideo

from .utils import MultiProcessSampler


class Charades(torch.utils.data.IterableDataset):
    """
    Action recognition video dataset for
    `Charades <https://prior.allenai.org/projects/charades>`_ stored as image frames.

    This dataset handles the parsing of frames, loading and clip sampling for the
    videos. All I/O is done through :code:`iopath.common.file_io.PathManager`, enabling
    non-local storage URIs to be used.
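
    Example (a minimal sketch; the CSV path, frame prefix, and clip duration below
    are illustrative placeholders, not values shipped with this module):

    .. code-block:: python

        from pytorchvideo.data.clip_sampling import make_clip_sampler

        dataset = Charades(
            data_path="path/to/charades_frame_list.csv",
            clip_sampler=make_clip_sampler("random", 2.0),
            video_path_prefix="path/to/frames",
            frames_per_clip=8,
        )
        clip = next(iter(dataset))  # dict with "video", "label", "video_label", ...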
    """

    # Number of classes represented by this dataset's annotated labels.
    NUM_CLASSES = 157

    def __init__(
        self,
        data_path: str,
        clip_sampler: ClipSampler,
        video_sampler: Type[torch.utils.data.Sampler] = torch.utils.data.RandomSampler,
        transform: Optional[Callable[[dict], Any]] = None,
        video_path_prefix: str = "",
        frames_per_clip: Optional[int] = None,
    ) -> None:
        """
        Args:
            data_path (str): Path to the data file. This file must be a space-separated
                CSV with the format: (original_vido_id video_id frame_id path labels)

            clip_sampler (ClipSampler): Defines how clips should be sampled from each
                video. See the clip sampling documentation for more information.

            video_sampler (Type[torch.utils.data.Sampler]): Sampler for the internal
                video container. This defines the order videos are decoded and,
                if necessary, the distributed split.

            transform (Optional[Callable]): This callable is evaluated on the clip output
                before the clip is returned. It can be used for user-defined preprocessing
                and augmentations of the clips. The clip output format is described in
                __next__().

            video_path_prefix (str): Prefix path to add to all paths from data_path.

            frames_per_clip (Optional[int]): The number of frames per clip to sample.
        """

        torch._C._log_api_usage_once("PYTORCHVIDEO.dataset.Charades.__init__")

        self._transform = transform
        self._clip_sampler = clip_sampler
        (
            self._path_to_videos,
            self._labels,
            self._video_labels,
        ) = _read_video_paths_and_labels(data_path, prefix=video_path_prefix)
        self._video_sampler = video_sampler(self._path_to_videos)
        self._video_sampler_iter = None  # Initialized on first call to self.__next__()
        self._frame_filter = (
            functools.partial(
                Charades._sample_clip_frames,
                frames_per_clip=frames_per_clip,
            )
            if frames_per_clip is not None
            else None
        )

        # Depending on the clip sampler type, we may want to sample multiple clips
        # from one video. In that case, we store the video, label, and previously
        # sampled clip time in these variables.
        self._loaded_video = None
        self._loaded_clip = None
        self._next_clip_start_time = 0.0

    @staticmethod
    def _sample_clip_frames(
        frame_indices: List[int], frames_per_clip: int
    ) -> List[int]:
        """
        Args:
            frame_indices (list): list of frame indices.
            frames_per_clip (int): The number of frames per clip to sample.

        Returns:
            (list): A subsampled list with frames_per_clip frames.
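
            For example (a worked illustration of the linear subsampling below),
            ``frame_indices=[10, 11, 12, 13, 14, 15]`` with ``frames_per_clip=3``
            returns ``[10, 12, 15]``.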
        """
        num_frames = len(frame_indices)
        indices = torch.linspace(0, num_frames - 1, frames_per_clip)
        indices = torch.clamp(indices, 0, num_frames - 1).long()

        return [frame_indices[idx] for idx in indices]

    @property
    def video_sampler(self) -> torch.utils.data.Sampler:
        return self._video_sampler

    def __next__(self) -> dict:
        """
        Retrieves the next clip based on the clip sampling strategy and video sampler.

        Returns:
            A dictionary with the following format.

            .. code-block:: text

                {
                    'video': <video_tensor>,
                    'label': <per_frame_label_lists>,
                    'video_label': <video_level_labels>,
                    'video_name': <video_name>,
                    'video_index': <video_index>,
                    'clip_index': <clip_index>,
                    'aug_index': <aug_index>,
                }
        """
        if not self._video_sampler_iter:
            # Set up MultiProcessSampler here - after PyTorch DataLoader workers are spawned.
            self._video_sampler_iter = iter(MultiProcessSampler(self._video_sampler))

        if self._loaded_video:
            video, video_index = self._loaded_video
        else:
            video_index = next(self._video_sampler_iter)
            path_to_video_frames = self._path_to_videos[video_index]
            video = FrameVideo.from_frame_paths(path_to_video_frames)
            self._loaded_video = (video, video_index)

        clip_start, clip_end, clip_index, aug_index, is_last_clip = self._clip_sampler(
            self._next_clip_start_time, video.duration, {}
        )
        # Only load the clip once and reuse the previously stored clip if there are
        # multiple augmentation views to produce from the same clip.
        if aug_index == 0:
            self._loaded_clip = video.get_clip(clip_start, clip_end, self._frame_filter)

        frames, frame_indices = (
            self._loaded_clip["video"],
            self._loaded_clip["frame_indices"],
        )
        self._next_clip_start_time = clip_end

        if is_last_clip:
            self._loaded_video = None
            self._next_clip_start_time = 0.0

        # Gather the per-frame labels for every frame index spanned by the clip.
        labels_by_frame = [
            self._labels[video_index][i]
            for i in range(min(frame_indices), max(frame_indices) + 1)
        ]
        sample_dict = {
            "video": frames,
            "label": labels_by_frame,
            "video_label": self._video_labels[video_index],
            "video_name": str(video_index),
            "video_index": video_index,
            "clip_index": clip_index,
            "aug_index": aug_index,
        }
        if self._transform is not None:
            sample_dict = self._transform(sample_dict)

        return sample_dict

    def __iter__(self):
        return self


def _read_video_paths_and_labels(
    video_path_label_file: str, prefix: str = ""
) -> Tuple[List[List[str]], List[List[List[int]]], List[List[int]]]:
    """
    Args:
        video_path_label_file (str): Path to a file that contains frame paths for each
            video and the corresponding frame labels. The file must be a space-separated
            CSV of the format:
                `original_vido_id video_id frame_id path labels`
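
            For illustration, a single row might look like the following (the values
            are hypothetical; the quoted labels column holds a comma-separated list
            of class ids):
                `46GP8 46GP8 0 46GP8/46GP8-000001.jpg "59,147"`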

        prefix (str): Prefix path to add to all paths from video_path_label_file.

    Returns:
        A tuple of (image_paths, labels, video_labels), holding per-video lists of
        frame paths, per-video lists of per-frame label lists, and per-video
        aggregated label lists, respectively.
    """
    image_paths = defaultdict(list)
    labels = defaultdict(list)
    with g_pathmgr.open(video_path_label_file, "r") as f:

        # Space separated CSV with format: original_vido_id video_id frame_id path labels
        csv_reader = csv.DictReader(f, delimiter=" ")
        for row in csv_reader:
            assert len(row) == 5
            video_name = row["original_vido_id"]
            path = os.path.join(prefix, row["path"])
            image_paths[video_name].append(path)
            frame_labels = row["labels"].replace('"', "")
            label_list = []
            if frame_labels:
                label_list = [int(x) for x in frame_labels.split(",")]

            labels[video_name].append(label_list)

    # Extract image paths from dictionary and return paths and labels as lists.
    video_names = image_paths.keys()
    image_paths = [image_paths[key] for key in video_names]
    labels = [labels[key] for key in video_names]
    # Aggregate labels from all frames to form video-level labels.
    video_labels = [list(set(itertools.chain(*label_list))) for label_list in labels]
    return image_paths, labels, video_labels