"""ADE20K Semantic Segmentation Dataset."""
2import os
3from PIL import Image
4import numpy as np
5import mxnet as mx
6from ..segbase import SegmentationDataset
7
class ADE20KSegmentation(SegmentationDataset):
    """ADE20K Semantic Segmentation Dataset.

    Parameters
    ----------
    root : string
        Path to the folder that contains the ``ADEChallengeData2016``
        directory. Default is '~/.mxnet/datasets/ade' (user-expanded).
    split: string
        'train', 'val' or 'test'. Only 'train' reads the training folders;
        every other value reads the validation folders.
    mode : string, optional
        Controls what ``__getitem__`` returns: 'train', 'val', 'testval' or
        'test'. It is forwarded unchanged to the base class constructor;
        how ``None`` is resolved is defined there, not in this class.
    transform : callable, optional
        A function that transforms the image

    Examples
    --------
    >>> from mxnet.gluon.data.vision import transforms
    >>> # Transforms for Normalization
    >>> input_transform = transforms.Compose([
    ...     transforms.ToTensor(),
    ...     transforms.Normalize([.485, .456, .406], [.229, .224, .225]),
    ... ])
    >>> # Create Dataset
    >>> trainset = gluoncv.data.ADE20KSegmentation(split='train', transform=input_transform)
    >>> # Create Training Loader
    >>> train_data = gluon.data.DataLoader(
    ...     trainset, 4, shuffle=True, last_batch='rollover',
    ...     num_workers=4)
    """
    # pylint: disable=abstract-method
    BASE_DIR = 'ADEChallengeData2016'
    NUM_CLASS = 150
    CLASSES = ("wall", "building, edifice", "sky", "floor, flooring", "tree",
               "ceiling", "road, route", "bed", "windowpane, window", "grass",
               "cabinet", "sidewalk, pavement",
               "person, individual, someone, somebody, mortal, soul",
               "earth, ground", "door, double door", "table", "mountain, mount",
               "plant, flora, plant life", "curtain, drape, drapery, mantle, pall",
               "chair", "car, auto, automobile, machine, motorcar",
               "water", "painting, picture", "sofa, couch, lounge", "shelf",
               "house", "sea", "mirror", "rug, carpet, carpeting", "field", "armchair",
               "seat", "fence, fencing", "desk", "rock, stone", "wardrobe, closet, press",
               "lamp", "bathtub, bathing tub, bath, tub", "railing, rail", "cushion",
               "base, pedestal, stand", "box", "column, pillar", "signboard, sign",
               "chest of drawers, chest, bureau, dresser", "counter", "sand", "sink",
               "skyscraper", "fireplace, hearth, open fireplace", "refrigerator, icebox",
               "grandstand, covered stand", "path", "stairs, steps", "runway",
               "case, display case, showcase, vitrine",
               "pool table, billiard table, snooker table", "pillow",
               "screen door, screen", "stairway, staircase", "river", "bridge, span",
               "bookcase", "blind, screen", "coffee table, cocktail table",
               "toilet, can, commode, crapper, pot, potty, stool, throne",
               "flower", "book", "hill", "bench", "countertop",
               "stove, kitchen stove, range, kitchen range, cooking stove",
               "palm, palm tree", "kitchen island",
               "computer, computing machine, computing device, data processor, "
               "electronic computer, information processing system",
               "swivel chair", "boat", "bar", "arcade machine",
               "hovel, hut, hutch, shack, shanty",
               "bus, autobus, coach, charabanc, double-decker, jitney, motorbus, "
               "motorcoach, omnibus, passenger vehicle",
               "towel", "light, light source", "truck, motortruck", "tower",
               "chandelier, pendant, pendent", "awning, sunshade, sunblind",
               "streetlight, street lamp", "booth, cubicle, stall, kiosk",
               "television receiver, television, television set, tv, tv set, idiot "
               "box, boob tube, telly, goggle box",
               "airplane, aeroplane, plane", "dirt track",
               "apparel, wearing apparel, dress, clothes",
               "pole", "land, ground, soil",
               "bannister, banister, balustrade, balusters, handrail",
               "escalator, moving staircase, moving stairway",
               "ottoman, pouf, pouffe, puff, hassock",
               "bottle", "buffet, counter, sideboard",
               "poster, posting, placard, notice, bill, card",
               "stage", "van", "ship", "fountain",
               "conveyer belt, conveyor belt, conveyer, conveyor, transporter",
               "canopy", "washer, automatic washer, washing machine",
               "plaything, toy", "swimming pool, swimming bath, natatorium",
               "stool", "barrel, cask", "basket, handbasket", "waterfall, falls",
               "tent, collapsible shelter", "bag", "minibike, motorbike", "cradle",
               "oven", "ball", "food, solid food", "step, stair", "tank, storage tank",
               "trade name, brand name, brand, marque", "microwave, microwave oven",
               "pot, flowerpot", "animal, animate being, beast, brute, creature, fauna",
               "bicycle, bike, wheel, cycle", "lake",
               "dishwasher, dish washer, dishwashing machine",
               "screen, silver screen, projection screen",
               "blanket, cover", "sculpture", "hood, exhaust hood", "sconce", "vase",
               "traffic light, traffic signal, stoplight", "tray",
               "ashcan, trash can, garbage can, wastebin, ash bin, ash-bin, ashbin, "
               "dustbin, trash barrel, trash bin",
               "fan", "pier, wharf, wharfage, dock", "crt screen",
               "plate", "monitor, monitoring device", "bulletin board, notice board",
               "shower", "radiator", "glass, drinking glass", "clock", "flag")

    def __init__(self, root=os.path.expanduser('~/.mxnet/datasets/ade'),
                 split='train', mode=None, transform=None, **kwargs):
        super(ADE20KSegmentation, self).__init__(root, split, mode, transform, **kwargs)
        # The archive unpacks into BASE_DIR below the user-supplied root.
        root = os.path.join(root, self.BASE_DIR)
        assert os.path.exists(root), \
            "Please setup the dataset using scripts/datasets/ade20k.py"
        self.images, self.masks = _get_ade20k_pairs(root, split)
        assert len(self.images) == len(self.masks)
        if not self.images:
            raise RuntimeError("Found 0 images in subfolders of: " + root + "\n")

    def __getitem__(self, index):
        """Return the sample at ``index``.

        In 'test' mode the result is ``(image, image file name)``; in every
        other mode it is ``(image, mask)``.
        """
        img = Image.open(self.images[index]).convert('RGB')
        if self.mode == 'test':
            # No ground truth is loaded; the file name identifies the sample.
            img = self._img_transform(img)
            if self.transform is not None:
                img = self.transform(img)
            return img, os.path.basename(self.images[index])
        mask = Image.open(self.masks[index])
        # synchronized transform: image and mask go through the same helper
        if self.mode == 'train':
            img, mask = self._sync_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_sync_transform(img, mask)
        else:
            assert self.mode == 'testval'
            img, mask = self._img_transform(img), self._mask_transform(mask)
        # general resize, normalize and toTensor
        if self.transform is not None:
            img = self.transform(img)
        return img, mask

    def _mask_transform(self, mask):
        """Convert a mask image to an int32 NDArray with every label shifted by -1.

        NOTE(review): presumably annotation value 0 marks unlabeled pixels, so
        it becomes -1 and the named classes become 0..NUM_CLASS-1 -- confirm
        against the dataset documentation.
        """
        return mx.nd.array(np.array(mask), mx.cpu(0)).astype('int32') - 1

    def __len__(self):
        """Number of (image, mask) pairs found for the split."""
        return len(self.images)

    @property
    def classes(self):
        """Category names."""
        return type(self).CLASSES

    @property
    def pred_offset(self):
        """Constant 1, matching the shift of 1 applied by ``_mask_transform``.

        NOTE(review): presumably added to predicted class indices to map them
        back to annotation values -- confirm against callers.
        """
        return 1
146
def _get_ade20k_pairs(folder, mode='train'):
    """Collect matching (image, mask) file paths for one ADE20K split.

    Parameters
    ----------
    folder : string
        Directory holding the ``images`` and ``annotations`` sub-folders.
    mode : string
        'train' selects the ``training`` sub-folders; any other value
        selects the ``validation`` sub-folders.

    Returns
    -------
    tuple of (list of str, list of str)
        Image paths and mask paths, aligned index by index and ordered by
        image file name. Images without a ``.png`` mask of the same base
        name are reported on stdout and left out.
    """
    subset = 'training' if mode == 'train' else 'validation'
    img_folder = os.path.join(folder, 'images/' + subset)
    mask_folder = os.path.join(folder, 'annotations/' + subset)

    img_paths = []
    mask_paths = []
    # os.listdir returns entries in arbitrary order; sorting keeps dataset
    # indices reproducible across filesystems and runs.
    for filename in sorted(os.listdir(img_folder)):
        if not filename.endswith(".jpg"):
            continue
        basename, _ = os.path.splitext(filename)
        maskpath = os.path.join(mask_folder, basename + '.png')
        if os.path.isfile(maskpath):
            img_paths.append(os.path.join(img_folder, filename))
            mask_paths.append(maskpath)
        else:
            print('cannot find the mask:', maskpath)

    return img_paths, mask_paths
169