import os
import random
import logging
import ray
from pathlib import Path
from tqdm import tqdm

from .io.video_io import VideoFrameReader, write_img_files_to_vid_ffmpeg, \
    resize_and_write_video_ffmpeg, convert_vid_ffmpeg
from .utils.serialization_utils import save_json
from .dataset import GluonCVMotionDataset, DataSample, FieldNames, get_resized_video_location, \
    get_vis_gt_location, get_vis_thumb_location, get_vis_video_location


def _remove_stale_symlink(path):
    """Delete *path* if it is a symlink, so a later write cannot follow it.

    The generators below may leave a symlink to the source video at the output
    location; encoding onto that path would write through the link onto the
    very file being read.
    """
    if os.path.islink(path):
        os.remove(path)


def _replace_with_symlink(target, link_path):
    """Make *link_path* a symlink to *target*, replacing any existing entry."""
    # lexists (unlike exists) is also true for dangling symlinks, which would
    # otherwise make os.symlink fail with FileExistsError.
    if os.path.lexists(link_path):
        os.remove(link_path)
    os.symlink(target, link_path)


def generate_resized_video(data_sample: DataSample, short_edge_res: int, overwrite=False, cache_dir=None,
                           upscale=True, use_vis_path=True, force_encode=False, encode_kwargs=None, **kwargs):
    """Create a resized copy of a sample's video.

    Args:
        data_sample: sample whose video is resized.
        short_edge_res: target short-edge resolution, forwarded to
            ``resize_and_write_video_ffmpeg``.
        overwrite: regenerate even if the output file already exists.
        cache_dir: if given, the output path comes from
            ``data_sample.get_cache_file(cache_dir, extension='.mp4')``;
            otherwise from ``get_resized_video_location``.
        upscale: if False, a video whose smaller dimension is already
            ``<= short_edge_res`` is not resized (it is symlinked or re-encoded).
        use_vis_path: read from ``get_vis_video_location(data_sample)`` instead
            of ``data_sample.data_path``.
        force_encode: in the "no resize needed" case, re-encode with
            ``convert_vid_ffmpeg`` instead of symlinking.
        encode_kwargs: keyword arguments for ``convert_vid_ffmpeg``; when None,
            only ``addn_args`` (if present in ``kwargs``) is forwarded.
        **kwargs: forwarded to ``resize_and_write_video_ffmpeg``.

    Returns:
        The output path, or None when the output already existed and
        ``overwrite`` was False.
    """
    if cache_dir is None:
        resized_path = get_resized_video_location(data_sample, short_edge_res)
    else:
        resized_path = data_sample.get_cache_file(cache_dir, extension='.mp4')
    if os.path.isfile(resized_path) and not overwrite:
        return
    os.makedirs(os.path.dirname(resized_path), exist_ok=True)

    orig_vid_path = get_vis_video_location(data_sample) if use_vis_path else data_sample.data_path
    needs_resize = upscale or min(data_sample.width, data_sample.height) > short_edge_res
    if needs_resize:
        # A symlink left by an earlier run points at the source video itself.
        _remove_stale_symlink(resized_path)
        resize_and_write_video_ffmpeg(orig_vid_path, resized_path, short_edge_res, **kwargs)
    elif force_encode:
        if encode_kwargs is None:
            encode_kwargs = {k: v for k, v in kwargs.items() if k == "addn_args"}
        _remove_stale_symlink(resized_path)
        convert_vid_ffmpeg(orig_vid_path, resized_path, **encode_kwargs)
    else:
        # Previously a bare os.symlink, which raised FileExistsError whenever
        # the destination was already present (e.g. overwrite=True).
        _replace_with_symlink(orig_vid_path, resized_path)

    return resized_path


def generate_video(data_sample: DataSample, force_encode=False, overwrite=False, **kwargs):
    """Create the visualisation video for a sample.

    A directory of images is assembled into a video, a non-``.mp4`` file (or
    any file when ``force_encode`` is set) is converted with
    ``convert_vid_ffmpeg``, and an ``.mp4`` file is simply symlinked.

    Args:
        data_sample: sample to process; ``data_sample.data_path`` is the input.
        force_encode: always re-encode; also regenerates an existing output
            that is only a symlink.
        overwrite: regenerate even if the output file already exists.
        **kwargs: forwarded to ``convert_vid_ffmpeg``.

    Returns:
        The output path, or None when generation was skipped.
    """
    new_file = get_vis_video_location(data_sample)
    up_to_date = os.path.isfile(new_file) and not overwrite
    must_replace_link = force_encode and os.path.islink(new_file)
    if up_to_date and not must_replace_link:
        return
    os.makedirs(os.path.dirname(new_file), exist_ok=True)

    video_file = data_sample.data_path
    if os.path.isdir(video_file):
        # The data is a set of images: generate a video from them.
        img_files = [os.path.join(video_file, f) for f in sorted(os.listdir(video_file))]
        _remove_stale_symlink(new_file)
        write_img_files_to_vid_ffmpeg(out_file=new_file, in_files=img_files, fps=data_sample.metadata['fps'])
    elif not video_file.endswith(".mp4") or force_encode:
        # Convert the video to mp4. Drop an existing symlink first: it may
        # point at video_file, and encoding through it would clobber the input.
        _remove_stale_symlink(new_file)
        convert_vid_ffmpeg(video_file, new_file, **kwargs)
    else:
        # The data is already an mp4 video: symlink to it.
        _replace_with_symlink(video_file, new_file)

    return new_file


def generate_thumbnail(data_sample: DataSample, overwrite=False):
    """Save a thumbnail image taken from the sample's visualisation video.

    Frame 30 is used when the video has more than 30 frames, otherwise frame 0.
    The image is shrunk in place with ``thumbnail((300, 300))`` before saving.

    Returns:
        The thumbnail path, or None when it already existed and ``overwrite``
        was False.
    """
    video_thumbnail_frame = get_vis_thumb_location(data_sample)
    if os.path.isfile(video_thumbnail_frame) and not overwrite:
        return
    os.makedirs(os.path.dirname(video_thumbnail_frame), exist_ok=True)

    video_file = get_vis_video_location(data_sample)

    vid = VideoFrameReader(video_file)
    frame_idx = 30 if len(vid) > 30 else 0
    img, _ = vid.get_frame(frame_idx)  # second element (timestamp) is unused
    img.thumbnail((300, 300))
    img.save(video_thumbnail_frame)

    return video_thumbnail_frame


def generate_gt_vis_json(data_sample: DataSample, cache_suffix="", overwrite=False):
    """Write the sample's annotation as JSON next to the visualisation files.

    The stored FPS metadata is replaced with the FPS reported by the
    visualisation video, so the JSON matches the video that is actually shown.

    Returns:
        The JSON path, or None when it already existed and ``overwrite`` was
        False.
    """
    gt_file = get_vis_gt_location(data_sample, cache_suffix)
    if os.path.isfile(gt_file) and not overwrite:
        return
    os.makedirs(os.path.dirname(gt_file), exist_ok=True)

    vis_video_file = get_vis_video_location(data_sample)
    vis_vid = VideoFrameReader(vis_video_file)
    sample_dict = data_sample.to_dict(include_id=True)
    sample_dict[FieldNames.METADATA][FieldNames.FPS] = vis_vid.fps
    save_json(sample_dict, gt_file, indent=0)

    return gt_file


@ray.remote
def generate_files_for_one_sample_ray(data_sample: DataSample, generator_list, overwrite):
    """Ray task wrapper around ``generate_files_for_one_sample``."""
    generate_files_for_one_sample(data_sample, generator_list, overwrite)


@ray.remote
def generate_files_for_multi_samples_ray(data_samples, generator_list, overwrite):
    """Ray task that processes a batch of ``(sample_id, data_sample)`` pairs."""
    for _sample_id, data_sample in data_samples:
        generate_files_for_one_sample(data_sample, generator_list, overwrite)
        logging.info('Finished: {}'.format(data_sample.data_path))


def generate_files_for_one_sample(data_sample: DataSample, generator_list, overwrite):
    """Run every generator in ``generator_list`` on one sample.

    Each generator is called as ``gen(data_sample, overwrite)``. If it raises,
    it is retried once with ``overwrite=True``; if the retry also raises, the
    failure is logged and the remaining generators are skipped for this sample.
    Samples whose relative path is absolute are skipped entirely.
    """
    if os.path.isabs(data_sample.data_relative_path):
        logging.error("Relative path of sample id: {} is absolute and so we cannot add to the cache, skipping."
                      " Path: {}".format(data_sample.id, data_sample.data_relative_path))
        return
    for gen in generator_list:
        try:
            gen(data_sample, overwrite)
        except Exception:
            # Keep a record of the first failure instead of discarding it.
            logging.warning('Retrying with overwrite: {}'.format(data_sample.data_path), exc_info=True)
            try:
                gen(data_sample, True)
            except Exception:
                logging.exception('Failed: {}'.format(data_sample.data_path))
                return


def generate_preprocess_files(part=0, parts=1,
                              annotation_file='./kinetics/annotation/anno_400.json',
                              use_ray=False, num_cpus=4, overwrite=False, force_encode=False, short_edge_res=256,
                              distributed=False, shuffle_seed=None, dataset=None):
    """Generate visualisation video, thumbnail and GT JSON for a dataset.

    Args:
        part: index of the slice of samples to process (``samples[part::parts]``).
        parts: total number of slices.
        annotation_file: annotation path used to build the dataset when
            ``dataset`` is None.
        use_ray: process samples with Ray tasks instead of a local loop.
        num_cpus: CPUs passed to ``ray.init`` for a non-distributed run.
        overwrite: regenerate outputs that already exist.
        force_encode: forwarded to ``generate_video``.
        short_edge_res: accepted for interface compatibility; not used here.
        distributed: connect to an existing Ray cluster on ``localhost:6379``.
        shuffle_seed: if not None (or the string ``"None"``), samples are
            shuffled with this seed before slicing.
        dataset: an already loaded dataset; overrides ``annotation_file``.
    """
    if dataset is None:
        dataset = GluonCVMotionDataset(annotation_file)

    generator_list = [
        lambda sample, overwrite: generate_video(sample, force_encode, overwrite),
        generate_thumbnail,
        lambda sample, overwrite: generate_gt_vis_json(sample, dataset.get_anno_subpath(), overwrite),
    ]

    samples = sorted(dataset.samples)
    # The string "None" is also treated as "no seed" (e.g. from the CLI).
    if shuffle_seed is not None and shuffle_seed != "None":
        random.seed(shuffle_seed)
        random.shuffle(samples)
    samples = samples[part::parts]

    logging.info("Using ray {} , distributed: {}".format(use_ray, distributed))
    if use_ray:
        num_ray_threads = 500
        if distributed:
            # NOTE(review): newer Ray releases take `address=` instead of
            # `redis_address=` - confirm against the pinned Ray version.
            ray.init(redis_address="localhost:6379")
        else:
            ray.init(num_cpus=num_cpus)
        # Capping at len(samples) yields the same partition as a fixed 500
        # strides, minus the tasks that would receive an empty batch.
        num_tasks = min(num_ray_threads, len(samples))
        ray.get([generate_files_for_multi_samples_ray.remote(samples[i::num_tasks], generator_list, overwrite)
                 for i in range(num_tasks)])
    else:
        for _sample_id, data_sample in tqdm(samples, mininterval=1.0):
            generate_files_for_one_sample(data_sample, generator_list, overwrite)


if __name__ == '__main__':
    import fire
    fire.Fire(generate_preprocess_files)