1"""Base segmentation dataset"""
2import random
3import numpy as np
4from PIL import Image, ImageOps, ImageFilter
5import mxnet as mx
6from mxnet import cpu
7import mxnet.ndarray as F
8from .base import VisionDataset
9
10__all__ = ['ms_batchify_fn', 'SegmentationDataset']
11
class SegmentationDataset(VisionDataset):
    """Segmentation Base Dataset

    Stores the configuration shared by segmentation datasets and provides
    the paired (image, mask) transforms used by subclasses. ``num_class``
    reads ``self.NUM_CLASS``, which this base class does not define, so
    subclasses must provide it.

    Parameters
    ----------
    root
        Forwarded to ``VisionDataset.__init__`` and kept as ``self.root``.
    split
        Kept as ``self.split``; also used as the mode when ``mode`` is None.
    mode
        Kept as ``self.mode``; falls back to ``split`` when None.
    transform
        Kept as ``self.transform``; this class never calls it itself.
    base_size : number, default 520
        Reference size for the random rescale in ``_sync_transform``.
    crop_size : int, default 480
        Side length of the square crop produced by both sync transforms.
    """
    # pylint: disable=abstract-method
    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        super(SegmentationDataset, self).__init__(root)
        self.root = root
        self.transform = transform
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = crop_size

    def _val_sync_transform(self, img, mask):
        """Deterministic image/mask transform: resize, then center crop.

        Both inputs are PIL images. They are resized so that the shorter
        side equals ``self.crop_size`` (aspect ratio kept up to integer
        truncation), center-cropped to a ``crop_size`` square, and converted
        with ``_img_transform`` / ``_mask_transform``.

        Returns
        -------
        tuple
            ``(img, mask)`` as produced by the two final converters.
        """
        outsize = self.crop_size
        short_size = outsize
        w, h = img.size
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        # Bilinear for the image; nearest for the mask so that no new
        # (interpolated) label values are created.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize) / 2.))
        y1 = int(round((h - outsize) / 2.))
        img = img.crop((x1, y1, x1+outsize, y1+outsize))
        mask = mask.crop((x1, y1, x1+outsize, y1+outsize))
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _sync_transform(self, img, mask):
        """Randomized image/mask transform.

        Applies, with identical geometry for ``img`` and ``mask`` (both PIL
        images): a random horizontal flip, a random rescale of the longer
        side to a value in ``[0.5 * base_size, 2.0 * base_size]``,
        zero-padding on the right/bottom up to ``crop_size`` when needed,
        and a random ``crop_size`` square crop. The image alone is then
        Gaussian-blurred with probability 0.5.

        Returns
        -------
        tuple
            ``(img, mask)`` as produced by ``_img_transform`` and
            ``_mask_transform``.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale (long edge): the drawn value is the target length of
        # the longer side; the shorter side follows the aspect ratio.
        long_size = random.randint(int(self.base_size*0.5), int(self.base_size*2.0))
        w, h = img.size
        # The shorter side is clamped to at least 1 px: for very elongated
        # inputs the rounded value can reach 0, and resizing to a zero-sized
        # image fails before the padding step below gets a chance to run.
        if h > w:
            oh = long_size
            ow = max(1, int(1.0 * w * long_size / h + 0.5))
            short_size = ow
        else:
            ow = long_size
            oh = max(1, int(1.0 * h * long_size / w + 0.5))
            short_size = oh
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop: grow only the right/bottom borders so both sides are at
        # least crop_size. Note the mask is padded with value 0 as well.
        if short_size < crop_size:
            padh = crop_size - oh if oh < crop_size else 0
            padw = crop_size - ow if ow < crop_size else 0
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=0)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size)
        y1 = random.randint(0, h - crop_size)
        img = img.crop((x1, y1, x1+crop_size, y1+crop_size))
        mask = mask.crop((x1, y1, x1+crop_size, y1+crop_size))
        # gaussian blur as in PSP (image only, radius drawn from [0, 1))
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))
        # final transform
        img, mask = self._img_transform(img), self._mask_transform(mask)
        return img, mask

    def _img_transform(self, img):
        """Convert ``img`` to an NDArray on ``cpu(0)`` via a numpy array."""
        return F.array(np.array(img), cpu(0))

    def _mask_transform(self, mask):
        """Convert ``mask`` to an int32 NDArray on ``cpu(0)`` via numpy."""
        return F.array(np.array(mask), cpu(0)).astype('int32')

    @property
    def num_class(self):
        """Number of categories."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Prediction offset; always 0 in this base class.

        NOTE(review): how callers apply this value is not visible here.
        """
        return 0
99
def ms_batchify_fn(data):
    """Multi-size batchify function

    Collates a sequence of samples without stacking them.

    * If the first sample is a ``str`` or an ``mx.nd.NDArray``, the samples
      are returned unchanged as a plain ``list``.
    * If the first sample is a ``tuple``, the samples are transposed
      field-by-field and every field is batchified recursively, giving a
      list with one entry per field.

    Only the first sample is inspected to decide which case applies.

    Raises
    ------
    RuntimeError
        If the first sample is none of the supported types.
    """
    head = data[0]
    # Leaf case: keep the individual items side by side in a list.
    if isinstance(head, (str, mx.nd.NDArray)):
        return list(data)
    if not isinstance(head, tuple):
        raise RuntimeError('unknown datatype')
    # Tuple samples: regroup per field, then recurse into each field.
    return [ms_batchify_fn(field) for field in zip(*data)]
108