1"""Visual Tracker Benchmark.
2Code adapted from https://github.com/STVIR/pysot"""
3import json
4import os
5from glob import glob
6from tqdm import tqdm
7from mxnet.gluon.data import dataset
8from gluoncv.utils.filesystem import try_import_cv2
9
10
11class Video(object):
12    """
13    Abstract video class. get video class information for example imgs.
14
15    Parameters
16    ----------
17        name : str
18            video name
19        root: str
20            dataset root
21        video_dir: str
22            video directory
23        init_rect: list
24            init rectangle
25        img_names: str
26            image names
27        gt_rect: list
28            groundtruth rectangle
29        attr: str
30            attribute of video
31    """
32    def __init__(self, name, root, video_dir, init_rect, img_names,
33                 gt_rect, attr, load_img=False):
34        self.name = name
35        self.video_dir = video_dir
36        self.init_rect = init_rect
37        self.gt_traj = gt_rect
38        self.attr = attr
39        self.pred_trajs = {}
40        self.img_names = [os.path.join(root, x) for x in img_names]
41        self.imgs = None
42        cv2 = try_import_cv2()
43
44        if load_img:
45            self.imgs = [cv2.imread(x) for x in self.img_names]
46            self.width = self.imgs[0].shape[1]
47            self.height = self.imgs[0].shape[0]
48        else:
49            img = cv2.imread(self.img_names[0])
50            assert img is not None, self.img_names[0]
51            self.width = img.shape[1]
52            self.height = img.shape[0]
53
54    def load_img(self):
55        if self.imgs is None:
56            cv2 = try_import_cv2()
57            self.imgs = [cv2.imread(x) for x in self.img_names]
58            self.width = self.imgs[0].shape[1]
59            self.height = self.imgs[0].shape[0]
60
61    def free_img(self):
62        self.imgs = None
63
64    def __len__(self):
65        return len(self.img_names)
66
67    def __getitem__(self, idx):
68        if self.imgs is None:
69            cv2 = try_import_cv2()
70            return cv2.imread(self.img_names[idx]), self.gt_traj[idx]
71        else:
72            return self.imgs[idx], self.gt_traj[idx]
73
74    def __iter__(self):
75        for i in range(len(self.img_names)):
76            if self.imgs is not None:
77                yield self.imgs[i], self.gt_traj[i]
78            else:
79                cv2 = try_import_cv2()
80                yield cv2.imread(self.img_names[i]), self.gt_traj[i]
81
82class OTBVideo(Video):
83    """
84    OTBVideo class. Including video operation
85
86    Parameters
87    ----------
88        name : str
89            video name
90        root: str
91            dataset root
92        video_dir: str
93            video directory
94        init_rect: list
95            init rectangle
96        img_names: str
97            image names
98        gt_rect: list
99            groundtruth rectangle
100        attr: str
101            attribute of video
102    """
103    def __init__(self, name, root, video_dir, init_rect, img_names,
104                 gt_rect, attr, load_img=False):
105        super(OTBVideo, self).__init__(name, root, video_dir,
106                                       init_rect, img_names, gt_rect, attr, load_img)
107
108    def load_tracker(self, path, tracker_names=None, store=True):
109        """
110        open txt and load_tracker
111        Parameters
112        ----------
113            path : str
114                path to result
115            tracker_name : list
116                name of tracker
117        """
118        if not tracker_names:
119            tracker_names = [x.split('/')[-1] for x in glob(path)
120                             if os.path.isdir(x)]
121        if isinstance(tracker_names, str):
122            tracker_names = [tracker_names]
123        for name in tracker_names:
124            traj_file = os.path.join(path, name, self.name+'.txt')
125            if not os.path.exists(traj_file):
126                if self.name == 'FleetFace':
127                    txt_name = 'fleetface.txt'
128                elif self.name == 'Jogging-1':
129                    txt_name = 'jogging_1.txt'
130                elif self.name == 'Jogging-2':
131                    txt_name = 'jogging_2.txt'
132                elif self.name == 'Skating2-1':
133                    txt_name = 'skating2_1.txt'
134                elif self.name == 'Skating2-2':
135                    txt_name = 'skating2_2.txt'
136                elif self.name == 'FaceOcc1':
137                    txt_name = 'faceocc1.txt'
138                elif self.name == 'FaceOcc2':
139                    txt_name = 'faceocc2.txt'
140                elif self.name == 'Human4-2':
141                    txt_name = 'human4_2.txt'
142                else:
143                    txt_name = self.name[0].lower()+self.name[1:]+'.txt'
144                traj_file = os.path.join(path, name, txt_name)
145            if os.path.exists(traj_file):
146                with open(traj_file, 'r') as f:
147                    pred_traj = [list(map(float, x.strip().split(',')))
148                                 for x in f.readlines()]
149                    if len(pred_traj) != len(self.gt_traj):
150                        print(name, len(pred_traj), len(self.gt_traj), self.name)
151                    if store:
152                        self.pred_trajs[name] = pred_traj
153                    else:
154                        return pred_traj
155            else:
156                print(traj_file)
157        self.tracker_names = list(self.pred_trajs.keys())
158        return None
159
160class OTBTracking(dataset.Dataset):
161    """OTB Visual Tracker Benchmark.
162
163    Parameters
164    ----------
165    name : str
166        name to data, and name to dataset json Default is 'OTB2015'
167    dataset_root: str
168        path to dataset root
169    """
170    def __init__(self, name, dataset_root, load_img=False):
171        super(OTBTracking, self).__init__()
172        self.name = name
173        self.dataset_root = dataset_root
174        with open(os.path.join(self.dataset_root, self.name+'.json'), 'r') as f:
175            meta_data = json.load(f)
176        # load videos
177        pbar = tqdm(meta_data.keys(), desc='loading '+self.name, ncols=100)
178        self.videos = {}
179        for video in pbar:
180            pbar.set_postfix_str(video)
181            self.videos[video] = OTBVideo(video,
182                                          self.dataset_root,
183                                          meta_data[video]['video_dir'],
184                                          meta_data[video]['init_rect'],
185                                          meta_data[video]['img_names'],
186                                          meta_data[video]['gt_rect'],
187                                          meta_data[video]['attr'],
188                                          load_img)
189        # set attr
190        attr = []
191        for x in self.videos.values():
192            attr += x.attr
193        attr = set(attr)
194        self.attr = {}
195        self.attr['ALL'] = list(self.videos.keys())
196        for x in attr:
197            self.attr[x] = []
198        for k, v in self.videos.items():
199            for attr_ in v.attr:
200                self.attr[attr_].append(k)
201
202    def __getitem__(self, idx):
203        if isinstance(idx, str):
204            return self.videos[idx]
205        elif isinstance(idx, int):
206            return self.videos[sorted(list(self.videos.keys()))[idx]]
207        return None
208
209    def __len__(self):
210        return len(self.videos)
211
212    def set_tracker(self, path, tracker_names):
213        """
214        Args:
215            path: path to tracker results,
216            tracker_names: list of tracker name
217        """
218        self.tracker_path = path
219        self.tracker_names = tracker_names
220