1"""Pascal VOC object detection dataset."""
2from __future__ import absolute_import
3from __future__ import division
4
5import glob
6import logging
7import os
8import warnings
9
10import numpy as np
11
12try:
13    import xml.etree.cElementTree as ET
14except ImportError:
15    import xml.etree.ElementTree as ET
16import mxnet as mx
17from ..base import VisionDataset
18
19
class VOCDetection(VisionDataset):
    """Pascal VOC detection Dataset.

    Parameters
    ----------
    root : str, default '~/.mxnet/datasets/voc'
        Path to folder storing the dataset.
    splits : list of tuples, default ((2007, 'trainval'), (2012, 'trainval'))
        List of combinations of (year, name).
        For years, candidates can be: 2007, 2012. An integer year is read from the
        ``VOC<year>`` sub-folder of `root`; a string is used as the sub-folder name as-is.
        For names, candidates can be: 'train', 'val', 'trainval', 'test'.
    transform : callable, default None
        A function that takes data and label and transforms them. Refer to
        :doc:`./transforms` for examples.

        A transform function for object detection should take label into consideration,
        because any geometric modification will require label to be modified.
    index_map : dict, default None
        By default, the 20 classes are mapped into indices from 0 to 19. We can
        customize it by providing a str to int dict specifying how to map class
        names to indices. Use by advanced users only, when you want to swap the orders
        of class labels.
    preload_label : bool, default True
        If True, then parse and load all labels into memory during
        initialization. It often accelerates loading but requires more memory.
        Typical preloaded labels take tens of MB. You only need to disable it
        when your dataset is extremely large.
    """
    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike',
               'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')

    def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'voc'),
                 splits=((2007, 'trainval'), (2012, 'trainval')),
                 transform=None, index_map=None, preload_label=True):
        super(VOCDetection, self).__init__(root)
        # idx -> (width, height), filled as annotations are parsed
        self._im_shapes = {}
        self._root = os.path.expanduser(root)
        self._transform = transform
        self._splits = splits
        # list of (split_root, image_id) tuples
        self._items = self._load_items(splits)
        self._anno_path = os.path.join('{}', 'Annotations', '{}.xml')
        self._image_path = os.path.join('{}', 'JPEGImages', '{}.jpg')
        self.index_map = index_map or dict(zip(self.classes, range(self.num_class)))
        self._label_cache = self._preload_labels() if preload_label else None

    def __str__(self):
        detail = ','.join([str(s[0]) + s[1] for s in self._splits])
        return self.__class__.__name__ + '(' + detail + ')'

    @property
    def classes(self):
        """Category names.

        Raises
        ------
        RuntimeError
            If a class name contains uppercase characters.
        """
        try:
            self._validate_class_names(self.CLASSES)
        except AssertionError as e:
            raise RuntimeError("Class names must not contain {}".format(e))
        return type(self).CLASSES

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for item `idx`, passed through `transform` if one is set."""
        img_id = self._items[idx]
        img_path = self._image_path.format(*img_id)
        if self._label_cache is not None:
            label = self._label_cache[idx]
        else:
            label = self._load_label(idx)
        img = mx.image.imread(img_path, 1)
        # Always hand out a copy so a transform that edits the label in place
        # cannot corrupt the preloaded cache.
        label = label.copy()
        if self._transform is not None:
            return self._transform(img, label)
        return img, label

    def _load_items(self, splits):
        """Load individual image indices from splits.

        Returns a list of ``(split_root, image_id)`` tuples, one per non-blank line of
        ``<split_root>/ImageSets/Main/<name>.txt``.
        """
        ids = []
        for subfolder, name in splits:
            folder = ('VOC' + str(subfolder)) if isinstance(subfolder, int) else subfolder
            root = os.path.join(self._root, folder)
            lf = os.path.join(root, 'ImageSets', 'Main', name + '.txt')
            with open(lf, 'r') as f:
                # skip blank lines so they do not become empty image ids
                ids += [(root, line.strip()) for line in f if line.strip()]
        return ids

    def _load_label(self, idx):
        """Parse the XML annotation of item `idx` and return its labels.

        Side effect: records ``(width, height)`` in ``self._im_shapes[idx]`` if absent.

        Returns
        -------
        numpy.ndarray
            Array of shape ``(N, 6)``; one row ``[xmin, ymin, xmax, ymax, cls_id, difficult]``
            per object whose class is in `classes` and whose box passes `_validate_label`.
            Box coordinates are shifted by -1 (presumably converting the annotation's
            1-based pixel indices to 0-based). ``N`` may be 0.
        """
        img_id = self._items[idx]
        anno_path = self._anno_path.format(*img_id)
        root = ET.parse(anno_path).getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        if idx not in self._im_shapes:
            # store the shapes for later usage
            self._im_shapes[idx] = (width, height)
        # `classes` re-validates the class list on every access; read it once per file.
        classes = self.classes
        label = []
        for obj in root.iter('object'):
            try:
                difficult = int(obj.find('difficult').text)
            except (AttributeError, TypeError, ValueError):
                # missing <difficult> tag, empty tag, or non-integer text
                difficult = 0
            name_node = obj.find('name')
            if name_node is None or name_node.text is None:
                continue
            cls_name = name_node.text.strip().lower()
            if cls_name not in classes:
                continue
            cls_id = self.index_map[cls_name]
            xml_box = obj.find('bndbox')
            xmin = (float(xml_box.find('xmin').text) - 1)
            ymin = (float(xml_box.find('ymin').text) - 1)
            xmax = (float(xml_box.find('xmax').text) - 1)
            ymax = (float(xml_box.find('ymax').text) - 1)
            try:
                self._validate_label(xmin, ymin, xmax, ymax, width, height)
            except AssertionError as e:
                logging.warning("Invalid label at %s, %s", anno_path, e)
                continue
            label.append([xmin, ymin, xmax, ymax, cls_id, difficult])
        if not label:
            # Keep a 2-D (0, 6) shape so callers can still index columns, e.g. label[:, :4].
            return np.zeros((0, 6))
        return np.array(label)

    def _validate_label(self, xmin, ymin, xmax, ymax, width, height):
        """Validate labels.

        Raises
        ------
        AssertionError
            If the box is not inside ``[0, width] x [0, height]`` or is degenerate.
            Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        """
        if not 0 <= xmin < width:
            raise AssertionError("xmin must in [0, {}), given {}".format(width, xmin))
        if not 0 <= ymin < height:
            raise AssertionError("ymin must in [0, {}), given {}".format(height, ymin))
        if not xmin < xmax <= width:
            raise AssertionError("xmax must in (xmin, {}], given {}".format(width, xmax))
        if not ymin < ymax <= height:
            raise AssertionError("ymax must in (ymin, {}], given {}".format(height, ymax))

    def _validate_class_names(self, class_list):
        """Validate class names.

        Raises
        ------
        AssertionError
            If any name contains an uppercase character. Names without cased characters
            (e.g. digits) are accepted. Raised explicitly so it survives ``python -O``.
        """
        if any(c != c.lower() for c in class_list):
            raise AssertionError("uppercase characters")
        stripped = [c for c in class_list if c.strip() != c]
        if stripped:
            warnings.warn('white space removed for {}'.format(stripped))

    def _preload_labels(self):
        """Preload all labels into memory."""
        logging.debug("Preloading %s labels into memory...", str(self))
        return [self._load_label(idx) for idx in range(len(self))]
153
154
class CustomVOCDetection(VOCDetection):
    """Custom Pascal VOC detection Dataset whose classes can be generated from the annotations.

    Parameters
    ----------
    generate_classes : bool, default False
        If True, generate class labels based on the annotations instead of the default
        class labels. Generated names are stripped and lower-cased (the same normalization
        applied to object names when labels are loaded) and sorted. They are then used to
        build the default `index_map` and the label cache.
    **kwargs
        Keyword arguments forwarded to :class:`VOCDetection` (`root`, `splits`,
        `transform`, `index_map`, `preload_label`).
    """

    def __init__(self, generate_classes=False, **kwargs):
        # Defer label preloading until the final class list is known; otherwise labels
        # would be filtered and indexed with the default class list.
        preload_label = kwargs.pop('preload_label', True)
        super(CustomVOCDetection, self).__init__(preload_label=False, **kwargs)
        if generate_classes:
            self.CLASSES = self._generate_classes()
            if not kwargs.get('index_map'):
                # Rebuild the default name -> index mapping for the generated classes.
                self.index_map = dict(zip(self.CLASSES, range(len(self.CLASSES))))
        self._label_cache = self._preload_labels() if preload_label else None

    @property
    def classes(self):
        """Category names, honoring a per-instance ``CLASSES`` set by `generate_classes`.

        Raises
        ------
        RuntimeError
            If a class name contains uppercase characters.
        """
        try:
            self._validate_class_names(self.CLASSES)
        except AssertionError as e:
            raise RuntimeError("Class names must not contain {}".format(e))
        return self.CLASSES

    def _generate_classes(self):
        """Return the sorted, de-duplicated object class names found in the annotations.

        Scans ``<root>/Annotations`` as well as the ``Annotations`` folder of every
        split root in ``self._items`` (e.g. ``<root>/VOC2007/Annotations``).
        """
        anno_dirs = {os.path.join(self._root, 'Annotations')}
        anno_dirs.update(os.path.join(split_root, 'Annotations') for split_root, _ in self._items)
        classes = set()
        for anno_dir in sorted(anno_dirs):
            for xml_file in glob.glob(os.path.join(anno_dir, '*.xml')):
                root = ET.parse(xml_file).getroot()
                # direct <object> children of the root, then their direct <name> children
                for obj in root.findall('object'):
                    for name in obj.findall('name'):
                        text = (name.text or '').strip().lower()
                        if text:
                            classes.add(text)
        return sorted(classes)
180
181
class CustomVOCDetectionBase(VOCDetection):
    """Base class for custom Dataset which follows protocol/formatting of the well-known VOC object detection dataset.

    Images whose annotation contains an out-of-bounds box for one of the selected
    classes are dropped at construction time.

    Parameters
    ----------
    classes : tuple of str, default None
        Class names of this dataset. They are stored on the instance, so different
        instances may use different classes. If None, the class-level ``CLASSES`` is used.
        We reuse the neural network weights if the corresponding class appears in the pretrained model.
        Otherwise, we randomly initialize the neural network weights for new classes.
    root : str, default '~/.mxnet/datasets/voc'
        Path to folder storing the dataset.
    splits : list of tuples, default ((2007, 'trainval'), (2012, 'trainval'))
        List of combinations of (subfolder, name).
        An integer subfolder such as 2007 is read from ``<root>/VOC2007``; a string is used
        as the sub-folder name as-is; a falsy subfolder (e.g. '' or None) means `root` itself.
        For names, candidates can be: 'train', 'val', 'trainval', 'test'.
    transform : callable, default = None
        A function that takes data and label and transforms them. Refer to
        :doc:`./transforms` for examples.
        A transform function for object detection should take label into consideration,
        because any geometric modification will require label to be modified.
    index_map : dict, default = None
        By default, the classes are mapped to indices 0 .. len(classes) - 1 in the order
        they are given. We can customize it by providing a str to int dict specifying how
        to map class names to indices. This is only for advanced users, when you want to
        swap the orders of class labels.
    preload_label : bool, default = True
        If True, then parse and load all labels into memory during
        initialization. It often accelerates loading but requires more memory.
        Typical preloaded labels take tens of MB. You only need to disable it
        when your dataset is extremely large.
    """

    def __init__(self, classes=None, root=os.path.join('~', '.mxnet', 'datasets', 'voc'),
                 splits=((2007, 'trainval'), (2012, 'trainval')),
                 transform=None, index_map=None, preload_label=True):

        # Store custom classes on this instance. Assigning them to the class (as
        # `_set_class` does) would leak them into every other instance of the class.
        if classes:
            self.CLASSES = classes
        super(CustomVOCDetectionBase, self).__init__(root=root,
                                                     splits=splits,
                                                     transform=transform,
                                                     index_map=index_map,
                                                     preload_label=False)
        # Keep only images whose boxes for the selected classes all lie inside the image.
        self._items_new = [item for idx, item in enumerate(self._items) if self._check_valid(idx)]
        self._items = self._items_new
        # `_check_valid` recorded shapes under pre-filter indices, which no longer line up
        # with the filtered `_items`; drop them so they are re-recorded per new index.
        self._im_shapes = {}
        self._label_cache = self._preload_labels() if preload_label else None

    @property
    def classes(self):
        """Category names of this instance.

        Raises
        ------
        RuntimeError
            If a class name contains uppercase characters.
        """
        try:
            self._validate_class_names(self.CLASSES)
        except AssertionError as e:
            raise RuntimeError("Class names must not contain {}".format(e))
        return self.CLASSES

    @classmethod
    def _set_class(cls, classes):
        """Set ``CLASSES`` on the class itself (affects every instance).

        No longer used by ``__init__``; kept for backward compatibility.
        """
        cls.CLASSES = classes

    def _load_items(self, splits):
        """Load individual image indices from splits.

        Returns a list of ``(split_root, image_id)`` tuples, one per non-blank line of
        ``<split_root>/ImageSets/Main/<name>.txt``.
        """
        ids = []
        for subfolder, name in splits:
            if not subfolder:
                root = self._root
            elif isinstance(subfolder, int):
                # os.path.join rejects ints; map years like 2007 to 'VOC2007'
                root = os.path.join(self._root, 'VOC' + str(subfolder))
            else:
                root = os.path.join(self._root, subfolder)
            lf = os.path.join(root, 'ImageSets', 'Main', name + '.txt')
            with open(lf, 'r') as f:
                ids += [(root, line.strip()) for line in f if line.strip()]
        return ids

    def _check_valid(self, idx, allow_difficult=True):
        """Check whether the annotation of item `idx` contains only in-bounds boxes.

        Objects whose class is not in `classes` are ignored, as are objects flagged
        difficult when `allow_difficult` is False. Also records ``(width, height)`` in
        ``self._im_shapes[idx]`` if absent.

        Returns
        -------
        bool
            False if any considered box (after the -1 coordinate shift) is outside
            ``[0, width] x [0, height]`` or degenerate; True otherwise.
        """
        img_id = self._items[idx]
        anno_path = self._anno_path.format(*img_id)
        root = ET.parse(anno_path).getroot()
        size = root.find('size')
        width = float(size.find('width').text)
        height = float(size.find('height').text)
        if idx not in self._im_shapes:
            # store the shapes for later usage
            self._im_shapes[idx] = (width, height)
        # `classes` re-validates the class list on every access; read it once per file.
        classes = self.classes
        for obj in root.iter('object'):
            name_node = obj.find('name')
            if name_node is None or name_node.text is None:
                continue
            if name_node.text.strip().lower() not in classes:
                continue
            if not allow_difficult:
                try:
                    difficult = int(obj.find('difficult').text)
                except (AttributeError, TypeError, ValueError):
                    # missing <difficult> tag, empty tag, or non-integer text
                    difficult = 0
                if difficult:
                    continue
            xml_box = obj.find('bndbox')
            xmin = float(xml_box.find('xmin').text) - 1
            ymin = float(xml_box.find('ymin').text) - 1
            xmax = float(xml_box.find('xmax').text) - 1
            ymax = float(xml_box.find('ymax').text) - 1
            if not (0 <= xmin < width and 0 <= ymin < height
                    and xmin < xmax <= width and ymin < ymax <= height):
                return False
        return True
276