1# pylint: disable=missing-function-docstring, line-too-long
2"""
3Utility functions, misc
4"""
5import os
6import time
7import numpy as np
8
9
def read_labelmap(labelmap_file):
    """Parse a label map text file.

    The file is scanned line by line. A line starting with ``  name:``
    supplies a class name (the text between the first pair of double
    quotes). A line starting with ``  id:`` or ``  label_id:`` supplies an
    integer class id (the last space-separated token) and emits one entry
    that pairs this id with the most recently seen name.

    Parameters
    ----------
    labelmap_file : str
        Path of the label map file to read.

    Returns
    -------
    tuple
        ``(labelmap, class_ids)`` where ``labelmap`` is a list of
        ``{"id": int, "name": str}`` dicts in file order and ``class_ids``
        is the set of all ids encountered.
    """
    entries = []
    seen_ids = set()
    current_name = ""
    id_prefixes = ("  id:", "  label_id:")

    with open(labelmap_file, 'r') as handle:
        for raw in handle:
            if raw.startswith("  name:"):
                # remember the name until the matching id line shows up
                current_name = raw.split('"')[1]
                continue
            if not raw.startswith(id_prefixes):
                continue
            label_id = int(raw.strip().split(" ")[-1])
            entries.append({"id": label_id, "name": current_name})
            seen_ids.add(label_id)

    return entries, seen_ids
26
27
def build_log_dir(cfg):
    """Create the experiment log directory tree and dump the config into it.

    The experiment directory is ``cfg.CONFIG.LOG.BASE_PATH`` joined with
    ``cfg.CONFIG.LOG.EXP_NAME``. When ``EXP_NAME`` equals ``'use_time'`` it
    is first replaced, in place on ``cfg``, by the current local time
    formatted as ``%Y-%m-%d-%H-%M-%S``. Inside the experiment directory
    this function writes ``config.yaml`` (the result of ``cfg.dump()``) and
    creates the sub-directories named by ``cfg.CONFIG.LOG.LOG_DIR`` and
    ``cfg.CONFIG.LOG.SAVE_DIR``.

    Parameters
    ----------
    cfg : object
        Config object exposing ``CONFIG.LOG.{EXP_NAME, BASE_PATH, LOG_DIR,
        SAVE_DIR}`` and a ``dump()`` method returning a string.

    Returns
    -------
    str
        Path of the tensorboard log directory (``LOG_DIR`` inside the
        experiment directory).
    """
    # resolve the experiment name; 'use_time' is a placeholder for "now"
    if cfg.CONFIG.LOG.EXP_NAME == 'use_time':
        cfg.CONFIG.LOG.EXP_NAME = time.strftime('%Y-%m-%d-%H-%M-%S')
    log_path = os.path.join(cfg.CONFIG.LOG.BASE_PATH, cfg.CONFIG.LOG.EXP_NAME)

    # exist_ok avoids the check-then-create race when several processes
    # build the same directory at startup
    os.makedirs(log_path, exist_ok=True)

    # dump config file
    with open(os.path.join(log_path, 'config.yaml'), 'w') as f:
        f.write(cfg.dump())

    # create tensorboard saving directory
    tb_logdir = os.path.join(log_path, cfg.CONFIG.LOG.LOG_DIR)
    os.makedirs(tb_logdir, exist_ok=True)

    # create checkpoint saving directory
    ckpt_logdir = os.path.join(log_path, cfg.CONFIG.LOG.SAVE_DIR)
    os.makedirs(ckpt_logdir, exist_ok=True)

    return tb_logdir
51
52
class AverageMeter(object):
    """Track the most recent value and the running weighted average.

    Attributes
    ----------
    val
        The value passed to the latest ``update`` call.
    sum
        Accumulated ``val * n`` over all updates since the last reset.
    count
        Accumulated ``n`` over all updates since the last reset.
    avg
        ``sum / count``; stays at its previous value while ``count`` is 0.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Reset all statistics to zero."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average.

        Parameters
        ----------
        val
            The newly observed value.
        n : int, optional
            Weight of the observation (e.g. a batch size). Defaults to 1.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        # an update with n == 0 on a fresh meter leaves count at zero;
        # skip the division instead of raising ZeroDivisionError
        if self.count != 0:
            self.avg = self.sum / self.count
70
71
def calculate_mAP(output, target):
    """Return the mean average precision reported by torchnet's mAPMeter.

    ``output`` and ``target`` are handed unchanged to ``mAPMeter.add`` and
    the meter's ``value()`` is returned.
    """
    # imported lazily so torchnet is only required when this is called
    from torchnet.meter import mAPMeter

    map_meter = mAPMeter()
    map_meter.add(output, target)
    return map_meter.value()
78
79
def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Parameters
    ----------
    output : torch.Tensor
        Score tensor; the top-k is taken along dimension 1, and dimension 0
        is expected to line up with ``target``.
    target : torch.Tensor
        Ground-truth class indices; ``target.size(0)`` is used as the batch
        size.
    topk : tuple of int, optional
        The values of k to evaluate. Defaults to ``(1,)``.

    Returns
    -------
    list of torch.Tensor
        One scalar tensor per entry of ``topk`` holding the percentage of
        samples whose target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # indices of the maxk largest scores per sample, sorted descending
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape instead of view: the slice of the transposed tensor is
        # not contiguous, so view(-1) raises a RuntimeError for k > 1
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
94
95
def per_class_error(output, target, num_classes):
    """Compute, for each class, the fraction of its samples predicted correctly.

    Despite the name, the returned values are accuracies: for class ``i``
    the number of positions where ``target == i`` and ``output`` equals
    ``target``, divided by the number of positions where ``target == i``.
    The two counts are also printed for every class.

    Parameters
    ----------
    output
        Predicted class ids, indexable with the result of ``np.where``
        (presumably a numpy array shaped like ``target`` - confirm with
        callers).
    target
        Ground-truth class ids, same indexing requirements as ``output``.
    num_classes : int
        Classes ``0 .. num_classes - 1`` are evaluated.

    Returns
    -------
    list of float
        One entry per class. A class with no samples in ``target`` yields
        ``nan`` because its accuracy is undefined.
    """
    per_class_acc = []
    for i in range(num_classes):
        index = np.where(target == i)
        diff = output[index] - target[index]
        corr = float(np.where(diff == 0)[0].shape[0])
        total = float(index[0].shape[0])
        print(corr, total)
        # a class absent from target would otherwise raise ZeroDivisionError
        per_class_acc.append(corr / total if total else float('nan'))

    return per_class_acc
107
108
class ProgressMeter(object):
    """Default PyTorch progress meter.

    Prints one tab-separated line per call to ``display``: the prefix
    followed by ``[batch/num_batches]`` and then ``str()`` of every meter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.prefix = prefix
        self.meters = meters
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)

    def display(self, batch):
        """Print the progress line for batch index ``batch``."""
        fields = [self.prefix + self.batch_fmtstr.format(batch)]
        fields.extend(str(m) for m in self.meters)
        print('\t'.join(fields))

    def _get_batch_fmtstr(self, num_batches):
        """Build ``[{:Nd}/total]`` where N is the digit count of the total."""
        width = len(str(num_batches // 1))
        # right-aligned slot so the batch counter lines up with the total
        slot = '{{:{}d}}'.format(width)
        return '[{}/{}]'.format(slot, slot.format(num_batches))
125
126
def adjust_learning_rate(optimizer, epoch, args):
    """Apply step decay: ``args.lr`` scaled by 0.1 for every 30 full epochs.

    The resulting rate is written to the ``'lr'`` entry of every group in
    ``optimizer.param_groups``.
    """
    decay_steps = epoch // 30
    new_lr = args.lr * (0.1 ** decay_steps)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
132
133
def get_iou(bb1, bb2):
    """
    Calculate the Intersection over Union (IoU) of two bounding boxes.

    Parameters
    ----------
    bb1 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner
    bb2 : dict
        Keys: {'x1', 'x2', 'y1', 'y2'}
        The (x1, y1) position is at the top left corner,
        the (x2, y2) position is at the bottom right corner

    Returns
    -------
    float
        in [0, 1]
    """
    # both boxes must have strictly positive width and height
    for box in (bb1, bb2):
        assert box['x1'] < box['x2']
        assert box['y1'] < box['y2']

    def _area(box):
        return (box['x2'] - box['x1']) * (box['y2'] - box['y1'])

    # corners of the overlap rectangle
    inter_left = max(bb1['x1'], bb2['x1'])
    inter_top = max(bb1['y1'], bb2['y1'])
    inter_right = min(bb1['x2'], bb2['x2'])
    inter_bottom = min(bb1['y2'], bb2['y2'])

    # disjoint boxes: the "overlap" has negative extent on some axis
    if inter_left > inter_right or inter_top > inter_bottom:
        return 0.0

    # the overlap of two axis-aligned boxes is itself axis-aligned
    intersection_area = (inter_right - inter_left) * (inter_bottom - inter_top)

    # union = sum of both areas minus the doubly counted intersection
    union_area = float(_area(bb1) + _area(bb2) - intersection_area)
    iou = intersection_area / union_area
    assert 0.0 <= iou <= 1.0
    return iou
183