1"""Dual Attention Network
2https://arxiv.org/abs/1809.02983"""
3from mxnet.gluon import nn
4from mxnet.context import cpu
5from mxnet.gluon.nn import HybridBlock
6from .segbase import SegBaseModel
7from .fcn import _FCNHead
8from .attention import PAM_Module, CAM_Module
9# pylint: disable-all
10
11__all__ = ['DANet', 'get_danet', 'get_danet_resnet50_citys', 'get_danet_resnet101_citys']
12
13
class DANet(SegBaseModel):
    r"""Dual Attention Networks for Semantic Segmentation

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Pre-trained dilated backbone network type (default: 'resnet50'; options: 'resnet50',
        'resnet101').
    aux : bool
        Whether to train with an auxiliary loss.
    pretrained_base : bool or str
        Whether to load ImageNet-pretrained weights for the backbone network.
    base_size : int
        Base image size used for training (default: 520).
    crop_size : int
        Crop image size used for training (default: 480).
    norm_layer : object
        Normalization layer used in backbone network
        (default: :class:`mxnet.gluon.nn.BatchNorm`).

    Reference:
        Jun Fu, Jing Liu, Haijie Tian, Yong Li, Yongjun Bao, Zhiwei Fang, Hanqing Lu. "Dual Attention
        Network for Scene Segmentation." *CVPR*, 2019
    """

    def __init__(self, nclass, backbone='resnet50', aux=False, ctx=cpu(), pretrained_base=True,
                 height=None, width=None, base_size=520, crop_size=480, dilated=True, **kwargs):
        super(DANet, self).__init__(nclass, aux, backbone, ctx=ctx, base_size=base_size,
                                    crop_size=crop_size, pretrained_base=pretrained_base, **kwargs)
        self.aux = aux
        height = height if height is not None else crop_size
        width = width if width is not None else crop_size

        with self.name_scope():
            self.head = DANetHead(2048, nclass, **kwargs)
            self.head.initialize(ctx=ctx)

        self._up_kwargs = {'height': height, 'width': width}

    def hybrid_forward(self, F, x):
        c3, c4 = self.base_forward(x)

        x = self.head(c4)
        x = list(x)
        # upsample the fused output and the two attention-branch outputs to input size
        x[0] = F.contrib.BilinearResize2D(x[0], **self._up_kwargs)
        x[1] = F.contrib.BilinearResize2D(x[1], **self._up_kwargs)
        x[2] = F.contrib.BilinearResize2D(x[2], **self._up_kwargs)

        outputs = [x[0], x[1], x[2]]
        return tuple(outputs)

    def predict(self, x):
        """Run inference on an image of arbitrary size (NDArray input, not hybridizable)."""
        h, w = x.shape[2:]
        self._up_kwargs['height'] = h
        self._up_kwargs['width'] = w
        _, c4 = self.base_forward(x)
        x = self.head.demo(c4)
        import mxnet.ndarray as F
        pred = F.contrib.BilinearResize2D(x, **self._up_kwargs)
        return pred


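# A hedged inference sketch (names and shapes are illustrative, not part of the API):
# ``DANet.predict`` above resizes the prediction back to the input resolution, so it
# accepts images of arbitrary size, unlike the hybridized forward pass which uses the
# fixed ``height``/``width`` chosen at construction time.
#
#     import mxnet as mx
#     net = get_danet_resnet50_citys(pretrained=False, pretrained_base=False)
#     net.initialize()
#     frame = mx.nd.random.uniform(shape=(1, 3, 480, 640))
#     mask = net.predict(frame).argmax(axis=1)   # (1, 480, 640) class indices
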
class DANetHead(HybridBlock):
    r"""DANet segmentation head: a position attention branch (PAM) and a channel
    attention branch (CAM) applied in parallel, whose outputs are fused by summation."""

    def __init__(self, in_channels, out_channels, norm_layer=nn.BatchNorm, norm_kwargs=None, **kwargs):
        super(DANetHead, self).__init__()
        inter_channels = in_channels // 4
        # input projection for the position attention branch
        self.conv5a = nn.HybridSequential()
        self.conv5a.add(nn.Conv2D(in_channels=in_channels, channels=inter_channels, kernel_size=3,
                                  padding=1, use_bias=False))
        self.conv5a.add(norm_layer(in_channels=inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.conv5a.add(nn.Activation('relu'))

        # input projection for the channel attention branch
        self.conv5c = nn.HybridSequential()
        self.conv5c.add(nn.Conv2D(in_channels=in_channels, channels=inter_channels, kernel_size=3,
                                  padding=1, use_bias=False))
        self.conv5c.add(norm_layer(in_channels=inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.conv5c.add(nn.Activation('relu'))

        self.sa = PAM_Module(inter_channels)
        self.sc = CAM_Module(inter_channels)
        self.conv51 = nn.HybridSequential()
        self.conv51.add(nn.Conv2D(in_channels=inter_channels, channels=inter_channels, kernel_size=3,
                                  padding=1, use_bias=False))
        self.conv51.add(norm_layer(in_channels=inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.conv51.add(nn.Activation('relu'))

        self.conv52 = nn.HybridSequential()
        self.conv52.add(nn.Conv2D(in_channels=inter_channels, channels=inter_channels, kernel_size=3,
                                  padding=1, use_bias=False))
        self.conv52.add(norm_layer(in_channels=inter_channels, **({} if norm_kwargs is None else norm_kwargs)))
        self.conv52.add(nn.Activation('relu'))

        # per-branch and fused classifiers (dropout followed by a 1x1 convolution)
        self.conv6 = nn.HybridSequential()
        self.conv6.add(nn.Dropout(0.1))
        self.conv6.add(nn.Conv2D(in_channels=inter_channels, channels=out_channels, kernel_size=1))

        self.conv7 = nn.HybridSequential()
        self.conv7.add(nn.Dropout(0.1))
        self.conv7.add(nn.Conv2D(in_channels=inter_channels, channels=out_channels, kernel_size=1))

        self.conv8 = nn.HybridSequential()
        self.conv8.add(nn.Dropout(0.1))
        self.conv8.add(nn.Conv2D(in_channels=inter_channels, channels=out_channels, kernel_size=1))

    def hybrid_forward(self, F, x):
        # position attention branch
        feat1 = self.conv5a(x)
        sa_feat = self.sa(feat1)
        sa_conv = self.conv51(sa_feat)
        sa_output = self.conv6(sa_conv)

        # channel attention branch
        feat2 = self.conv5c(x)
        sc_feat = self.sc(feat2)
        sc_conv = self.conv52(sc_feat)
        sc_output = self.conv7(sc_conv)

        # element-wise fusion of the two branches
        feat_sum = sa_conv + sc_conv
        sasc_output = self.conv8(feat_sum)

        output = [sasc_output, sa_output, sc_output]
        return tuple(output)

    def demo(self, x):
        """Single-tensor forward used by ``DANet.predict``: returns only the fused output."""
        feat1 = self.conv5a(x)
        sa_feat = self.sa(feat1)
        sa_conv = self.conv51(sa_feat)

        feat2 = self.conv5c(x)
        sc_feat = self.sc(feat2)
        sc_conv = self.conv52(sc_feat)

        feat_sum = sa_conv + sc_conv
        sasc_output = self.conv8(feat_sum)
        return sasc_output


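# A hedged standalone sketch of the head (illustrative only): feed a 2048-channel
# feature map and get the fused prediction plus the two per-branch auxiliary
# predictions, all at the feature-map resolution.
#
#     import mxnet as mx
#     head = DANetHead(2048, 19)
#     head.initialize()
#     feats = mx.nd.random.uniform(shape=(1, 2048, 60, 60))
#     fused, pam_out, cam_out = head(feats)   # each: (1, 19, 60, 60)
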
def get_danet(dataset='pascal_voc', backbone='resnet50', pretrained=False,
              root='~/.mxnet/models', ctx=cpu(0), **kwargs):
    r"""DANet model from the paper `"Dual Attention Network for Scene Segmentation"
    <https://arxiv.org/abs/1809.02983>`_

    Parameters
    ----------
    dataset : str, default 'pascal_voc'
        The dataset that the model is trained on ('pascal_voc', 'pascal_aug', 'ade20k',
        'coco' or 'citys').
    backbone : str, default 'resnet50'
        Backbone network type ('resnet50' or 'resnet101').
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    """
    acronyms = {
        'pascal_voc': 'voc',
        'pascal_aug': 'voc',
        'ade20k': 'ade',
        'coco': 'coco',
        'citys': 'citys',
    }
    from ..data import datasets
    # infer the number of classes from the dataset definition
    model = DANet(nclass=datasets[dataset].NUM_CLASS, backbone=backbone, ctx=ctx, **kwargs)
    model.classes = datasets[dataset].classes
    if pretrained:
        from .model_store import get_model_file
        model.load_parameters(get_model_file('danet_%s_%s' % (backbone, acronyms[dataset]),
                                             tag=pretrained, root=root), ctx=ctx)
    return model


def get_danet_resnet50_citys(**kwargs):
    r"""DANet model with a ResNet-50 backbone, trained on the Cityscapes dataset.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_danet_resnet50_citys(pretrained=True)
    >>> print(model)
    """
    return get_danet('citys', 'resnet50', **kwargs)


def get_danet_resnet101_citys(**kwargs):
    r"""DANet model with a ResNet-101 backbone, trained on the Cityscapes dataset.

    Parameters
    ----------
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    ctx : Context, default CPU
        The context in which to load the pretrained weights.
    root : str, default '~/.mxnet/models'
        Location for keeping the model parameters.

    Examples
    --------
    >>> model = get_danet_resnet101_citys(pretrained=True)
    >>> print(model)
    """
    return get_danet('citys', 'resnet101', **kwargs)
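

# A minimal smoke-test sketch (hedged; not part of the GluonCV API). It builds DANet with
# a randomly initialized backbone and runs a forward pass on a dummy 480x480 batch; the
# 19 classes assume a Cityscapes-style setup and are illustrative only.
if __name__ == '__main__':
    import mxnet as mx
    # pretrained_base=False avoids downloading ImageNet weights for the backbone
    net = DANet(nclass=19, backbone='resnet50', pretrained_base=False)
    net.initialize(ctx=mx.cpu())  # the head was initialized in __init__; this covers the backbone
    data = mx.nd.random.uniform(shape=(1, 3, 480, 480))
    for i, out in enumerate(net(data)):
        # fused output first, then the position- and channel-attention auxiliary outputs
        print(i, out.shape)  # each: (1, 19, 480, 480)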