"""Dual Attention Network
https://arxiv.org/abs/1809.02983"""
from mxnet.gluon import nn
from mxnet.context import cpu
from mxnet.gluon.nn import HybridBlock
from .segbase import SegBaseModel
from .fcn import _FCNHead
from .attention import PAM_Module, CAM_Module
# pylint: disable-all

__all__ = ['DANet', 'get_danet', 'get_danet_resnet50_citys', 'get_danet_resnet101_citys']


class DANet(SegBaseModel):
    r"""Dual Attention Networks for Semantic Segmentation

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default:'resnet50'; 'resnet50',
        'resnet101').
    aux : bool, default False
        Stored on the model and forwarded to :class:`SegBaseModel`.
    ctx : Context, default CPU
        Context the segmentation head parameters are initialized on; also forwarded
        to :class:`SegBaseModel`.
    pretrained_base : bool, default True
        Forwarded to :class:`SegBaseModel`.
    height : int or None, default None
        Output height that all predictions are resized to (defaults to ``crop_size``).
    width : int or None, default None
        Output width that all predictions are resized to (defaults to ``crop_size``).
    base_size : int, default 520
        Forwarded to :class:`SegBaseModel`.
    crop_size : int, default 480
        Forwarded to :class:`SegBaseModel`; fallback for ``height``/``width``.
    dilated : bool, default True
        NOTE(review): accepted but not used by this constructor.
    norm_layer : object
        Normalization layer (default: :class:`mxnet.gluon.nn.BatchNorm`), passed through
        ``**kwargs`` to both the backbone base class and :class:`DANetHead`.

    Reference:
        Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang, Hanqing Lu. "Dual Attention
        Network for Scene Segmentation." *CVPR*, 2019
    """

    def __init__(self, nclass, backbone='resnet50', aux=False, ctx=cpu(), pretrained_base=True,
                 height=None, width=None, base_size=520, crop_size=480, dilated=True, **kwargs):
        super(DANet, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                    crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)
        self.aux = aux
        height = height if height is not None else crop_size
        width = width if width is not None else crop_size

        with self.name_scope():
            # 2048 = channel count of the backbone feature map fed to the head (c4).
            self.head = DANetHead(2048, nclass, **kwargs)
        self.head.initialize(ctx=ctx)

        self._up_kwargs = {'height': height, 'width': width}

    def hybrid_forward(self, F, x):
        """Run backbone + dual-attention head.

        Returns a tuple ``(fused, position_attention, channel_attention)`` of class score
        maps, each bilinearly resized to ``self._up_kwargs`` (height, width).
        """
        _, c4 = self.base_forward(x)
        return tuple(F.contrib.BilinearResize2D(out, **self._up_kwargs)
                     for out in self.head(c4))

    def demo(self, x):
        """Alias of :meth:`predict`."""
        return self.predict(x)

    def predict(self, x):
        """Return only the fused prediction, resized to the spatial size of ``x``.

        Assumes ``x`` is an NDArray laid out as (N, C, H, W) — TODO confirm for callers.
        Unlike earlier versions, this does not modify ``self._up_kwargs``, so it does not
        change the output size of subsequent :meth:`hybrid_forward` calls.
        """
        h, w = x.shape[2:]
        up_kwargs = {'height': h, 'width': w}
        _, c4 = self.base_forward(x)
        pred = self.head.demo(c4)
        import mxnet.ndarray as F
        return F.contrib.BilinearResize2D(pred, **up_kwargs)


def _conv_bn_relu(in_channels, channels, norm_layer, norm_kwargs):
    """Build a 3x3 Conv2D (no bias) -> norm_layer -> ReLU stack."""
    block = nn.HybridSequential()
    block.add(nn.Conv2D(in_channels=in_channels, channels=channels, kernel_size=3,
                        padding=1, use_bias=False))
    block.add(norm_layer(in_channels=channels, **norm_kwargs))
    block.add(nn.Activation('relu'))
    return block


def _dropout_classifier(in_channels, out_channels):
    """Build a Dropout(0.1) -> 1x1 Conv2D classifier stack."""
    block = nn.HybridSequential()
    block.add(nn.Dropout(0.1))
    block.add(nn.Conv2D(in_channels=in_channels, channels=out_channels, kernel_size=1))
    return block


class DANetHead(HybridBlock):
    r"""Dual attention head: a position-attention branch and a channel-attention branch
    whose features are also summed for a fused prediction.

    Parameters
    ----------
    in_channels : int
        Channels of the input feature map; branches work on ``in_channels // 4`` channels.
    out_channels : int
        Channels of each output score map (number of classes).
    norm_layer : object, default :class:`mxnet.gluon.nn.BatchNorm`
        Normalization layer used after each 3x3 convolution.
    norm_kwargs : dict or None
        Extra keyword arguments for ``norm_layer``.
    """

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm, norm_kwargs=None,
                 **kwargs):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4
        norm_kwargs = {} if norm_kwargs is None else norm_kwargs

        # NOTE: attribute names and creation order are kept as before; Gluon derives
        # parameter names from them, presumably matching the pretrained checkpoints.
        self.conv5a = _conv_bn_relu(in_channels, inter_channels, norm_layer, norm_kwargs)
        self.conv5c = _conv_bn_relu(in_channels, inter_channels, norm_layer, norm_kwargs)

        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module(inter_channels)
        self.conv51 = _conv_bn_relu(inter_channels, inter_channels, norm_layer, norm_kwargs)
        self.conv52 = _conv_bn_relu(inter_channels, inter_channels, norm_layer, norm_kwargs)

        # Classifiers consume the inter_channels-wide branch features (previously hard-coded
        # to 512, which only matched in_channels == 2048).
        self.conv6 = _dropout_classifier(inter_channels, out_channels)
        self.conv7 = _dropout_classifier(inter_channels, out_channels)
        self.conv8 = _dropout_classifier(inter_channels, out_channels)

    def hybrid_forward(self, F, x):
        """Return ``(fused, position_attention, channel_attention)`` score maps."""
        sa_conv = self.conv51(self.sa(self.conv5a(x)))
        sa_output = self.conv6(sa_conv)

        sc_conv = self.conv52(self.sc(self.conv5c(x)))
        sc_output = self.conv7(sc_conv)

        sasc_output = self.conv8(sa_conv + sc_conv)
        return sasc_output, sa_output, sc_output

    def demo(self, x):
        """Return only the fused score map (first output of :meth:`hybrid_forward`)."""
        sa_conv = self.conv51(self.sa(self.conv5a(x)))
        sc_conv = self.conv52(self.sc(self.conv5c(x)))
        return self.conv8(sa_conv + sc_conv)


def get_danet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
              root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    r"""DANet model from the paper `"Dual Attention Network for Scene Segmentation"
    <https://arxiv.org/abs/1809.02983>`

    Parameters
    ----------
    dataset : str, default 'pascal_voc'
        Dataset key; selects the number of classes and class names, and the pretrained
        weight suffix ('pascal_voc', 'pascal_aug', 'ade20k', 'coco', 'citys').
    backbone : str, default 'resnet50'
        Backbone network type passed to :class:`DANet`.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    **kwargs
        Forwarded to :class:`DANet`.
    """
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data import datasets
    # infer number of classes
    model = DANet(nclass=datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
    model.classes = datasets[dataset].classes
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('danet_%s_%s' % (backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model


def get_danet_resnet50_citys(**kwargs):
    r"""DANet with a ResNet-50 backbone for Cityscapes.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_danet_resnet50_citys(pretrained=True)
    >>> print(model)
    """
    return get_danet('citys', 'resnet50', **kwargs)


def get_danet_resnet101_citys(**kwargs):
    r"""DANet with a ResNet-101 backbone for Cityscapes.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_danet_resnet101_citys(pretrained=True)
    >>> print(model)
    """
    return get_danet('citys', 'resnet101', **kwargs)