1"""Pascal VOC object detection dataset.""" 2from __future__ import absolute_import 3from __future__ import division 4 5import glob 6import logging 7import os 8import warnings 9 10import numpy as np 11 12try: 13 import xml.etree.cElementTree as ET 14except ImportError: 15 import xml.etree.ElementTree as ET 16import mxnet as mx 17from ..base import VisionDataset 18 19 20class VOCDetection(VisionDataset): 21 """Pascal VOC detection Dataset. 22 23 Parameters 24 ---------- 25 root : str, default '~/mxnet/datasets/voc' 26 Path to folder storing the dataset. 27 splits : list of tuples, default ((2007, 'trainval'), (2012, 'trainval')) 28 List of combinations of (year, name) 29 For years, candidates can be: 2007, 2012. 30 For names, candidates can be: 'train', 'val', 'trainval', 'test'. 31 transform : callable, default None 32 A function that takes data and label and transforms them. Refer to 33 :doc:`./transforms` for examples. 34 35 A transform function for object detection should take label into consideration, 36 because any geometric modification will require label to be modified. 37 index_map : dict, default None 38 In default, the 20 classes are mapped into indices from 0 to 19. We can 39 customize it by providing a str to int dict specifying how to map class 40 names to indices. Use by advanced users only, when you want to swap the orders 41 of class labels. 42 preload_label : bool, default True 43 If True, then parse and load all labels into memory during 44 initialization. It often accelerate speed but require more memory 45 usage. Typical preloaded labels took tens of MB. You only need to disable it 46 when your dataset is extremely large. 47 """ 48 CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 49 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 50 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor') 51 52 def __init__(self, root=os.path.join('~', '.mxnet', 'datasets', 'voc'), 53 splits=((2007, 'trainval'), (2012, 'trainval')), 54 transform=None, index_map=None, preload_label=True): 55 super(VOCDetection, self).__init__(root) 56 self._im_shapes = {} 57 self._root = os.path.expanduser(root) 58 self._transform = transform 59 self._splits = splits 60 self._items = self._load_items(splits) 61 self._anno_path = os.path.join('{}', 'Annotations', '{}.xml') 62 self._image_path = os.path.join('{}', 'JPEGImages', '{}.jpg') 63 self.index_map = index_map or dict(zip(self.classes, range(self.num_class))) 64 self._label_cache = self._preload_labels() if preload_label else None 65 66 def __str__(self): 67 detail = ','.join([str(s[0]) + s[1] for s in self._splits]) 68 return self.__class__.__name__ + '(' + detail + ')' 69 70 @property 71 def classes(self): 72 """Category names.""" 73 try: 74 self._validate_class_names(self.CLASSES) 75 except AssertionError as e: 76 raise RuntimeError("Class names must not contain {}".format(e)) 77 return type(self).CLASSES 78 79 def __len__(self): 80 return len(self._items) 81 82 def __getitem__(self, idx): 83 img_id = self._items[idx] 84 img_path = self._image_path.format(*img_id) 85 label = self._label_cache[idx] if self._label_cache else self._load_label(idx) 86 img = mx.image.imread(img_path, 1) 87 if self._transform is not None: 88 return self._transform(img, label) 89 return img, label.copy() 90 91 def _load_items(self, splits): 92 """Load individual image indices from splits.""" 93 ids = [] 94 for subfolder, name in splits: 95 root = os.path.join( 96 self._root, ('VOC' + str(subfolder)) if isinstance(subfolder, int) else subfolder) 97 lf = os.path.join(root, 'ImageSets', 'Main', name + '.txt') 98 with open(lf, 'r') as f: 99 ids += [(root, line.strip()) for line in f.readlines()] 100 return ids 101 102 def _load_label(self, idx): 103 """Parse xml file and return labels.""" 104 img_id = self._items[idx] 105 anno_path = self._anno_path.format(*img_id) 106 root = ET.parse(anno_path).getroot() 107 size = root.find('size') 108 width = float(size.find('width').text) 109 height = float(size.find('height').text) 110 if idx not in self._im_shapes: 111 # store the shapes for later usage 112 self._im_shapes[idx] = (width, height) 113 label = [] 114 for obj in root.iter('object'): 115 try: 116 difficult = int(obj.find('difficult').text) 117 except ValueError: 118 difficult = 0 119 cls_name = obj.find('name').text.strip().lower() 120 if cls_name not in self.classes: 121 continue 122 cls_id = self.index_map[cls_name] 123 xml_box = obj.find('bndbox') 124 xmin = (float(xml_box.find('xmin').text) - 1) 125 ymin = (float(xml_box.find('ymin').text) - 1) 126 xmax = (float(xml_box.find('xmax').text) - 1) 127 ymax = (float(xml_box.find('ymax').text) - 1) 128 try: 129 self._validate_label(xmin, ymin, xmax, ymax, width, height) 130 label.append([xmin, ymin, xmax, ymax, cls_id, difficult]) 131 except AssertionError as e: 132 logging.warning("Invalid label at %s, %s", anno_path, e) 133 return np.array(label) 134 135 def _validate_label(self, xmin, ymin, xmax, ymax, width, height): 136 """Validate labels.""" 137 assert 0 <= xmin < width, "xmin must in [0, {}), given {}".format(width, xmin) 138 assert 0 <= ymin < height, "ymin must in [0, {}), given {}".format(height, ymin) 139 assert xmin < xmax <= width, "xmax must in (xmin, {}], given {}".format(width, xmax) 140 assert ymin < ymax <= height, "ymax must in (ymin, {}], given {}".format(height, ymax) 141 142 def _validate_class_names(self, class_list): 143 """Validate class names.""" 144 assert all(c.islower() for c in class_list), "uppercase characters" 145 stripped = [c for c in class_list if c.strip() != c] 146 if stripped: 147 warnings.warn('white space removed for {}'.format(stripped)) 148 149 def _preload_labels(self): 150 """Preload all labels into memory.""" 151 logging.debug("Preloading %s labels into memory...", str(self)) 152 return [self._load_label(idx) for idx in range(len(self))] 153 154 155class CustomVOCDetection(VOCDetection): 156 """Custom Pascal VOC detection Dataset. 157 Classes are generated from dataset 158 generate_classes : bool, default False 159 If True, generate class labels base on the annotations instead of the default classe labels. 160 """ 161 162 def __init__(self, generate_classes=False, **kwargs): 163 super(CustomVOCDetection, self).__init__(**kwargs) 164 if generate_classes: 165 self.CLASSES = self._generate_classes() 166 167 def _generate_classes(self): 168 classes = set() 169 all_xml = glob.glob(os.path.join(self._root, 'Annotations', '*.xml')) 170 for each_xml_file in all_xml: 171 tree = ET.parse(each_xml_file) 172 root = tree.getroot() 173 for child in root: 174 if child.tag == 'object': 175 for item in child: 176 if item.tag == 'name': 177 classes.add(item.text) 178 classes = sorted(list(classes)) 179 return classes 180 181 182class CustomVOCDetectionBase(VOCDetection): 183 """Base class for custom Dataset which follows protocol/formatting of the well-known VOC object detection dataset. 184 185 Parameters 186 ---------- 187 class: tuple of classes, default = None 188 We reuse the neural network weights if the corresponding class appears in the pretrained model. 189 Otherwise, we randomly initialize the neural network weights for new classes. 190 root : str, default '~/mxnet/datasets/voc' 191 Path to folder storing the dataset. 192 splits : list of tuples, default ((2007, 'trainval'), (2012, 'trainval')) 193 List of combinations of (year, name) 194 For years, candidates can be: 2007, 2012. 195 For names, candidates can be: 'train', 'val', 'trainval', 'test'. 196 transform : callable, default = None 197 A function that takes data and label and transforms them. Refer to 198 :doc:`./transforms` for examples. 199 A transform function for object detection should take label into consideration, 200 because any geometric modification will require label to be modified. 201 index_map : dict, default = None 202 By default, the 20 classes are mapped into indices from 0 to 19. We can 203 customize it by providing a str to int dict specifying how to map class 204 names to indices. This is only for advanced users, when you want to swap the orders 205 of class labels. 206 preload_label : bool, default = True 207 If True, then parse and load all labels into memory during 208 initialization. It often accelerate speed but require more memory 209 usage. Typical preloaded labels took tens of MB. You only need to disable it 210 when your dataset is extremely large. 211 """ 212 213 def __init__(self, classes=None, root=os.path.join('~', '.mxnet', 'datasets', 'voc'), 214 splits=((2007, 'trainval'), (2012, 'trainval')), 215 transform=None, index_map=None, preload_label=True): 216 217 # update classes 218 if classes: 219 self._set_class(classes) 220 super(CustomVOCDetectionBase, self).__init__(root=root, 221 splits=splits, 222 transform=transform, 223 index_map=index_map, 224 preload_label=False) 225 self._items_new = [self._items[each_id] for each_id in range(len(self._items)) if self._check_valid(each_id)] 226 self._items = self._items_new 227 self._label_cache = self._preload_labels() if preload_label else None 228 229 @classmethod 230 def _set_class(cls, classes): 231 cls.CLASSES = classes 232 233 def _load_items(self, splits): 234 """Load individual image indices from splits.""" 235 ids = [] 236 for subfolder, name in splits: 237 root = os.path.join(self._root, subfolder) if subfolder else self._root 238 lf = os.path.join(root, 'ImageSets', 'Main', name + '.txt') 239 with open(lf, 'r') as f: 240 ids += [(root, line.strip()) for line in f.readlines()] 241 return ids 242 243 def _check_valid(self, idx, allow_difficult=True): 244 """Parse xml file and return labels.""" 245 img_id = self._items[idx] 246 anno_path = self._anno_path.format(*img_id) 247 root = ET.parse(anno_path).getroot() 248 size = root.find('size') 249 width = float(size.find('width').text) 250 height = float(size.find('height').text) 251 if idx not in self._im_shapes: 252 # store the shapes for later usage 253 self._im_shapes[idx] = (width, height) 254 for obj in root.iter('object'): 255 try: 256 difficult = int(obj.find('difficult').text) 257 except ValueError: 258 difficult = 0 259 cls_name = obj.find('name').text.strip().lower() 260 if cls_name not in self.classes: 261 continue 262 if difficult and not allow_difficult: 263 continue 264 # cls_id = self.index_map[cls_name] 265 xml_box = obj.find('bndbox') 266 xmin = (float(xml_box.find('xmin').text) - 1) 267 ymin = (float(xml_box.find('ymin').text) - 1) 268 xmax = (float(xml_box.find('xmax').text) - 1) 269 ymax = (float(xml_box.find('ymax').text) - 1) 270 271 if not ((0 <= xmin < width) and (0 <= ymin < height) \ 272 and (xmin < xmax <= width) and (ymin < ymax <= height)): 273 return False 274 275 return True 276