1"""Train CenterNet"""
import argparse
import os
import logging
import warnings
import time
import numpy as np
import mxnet as mx
from mxnet import nd
from mxnet import gluon
from mxnet import autograd
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv import data as gdata
from gluoncv import utils as gutils
from gluoncv.model_zoo import get_model
from gluoncv.data.batchify import Tuple, Stack, Pad
from gluoncv.data.transforms.presets.center_net import CenterNetDefaultTrainTransform
from gluoncv.data.transforms.presets.center_net import CenterNetDefaultValTransform, get_post_transform

from gluoncv.utils.metrics.voc_detection import VOC07MApMetric
from gluoncv.utils.metrics.coco_detection import COCODetectionMetric
from gluoncv.utils.metrics.accuracy import Accuracy
from gluoncv.utils import LRScheduler, LRSequential


def parse_args():
    parser = argparse.ArgumentParser(description='Train CenterNet networks.')
    parser.add_argument('--network', type=str, default='resnet18_v1b',
                        help="Base network name which serves as feature extraction base.")
    parser.add_argument('--data-shape', type=int, default=512,
                        help="Input data shape, e.g. 512.")
    parser.add_argument('--batch-size', type=int, default=32,
                        help='Training mini-batch size.')
    parser.add_argument('--dataset', type=str, default='voc',
                        help='Training dataset. Now support voc and coco.')
    parser.add_argument('--dataset-root', type=str, default='~/.mxnet/datasets/',
                        help='Path of the directory where the dataset is located.')
    parser.add_argument('--num-workers', '-j', dest='num_workers', type=int,
                        default=4, help='Number of data workers; use a larger '
                        'number to accelerate data loading if your CPU and GPUs are powerful.')
    parser.add_argument('--gpus', type=str, default='0',
                        help='Training with GPUs, you can specify 1,3 for example.')
    parser.add_argument('--epochs', type=int, default=140,
                        help='Training epochs.')
    parser.add_argument('--resume', type=str, default='',
                        help='Resume from previously saved parameters if not None. '
                        'For example, you can resume from ./center_net_xxx_0123.params')
    parser.add_argument('--start-epoch', type=int, default=0,
                        help='Starting epoch for resuming, default is 0 for new training. '
                        'You can specify it to 100, for example, to resume from epoch 100.')
    parser.add_argument('--lr', type=float, default=1.25e-4,
                        help='Learning rate, default is 0.000125.')
    parser.add_argument('--lr-decay', type=float, default=0.1,
                        help='Decay rate of learning rate. Default is 0.1.')
    parser.add_argument('--lr-decay-epoch', type=str, default='90,120',
                        help='Epochs at which learning rate decays. Default is 90,120.')
    parser.add_argument('--lr-mode', type=str, default='step',
                        help='Learning rate scheduler mode. Options are step, poly and cosine.')
    parser.add_argument('--warmup-lr', type=float, default=0.0,
                        help='Starting warmup learning rate. Default is 0.0.')
    parser.add_argument('--warmup-epochs', type=int, default=0,
                        help='Number of warmup epochs.')
    parser.add_argument('--momentum', type=float, default=0.9,
                        help='SGD momentum, default is 0.9.')
    parser.add_argument('--wd', type=float, default=0.0001,
                        help='Weight decay, default is 1e-4.')
    parser.add_argument('--log-interval', type=int, default=100,
                        help='Logging mini-batch interval. Default is 100.')
    parser.add_argument('--num-samples', type=int, default=-1,
                        help='Number of training images. Use -1 to automatically get the number.')
    parser.add_argument('--save-prefix', type=str, default='',
                        help='Saving parameter prefix.')
    parser.add_argument('--save-interval', type=int, default=10,
                        help='Saving parameters epoch interval, best model will always be saved.')
    parser.add_argument('--val-interval', type=int, default=1,
                        help='Epoch interval for validation; increasing the number will reduce the '
                             'training time if validation is slow.')
    parser.add_argument('--seed', type=int, default=233,
                        help='Random seed to be fixed.')
    parser.add_argument('--wh-weight', type=float, default=0.1,
                        help='Loss weight for width/height regression.')
    parser.add_argument('--center-reg-weight', type=float, default=1.0,
                        help='Center offset regression loss weight.')
    parser.add_argument('--flip-validation', action='store_true',
                        help='Use horizontal-flip data augmentation during validation.')

    args = parser.parse_args()
    return args

def get_dataset(dataset, args):
    if dataset.lower() == 'voc':
        train_dataset = gdata.VOCDetection(
            splits=[(2007, 'trainval'), (2012, 'trainval')])
        val_dataset = gdata.VOCDetection(
            splits=[(2007, 'test')])
        val_metric = VOC07MApMetric(iou_thresh=0.5, class_names=val_dataset.classes)
    elif dataset.lower() == 'coco':
        train_dataset = gdata.COCODetection(root=os.path.join(args.dataset_root, 'coco'), splits='instances_train2017')
        val_dataset = gdata.COCODetection(root=os.path.join(args.dataset_root, 'coco'), splits='instances_val2017', skip_empty=False)
        val_metric = COCODetectionMetric(
            val_dataset, args.save_prefix + '_eval', cleanup=True,
            data_shape=(args.data_shape, args.data_shape), post_affine=get_post_transform)
        # coco validation is slow, consider increasing the validation interval
        if args.val_interval == 1:
            args.val_interval = 10
    else:
        raise NotImplementedError('Dataset: {} not implemented.'.format(dataset))
    if args.num_samples < 0:
        args.num_samples = len(train_dataset)
    return train_dataset, val_dataset, val_metric

def get_dataloader(net, train_dataset, val_dataset, data_shape, batch_size, num_workers, ctx):
    """Get dataloader."""
    width, height = data_shape, data_shape
    num_class = len(train_dataset.classes)
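    # Note: CenterNetDefaultTrainTransform emits six arrays per sample -- the image plus the
    # five training targets unpacked in train() below (heatmap targets, width/height targets
    # and masks, center-offset targets and masks) -- hence the 6-way Stack batchify.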
    batchify_fn = Tuple([Stack() for _ in range(6)])  # stack image and the five target/mask arrays
    train_loader = gluon.data.DataLoader(
        train_dataset.transform(CenterNetDefaultTrainTransform(
            width, height, num_class=num_class, scale_factor=net.scale)),
        batch_size, True, batchify_fn=batchify_fn, last_batch='rollover', num_workers=num_workers)
    val_batchify_fn = Tuple(Stack(), Pad(pad_val=-1))
    val_loader = gluon.data.DataLoader(
        val_dataset.transform(CenterNetDefaultValTransform(width, height)),
        batch_size, False, batchify_fn=val_batchify_fn, last_batch='keep', num_workers=num_workers)
    return train_loader, val_loader

def save_params(net, best_map, current_map, epoch, save_interval, prefix):
    current_map = float(current_map)
    if current_map > best_map[0]:
        best_map[0] = current_map
        net.save_parameters('{:s}_best.params'.format(prefix))
        with open(prefix + '_best_map.log', 'a') as f:
            f.write('{:04d}:\t{:.4f}\n'.format(epoch, current_map))
    if save_interval and epoch % save_interval == 0:
        net.save_parameters('{:s}_{:04d}_{:.4f}.params'.format(prefix, epoch, current_map))

def validate(net, val_data, ctx, eval_metric, flip_test=False):
    """Test on validation dataset."""
    eval_metric.reset()
    # toggle horizontal-flip test-time augmentation on the detector
    net.flip_test = flip_test
    mx.nd.waitall()
    net.hybridize()
    for batch in val_data:
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
        det_bboxes = []
        det_ids = []
        det_scores = []
        gt_bboxes = []
        gt_ids = []
        gt_difficults = []
        for x, y in zip(data, label):
            # get prediction results
            ids, scores, bboxes = net(x)
            det_ids.append(ids)
            det_scores.append(scores)
            # clip to image size
            det_bboxes.append(bboxes.clip(0, batch[0].shape[2]))
            # split ground truths
            gt_ids.append(y.slice_axis(axis=-1, begin=4, end=5))
            gt_bboxes.append(y.slice_axis(axis=-1, begin=0, end=4))
            gt_difficults.append(y.slice_axis(axis=-1, begin=5, end=6) if y.shape[-1] > 5 else None)

        # update metric
        eval_metric.update(det_bboxes, det_ids, det_scores, gt_bboxes, gt_ids, gt_difficults)
    return eval_metric.get()

def train(net, train_data, val_data, eval_metric, ctx, args):
    """Training pipeline"""
    net.collect_params().reset_ctx(ctx)
    # lr decay policy
    lr_decay = float(args.lr_decay)
    lr_steps = sorted([int(ls) for ls in args.lr_decay_epoch.split(',') if ls.strip()])
    lr_decay_epoch = [e - args.warmup_epochs for e in lr_steps]
    num_batches = args.num_samples // args.batch_size
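    # The schedule below warms the learning rate up linearly from 0 to args.lr over
    # args.warmup_epochs, then follows args.lr_mode (default: step decay by args.lr_decay
    # at each epoch in lr_decay_epoch, counted relative to the end of warmup).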
    lr_scheduler = LRSequential([
        LRScheduler('linear', base_lr=0, target_lr=args.lr,
                    nepochs=args.warmup_epochs, iters_per_epoch=num_batches),
        LRScheduler(args.lr_mode, base_lr=args.lr,
                    nepochs=args.epochs - args.warmup_epochs,
                    iters_per_epoch=num_batches,
                    step_epoch=lr_decay_epoch,
                    step_factor=args.lr_decay, power=2),
    ])

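    # Common convention: exclude bias parameters from weight decay so that only the
    # weight tensors are regularized.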
    for k, v in net.collect_params('.*bias').items():
        v.wd_mult = 0.0
    trainer = gluon.Trainer(
        net.collect_params(), 'adam',
        {'learning_rate': args.lr, 'wd': args.wd,
         'lr_scheduler': lr_scheduler})

    heatmap_loss = gcv.loss.HeatmapFocalLoss(from_logits=True)
    wh_loss = gcv.loss.MaskedL1Loss(weight=args.wh_weight)
    center_reg_loss = gcv.loss.MaskedL1Loss(weight=args.center_reg_weight)
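    # The total training objective assembled in the loop below is
    #   loss = heatmap_focal + args.wh_weight * wh_l1 + args.center_reg_weight * offset_l1,
    # with the two weights applied inside the masked L1 losses via their `weight` argument.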
    heatmap_loss_metric = mx.metric.Loss('HeatmapFocal')
    wh_metric = mx.metric.Loss('WHL1')
    center_reg_metric = mx.metric.Loss('CenterRegL1')

    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    log_file_path = args.save_prefix + '_train.log'
    log_dir = os.path.dirname(log_file_path)
    if log_dir and not os.path.exists(log_dir):
        os.makedirs(log_dir)
    fh = logging.FileHandler(log_file_path)
    logger.addHandler(fh)
    logger.info(args)
    logger.info('Start training from [Epoch {}]'.format(args.start_epoch))
    best_map = [0]

    for epoch in range(args.start_epoch, args.epochs):
        heatmap_loss_metric.reset()
        wh_metric.reset()
        center_reg_metric.reset()
        tic = time.time()
        btic = time.time()
        net.hybridize()

        for i, batch in enumerate(train_data):
            split_data = [gluon.utils.split_and_load(batch[ind], ctx_list=ctx, batch_axis=0) for ind in range(6)]
            data, heatmap_targets, wh_targets, wh_masks, center_reg_targets, center_reg_masks = split_data
            batch_size = args.batch_size
            with autograd.record():
                sum_losses = []
                heatmap_losses = []
                wh_losses = []
                center_reg_losses = []
                wh_preds = []
                center_reg_preds = []
                for x, heatmap_target, wh_target, wh_mask, center_reg_target, center_reg_mask in zip(*split_data):
                    heatmap_pred, wh_pred, center_reg_pred = net(x)
                    wh_preds.append(wh_pred)
                    center_reg_preds.append(center_reg_pred)
                    wh_losses.append(wh_loss(wh_pred, wh_target, wh_mask))
                    center_reg_losses.append(center_reg_loss(center_reg_pred, center_reg_target, center_reg_mask))
                    heatmap_losses.append(heatmap_loss(heatmap_pred, heatmap_target))
                    curr_loss = heatmap_losses[-1] + wh_losses[-1] + center_reg_losses[-1]
                    sum_losses.append(curr_loss)
                autograd.backward(sum_losses)
            trainer.step(len(sum_losses))  # step with number of devices, one loss per device

            heatmap_loss_metric.update(0, heatmap_losses)
            wh_metric.update(0, wh_losses)
            center_reg_metric.update(0, center_reg_losses)
            if args.log_interval and not (i + 1) % args.log_interval:
                name2, loss2 = wh_metric.get()
                name3, loss3 = center_reg_metric.get()
                name4, loss4 = heatmap_loss_metric.get()
                logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, LR={}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
                    epoch, i, batch_size/(time.time()-btic), trainer.learning_rate, name2, loss2, name3, loss3, name4, loss4))
            btic = time.time()

        name2, loss2 = wh_metric.get()
        name3, loss3 = center_reg_metric.get()
        name4, loss4 = heatmap_loss_metric.get()
        logger.info('[Epoch {}] Training cost: {:.3f}, {}={:.3f}, {}={:.3f}, {}={:.3f}'.format(
            epoch, (time.time()-tic), name2, loss2, name3, loss3, name4, loss4))
        if (epoch % args.val_interval == 0) or (args.save_interval and epoch % args.save_interval == 0) or (epoch == args.epochs - 1):
            # consider reducing the frequency of validation to save time
            map_name, mean_ap = validate(net, val_data, ctx, eval_metric, flip_test=args.flip_validation)
            val_msg = '\n'.join(['{}={}'.format(k, v) for k, v in zip(map_name, mean_ap)])
            logger.info('[Epoch {}] Validation: \n{}'.format(epoch, val_msg))
            current_map = float(mean_ap[-1])
        else:
            current_map = 0.
        save_params(net, best_map, current_map, epoch, args.save_interval, args.save_prefix)

if __name__ == '__main__':
    args = parse_args()

    # fix seed for mxnet, numpy and python builtin random generator.
    gutils.random.seed(args.seed)

    # training contexts
    ctx = [mx.gpu(int(i)) for i in args.gpus.split(',') if i.strip()]
    ctx = ctx if ctx else [mx.cpu()]

    # network
    net_name = '_'.join(('center_net', args.network, args.dataset))
    args.save_prefix += net_name
    net = get_model(net_name, pretrained_base=True, norm_layer=gluon.nn.BatchNorm)
    if args.resume.strip():
        net.load_parameters(args.resume.strip())
    else:
        with warnings.catch_warnings(record=True) as w:
            warnings.simplefilter("always")
            net.initialize()
            # the net needs to live on the first gpu when using AMP
            net.collect_params().reset_ctx(ctx[0])

    # training data
    train_dataset, val_dataset, eval_metric = get_dataset(args.dataset, args)
    batch_size = args.batch_size
    train_data, val_data = get_dataloader(
        net, train_dataset, val_dataset, args.data_shape, batch_size, args.num_workers, ctx[0])

    # training
    train(net, train_data, val_data, eval_metric, ctx, args)
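# After training, the best checkpoint written by save_params() can be reloaded for
# inference, e.g. (illustrative sketch):
#   net = get_model(net_name, pretrained_base=False)
#   net.load_parameters(args.save_prefix + '_best.params')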