import glob
import os
import shlex
import shutil
import subprocess
import tempfile
from pathlib import Path

import cv2
import numpy as np
from PIL import Image
9
10
class ColorSpace(object):
    """Integer tags naming the channel layout of an image array.

    The numbers are arbitrary identifiers (they are not OpenCV conversion
    codes); conversions between layouts are looked up in
    ``convert_from_to_dict``.
    """

    RGB = 0
    GRAY = 2
    BGR = 3
15
16
# Lookup table: source ColorSpace -> target ColorSpace -> cv2.cvtColor code.
# It contains no identity (X -> X) entries.
convert_from_to_dict = {
    ColorSpace.BGR: {
        ColorSpace.RGB: cv2.COLOR_BGR2RGB,
        ColorSpace.GRAY: cv2.COLOR_BGR2GRAY,
    },
    ColorSpace.RGB: {
        ColorSpace.BGR: cv2.COLOR_RGB2BGR,
        ColorSpace.GRAY: cv2.COLOR_RGB2GRAY,
    },
    ColorSpace.GRAY: {
        ColorSpace.BGR: cv2.COLOR_GRAY2BGR,
        ColorSpace.RGB: cv2.COLOR_GRAY2RGB,
    },
}
23
# FourCC values handed to cv2.VideoWriter, keyed by a codec name.
# NOTE(review): 0x21, 0x6C and 0x20 are not packed 4-character codes;
# presumably they are raw codec tags accepted by OpenCV's FFmpeg backend --
# confirm against the OpenCV build in use.
FFMPEG_FOURCC = {
    'libx264': 0x21,
    'avc1': cv2.VideoWriter_fourcc('a', 'v', 'c', '1'),
    'mjpeg': 0x6C,
    'mpeg-4': 0x20,
}
30
31
def convert_color_from_to(frame, cs_from, cs_to):
    """Convert ``frame`` from one ``ColorSpace`` layout to another.

    Args:
        frame: image array accepted by ``cv2.cvtColor``.
        cs_from (int): ``ColorSpace`` value describing ``frame``.
        cs_to (int): ``ColorSpace`` value to convert to.

    Returns:
        The converted array. When ``cs_from == cs_to`` no conversion is needed
        and ``frame`` itself is returned (no copy).

    Raises:
        ValueError: if the (cs_from, cs_to) pair is not in
            ``convert_from_to_dict``. ``ValueError`` subclasses ``Exception``,
            so existing ``except Exception`` handlers still catch it.
    """
    # The lookup table has no identity entries; without this early return a
    # BGR -> BGR request (e.g. from write_img) would be rejected.
    if cs_from == cs_to:
        return frame
    convert_spec = convert_from_to_dict.get(cs_from, {}).get(cs_to)
    if convert_spec is None:
        raise ValueError('color conversion is not supported')
    return cv2.cvtColor(frame, convert_spec)
37
38
def read_vid_rgb(file):
    """Decode every frame of a video file into an ``InMemoryVideo``.

    Frames are produced by ``read_frame`` (RGB ``PIL.Image`` objects). Each
    frame's timestamp is ``CAP_PROP_POS_MSEC`` read just before that frame is
    decoded, truncated to an int.

    Args:
        file (str): path of the video to read.

    Returns:
        InMemoryVideo: the decoded frames, the capture's reported FPS and the
        per-frame timestamps.
    """
    cap = cv2.VideoCapture(file)
    try:
        all_ts = []
        all_frames = []
        while True:
            ts = int(cap.get(cv2.CAP_PROP_POS_MSEC))
            frame = read_frame(cap)
            if frame is None:
                break
            all_frames.append(frame)
            all_ts.append(ts)
        fps = cap.get(cv2.CAP_PROP_FPS)
    finally:
        # Release the decoder handle even when decoding raises part-way.
        cap.release()
    return InMemoryVideo(all_frames, fps, frame_ts=all_ts)
53
54
def format_frame(frame, color_space=ColorSpace.RGB):
    """Return a BGR (OpenCV-native) ``frame`` in the requested color space.

    A BGR request hands the frame back untouched; any other target is
    converted with ``convert_color_from_to``.
    """
    if color_space == ColorSpace.BGR:
        return frame
    return convert_color_from_to(frame, ColorSpace.BGR, color_space)
59
60
def read_frame(cap):
    """Grab the next frame from ``cap`` as an RGB ``PIL.Image``.

    The success flag from ``cap.read()`` is not consulted; ``None`` is
    returned whenever no frame array comes back.
    """
    _ok, bgr = cap.read()
    if bgr is not None:
        rgb = format_frame(bgr, ColorSpace.RGB)
        return Image.fromarray(rgb, 'RGB')
    return None
66
67
def read_img(file):
    """Load an image file as an RGB ``PIL.Image``.

    Returns ``None`` when ``cv2.imread`` cannot decode ``file``.
    """
    bgr = cv2.imread(file)
    return None if bgr is None else Image.fromarray(
        format_frame(bgr, ColorSpace.RGB), 'RGB')
73
74
def write_img(file, img, color_space=ColorSpace.RGB):
    """Write an image to ``file`` with ``cv2.imwrite``.

    Args:
        file (str): destination path.
        img: numpy image array, or anything ``np.asarray`` accepts (such as
            the ``PIL.Image`` objects returned by ``read_img``).
        color_space (int): ``ColorSpace`` value describing ``img``.

    Returns:
        bool: the success flag returned by ``cv2.imwrite``.
    """
    # cv2 functions operate on numpy arrays, not PIL images.
    img = np.asarray(img)
    # cv2.imwrite expects BGR. The conversion table has no BGR -> BGR entry,
    # so already-BGR input is written as is instead of being rejected.
    if color_space != ColorSpace.BGR:
        img = convert_color_from_to(img, color_space, ColorSpace.BGR)
    return cv2.imwrite(file, img)
78
79
class VideoBaseClass(object):
    """Abstract cursor-based interface for reading frames of a video.

    Subclasses implement ``__init__``, ``__len__``, ``_set_frame_ndx``,
    ``get_next_frame_time_stamp``, ``read`` and ``get_frame_rate`` (plus
    ``get_width``/``get_height``/``duration`` where supported). They keep a
    ``_next_frame_to_read`` attribute with the index of the frame the next
    ``read()`` returns, and ``read()`` returns ``None`` once past the end.

    Iterating an instance restarts at frame 0 and yields ``(frame, ts)``
    tuples, where ``ts`` is ``get_next_frame_time_stamp()`` taken just before
    the frame is read.
    """

    def __init__(self):
        raise NotImplementedError()

    def __del__(self):
        # No-op by default: an exception raised from __del__ cannot propagate
        # and only produces an "Exception ignored in ..." warning during GC.
        pass

    def __len__(self):
        raise NotImplementedError()

    def _set_frame_ndx(self, frame_num):
        """Move the cursor so the next ``read()`` returns frame ``frame_num``."""
        raise NotImplementedError()

    def _set_frame_time(self, frame_ts):
        """Move the cursor to the frame at or before ``frame_ts`` milliseconds.

        Default used by ``iter_frame_ts`` for subclasses that provide no
        time-based seek of their own.
        """
        self._set_frame_ndx(self.get_frame_ind_for_time(frame_ts))

    def get_next_frame_time_stamp(self):
        raise NotImplementedError()

    def read(self):
        raise NotImplementedError()

    def __iter__(self):
        self._set_frame_ndx(0)
        return self

    def iter_frame_ts(self, start_ts=0):
        """Iterate ``(frame, ts)`` pairs starting at ``start_ts`` milliseconds."""
        return FrameTimeStampIterator(self, start_ts)

    def next(self):
        # Python 2 style alias of __next__.
        return self.__next__()

    def __next__(self):
        ts = self.get_next_frame_time_stamp()
        frame = self.read()
        if frame is None:
            raise StopIteration()
        return frame, ts

    def __getitem__(self, frame_num):
        """Return ``(frame, ts)`` for index ``frame_num``, seeking only if needed."""
        if self._next_frame_to_read != frame_num:
            self._set_frame_ndx(frame_num)
        ts = self.get_next_frame_time_stamp()
        return self.read(), ts

    @property
    def verified_len(self):
        return len(self)

    @property
    def fps(self):
        return self.get_frame_rate()

    @property
    def width(self):
        return self.get_width()

    @property
    def height(self):
        return self.get_height()

    def get_frame_ind_for_time(self, time_stamp):
        """
        Returns the index for the frame at the timestamp provided.
        The frame index returned is the first frame that occurs before or at the timestamp given.

        Args:
            time_stamp (int or float): the millisecond time stamp for the desired frame

        Returns (int):
            the index for the frame at the given timestamp, computed as
            ``int(fps * time_stamp / 1000)`` (assumes a constant frame rate).

        Raises:
            TypeError: if ``time_stamp`` is not a number.
        """
        # Explicit check instead of `assert` (stripped under -O); floats and
        # numpy scalars are accepted so fractional timestamps also work.
        if not isinstance(time_stamp, (int, float, np.integer, np.floating)):
            raise TypeError('time_stamp must be a number, got {!r}'.format(type(time_stamp)))
        return int(self.fps * time_stamp / 1000.)

    def get_frame_for_time(self, time_stamp):
        return self[self.get_frame_ind_for_time(time_stamp)]

    def get_frame_rate(self):
        raise NotImplementedError()

    def get_width(self):
        raise NotImplementedError()

    def get_height(self):
        raise NotImplementedError()

    @property
    def duration(self):
        raise NotImplementedError()

    def asnumpy_and_ts(self):
        """Return ``(frames, timestamps)`` as two parallel lists."""
        out = []
        out_ts = []
        for frame, ts in self.iter_frame_ts():
            out.append(frame)
            out_ts.append(ts)
        return out, out_ts

    def asnumpy(self):
        """Return all frames, without timestamps, as a list starting at frame 0."""
        # Iteration yields (frame, ts) tuples; keep only the frame so the
        # result matches InMemoryVideo.asnumpy and asnumpy_and_ts()[0].
        return [frame for frame, _ in self]

    def num_frames(self):
        return len(self)

    def get_frame(self, index):
        return self[index]

    def get_frame_batch(self, index_list):
        '''
        Return the frames at the given indexes.

        Args:
            index_list (List[int]): list of frame indexes

        Returns:
            List[Tuple[frame, ts]]: one ``(frame, ts)`` tuple per index, as
            returned by ``get_frame`` / ``__getitem__``.
        '''
        return [self.get_frame(i) for i in index_list]
200
201
class FrameTimeStampIterator(object):
    """Iterator over a reader's ``(frame, ts)`` pairs from a start timestamp.

    Construction seeks ``frame_reader`` through its ``_set_frame_time``
    method; each step then delegates to ``next(frame_reader)``.
    """

    def __init__(self, frame_reader, start_ts=0):
        self.frame_reader = frame_reader
        # Position the wrapped reader before the first __next__ call.
        frame_reader._set_frame_time(start_ts)

    def __iter__(self):
        return self

    def __next__(self):
        reader = self.frame_reader
        return next(reader)

    def next(self):
        # Python 2 compatible spelling of __next__.
        return self.__next__()
215
216
class InMemoryVideo(VideoBaseClass):
    """Video whose frames and millisecond timestamps live in Python lists.

    Args:
        frames (Iterable, optional): initial frames; copied into a list.
        fps (float, optional): frame rate. Used to synthesise timestamps in
            ``append`` when none is given and, if ``frame_ts`` is omitted,
            for the initial frames (``i * 1000 / fps``).
        frame_ts (Iterable[number], optional): non-decreasing timestamps in
            milliseconds, one per initial frame. Copied, so later
            ``append``/``extend`` calls never mutate the caller's sequence.

    Raises:
        ValueError: if ``frames`` is non-empty and neither ``frame_ts`` nor a
            non-zero ``fps`` is given.
    """

    def __init__(self, frames=None, fps=None, frame_ts=None):
        self._frames = []
        if frames is not None:
            self._frames = list(frames)

        self._fps = fps
        self._next_frame_to_read = 0

        self._frame_ts = []
        if len(self._frames) > 0:
            if frame_ts is None:
                if not fps:
                    raise ValueError(
                        'frame_ts or a non-zero fps is required when frames are given')
                # Evenly spaced timestamps, the same spacing append() uses.
                frame_ts = [i * 1000. / fps for i in range(len(self._frames))]
            # Copy: append()/extend() grow self._frame_ts in place and must
            # not alias the caller's list.
            frame_ts = list(frame_ts)
            assert len(frame_ts) == len(self._frames)
            assert all(a <= b for a, b in zip(frame_ts[:-1], frame_ts[1:]))
            self._frame_ts = frame_ts

    def __del__(self):
        pass

    def __len__(self):
        return len(self._frames)

    def _set_frame_ndx(self, frame_num):
        self._next_frame_to_read = frame_num

    def _set_frame_time(self, frame_ts):
        """Seek to the frame at or before ``frame_ts`` ms (used by iter_frame_ts)."""
        self._set_frame_ndx(self.get_frame_ind_for_time(frame_ts))

    def get_next_frame_time_stamp(self):
        if self._next_frame_to_read >= len(self._frame_ts):
            return None
        return self._frame_ts[self._next_frame_to_read]

    def read(self):
        if self._next_frame_to_read >= len(self._frames):
            return None
        f = self._frames[self._next_frame_to_read]
        self._next_frame_to_read += 1
        return f

    def __setitem__(self, key, value):
        self._next_frame_to_read = key + 1
        self._frames[key] = value

    def append(self, frame, ts=None):
        """Append one frame and move the cursor past the end.

        With ``ts=None`` the timestamp is the previous one plus ``1000 / fps``
        ms, or ``0.`` for the first frame.
        """
        # Non-decreasing, consistent with the checks in __init__ and extend.
        assert ts is None or len(self._frame_ts) == 0 or ts >= self._frame_ts[-1]
        self._frames.append(frame)
        self._next_frame_to_read = len(self._frames)
        if ts is None:
            if len(self._frame_ts) > 0:
                self._frame_ts.append(self._frame_ts[-1] + 1000. / self.fps)
            else:
                self._frame_ts.append(0.)
        else:
            self._frame_ts.append(ts)

    def extend(self, frames, tss):
        """Append several frames with their timestamps; cursor moves past the end."""
        frames = list(frames)
        tss = list(tss)
        assert len(frames) == len(tss)
        assert all(a <= b for a, b in zip(tss[:-1], tss[1:]))
        # The new timestamps must also continue the existing sequence.
        assert not tss or not self._frame_ts or tss[0] >= self._frame_ts[-1]
        self._frames.extend(frames)
        self._frame_ts.extend(tss)
        self._next_frame_to_read = len(self._frames)

    def get_frame_rate(self):
        return self._fps

    def asnumpy(self):
        return self._frames

    def get_frame_ind_for_time(self, time_stamp):
        """Index of the last frame whose timestamp is <= ``time_stamp``.

        Timestamps before the first frame (and an empty video) map to 0.
        """
        # side='right' so an exact timestamp match selects that frame rather
        # than the one before it.
        ind = int(np.searchsorted(self._frame_ts, time_stamp, side='right'))
        return max(ind - 1, 0)
286
287
class InMemoryMXVideo(InMemoryVideo):
    """``InMemoryVideo`` whose stored frames expose an ``asnumpy()`` method.

    NOTE(review): presumably the frames are MXNet NDArrays (hence the name) --
    confirm with callers.
    """

    def asnumpy(self):
        """Return a new list holding ``frame.asnumpy()`` for every frame."""
        converted = []
        for frame in self._frames:
            converted.append(frame.asnumpy())
        return converted
291
292
# Lower-case extensions (dot included) used to classify input files;
# VideoFrameReader.is_img compares against img_exts.
img_exts = '.jpg .jpeg .jp .png'.split()
vid_exts = '.avi .mpeg .mp4 .mov'.split()
295
296
class VideoFrameReader(VideoBaseClass):
    """Random-access frame reader for a video file or a single image file.

    The file kind is decided by extension: anything in ``img_exts`` is an
    image and behaves as a one-frame video read with ``read_img``; anything
    else is opened lazily with ``cv2.VideoCapture``. Frames come back as RGB
    ``PIL.Image`` objects.

    Args:
        file (str): path of the video or image.
    """

    def __init__(self, file):
        self.cap = None
        self.file_name = file
        self._next_frame_to_read = 0
        self._verified_len = None  # set when read() returns None
        self.frame_cache = {}
        self._is_vid = None
        self._is_img = None
        self._len = None
        self._duration = None

    def __del__(self):
        # getattr: __del__ also runs when __init__ failed before setting cap.
        cap = getattr(self, 'cap', None)
        if cap is not None:
            cap.release()

    @property
    def is_video(self):
        return not self.is_img

    @property
    def is_img(self):
        if self._is_img is None:
            _, ext = os.path.splitext(self.file_name)
            self._is_img = ext.lower() in img_exts
        return self._is_img

    def _lazy_init(self):
        """Open the ``cv2.VideoCapture`` on first use (videos only)."""
        if self.is_video and self.cap is None:
            self.cap = cv2.VideoCapture(self.file_name)

    def read_from_mem_cache(self):
        return None

    def read(self):
        """Return the next frame (RGB ``PIL.Image``) or ``None`` at the end.

        Raises:
            Exception: for videos, if the capture's ``CAP_PROP_POS_FRAMES``
                disagrees with this reader's cursor.
        """
        self._lazy_init()
        if (not self.is_img) and self._next_frame_to_read != self.cap.get(cv2.CAP_PROP_POS_FRAMES):
            raise Exception("failed read frame check, stored {} , cap val {} , file {}".format(
                self._next_frame_to_read, self.cap.get(cv2.CAP_PROP_POS_FRAMES), self.file_name))
        if self.is_video:
            frame = read_frame(self.cap)
        else:
            if self._next_frame_to_read == 0:
                frame = read_img(self.file_name)
            else:
                frame = None
        if frame is None:
            self._verified_len = self._next_frame_to_read
        self._next_frame_to_read += 1
        return frame

    def _set_frame_ndx(self, frame_num):
        self._lazy_init()
        if self.is_video:
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
        self._next_frame_to_read = frame_num

    def _set_frame_time(self, frame_ts):
        """Seek to ``frame_ts`` milliseconds; an image only has frame 0."""
        self._lazy_init()
        if self.is_video:
            self.cap.set(cv2.CAP_PROP_POS_MSEC, frame_ts)
            # cap.get returns a float; keep the cursor an int frame index.
            self._next_frame_to_read = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
        else:
            # Images have no capture object to seek.
            self._next_frame_to_read = 0

    def get_frame_for_time(self, time_stamp):
        """Seek to ``time_stamp`` ms (videos only) and return the next frame."""
        self._lazy_init()
        if self.is_video:
            self.cap.set(cv2.CAP_PROP_POS_MSEC, time_stamp)
            self._next_frame_to_read = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
        return self.read()

    def get_next_frame_time_stamp(self):
        self._lazy_init()
        if self.is_video:
            return max(0, int(self.cap.get(cv2.CAP_PROP_POS_MSEC)))
        else:
            return 0

    def _init_len_and_duration(self):
        """Cache length and duration by seeking to the end, then seeking back."""
        if self._duration is None:
            self._lazy_init()
            # Save the frame index, not the msec position: restoring by index
            # keeps CAP_PROP_POS_FRAMES equal to the cursor that read() checks.
            pos_frames = self.cap.get(cv2.CAP_PROP_POS_FRAMES)
            self.cap.set(cv2.CAP_PROP_POS_AVI_RATIO, 1)
            self._duration = int(self.cap.get(cv2.CAP_PROP_POS_MSEC))
            self._len = int(self.cap.get(cv2.CAP_PROP_POS_FRAMES))
            self.cap.set(cv2.CAP_PROP_POS_FRAMES, pos_frames)

    def __len__(self):
        if self.is_video:
            self._init_len_and_duration()
            return self._len
        else:
            return 1

    @property
    def duration(self):
        self._init_len_and_duration()
        return self._duration

    @property
    def verified_len(self):
        if self.is_video:
            return self._verified_len
        else:
            return 1

    def get_frame_rate(self):
        self._lazy_init()
        if self.is_video:
            return self.cap.get(cv2.CAP_PROP_FPS)
        else:
            return 1

    def _image_size(self):
        """Return ``(width, height)`` of the image file, decoded from disk."""
        img = read_img(self.file_name)
        if img is None:
            raise IOError('could not read image {}'.format(self.file_name))
        return img.size

    def get_width(self):
        self._lazy_init()
        if self.is_video:
            return self.cap.get(cv2.CAP_PROP_FRAME_WIDTH)
        # Images have no capture object; read the size from the file.
        return self._image_size()[0]

    def get_height(self):
        self._lazy_init()
        if self.is_video:
            return self.cap.get(cv2.CAP_PROP_FRAME_HEIGHT)
        return self._image_size()[1]
415
416
class VideoSortedFolderReader(VideoBaseClass):
    """Expose the sorted image files of a folder as the frames of a video.

    File paths are sorted as strings. Timestamps are synthesised assuming the
    frames are evenly spaced by ``1000 / fps`` milliseconds.

    Args:
        data_path (str): folder holding the frame images.
        fps (float): frame rate used to derive timestamps and duration.
        glob_pattern (str): pattern, relative to ``data_path``, selecting the
            frame files.
    """

    def __init__(self, data_path, fps, glob_pattern="*"):
        self._data_path = data_path
        self._glob_pattern = glob_pattern
        self._frame_paths = sorted(glob.glob(os.path.join(data_path, glob_pattern)))

        self._next_frame_to_read = 0
        self._last_read_frame = None
        self._fps = fps
        # Milliseconds between consecutive frames.
        self._period = 1.0 / fps * 1000

    def __del__(self):
        pass

    def __len__(self):
        return len(self._frame_paths)

    @property
    def duration(self):
        """Total length in milliseconds, rounded to an integer."""
        return round(self._period * len(self))

    def get_frame_rate(self):
        return self._fps

    def _set_frame_ndx(self, frame_num):
        self._next_frame_to_read = frame_num

    def _set_frame_time(self, frame_ts):
        # Nearest frame index for the requested millisecond timestamp.
        nearest_index = round(frame_ts / self._period)
        self._set_frame_ndx(nearest_index)

    def get_next_frame_time_stamp(self):
        return int(self._next_frame_to_read * self._period)

    def read(self):
        """Return the next frame (via ``read_img``) or ``None`` past the end."""
        idx = self._next_frame_to_read
        if idx < len(self._frame_paths):
            frame = read_img(self._frame_paths[idx])
            self._last_read_frame = idx
            self._next_frame_to_read += 1
            return frame
        return None

    def get_image_ext(self):
        """Suffix of the first frame file (e.g. ``'.jpg'``)."""
        first_path = Path(self._frame_paths[0])
        return first_path.suffix

    def get_frame_path(self, frame_num=None):
        """Path of frame ``frame_num``; defaults to the most recently read frame."""
        idx = self._last_read_frame if frame_num is None else frame_num
        return self._frame_paths[idx]
468
469
def write_video_rgb(file, frames, fps=None):
    """Encode RGB frames into ``file`` with ``cv2.VideoWriter``.

    Args:
        file (str): output video path.
        frames: iterable of RGB frames (numpy arrays or anything
            ``np.asarray`` accepts, e.g. ``PIL.Image``), or a
            ``VideoBaseClass`` instance such as ``InMemoryVideo``. For the
            latter, the ``(frame, ts)`` pairs it yields are unpacked and the
            timestamps are ignored.
        fps (float, optional): output frame rate. If omitted, ``frames.fps``
            is used when available and truthy, otherwise 30.

    The output frame size is taken from the first frame; nothing is written
    when ``frames`` is empty.
    """
    # Only fall back to the source's frame rate when the caller gave none.
    if fps is None:
        try:
            fps = frames.fps
        except (AttributeError, NotImplementedError):
            fps = None
        if not fps:
            fps = 30

    # Iterating a VideoBaseClass yields (frame, ts) tuples.
    if isinstance(frames, VideoBaseClass):
        frames = (frame for frame, _ in frames)

    writer = None
    try:
        for frame in frames:
            frame = np.asarray(frame)
            if writer is None:
                writer = cv2.VideoWriter(file, FFMPEG_FOURCC['libx264'],
                                         fps=fps, frameSize=frame.shape[1::-1],
                                         isColor=True)
            writer.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
    finally:
        # Finalise the container even if a frame fails to convert.
        if writer is not None:
            writer.release()
492
493
def cmd_with_addn_args(cmd, addn_args):
    """Return ``cmd`` extended with extra command-line arguments.

    Args:
        cmd (list[str]): base argument list.
        addn_args (str | Iterable[str] | None): extra arguments. A string is
            split with shell-like rules (``shlex.split``), so a quoted value
            containing spaces stays a single argument. Any other iterable is
            appended element by element. Falsy values add nothing.

    Returns:
        list[str]: ``cmd`` itself when there is nothing to add, otherwise a
        new list (``cmd`` is never mutated).
    """
    if not addn_args:
        return cmd
    if isinstance(addn_args, str):
        addn_args = shlex.split(addn_args)
    return cmd + list(addn_args)
501
502
def resize_and_write_video_ffmpeg(in_path, out_path, short_edge_res,
                                  scaling_algorithm="lanczos", raw_scale_input=None,
                                  keep_audio=True, addn_args=None):
    """Rescale a video with the ffmpeg CLI and write it to ``out_path``.

    Exactly one of ``short_edge_res`` and ``raw_scale_input`` must be given.

    Args:
        in_path (str): input video.
        out_path (str): destination. ffmpeg writes to a temporary file with the
            same basename, which is then moved here.
        short_edge_res (int | None): target length of the shorter edge. The
            other edge keeps the aspect ratio; both are rounded to even sizes.
        scaling_algorithm (str): value of the scale filter's ``flags`` option.
        raw_scale_input (str | None): verbatim ``scale`` filter arguments.
        keep_audio (bool): when False, ``-an`` drops the audio stream.
        addn_args (str | list[str] | None): extra ffmpeg arguments, inserted
            before the output options via ``cmd_with_addn_args``.

    Raises:
        ValueError: if both or neither of ``short_edge_res`` and
            ``raw_scale_input`` are provided.
        subprocess.CalledProcessError: if ffmpeg fails.
    """
    # See https://trac.ffmpeg.org/wiki/Scaling for scaling options / details
    if short_edge_res is not None and raw_scale_input is not None:
        raise ValueError("Either short_edge_res or raw_scale_input should be provided, not both")
    if short_edge_res is None and raw_scale_input is None:
        raise ValueError("One of short_edge_res or raw_scale_input must be provided")

    if short_edge_res is not None:
        # Height: ih / min(iw, ih) * short_edge_res, rounded to the nearest
        # even number (equals short_edge_res when the height is the short
        # edge). Width '-2' keeps the aspect ratio and also makes the width
        # divisible by 2 (see the ffmpeg scaling wiki).
        scale_arg = "-2:'round( ih/min(iw,ih) * {} /2)*2'".format(short_edge_res)
    else:
        scale_arg = raw_scale_input
    scale_arg += ":flags={}".format(scaling_algorithm)

    audio_arg = None if keep_audio else "-an"

    with tempfile.TemporaryDirectory() as tmp_path:
        tmp_file_path = os.path.join(tmp_path, os.path.basename(out_path))

        ffmpeg_cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
                      "-i", in_path, "-vf", "scale={}".format(scale_arg),
                      audio_arg]
        ffmpeg_cmd = cmd_with_addn_args(ffmpeg_cmd, addn_args)
        ffmpeg_cmd += ["-strict", "experimental", tmp_file_path]
        # Drop the audio placeholder when audio is kept.
        ffmpeg_cmd = [x for x in ffmpeg_cmd if x is not None]

        subprocess.run(ffmpeg_cmd, check=True)

        # shutil.move works across filesystems and without a POSIX `mv`.
        shutil.move(tmp_file_path, out_path)
547
548
def resize_and_write_video(file, frames, short_edge_res, fps=None):
    """Resize ``PIL.Image`` RGB frames and encode them with ``cv2.VideoWriter``.

    Args:
        file (str): output video path.
        frames: iterable of RGB ``PIL.Image`` frames, or a ``VideoBaseClass``
            instance such as ``InMemoryVideo``. For the latter, the
            ``(frame, ts)`` pairs it yields are unpacked and the timestamps
            ignored.
        short_edge_res (number): target length of the shorter edge. The scale
            factor is computed from the first frame and applied to all frames.
        fps (float, optional): output frame rate. If omitted, ``frames.fps``
            is used when available and truthy, otherwise 30.
    """
    # Only fall back to the source's frame rate when the caller gave none.
    if fps is None:
        try:
            fps = frames.fps
        except (AttributeError, NotImplementedError):
            fps = None
        if not fps:
            fps = 30

    # Iterating a VideoBaseClass yields (frame, ts) tuples.
    if isinstance(frames, VideoBaseClass):
        frames = (frame for frame, _ in frames)

    writer = None
    new_size = None
    try:
        for frame in frames:
            if new_size is None:
                # PIL size is (width, height); scale both by the same factor.
                factor = float(short_edge_res) / min(frame.size)
                new_size = tuple(int(i * factor) for i in frame.size)

            frame_np = np.asarray(frame.resize(new_size))

            if writer is None:
                writer = cv2.VideoWriter(file, FFMPEG_FOURCC['libx264'],
                                         fps=fps, frameSize=frame_np.shape[1::-1],
                                         isColor=True)
            writer.write(cv2.cvtColor(frame_np, cv2.COLOR_RGB2BGR))
    finally:
        # Finalise the container even if a frame fails to convert.
        if writer is not None:
            writer.release()
578
579
def write_img_files_to_vid(out_file, in_files, fps=None):
    """Encode a sequence of image files into a video with ``cv2.VideoWriter``.

    Args:
        out_file (str): output video path (written with the ``avc1`` FourCC).
        in_files (Iterable[str]): image paths in frame order. The output size
            is taken from the first image.
        fps (float, optional): output frame rate, default 30.
    """
    if fps is None:
        fps = 30

    writer = None
    try:
        for in_file in in_files:
            with open(in_file, 'rb') as fp:
                # convert('RGB') normalises grayscale, RGBA and palette images
                # so the RGB->BGR conversion below always sees 3 channels; it
                # also forces the pixel data to load while the file is open.
                frame = np.asarray(Image.open(fp).convert('RGB'))
            frame = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
            if writer is None:
                writer = cv2.VideoWriter(out_file, FFMPEG_FOURCC['avc1'],
                                         fps=fps, frameSize=frame.shape[1::-1],
                                         isColor=True)
            writer.write(frame)
    finally:
        # Finalise the container even if an image fails to load.
        if writer is not None:
            writer.release()
600
601
def write_img_files_to_vid_ffmpeg(out_file, in_files, fps=None):
    """Encode image files into a video using ffmpeg's concat demuxer.

    Args:
        out_file (str): output video path. ffmpeg writes to a temporary file
            with the same basename, which is then moved here.
        in_files (Iterable[str | os.PathLike]): image paths in frame order.
        fps (float, optional): frame rate passed as ``-r``, default 30.

    Returns:
        subprocess.CompletedProcess: result of the ffmpeg invocation.

    Raises:
        ValueError: if ``in_files`` is empty.
        subprocess.CalledProcessError: if ffmpeg fails.
    """
    if fps is None:
        fps = 30

    # Build the concat list ("file '<path>'" per line).
    # - Paths are made absolute because the list is read from /dev/stdin, and
    #   the demuxer resolves relative entries against the list's location.
    # - A literal ' is written as '\'' (close quote, escaped quote, reopen).
    entries = ["file '{}'\n".format(os.path.abspath(in_file).replace("'", "'\\''"))
               for in_file in in_files]
    if not entries:
        raise ValueError("in_files must contain at least one image path")
    input_str = "".join(entries)

    with tempfile.TemporaryDirectory() as tmp_path:
        tmp_file_path = os.path.join(tmp_path, os.path.basename(out_file))
        # See https://trac.ffmpeg.org/wiki/Slideshow
        # for why we are using input_str like this (for concat filter)
        # Need -safe 0 due to:
        # https://stackoverflow.com/questions/38996925/ffmpeg-concat-unsafe-file-name
        ret = subprocess.run(["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
                              "-f", "concat", "-safe", "0", "-r", str(fps), "-i", "/dev/stdin",
                              tmp_file_path],
                             input=input_str.encode('utf-8'), check=True)
        # shutil.move works across filesystems and without a POSIX `mv`.
        shutil.move(tmp_file_path, out_file)
    return ret
620
621
def convert_vid_ffmpeg(in_path, out_path, addn_args=None):
    """Run ffmpeg on ``in_path`` and place the result at ``out_path``.

    ffmpeg writes to a temporary file that keeps ``out_path``'s basename (and
    therefore its extension), which is then moved to ``out_path``.

    Args:
        in_path (str): input media file.
        out_path (str): destination path.
        addn_args (str | list[str] | None): extra ffmpeg arguments placed
            after the input options and before the output path.

    Returns:
        subprocess.CompletedProcess: result of the ffmpeg invocation.

    Raises:
        subprocess.CalledProcessError: if ffmpeg fails.
    """
    # muxing queue size bug workaround (presumably ffmpeg's "Too many packets
    # buffered for output stream" error -- confirm).
    cmd = ["ffmpeg", "-y", "-hide_banner", "-loglevel", "error",
           "-i", in_path, "-max_muxing_queue_size", "99999"]
    cmd = cmd_with_addn_args(cmd, addn_args)
    with tempfile.TemporaryDirectory() as tmp_path:
        tmp_file_path = os.path.join(tmp_path, os.path.basename(out_path))
        ret = subprocess.run(cmd + [tmp_file_path], check=True)
        # shutil.move works across filesystems and without a POSIX `mv`.
        shutil.move(tmp_file_path, out_path)
    return ret
633