"""Base segmentation dataset"""
import random
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import mxnet as mx
from mxnet import cpu
import mxnet.ndarray as F
from .base import VisionDataset

__all__ = ['ms_batchify_fn', 'SegmentationDataset']


class SegmentationDataset(VisionDataset):
    """Common base for segmentation datasets.

    Stores the dataset configuration and provides paired image/mask
    transforms: a deterministic resize + center crop for validation and a
    randomized flip/scale/pad/crop/blur pipeline for training.

    Parameters
    ----------
    root : passed through to ``VisionDataset.__init__`` and kept on ``self.root``.
    split : dataset split identifier, stored as given.
    mode : processing mode; when ``None`` the value of ``split`` is used.
    transform : stored on ``self.transform``; not invoked by this class.
    base_size : reference edge length for the random rescale (default 520).
    crop_size : edge length of the square crops that are produced (default 480).
    """
    # pylint: disable=abstract-method
    def __init__(self, root, split, mode, transform, base_size=520, crop_size=480):
        super(SegmentationDataset, self).__init__(root)
        self.root = root
        self.transform = transform
        self.split = split
        # An explicit mode wins; otherwise the split name doubles as the mode.
        self.mode = split if mode is None else mode
        self.base_size = base_size
        self.crop_size = crop_size

    def _val_sync_transform(self, img, mask):
        """Resize so the shorter edge equals ``crop_size``, then center crop.

        ``img`` and ``mask`` are handled through the PIL ``Image`` API and
        receive the same geometry; the image is resampled bilinearly and the
        mask with nearest neighbour. Returns the ``(img, mask)`` pair after
        ``_img_transform`` / ``_mask_transform``.
        """
        target = self.crop_size
        width, height = img.size
        # Pin the shorter edge to `target`; scale the other edge to match.
        if width > height:
            new_h = target
            new_w = int(1.0 * width * new_h / height)
        else:
            new_w = target
            new_h = int(1.0 * height * new_w / width)
        scaled = (new_w, new_h)
        img = img.resize(scaled, Image.BILINEAR)
        mask = mask.resize(scaled, Image.NEAREST)
        # Square window of side `target` centered in the resized image.
        width, height = img.size
        left = int(round((width - target) / 2.))
        top = int(round((height - target) / 2.))
        box = (left, top, left + target, top + target)
        img = img.crop(box)
        mask = mask.crop(box)
        # Convert both to arrays.
        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Apply the randomized training augmentation to an image/mask pair.

        Steps, in order: horizontal flip with probability 0.5; rescale so the
        longer edge is a random integer in ``[0.5, 2.0] * base_size``; zero-pad
        on the right/bottom up to ``crop_size`` when needed; random square crop
        of side ``crop_size``; Gaussian blur of the image only with
        probability 0.5. Returns the ``(img, mask)`` pair after
        ``_img_transform`` / ``_mask_transform``.
        """
        # Mirror both inputs together so they stay aligned.
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop = self.crop_size
        # Draw the new length of the longer edge.
        long_edge = random.randint(int(self.base_size * 0.5),
                                   int(self.base_size * 2.0))
        width, height = img.size
        if height > width:
            new_h = long_edge
            new_w = int(1.0 * width * long_edge / height + 0.5)
            short_edge = new_w
        else:
            new_w = long_edge
            new_h = int(1.0 * height * long_edge / width + 0.5)
            short_edge = new_h
        scaled = (new_w, new_h)
        img = img.resize(scaled, Image.BILINEAR)
        mask = mask.resize(scaled, Image.NEAREST)
        # Grow undersized results on the right/bottom so a full crop fits.
        if short_edge < crop:
            pad_h = max(crop - new_h, 0)
            pad_w = max(crop - new_w, 0)
            border = (0, 0, pad_w, pad_h)
            img = ImageOps.expand(img, border=border, fill=0)
            mask = ImageOps.expand(mask, border=border, fill=0)
        # Pick a random crop window of side `crop`.
        width, height = img.size
        left = random.randint(0, width - crop)
        top = random.randint(0, height - crop)
        box = (left, top, left + crop, top + crop)
        img = img.crop(box)
        mask = mask.crop(box)
        # Gaussian blur as in PSP; the mask is left untouched.
        if random.random() < 0.5:
            blur = ImageFilter.GaussianBlur(radius=random.random())
            img = img.filter(blur)
        # Convert both to arrays.
        return self._img_transform(img), self._mask_transform(mask)

    def _img_transform(self, img):
        """Convert ``img`` to an MXNet NDArray on ``cpu(0)`` via numpy."""
        pixels = np.array(img)
        return F.array(pixels, cpu(0))

    def _mask_transform(self, mask):
        """Convert ``mask`` to an int32 MXNet NDArray on ``cpu(0)`` via numpy."""
        labels = F.array(np.array(mask), cpu(0))
        return labels.astype('int32')

    @property
    def num_class(self):
        """Number of categories."""
        # NUM_CLASS is not defined in this class; presumably subclasses set it.
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Prediction offset; always 0 in this base class."""
        return 0


def ms_batchify_fn(data):
    """Multi-size batchify function

    Keeps samples as Python lists instead of stacking them. A sequence whose
    first element is a ``str`` or NDArray is returned as a list; a sequence
    whose first element is a tuple is transposed field-wise and each field is
    batchified recursively. Any other element type raises ``RuntimeError``.
    """
    first = data[0]
    if isinstance(first, (str, mx.nd.NDArray)):
        return list(data)
    if isinstance(first, tuple):
        return [ms_batchify_fn(field) for field in zip(*data)]
    raise RuntimeError('unknown datatype')