1"""Single-shot Multi-box Detector."""
2from __future__ import absolute_import
3
4import os
5import warnings
6import mxnet as mx
7from mxnet import autograd
8from mxnet.gluon import nn
9from mxnet.gluon import HybridBlock
10from ...nn.feature import FeatureExpander
11from .anchor import SSDAnchorGenerator
12from ...nn.predictor import ConvPredictor
13from ...nn.coder import MultiPerClassDecoder, NormalizedBoxCenterDecoder
14from .vgg_atrous import vgg16_atrous_300, vgg16_atrous_512
15
16__all__ = ['SSD', 'get_ssd', 'custom_ssd']
17
18
19class SSD(HybridBlock):
20    """Single-shot Object Detection Network: https://arxiv.org/abs/1512.02325.
21
22    Parameters
23    ----------
24    network : string or None
25        Name of the base network, if `None` is used, will instantiate the
26        base network from `features` directly instead of composing.
27    base_size : int
28        Base input size, it is speficied so SSD can support dynamic input shapes.
29    features : list of str or mxnet.gluon.HybridBlock
30        Intermediate features to be extracted or a network with multi-output.
31        If `network` is `None`, `features` is expected to be a multi-output network.
32    num_filters : list of int
33        Number of channels for the appended layers, ignored if `network`is `None`.
34    sizes : iterable fo float
35        Sizes of anchor boxes, this should be a list of floats, in incremental order.
36        The length of `sizes` must be len(layers) + 1. For example, a two stage SSD
37        model can have ``sizes = [30, 60, 90]``, and it converts to `[30, 60]` and
38        `[60, 90]` for the two stages, respectively. For more details, please refer
39        to original paper.
40    ratios : iterable of list
41        Aspect ratios of anchors in each output layer. Its length must be equals
42        to the number of SSD output layers.
43    steps : list of int
44        Step size of anchor boxes in each output layer.
45    classes : iterable of str
46        Names of all categories.
47    use_1x1_transition : bool
48        Whether to use 1x1 convolution as transition layer between attached layers,
49        it is effective reducing model capacity.
50    use_bn : bool
51        Whether to use BatchNorm layer after each attached convolutional layer.
52    reduce_ratio : float
53        Channel reduce ratio (0, 1) of the transition layer.
54    min_depth : int
55        Minimum channels for the transition layers.
56    global_pool : bool
57        Whether to attach a global average pooling layer as the last output layer.
58    pretrained : bool or str
59        Boolean value controls whether to load the default pretrained weights for model.
60        String value represents the hashtag for a certain version of pretrained weights.
61    stds : tuple of float, default is (0.1, 0.1, 0.2, 0.2)
62        Std values to be divided/multiplied to box encoded values.
63    nms_thresh : float, default is 0.45.
64        Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS.
65    nms_topk : int, default is 400
66        Apply NMS to top k detection results, use -1 to disable so that every Detection
67         result is used in NMS.
68    post_nms : int, default is 100
69        Only return top `post_nms` detection results, the rest is discarded. The number is
70        based on COCO dataset which has maximum 100 objects per image. You can adjust this
71        number if expecting more objects. You can use -1 to return all detections.
72    anchor_alloc_size : tuple of int, default is (128, 128)
73        For advanced users. Define `anchor_alloc_size` to generate large enough anchor
74        maps, which will later saved in parameters. During inference, we support arbitrary
75        input image by cropping corresponding area of the anchor map. This allow us
76        to export to symbol so we can run it in c++, scalar, etc.
77    ctx : mx.Context
78        Network context.
79    norm_layer : object
80        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`)
81        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
82        This will only apply to base networks that has `norm_layer` specified, will ignore if the
83        base network (e.g. VGG) don't accept this argument.
84    norm_kwargs : dict
85        Additional `norm_layer` arguments, for example `num_devices=4`
86        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
87    root : str
88        The root path for model storage, default is '~/.mxnet/models'
89    minimal_opset : bool
90        We sometimes add special operators to accelerate training/inference, however, for exporting
91        to third party compilers we want to utilize most widely used operators.
92        If `minimal_opset` is `True`, the network will use a minimal set of operators good
93        for e.g., `TVM`.
94    predictor_kernel: tuple of int. default is (3,3)
95        Dimension of predictor kernel
96    predictor_pad: tuple of int. default is (1,1)
97        Padding of the predictor kenrel conv.
98    anchor_generator: default is SSDAnchorGenerator
99        Anchor Generator to be used. The default it SSDAnchorGenerator corresponding
100        to SSD published article. This argument can be used for other custom
101        anchor generators. Like LiteAnchorGenerator.
102
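    Example
    -------
    A minimal inference sketch; the dummy input below is illustrative, not part of
    this module:

    >>> net = gluoncv.model_zoo.get_model('ssd_512_resnet50_v1_voc', pretrained=True)
    >>> x = mx.nd.zeros((1, 3, 512, 512))  # one dummy 512x512 RGB image
    >>> ids, scores, bboxes = net(x)  # see `hybrid_forward` for the output layout
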
103    """
104    def __init__(self, network, base_size, features, num_filters, sizes, ratios,
105                 steps, classes, use_1x1_transition=True, use_bn=True,
106                 reduce_ratio=1.0, min_depth=128, global_pool=False, pretrained=False,
107                 stds=(0.1, 0.1, 0.2, 0.2), nms_thresh=0.45, nms_topk=400, post_nms=100,
108                 anchor_alloc_size=128, ctx=mx.cpu(),
109                 norm_layer=nn.BatchNorm, norm_kwargs=None,
110                 root=os.path.join('~', '.mxnet', 'models'), minimal_opset=False,
111                 predictors_kernel=(3, 3), predictors_pad=(1, 1),
112                 anchor_generator=SSDAnchorGenerator, **kwargs):
113        super(SSD, self).__init__(**kwargs)
114        if norm_kwargs is None:
115            norm_kwargs = {}
116        if network is None:
117            num_layers = len(ratios)
118        else:
119            num_layers = len(features) + len(num_filters) + int(global_pool)
120        assert len(sizes) == num_layers + 1
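        # convert the flat size list into per-layer (min, max) pairs,
        # e.g. [30, 60, 90] -> [(30, 60), (60, 90)]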
        sizes = list(zip(sizes[:-1], sizes[1:]))
        assert isinstance(ratios, list), "Must provide ratios as list or list of list"
        if not isinstance(ratios[0], (tuple, list)):
            ratios = ratios * num_layers  # propagate to all layers if use same ratio
        assert num_layers == len(sizes) == len(ratios), \
            "Mismatched (number of layers) vs (sizes) vs (ratios): {}, {}, {}".format(
                num_layers, len(sizes), len(ratios))
        assert num_layers > 0, "SSD requires at least one layer; multiple layers are suggested."
        self._num_layers = num_layers
        self.classes = classes
        self.nms_thresh = nms_thresh
        self.nms_topk = nms_topk
        self.post_nms = post_nms

        with self.name_scope():
            if network is None:
                # use fine-grained manually designed block as features
                try:
                    self.features = features(pretrained=pretrained, ctx=ctx, root=root,
                                             norm_layer=norm_layer, norm_kwargs=norm_kwargs)
                except TypeError:
                    self.features = features(pretrained=pretrained, ctx=ctx, root=root)
            else:
                try:
                    self.features = FeatureExpander(
                        network=network, outputs=features, num_filters=num_filters,
                        use_1x1_transition=use_1x1_transition,
                        use_bn=use_bn, reduce_ratio=reduce_ratio, min_depth=min_depth,
                        global_pool=global_pool, pretrained=pretrained, ctx=ctx,
                        norm_layer=norm_layer, norm_kwargs=norm_kwargs, root=root)
                except TypeError:
                    self.features = FeatureExpander(
                        network=network, outputs=features, num_filters=num_filters,
                        use_1x1_transition=use_1x1_transition,
                        use_bn=use_bn, reduce_ratio=reduce_ratio, min_depth=min_depth,
                        global_pool=global_pool, pretrained=pretrained, ctx=ctx, root=root)
            self.class_predictors = nn.HybridSequential()
            self.box_predictors = nn.HybridSequential()
            self.anchor_generators = nn.HybridSequential()
            asz = anchor_alloc_size
            im_size = (base_size, base_size)
            for i, s, r, st in zip(range(num_layers), sizes, ratios, steps):
                branch_anchor_generator = anchor_generator(i, im_size, s, r, st, (asz, asz))
                self.anchor_generators.add(branch_anchor_generator)
                asz = max(asz // 2, 16)  # pre-compute anchor maps no smaller than 16x16
                num_anchors = branch_anchor_generator.num_depth
                self.class_predictors.add(ConvPredictor(num_anchors * (len(self.classes) + 1),
                                                        kernel=predictors_kernel,
                                                        pad=predictors_pad))
                self.box_predictors.add(ConvPredictor(num_anchors * 4,
                                                      kernel=predictors_kernel,
                                                      pad=predictors_pad))
            self.bbox_decoder = NormalizedBoxCenterDecoder(stds, minimal_opset=minimal_opset)
            self.cls_decoder = MultiPerClassDecoder(len(self.classes) + 1, thresh=0.01)

    @property
    def num_classes(self):
        """Return number of foreground classes.

        Returns
        -------
        int
            Number of foreground classes

        """
        return len(self.classes)

    def set_nms(self, nms_thresh=0.45, nms_topk=400, post_nms=100):
        """Set non-maximum suppression parameters.

        Parameters
        ----------
        nms_thresh : float, default is 0.45.
            Non-maximum suppression threshold. You can specify < 0 or > 1 to disable NMS.
        nms_topk : int, default is 400
            Apply NMS to the top k detection results. Use -1 to disable so that every
            detection result is used in NMS.
        post_nms : int, default is 100
            Only return the top `post_nms` detection results; the rest are discarded.
            The default is based on the COCO dataset, which has a maximum of 100 objects
            per image. You can adjust this number if expecting more objects, or use -1
            to return all detections.

        Returns
        -------
        None

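        Example
        -------
        A minimal sketch, assuming `net` is an already constructed SSD model:

        >>> net.set_nms(nms_thresh=0.5, nms_topk=200, post_nms=50)
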
207        """
208        self._clear_cached_op()
209        self.nms_thresh = nms_thresh
210        self.nms_topk = nms_topk
211        self.post_nms = post_nms
212
213    # pylint: disable=arguments-differ
214    def hybrid_forward(self, F, x):
215        """Hybrid forward"""
216        features = self.features(x)
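        # for each feature scale, move the channel axis to the end (NCHW -> NHWC)
        # and flatten, so per-scale outputs can be concatenated along the anchor axis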
        cls_preds = [F.flatten(F.transpose(cp(feat), (0, 2, 3, 1)))
                     for feat, cp in zip(features, self.class_predictors)]
        box_preds = [F.flatten(F.transpose(bp(feat), (0, 2, 3, 1)))
                     for feat, bp in zip(features, self.box_predictors)]
        anchors = [F.reshape(ag(feat), shape=(1, -1))
                   for feat, ag in zip(features, self.anchor_generators)]
        cls_preds = F.concat(*cls_preds, dim=1).reshape((0, -1, self.num_classes + 1))
        box_preds = F.concat(*box_preds, dim=1).reshape((0, -1, 4))
        anchors = F.concat(*anchors, dim=1).reshape((1, -1, 4))
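        # during training, return raw predictions and anchors so targets and
        # losses can be computed outside the network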
        if autograd.is_training():
            return [cls_preds, box_preds, anchors]
        bboxes = self.bbox_decoder(box_preds, anchors)
        cls_ids, scores = self.cls_decoder(F.softmax(cls_preds, axis=-1))
        results = []
        for i in range(self.num_classes):
            cls_id = cls_ids.slice_axis(axis=-1, begin=i, end=i+1)
            score = scores.slice_axis(axis=-1, begin=i, end=i+1)
            # per class results
            per_result = F.concat(*[cls_id, score, bboxes], dim=-1)
            results.append(per_result)
        result = F.concat(*results, dim=1)
        if 0 < self.nms_thresh < 1:
            result = F.contrib.box_nms(
                result, overlap_thresh=self.nms_thresh, topk=self.nms_topk, valid_thresh=0.01,
                id_index=0, score_index=1, coord_start=2, force_suppress=False)
            if self.post_nms > 0:
                result = result.slice_axis(axis=1, begin=0, end=self.post_nms)
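        # each detection row is laid out as (class_id, score, xmin, ymin, xmax, ymax)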
        ids = F.slice_axis(result, axis=2, begin=0, end=1)
        scores = F.slice_axis(result, axis=2, begin=1, end=2)
        bboxes = F.slice_axis(result, axis=2, begin=2, end=6)
        return ids, scores, bboxes

    def reset_class(self, classes, reuse_weights=None):
        """Reset class categories and class predictors.

        Parameters
        ----------
        classes : iterable of str
            The new categories. ['apple', 'orange'] for example.
        reuse_weights : dict
            A {new_integer : old_integer} mapping dict or a {new_name : old_name}
            mapping dict, or a list of [name0, name1,...] if class names don't change.
            This allows the new predictor to reuse the
            previously trained weights specified.

        Example
        -------
        >>> net = gluoncv.model_zoo.get_model('ssd_512_resnet50_v1_voc', pretrained=True)
        >>> # use direct name to name mapping to reuse weights
        >>> net.reset_class(classes=['person'], reuse_weights={'person':'person'})
        >>> # or use integer mapping, person has index 14 in VOC
        >>> net.reset_class(classes=['person'], reuse_weights={0:14})
        >>> # you can even mix them
        >>> net.reset_class(classes=['person'], reuse_weights={'person':14})
        >>> # or use a list of strings if class names don't change
        >>> net.reset_class(classes=['person'], reuse_weights=['person'])

        """
        self._clear_cached_op()
        old_classes = self.classes
        self.classes = classes
        # trying to reuse weights by mapping old and new classes
        if isinstance(reuse_weights, (dict, list)):
            if isinstance(reuse_weights, dict):
                # trying to replace str with indices
                new_keys = []
                new_vals = []
                for k, v in reuse_weights.items():
                    if isinstance(v, str):
                        try:
                            new_vals.append(old_classes.index(v))  # raise ValueError if not found
                        except ValueError:
                            raise ValueError(
                                "{} not found in old class names {}".format(v, old_classes))
                    else:
                        if v < 0 or v >= len(old_classes):
                            raise ValueError(
                                "Index {} out of bounds for old class names".format(v))
                        new_vals.append(v)
                    if isinstance(k, str):
                        try:
                            new_keys.append(self.classes.index(k))  # raise ValueError if not found
                        except ValueError:
                            raise ValueError(
                                "{} not found in new class names {}".format(k, self.classes))
                    else:
                        if k < 0 or k >= len(self.classes):
                            raise ValueError(
                                "Index {} out of bounds for new class names".format(k))
                        new_keys.append(k)
                reuse_weights = dict(zip(new_keys, new_vals))
            else:
                new_map = {}
                for x in reuse_weights:
                    try:
                        new_idx = self.classes.index(x)
                        old_idx = old_classes.index(x)
                        new_map[new_idx] = old_idx
                    except ValueError:
                        warnings.warn("{} not found in old: {} or new class names: {}".format(
                            x, old_classes, self.classes))
                reuse_weights = new_map
        # replace class predictors
        with self.name_scope():
            class_predictors = nn.HybridSequential(prefix=self.class_predictors.prefix)
            for i, ag in zip(range(len(self.class_predictors)), self.anchor_generators):
                # Re-use the same prefix and ctx_list as used by the current ConvPredictor
                prefix = self.class_predictors[i].prefix
                old_pred = self.class_predictors[i].predictor
                ctx = list(old_pred.params.values())[0].list_ctx()
                # to avoid deferred init, number of in_channels must be defined
                in_channels = list(old_pred.params.values())[0].shape[1]
                new_cp = ConvPredictor(ag.num_depth * (self.num_classes + 1),
                                       in_channels=in_channels, prefix=prefix)
                new_cp.collect_params().initialize(ctx=ctx)
                if reuse_weights:
                    assert isinstance(reuse_weights, dict)
                    for old_params, new_params in zip(old_pred.params.values(),
                                                      new_cp.predictor.params.values()):
                        old_data = old_params.data()
                        new_data = new_params.data()

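                        # predictor output channels are grouped per anchor as
                        # (num_classes + 1) consecutive entries, background first;
                        # strided slicing with step len(classes) + 1 therefore picks
                        # the same class slot across every anchor position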
                        for k, v in reuse_weights.items():
                            if k >= len(self.classes) or v >= len(old_classes):
                                warnings.warn("reuse mapping {}/{} -> {}/{} out of range".format(
                                    k, self.classes, v, old_classes))
                                continue
                            # always increment k and v (background is always the 0th)
                            new_data[k+1::len(self.classes)+1] = old_data[v+1::len(old_classes)+1]
                        # reuse background weights as well
                        new_data[0::len(self.classes)+1] = old_data[0::len(old_classes)+1]
                        # set data to new conv layers
                        new_params.set_data(new_data)
                class_predictors.add(new_cp)
            self.class_predictors = class_predictors
            self.cls_decoder = MultiPerClassDecoder(len(self.classes) + 1, thresh=0.01)


def get_ssd(name, base_size, features, filters, sizes, ratios, steps, classes,
            dataset, pretrained=False, pretrained_base=True, ctx=mx.cpu(),
            root=os.path.join('~', '.mxnet', 'models'),
            anchor_generator=SSDAnchorGenerator, **kwargs):
    """Get SSD models.

    Parameters
    ----------
    name : str or None
        Model name, if `None` is used, you must specify `features` to be a `HybridBlock`.
    base_size : int
        Base image size for training. This is fixed once the model is created;
        a fixed base size still allows variable input sizes at test time.
    features : iterable of str or `HybridBlock`
        List of network internal output names, in order to specify which layers are
        used for predicting bbox values.
        If `name` is `None`, `features` must be a `HybridBlock` which generates multiple
        outputs for prediction.
    filters : iterable of int or None
        List of convolution layer channels which are going to be appended to the base
        network feature extractor. If `name` is `None`, this is ignored.
    sizes : iterable of float
        Sizes of anchor boxes. This should be a list of floats in increasing order.
        The length of `sizes` must be len(layers) + 1. For example, a two-stage SSD
        model can have ``sizes = [30, 60, 90]``, which converts to `[30, 60]` and
        `[60, 90]` for the two stages, respectively. For more details, please refer
        to the original paper.
    ratios : iterable of list
        Aspect ratios of anchors in each output layer. Its length must be equal
        to the number of SSD output layers.
    steps : list of int
        Step size of anchor boxes in each output layer.
    classes : iterable of str
        Names of categories.
    dataset : str
        Name of dataset. This is used to identify model name because models trained on
        different datasets are going to be very different.
    pretrained : bool or str
        Boolean value controls whether to load the default pretrained weights for model.
        String value represents the hashtag for a certain version of pretrained weights.
    pretrained_base : bool or str, optional, default is True
        Load pretrained base network; the extra layers are randomized. Note that
        if pretrained is `True`, this has no effect.
    ctx : mxnet.Context
        Context such as mx.cpu(), mx.gpu(0).
    root : str
        Model weights storing path.
    norm_layer : object
        Normalization layer used (default: :class:`mxnet.gluon.nn.BatchNorm`).
        Can be :class:`mxnet.gluon.nn.BatchNorm` or :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.
    norm_kwargs : dict
        Additional `norm_layer` arguments, for example `num_devices=4`
        for :class:`mxnet.gluon.contrib.nn.SyncBatchNorm`.

    Returns
    -------
    HybridBlock
        An SSD detection network.
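
    Example
    -------
    A sketch of building a custom model; the anchor settings below are
    illustrative, not a tuned configuration:

    >>> net = get_ssd('resnet18_v1', 512,
    ...               features=['stage3_activation1', 'stage4_activation1'],
    ...               filters=[512, 512, 256, 256],
    ...               sizes=[51.2, 102.4, 189.4, 276.4, 363.52, 450.6, 492.0],
    ...               ratios=[[1, 2, 0.5]] * 6,
    ...               steps=[16, 32, 64, 128, 256, 512],
    ...               classes=['person'], dataset='custom', pretrained_base=True)
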
412    """
    pretrained_base = False if pretrained else pretrained_base
    base_name = None if callable(features) else name
    net = SSD(base_name, base_size, features, filters, sizes, ratios, steps,
              pretrained=pretrained_base, classes=classes, ctx=ctx, root=root,
              minimal_opset=pretrained, anchor_generator=anchor_generator, **kwargs)
    if pretrained:
        from ..model_store import get_model_file
        full_name = '_'.join(('ssd', str(base_size), name, dataset))
        net.load_parameters(get_model_file(full_name, tag=pretrained, root=root), ctx=ctx)
    return net


def custom_ssd(base_network_name, base_size, filters, sizes, ratios, steps,
               classes, dataset, pretrained_base, **kwargs):
    """Custom SSD models for a set of supported base networks, with feature
    layers chosen to match each base network and input size. See `get_ssd`
    for the meaning of the remaining arguments.
427    """
428    if base_network_name == 'vgg16_atrous' and base_size == 300:
429        features = vgg16_atrous_300
430    elif base_network_name == 'vgg16_atrous' and base_size == 512:
431        features = vgg16_atrous_512
432    elif base_network_name == 'resnet18_v1' and base_size == 512:
433        features = ['stage3_activation1', 'stage4_activation1']
434    elif base_network_name == 'resnet50_v1' and base_size == 512:
435        features = ['stage3_activation5', 'stage4_activation2']
436    elif base_network_name == 'resnet101_v2' and base_size == 512:
437        features = ['stage3_activation22', 'stage4_activation2']
438    elif base_network_name == 'resnet152_v2' and base_size == 512:
439        features = ['stage2_activation7', 'stage3_activation35', 'stage4_activation2']
440    elif base_network_name == 'mobilenet1.0' and base_size == 512:
441        features = ['relu22_fwd', 'relu26_fwd']
442    elif base_network_name == 'mobilenet1.0' and base_size == 300:
443        features = ['relu22_fwd', 'relu26_fwd']
444    elif base_network_name == 'mobilenet0.25' and base_size == 300:
445        features = ['relu22_fwd', 'relu26_fwd']
446    else:
447        raise NotImplementedError('Unsupported network', base_network_name)
448
449    net = get_ssd(name=base_network_name,
450                  base_size=base_size,
451                  features=features,
452                  filters=filters,
453                  sizes=sizes,
454                  ratios=ratios,
455                  steps=steps,
456                  classes=classes,
457                  dataset=dataset,
458                  pretrained_base=pretrained_base,
459                  **kwargs)
460
461    return net
462