1"""Monocular Depth Estimation Dataset.
2Digging into Self-Supervised Monocular Depth Prediction, ICCV 2019
3https://arxiv.org/abs/1806.01260
4Code partially borrowed from
5https://github.com/nianticlabs/monodepth2/blob/master/datasets/mono_dataset.py
6"""
7import random
8import copy
9import numpy as np
10from PIL import Image  # using pillow-simd for increased speed
11
12import mxnet as mx
13from mxnet.gluon.data import dataset
14from mxnet.gluon.data.vision import transforms
15
16
def pil_loader(path):
    """Read the image at ``path`` and return it as an RGB PIL image.

    The file handle is opened explicitly so it is closed promptly,
    avoiding ResourceWarning
    (https://github.com/python-pillow/Pillow/issues/835).
    """
    with open(path, 'rb') as stream, Image.open(stream) as img:
        return img.convert('RGB')
23
24
class MonoDataset(dataset.Dataset):
    """Superclass for monocular dataloaders

    Parameters
    ----------
    data_path : string
        Path to dataset folder.
    filenames : string
        Path to split file.
        For example: '$(HOME)/.mxnet/datasets/kitti/splits/eigen_full/train_files.txt'
    height : int
        The height for input images.
    width : int
        The width for input images.
    frame_idxs : list
        The frames to load.
        an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or "s" for the opposite image in the stereo pair.
    num_scales : int
        The number of scales of the image relative to the full-size image.
    is_train : bool
        Whether use Data Augmentation. Default is: False
    img_ext : string
        The extension name of input image. Default is '.jpg'
    """

    def __init__(self, data_path, filenames, height, width, frame_idxs,
                 num_scales, is_train=False, img_ext='.jpg'):
        super(MonoDataset, self).__init__()

        self.data_path = data_path
        self.filenames = filenames
        self.height = height
        self.width = width
        self.num_scales = num_scales
        # NOTE: ANTIALIAS is a deprecated alias of LANCZOS in modern
        # Pillow; kept as-is for compatibility with pillow-simd.
        self.interp = Image.ANTIALIAS

        self.frame_idxs = frame_idxs

        self.is_train = is_train
        self.img_ext = img_ext

        self.loader = pil_loader
        self.to_tensor = transforms.ToTensor()

        # Color-jitter strengths used when color augmentation is enabled
        # (augmentation is currently disabled in __getitem__).
        self.brightness = 0.2
        self.contrast = 0.2
        self.saturation = 0.2
        self.hue = 0.1

        # Subclasses report whether ground-truth depth maps are available.
        self.load_depth = self.check_depth()

    def preprocess(self, inputs, color_aug):
        """Resize colour images to the required scales and augment if required

        We create the color_aug object in advance and apply the same augmentation to all
        images in this item. This ensures that all images input to the pose network receive the
        same augmentation.
        """
        # Build the scale pyramid: scale i is derived from scale i - 1,
        # starting from the native-resolution image stored at scale -1.
        for k in list(inputs):
            if "color" in k:
                n, im, _ = k
                for i in range(self.num_scales):
                    s = 2 ** i
                    # PIL's resize expects (width, height) and returns a
                    # brand-new image, so no extra deep copy is needed.
                    size = (self.width // s, self.height // s)
                    inputs[(n, im, i)] = inputs[(n, im, i - 1)].resize(
                        size, self.interp)

        # Convert each colour image to a tensor and store an augmented copy
        # alongside it under the "<name>_aug" key.
        for k in list(inputs):
            if "color" in k:
                n, im, i = k
                # Only colour entries need conversion; K/inv_K are untouched.
                f = mx.nd.array(inputs[k])
                inputs[(n, im, i)] = self.to_tensor(f)
                inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f))

    def __len__(self):
        # One sample per line of the split file.
        return len(self.filenames)

    def __getitem__(self, index):
        """Returns a single training item from the dataset as a dictionary.

        Values correspond to mxnet NDArray.
        Keys in the dictionary are either strings or tuples:

            ("color", <frame_id>, <scale>)          for raw colour images,
            ("color_aug", <frame_id>, <scale>)      for augmented colour images,
            ("K", scale) or ("inv_K", scale)        for camera intrinsics,
            "stereo_T"                              for camera extrinsics, and
            "depth_gt"                              for ground truth depth maps.

        <frame_id> is either:
            an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index',
        or
            "s" for the opposite image in the stereo pair.

        <scale> is an integer representing the scale of the image relative to the full-size image:
            -1      images at native resolution as loaded from disk
            0       images resized to (self.width,      self.height     )
            1       images resized to (self.width // 2, self.height // 2)
            2       images resized to (self.width // 4, self.height // 4)
            3       images resized to (self.width // 8, self.height // 8)
        """
        inputs = {}

        do_color_aug = False  # self.is_train and random.random() > 0.5
        do_flip = self.is_train and random.random() > 0.5

        # Split-file lines look like "<folder> [<frame_index> <side>]".
        line = self.filenames[index].split()
        folder = line[0]

        if len(line) == 3:
            frame_index = int(line[1])
            side = line[2]
        else:
            frame_index = 0
            side = None

        # Load the native-resolution (scale -1) image for every requested frame.
        for i in self.frame_idxs:
            if i == "s":
                # "s" requests the opposite image of the stereo pair.
                other_side = {"r": "l", "l": "r"}[side]
                inputs[("color", i, -1)] = self.get_color(
                    folder, frame_index, other_side, do_flip)
            else:
                inputs[("color", i, -1)] = self.get_color(
                    folder, frame_index + i, side, do_flip)

        # adjusting intrinsics to match each scale in the pyramid
        for scale in range(self.num_scales):
            K = self.K.copy()

            K[0, :] *= self.width // (2 ** scale)
            K[1, :] *= self.height // (2 ** scale)

            inv_K = np.linalg.pinv(K)

            inputs[("K", scale)] = mx.nd.array(K)

            inputs[("inv_K", scale)] = mx.nd.array(inv_K)

        if do_color_aug:
            color_aug = transforms.RandomColorJitter(
                self.brightness, self.contrast, self.saturation, self.hue)
        else:
            # Identity transform: augmented keys duplicate the raw images.
            color_aug = (lambda x: x)

        self.preprocess(inputs, color_aug)

        # The native-resolution images were only needed to build the pyramid.
        for i in self.frame_idxs:
            del inputs[("color", i, -1)]
            del inputs[("color_aug", i, -1)]

        if self.load_depth:
            depth_gt = self.get_depth(folder, frame_index, side, do_flip)
            # Prepend a channel dimension: (1, H, W).
            inputs["depth_gt"] = np.expand_dims(depth_gt, 0)
            inputs["depth_gt"] = mx.nd.array(inputs["depth_gt"].astype(np.float32))

        if "s" in self.frame_idxs:
            # Relative pose of the stereo pair; 0.1 is presumably the
            # normalized stereo baseline used by monodepth2 — TODO confirm.
            # Flipping the image also flips the translation sign.
            stereo_T = np.eye(4, dtype=np.float32)
            baseline_sign = -1 if do_flip else 1
            side_sign = -1 if side == "l" else 1
            stereo_T[0, 3] = side_sign * baseline_sign * 0.1

            inputs["stereo_T"] = mx.nd.array(stereo_T)

        return inputs

    def get_color(self, folder, frame_index, side, do_flip):
        """Load the colour image for (folder, frame_index, side); subclass hook."""
        raise NotImplementedError

    def check_depth(self):
        """Return whether ground-truth depth exists for this split; subclass hook."""
        raise NotImplementedError

    def get_depth(self, folder, frame_index, side, do_flip):
        """Load the ground-truth depth map; subclass hook."""
        raise NotImplementedError