"""RCNN framework training metrics."""
2
3import mxnet as mx
4try:
5    from mxnet.metric import EvalMetric
6except ImportError:
7    from mxnet.gluon.metric import EvalMetric
8
9
class RPNAccMetric(EvalMetric):
    """Accuracy of the RPN objectness classifier.

    Each anchor logit is squashed with a sigmoid and thresholded at 0.5.
    The binary prediction is compared with the anchor label, and every
    comparison is scaled by the per-anchor weight. Anchors with zero weight
    therefore contribute to neither the hit count nor the denominator.
    """

    def __init__(self):
        super(RPNAccMetric, self).__init__('RPNAcc')

    def update(self, labels, preds):
        """Accumulate weighted RPN classification hits.

        Parameters
        ----------
        labels : sequence of two NDArray
            ``[rpn_label, rpn_weight]``.
        preds : sequence of NDArray
            Only ``preds[0]`` (the RPN classification logits) is used.
        """
        rpn_label, rpn_weight = labels
        logits = preds[0]

        # Binary foreground decision per anchor. Logits, label and weight
        # are presumably all (b, 1, h, w) -- TODO confirm against caller.
        is_fg = mx.nd.sigmoid(logits) >= 0.5
        hits = mx.nd.sum(rpn_weight * (is_fg == rpn_label))

        # Denominator is the total weight (the number of contributing
        # anchors when weights are 0/1).
        total = mx.nd.sum(rpn_weight)

        self.sum_metric += hits.asscalar()
        self.num_inst += total.asscalar()
33
34
class RPNL1LossMetric(EvalMetric):
    """Smooth-L1 loss of the RPN box-regression branch.

    Derives from the module-level ``EvalMetric`` (imported from
    ``mxnet.metric`` or, as a fallback, ``mxnet.gluon.metric``) so the class
    can be defined on both MXNet 1.x and 2.x.
    """

    def __init__(self):
        super(RPNL1LossMetric, self).__init__('RPNL1Loss')

    def update(self, labels, preds):
        """Accumulate the weighted smooth-L1 error of RPN box regression.

        Parameters
        ----------
        labels : sequence of two NDArray
            ``[rpn_bbox_target, rpn_bbox_weight]``.
        preds : sequence of NDArray
            Only ``preds[0]`` (the RPN box-regression output) is used.
        """
        rpn_bbox_target, rpn_bbox_weight = labels
        rpn_bbox_reg = preds[0]

        # Normalize per box: total weight / 4, since each box has 4 regressed
        # coordinates (assumes 0/1 weights on all 4 coords -- TODO confirm).
        num_inst = mx.nd.sum(rpn_bbox_weight) / 4

        # ``scalar`` is smooth_l1's sigma; the RPN branch uses sigma=3.
        loss = mx.nd.sum(
            rpn_bbox_weight * mx.nd.smooth_l1(rpn_bbox_reg - rpn_bbox_target, scalar=3))

        self.sum_metric += loss.asscalar()
        self.num_inst += num_inst.asscalar()
57
58
class RCNNAccMetric(EvalMetric):
    """Classification accuracy of the RCNN head.

    Derives from the module-level ``EvalMetric`` (imported from
    ``mxnet.metric`` or, as a fallback, ``mxnet.gluon.metric``) so the class
    can be defined on both MXNet 1.x and 2.x.
    """

    def __init__(self):
        super(RCNNAccMetric, self).__init__('RCNNAcc')

    def update(self, labels, preds):
        """Accumulate RCNN classification hits.

        Parameters
        ----------
        labels : sequence of NDArray
            Only ``labels[0]`` (the RCNN class labels) is used.
        preds : sequence of NDArray
            Only ``preds[0]`` (class scores, classes on the last axis) is used.
        """
        rcnn_label = labels[0]
        rcnn_cls = preds[0]

        # Predicted class = argmax over the last (class) axis.
        pred_label = mx.nd.argmax(rcnn_cls, axis=-1)
        num_acc = mx.nd.sum(pred_label == rcnn_label)

        self.sum_metric += num_acc.asscalar()
        # Every label element counts toward the denominator.
        # NOTE(review): argmax never yields a negative index, so if callers
        # mark ignored samples with negative labels they are counted as
        # misses here -- confirm that is intended.
        self.num_inst += rcnn_label.size
78
79
class RCNNL1LossMetric(EvalMetric):
    """Smooth-L1 loss of the RCNN box-regression head.

    Derives from the module-level ``EvalMetric`` (imported from
    ``mxnet.metric`` or, as a fallback, ``mxnet.gluon.metric``) so the class
    can be defined on both MXNet 1.x and 2.x.
    """

    def __init__(self):
        super(RCNNL1LossMetric, self).__init__('RCNNL1Loss')

    def update(self, labels, preds):
        """Accumulate the weighted smooth-L1 error of RCNN box regression.

        Parameters
        ----------
        labels : sequence of two NDArray
            ``[rcnn_bbox_target, rcnn_bbox_weight]``.
        preds : sequence of NDArray
            Only ``preds[0]`` (the RCNN box-regression output) is used.
        """
        rcnn_bbox_target, rcnn_bbox_weight = labels
        rcnn_bbox_reg = preds[0]

        # Normalize per box: total weight / 4, since each box has 4 regressed
        # coordinates (assumes 0/1 weights on all 4 coords -- TODO confirm).
        num_inst = mx.nd.sum(rcnn_bbox_weight) / 4

        # ``scalar`` is smooth_l1's sigma; the RCNN head uses sigma=1.
        loss = mx.nd.sum(
            rcnn_bbox_weight * mx.nd.smooth_l1(rcnn_bbox_reg - rcnn_bbox_target, scalar=1))

        self.sum_metric += loss.asscalar()
        self.num_inst += num_inst.asscalar()
102
103
class MaskAccMetric(EvalMetric):
    """Per-pixel accuracy of the RCNN mask branch.

    Derives from the module-level ``EvalMetric`` (imported from
    ``mxnet.metric`` or, as a fallback, ``mxnet.gluon.metric``) so the class
    can be defined on both MXNet 1.x and 2.x.
    """

    def __init__(self):
        super(MaskAccMetric, self).__init__('MaskAcc')

    def update(self, labels, preds):
        """Accumulate weighted mask-pixel hits.

        Parameters
        ----------
        labels : sequence of two NDArray
            ``[rcnn_mask_target, rcnn_mask_weight]``.
        preds : sequence of NDArray
            Only ``preds[0]`` (mask logits) is used.
        """
        rcnn_mask_target, rcnn_mask_weight = labels
        rcnn_mask = preds[0]

        # Denominator is the total pixel weight.
        num_inst = mx.nd.sum(rcnn_mask_weight)

        # Both prediction and target are binarized at 0.5, so non-binary
        # (soft) targets are handled. Presumably (b, n, c, h, w) -- TODO confirm.
        pred_label = mx.nd.sigmoid(rcnn_mask) >= 0.5
        label = rcnn_mask_target >= 0.5
        num_acc = mx.nd.sum((pred_label == label) * rcnn_mask_weight)

        self.sum_metric += num_acc.asscalar()
        self.num_inst += num_inst.asscalar()
128
129
class MaskFGAccMetric(EvalMetric):
    """Foreground-pixel accuracy (recall) of the RCNN mask branch.

    The metric is the fraction of target-foreground pixels (target >= 0.5)
    that are also predicted as foreground (sigmoid(logit) >= 0.5).

    Derives from the module-level ``EvalMetric`` (imported from
    ``mxnet.metric`` or, as a fallback, ``mxnet.gluon.metric``) so the class
    can be defined on both MXNet 1.x and 2.x.
    """

    def __init__(self):
        super(MaskFGAccMetric, self).__init__('MaskFGAcc')

    def update(self, labels, preds):
        """Accumulate foreground mask-pixel hits.

        Parameters
        ----------
        labels : sequence of two NDArray
            ``[rcnn_mask_target, rcnn_mask_weight]``; the weight is unused.
        preds : sequence of NDArray
            Only ``preds[0]`` (mask logits) is used.
        """
        rcnn_mask_target, _ = labels
        rcnn_mask = preds[0]

        # Binarize prediction and target at 0.5. Presumably (b, n, c, h, w)
        # -- TODO confirm.
        pred_label = mx.nd.sigmoid(rcnn_mask) >= 0.5
        label = rcnn_mask_target >= 0.5

        # Count foreground pixels from the *binarized* target so that the
        # denominator matches the numerator below. Summing the raw target
        # would under- or over-count with soft targets, and accuracy could
        # then exceed 1. For binary targets the two are identical.
        num_inst = mx.nd.sum(label)
        num_acc = mx.nd.sum((pred_label == label) * label)

        self.sum_metric += num_acc.asscalar()
        self.num_inst += num_inst.asscalar()
154