1import os
2import random
3import logging
4import ray
5from pathlib import Path
6from tqdm import tqdm
7
8from .io.video_io import VideoFrameReader, write_img_files_to_vid_ffmpeg, \
9    resize_and_write_video_ffmpeg, convert_vid_ffmpeg
10from .utils.serialization_utils import save_json
11from .dataset import GluonCVMotionDataset, DataSample, FieldNames, get_resized_video_location, \
12    get_vis_gt_location, get_vis_thumb_location, get_vis_video_location
13
14
def generate_resized_video(data_sample:DataSample, short_edge_res:int, overwrite=False, cache_dir=None,
                           upscale=True, use_vis_path=True, force_encode=False, encode_kwargs=None, **kwargs):
    """Produce a copy of a sample's video resized to a given short-edge resolution.

    The output goes to ``get_resized_video_location(data_sample, short_edge_res)``,
    or to ``data_sample.get_cache_file(cache_dir, extension='.mp4')`` when
    ``cache_dir`` is given.

    Args:
        data_sample: sample whose video is resized.
        short_edge_res: target length of the shorter frame edge, passed to
            ``resize_and_write_video_ffmpeg``.
        overwrite: regenerate even if the output file already exists.
        cache_dir: optional alternative output root (see above).
        upscale: if False and the video's short edge is already
            ``<= short_edge_res``, the video is not resized; it is symlinked
            (or re-encoded when ``force_encode``) instead.
        use_vis_path: read from ``get_vis_video_location(data_sample)`` instead
            of ``data_sample.data_path``.
        force_encode: in the no-upscale case, re-encode with
            ``convert_vid_ffmpeg`` instead of symlinking.
        encode_kwargs: kwargs for ``convert_vid_ffmpeg``. Defaults to just the
            ``addn_args`` entry of ``kwargs``, if present.
        **kwargs: forwarded to ``resize_and_write_video_ffmpeg``.

    Returns:
        The output path, or None if it already existed and ``overwrite`` is False.
    """
    if cache_dir is None:
        resized_path = get_resized_video_location(data_sample, short_edge_res)
    else:
        resized_path = data_sample.get_cache_file(cache_dir, extension='.mp4')
    if os.path.isfile(resized_path) and not overwrite:
        return
    os.makedirs(os.path.dirname(resized_path), exist_ok=True)

    # A previous run may have left a symlink here (possibly dangling).
    # Writing to a symlinked path follows the link, so re-encoding onto it could
    # clobber the original video it points to. Drop the link first.
    if os.path.islink(resized_path):
        os.remove(resized_path)

    orig_vid_path = get_vis_video_location(data_sample) if use_vis_path else data_sample.data_path
    if not upscale and min(data_sample.width, data_sample.height) <= short_edge_res:
        # Already small enough: don't resize (no upscaling requested).
        if force_encode:
            if encode_kwargs is None:
                encode_kwargs = {k: v for k, v in kwargs.items() if k == "addn_args"}
            convert_vid_ffmpeg(orig_vid_path, resized_path, **encode_kwargs)
        else:
            # os.symlink raises FileExistsError if anything is already at the
            # destination (e.g. a regular file when overwrite=True).
            if os.path.lexists(resized_path):
                os.remove(resized_path)
            # NOTE(review): a relative orig_vid_path would be resolved relative to
            # the link's directory, not the CWD -- confirm callers pass absolute paths.
            os.symlink(orig_vid_path, resized_path)
    else:
        resize_and_write_video_ffmpeg(orig_vid_path, resized_path, short_edge_res, **kwargs)

    return resized_path
37
38
def generate_video(data_sample:DataSample, force_encode=False, overwrite=False, **kwargs):
    """Create the visualization video for a sample at ``get_vis_video_location``.

    Behavior depends on ``data_sample.data_path``:

    - If it is a directory, its entries (sorted by name) are treated as frames
      and written as a video at ``data_sample.metadata['fps']``.
    - If it is a non-``.mp4`` file, or ``force_encode`` is set, it is converted
      with ``convert_vid_ffmpeg``.
    - Otherwise (an ``.mp4`` file), the output is a symlink to it.

    Args:
        data_sample: sample to process.
        force_encode: always re-encode video files instead of symlinking. An
            existing output that is a symlink is then regenerated even when
            ``overwrite`` is False.
        overwrite: regenerate even if the output already exists.
        **kwargs: forwarded to ``convert_vid_ffmpeg``.

    Returns:
        The output path, or None when generation was skipped.
    """
    new_file = get_vis_video_location(data_sample)
    must_replace_link = force_encode and os.path.islink(new_file)
    if os.path.isfile(new_file) and not overwrite and not must_replace_link:
        return
    os.makedirs(os.path.dirname(new_file), exist_ok=True)

    # Writing to a path that is a symlink follows the link. A link created by the
    # symlink branch below points at the input video itself, so encoding onto it
    # would overwrite the source. Remove any existing link (including dangling
    # ones, which os.path.exists does not see) before producing new output.
    if os.path.islink(new_file):
        os.remove(new_file)

    video_file = data_sample.data_path
    #### Generate Video ####
    if os.path.isdir(video_file):
        # the data is a set of images, generate a video
        img_files = [os.path.join(video_file, f) for f in sorted(os.listdir(video_file))]
        write_img_files_to_vid_ffmpeg(out_file=new_file, in_files=img_files, fps=data_sample.metadata['fps'])
    elif force_encode or not video_file.endswith(".mp4"):
        # Convert the video to mp4
        convert_vid_ffmpeg(video_file, new_file, **kwargs)
    else:
        # the data is already an mp4 video, symlink to it
        if os.path.lexists(new_file):
            os.remove(new_file)
        # NOTE(review): a relative data_path would resolve relative to the link's
        # directory -- confirm data_path is absolute.
        os.symlink(video_file, new_file)

    return new_file
62
63
def generate_thumbnail(data_sample:DataSample, overwrite=False):
    """Save a small preview image of the sample's visualization video.

    Uses frame 30 when the video has more than 30 frames, otherwise frame 0.
    The frame is shrunk in place with ``thumbnail((300, 300))`` (presumably a
    PIL image, so aspect ratio is kept) and saved to
    ``get_vis_thumb_location(data_sample)``.

    Args:
        data_sample: sample whose visualization video is read.
        overwrite: regenerate even if the thumbnail already exists.

    Returns:
        The thumbnail path, or None when it already existed and was kept.
    """
    thumb_path = get_vis_thumb_location(data_sample)
    if not overwrite and os.path.isfile(thumb_path):
        return
    os.makedirs(os.path.dirname(thumb_path), exist_ok=True)

    reader = VideoFrameReader(get_vis_video_location(data_sample))
    # Skip the first second-ish of footage when the clip is long enough.
    frame_idx = 30 if len(reader) > 30 else 0
    frame, _timestamp = reader.get_frame(frame_idx)
    frame.thumbnail((300, 300))
    frame.save(thumb_path)

    return thumb_path
79
80
def generate_gt_vis_json(data_sample:DataSample, cache_suffix="", overwrite=False):
    """Write the sample's ground-truth annotations as JSON for visualization.

    The serialized sample (``to_dict(include_id=True)``) has its metadata FPS
    replaced by the FPS reported for the visualization video, since that video
    may have been re-encoded.

    Args:
        data_sample: sample to serialize.
        cache_suffix: forwarded to ``get_vis_gt_location`` to pick the output path.
        overwrite: regenerate even if the JSON file already exists.

    Returns:
        The JSON path, or None when it already existed and was kept.
    """
    gt_file = get_vis_gt_location(data_sample, cache_suffix)
    if not overwrite and os.path.isfile(gt_file):
        return
    #### Generate GT Track json ####
    os.makedirs(os.path.dirname(gt_file), exist_ok=True)

    reader = VideoFrameReader(get_vis_video_location(data_sample))
    payload = data_sample.to_dict(include_id=True)
    payload[FieldNames.METADATA][FieldNames.FPS] = reader.fps
    save_json(payload, gt_file, indent=0)

    return gt_file
95
96
@ray.remote
def generate_files_for_one_sample_ray(data_sample:DataSample, generator_list, overwrite):
    """Ray remote task: process one sample via ``generate_files_for_one_sample``.

    The wrapped call's return value is discarded, so the task result is None.
    """
    generate_files_for_one_sample(
        data_sample=data_sample,
        generator_list=generator_list,
        overwrite=overwrite,
    )
100
@ray.remote
def generate_files_for_multi_samples_ray(data_samples, generator_list, overwrite):
    """Ray remote task: process a batch of samples sequentially on one worker.

    Args:
        data_samples: iterable of ``(id, data_sample)`` pairs; the id is unused.
        generator_list: generators forwarded to ``generate_files_for_one_sample``.
        overwrite: forwarded to ``generate_files_for_one_sample``.
    """
    for _sample_id, sample in data_samples:
        generate_files_for_one_sample(sample, generator_list, overwrite)
        # Logged even if the sample was skipped or failed; those cases are
        # logged separately by generate_files_for_one_sample.
        logging.info('Finished: {}'.format(sample.data_path))
106
def generate_files_for_one_sample(data_sample: "DataSample", generator_list, overwrite):
    """Run every generator in ``generator_list`` on one sample, in order.

    Each generator is called as ``gen(data_sample, overwrite)``. If one raises
    while ``overwrite`` is False, it is retried once with ``overwrite=True``
    (e.g. to replace a stale or corrupt existing output). If it still fails,
    or it failed while already overwriting, the error is logged and the
    remaining generators are skipped.

    Args:
        data_sample: sample to process. Its ``data_relative_path`` must be
            relative, otherwise the sample is skipped.
        generator_list: callables taking ``(data_sample, overwrite)``.
        overwrite: whether generators should regenerate existing outputs.

    Returns:
        True if every generator succeeded, False if the sample was skipped or a
        generator failed.
    """
    if os.path.isabs(data_sample.data_relative_path):
        logging.error("Relative path of sample id: {} is absolute and so we cannot add to the cache, skipping."
                      " Path: {}".format(data_sample.id, data_sample.data_relative_path))
        return False
    for gen in generator_list:
        try:
            gen(data_sample, overwrite)
        except Exception:
            if overwrite:
                # Already regenerating from scratch; repeating the identical call
                # (often an ffmpeg encode) would just fail again.
                logging.exception('Failed: {}'.format(data_sample.data_path))
                return False
            # Keep the first traceback available; it is otherwise lost if the retry succeeds.
            logging.debug("Generator failed for {}, retrying with overwrite=True".format(data_sample.data_path),
                          exc_info=True)
            try:
                gen(data_sample, True)
            except Exception:
                logging.exception('Failed: {}'.format(data_sample.data_path))
                # Later generators may depend on this one's output, so stop here.
                return False
    return True
121
122
def generate_preprocess_files(part=0, parts=1,
                              annotation_file='./kinetics/annotation/anno_400.json',
                              use_ray=False, num_cpus=4, overwrite=False, force_encode=False, short_edge_res=256,
                              distributed=False, shuffle_seed=None, dataset=None, num_ray_tasks=500):
    """Generate visualization artifacts for one shard of a dataset.

    For each sample this produces, in order: the visualization video
    (``generate_video``), a thumbnail (``generate_thumbnail``) and the GT JSON
    (``generate_gt_vis_json``).

    Args:
        part: 0-based index of the shard to process. Samples
            ``part, part + parts, ...`` of the (optionally shuffled) list are used.
        parts: total number of shards.
        annotation_file: used to build a ``GluonCVMotionDataset`` when
            ``dataset`` is None.
        use_ray: run as ray tasks instead of a serial loop.
        num_cpus: CPUs for a local ``ray.init`` (ignored when ``distributed``).
        overwrite: regenerate outputs that already exist.
        force_encode: forwarded to ``generate_video``.
        short_edge_res: currently unused by this function.
        distributed: connect ray to an existing cluster at localhost:6379.
        shuffle_seed: if set (and not the string "None", which a CLI may pass),
            shuffle the sorted samples with this seed before sharding.
        dataset: pre-built dataset; takes precedence over ``annotation_file``.
        num_ray_tasks: upper bound on the number of ray tasks the shard is split into.
    """
    if dataset is None:
        dataset = GluonCVMotionDataset(annotation_file)

    # Resolve the subpath once so the lambda below captures only a string. Under
    # ray, generator_list is serialized into every task, and capturing `dataset`
    # would ship the whole dataset with each one.
    # NOTE(review): assumes get_anno_subpath() is constant for the dataset.
    anno_subpath = dataset.get_anno_subpath()
    generator_list = [
        lambda sample, overwrite: generate_video(sample, force_encode, overwrite),
        generate_thumbnail,
        lambda sample, overwrite: generate_gt_vis_json(sample, anno_subpath, overwrite),
    ]

    samples = sorted(dataset.samples)
    if shuffle_seed is not None and shuffle_seed != "None":
        # A private RNG yields the same permutation as random.seed + random.shuffle
        # without resetting the global random state.
        random.Random(shuffle_seed).shuffle(samples)
    samples = samples[part::parts]

    logging.info("Using ray {} ,  distributed: {}".format(use_ray, distributed))
    if use_ray:
        if distributed:
            ray.init(redis_address="localhost:6379")
        else:
            ray.init(num_cpus=num_cpus)
        # Launching more tasks than samples would only create empty tasks.
        num_tasks = max(1, min(num_ray_tasks, len(samples)))
        ray.get([generate_files_for_multi_samples_ray.remote(samples[i::num_tasks], generator_list, overwrite)
                 for i in range(num_tasks)])
    else:
        for _sample_id, data_sample in tqdm(samples, mininterval=1.0):
            generate_files_for_one_sample(data_sample, generator_list, overwrite)
154
155
if __name__ == '__main__':
    # Command-line entry point: python-fire maps CLI flags onto the keyword
    # arguments of generate_preprocess_files.
    from fire import Fire
    Fire(generate_preprocess_files)
159