"""Extract features from pre-trained models for video related tasks.

Reads a list of videos, runs each clip through a GluonCV model with
feat_ext=True, and saves one .npy feature file per video.
"""
import os
import time
import argparse
import logging
import gc

import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet.gluon.data.vision import transforms
from gluoncv.data.transforms import video
from gluoncv.model_zoo import get_model
from gluoncv.data import VideoClsCustom
from gluoncv.utils.filesystem import try_import_decord

def parse_args():
    parser = argparse.ArgumentParser(description='Extract features from pre-trained models for video related tasks.')
    parser.add_argument('--data-dir', type=str, default='',
                        help='the root path to your data')
    parser.add_argument('--need-root', action='store_true',
                        help='if set, paths in --data-list are treated as relative to --data-dir.')
    parser.add_argument('--data-list', type=str, default='',
                        help='the list of your videos. Each line may contain an absolute path, or a relative path when --need-root is set.')
    parser.add_argument('--dtype', type=str, default='float32',
                        help='data type for inference. default is float32')
    parser.add_argument('--gpu-id', type=int, default=0,
                        help='id of the GPU to use. Use -1 for CPU')
    parser.add_argument('--mode', type=str,
                        help='mode in which to run the model. options are symbolic, imperative, hybrid')
    parser.add_argument('--model', type=str, required=True,
                        help='type of model to use. see vision_model for options.')
    parser.add_argument('--input-size', type=int, default=224,
                        help='size of the input image. default is 224')
    parser.add_argument('--use-pretrained', action='store_true', default=True,
                        help='enable using pretrained model from GluonCV. enabled by default.')
    parser.add_argument('--hashtag', type=str, default='',
                        help='hashtag for pretrained models.')
    parser.add_argument('--resume-params', type=str, default='',
                        help='path of parameters to load from.')
    parser.add_argument('--log-interval', type=int, default=10,
                        help='number of batches to wait before logging.')
    parser.add_argument('--new-height', type=int, default=256,
                        help='new height of the resized image. default is 256')
    parser.add_argument('--new-width', type=int, default=340,
                        help='new width of the resized image. default is 340')
    parser.add_argument('--new-length', type=int, default=32,
                        help='length of the sampled video clip in frames. default is 32')
    parser.add_argument('--new-step', type=int, default=1,
                        help='temporal stride between sampled frames. default is 1')
    parser.add_argument('--num-classes', type=int, default=400,
                        help='number of classes.')
    parser.add_argument('--ten-crop', action='store_true',
                        help='whether to use ten-crop evaluation.')
    parser.add_argument('--three-crop', action='store_true',
                        help='whether to use three-crop evaluation.')
    parser.add_argument('--video-loader', action='store_true', default=True,
                        help='if set to True, read from video files directly instead of pre-extracted frames.')
    parser.add_argument('--use-decord', action='store_true', default=True,
                        help='if set to True, use Decord video loader to load data.')
    parser.add_argument('--slowfast', action='store_true',
                        help='if set to True, use the data loader designed for the SlowFast network.')
    parser.add_argument('--slow-temporal-stride', type=int, default=16,
                        help='temporal stride for sparse sampling of video frames for the slow branch in SlowFast.')
    parser.add_argument('--fast-temporal-stride', type=int, default=2,
                        help='temporal stride for sparse sampling of video frames for the fast branch in SlowFast.')
    parser.add_argument('--num-crop', type=int, default=1,
                        help='number of crops for each image. default is 1')
    parser.add_argument('--data-aug', type=str, default='v1',
                        help='type of data augmentation pipeline. supports v1, v2, v3 and v4.')
    parser.add_argument('--num-segments', type=int, default=1,
                        help='number of segments to evenly split the video.')
    parser.add_argument('--save-dir', type=str, default='./',
                        help='directory to save the extracted features')
    opt = parser.parse_args()
    return opt

def read_data(opt, video_name, transform, video_utils):
    """Decode one video, sample frames, and return a model-ready NDArray."""
    decord = try_import_decord()
    decord_vr = decord.VideoReader(video_name, width=opt.new_width, height=opt.new_height)
    duration = len(decord_vr)

    opt.skip_length = opt.new_length * opt.new_step
    segment_indices, skip_offsets = video_utils._sample_test_indices(duration)

    if opt.video_loader:
        if opt.slowfast:
            clip_input = video_utils._video_TSN_decord_slowfast_loader(video_name, decord_vr, duration, segment_indices, skip_offsets)
        else:
            clip_input = video_utils._video_TSN_decord_batch_loader(video_name, decord_vr, duration, segment_indices, skip_offsets)
    else:
        raise RuntimeError('We only support video-based inference.')

    clip_input = transform(clip_input)

    # stack the transformed frames and reshape to N x C x T x H x W,
    # the layout expected by 3D CNNs
    if opt.slowfast:
        sparse_samples = len(clip_input) // (opt.num_segments * opt.num_crop)
        clip_input = np.stack(clip_input, axis=0)
        clip_input = clip_input.reshape((-1, sparse_samples, 3, opt.input_size, opt.input_size))
        clip_input = np.transpose(clip_input, (0, 2, 1, 3, 4))
    else:
        clip_input = np.stack(clip_input, axis=0)
        clip_input = clip_input.reshape((-1, opt.new_length, 3, opt.input_size, opt.input_size))
        clip_input = np.transpose(clip_input, (0, 2, 1, 3, 4))

    if opt.new_length == 1:
        clip_input = np.squeeze(clip_input, axis=2)    # 2D model: drop the temporal axis

    return nd.array(clip_input)

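# Shape sanity check (a sketch under assumed defaults: num_segments=1,
# single crop, new_length=32, input_size=224): for a 3D model, read_data
# returns an NDArray of shape (1, 3, 32, 224, 224), i.e. N x C x T x H x W.
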
def main(logger):
    opt = parse_args()
    logger.info(opt)
    gc.set_threshold(100, 5, 5)

    if not os.path.exists(opt.save_dir):
        os.makedirs(opt.save_dir)

    # set the context: a single GPU, or CPU when --gpu-id is -1
    if opt.gpu_id == -1:
        context = mx.cpu()
    else:
        context = mx.gpu(opt.gpu_id)

    # data preprocessing: normalize with the ImageNet mean/std, optionally
    # evaluating over ten or three spatial crops per frame
    image_norm_mean = [0.485, 0.456, 0.406]
    image_norm_std = [0.229, 0.224, 0.225]
    if opt.ten_crop:
        transform_test = transforms.Compose([
            video.VideoTenCrop(opt.input_size),
            video.VideoToTensor(),
            video.VideoNormalize(image_norm_mean, image_norm_std)
        ])
        opt.num_crop = 10
    elif opt.three_crop:
        transform_test = transforms.Compose([
            video.VideoThreeCrop(opt.input_size),
            video.VideoToTensor(),
            video.VideoNormalize(image_norm_mean, image_norm_std)
        ])
        opt.num_crop = 3
    else:
        transform_test = video.VideoGroupValTransform(size=opt.input_size, mean=image_norm_mean, std=image_norm_std)
        opt.num_crop = 1

    # get the model; `pretrained` accepts either a boolean or a hashtag string
    # identifying a particular pre-trained checkpoint
    if opt.use_pretrained and len(opt.hashtag) > 0:
        opt.use_pretrained = opt.hashtag
    classes = opt.num_classes
    model_name = opt.model
    net = get_model(name=model_name, nclass=classes, pretrained=opt.use_pretrained,
                    feat_ext=True, num_segments=opt.num_segments, num_crop=opt.num_crop)
    net.cast(opt.dtype)
    net.collect_params().reset_ctx(context)
    if opt.mode == 'hybrid':
        net.hybridize(static_alloc=True, static_shape=True)
    # note: --use-pretrained defaults to True and store_true can only set it,
    # so this branch is only reachable if that default is changed
    if opt.resume_params != '' and not opt.use_pretrained:
        net.load_parameters(opt.resume_params, ctx=context)
        logger.info('Pre-trained model %s is successfully loaded.' % (opt.resume_params))
    else:
        logger.info('Pre-trained model is successfully loaded from the model zoo.')
    logger.info("Successfully built model {}".format(model_name))

    # read the annotation list; each line starts with a video path
    anno_file = opt.data_list
    with open(anno_file, 'r') as f:
        data_list = f.readlines()
    logger.info('Load %d video samples.' % len(data_list))

    # build a pseudo dataset instance so we can reuse its frame-sampling
    # and decoding helper methods
    video_utils = VideoClsCustom(root=opt.data_dir,
                                 setting=opt.data_list,
                                 num_segments=opt.num_segments,
                                 num_crop=opt.num_crop,
                                 new_length=opt.new_length,
                                 new_step=opt.new_step,
                                 new_width=opt.new_width,
                                 new_height=opt.new_height,
                                 video_loader=opt.video_loader,
                                 use_decord=opt.use_decord,
                                 slowfast=opt.slowfast,
                                 slow_temporal_stride=opt.slow_temporal_stride,
                                 fast_temporal_stride=opt.fast_temporal_stride,
                                 data_aug=opt.data_aug,
                                 lazy_init=True)

    start_time = time.time()
    for vid, vline in enumerate(data_list):
        video_path = vline.split()[0]
        video_name = video_path.split('/')[-1]
        if opt.need_root:
            video_path = os.path.join(opt.data_dir, video_path)
        video_data = read_data(opt, video_path, transform_test, video_utils)
        video_input = video_data.as_in_context(context)
        video_feat = net(video_input.astype(opt.dtype, copy=False))

        # save features as <model>_<video>_feat.npy under --save-dir
        feat_file = '%s_%s_feat.npy' % (model_name, video_name)
        np.save(os.path.join(opt.save_dir, feat_file), video_feat.asnumpy())

        if vid > 0 and vid % opt.log_interval == 0:
            logger.info('%04d/%04d is done' % (vid, len(data_list)))

    end_time = time.time()
    logger.info('Total feature extraction time is %4.2f minutes' % ((end_time - start_time) / 60))

if __name__ == '__main__':
    logging.basicConfig()
    logger = logging.getLogger('logger')
    logger.setLevel(logging.INFO)

    main(logger)
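
# Example invocation (a sketch; the script filename, data list, and model name
# below are illustrative -- any video model from the GluonCV model zoo works):
#
#   python feat_extract.py --data-list videos.txt \
#       --model i3d_resnet50_v1_kinetics400 --save-dir ./features
#
# Features are saved as <model>_<video>_feat.npy under --save-dir.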