1"""Default configs for center net"""
2# pylint: disable=bad-whitespace,missing-class-docstring,bad-indentation
3import os
4from typing import Union, Tuple
5from autocfg import dataclass, field
6
@dataclass
class CenterNetHead:
    """Configuration for the CenterNet prediction heads."""

    # Initial head bias: -log((1 - 0.1) / 0.1) ~= -2.197, rounded to -2.19.
    bias: float = -2.19
    # Number of channels in the width/height head.
    wh_outputs: int = 2
    # Number of channels in the regression head.
    reg_outputs: int = 2
    # Channel count of the additional conv layer in the heads.
    head_conv_channel: int = 64
13
@dataclass
class CenterNet:
    """Network-level configuration for CenterNet."""

    # Base feature network.
    base_network: str = 'dla34_deconv'
    # Head settings; default_factory gives each instance its own CenterNetHead.
    heads: CenterNetHead = field(default_factory=CenterNetHead)
    # Output vs. input scaling ratio, e.g. input_h // feature_h.
    scale: float = 4.0
    # Number of top-k detection results kept after inference.
    topk: int = 100
    # Model zoo root dir. The '~' is expanded once, when this class is defined,
    # so later changes to the home directory are not reflected in the default.
    root: str = os.path.expanduser(os.path.join('~', '.mxnet', 'models'))
    # Loss weight for width/height.
    wh_weight: float = 0.1
    # Loss weight for center regression.
    center_reg_weight: float = 1.0
    # Network input shape; presumably (height, width) -- TODO confirm ordering.
    data_shape: Tuple[int, int] = (512, 512)
    # Pre-trained detector used for transfer learning. When used, the preset is
    # applied and the other network settings above are ignored.
    transfer: str = 'center_net_resnet50_v1b_coco'
26
@dataclass
class TrainCfg:
    """Training hyper-parameters for CenterNet."""

    # Whether to load the ImageNet pre-trained base.
    pretrained_base: bool = True
    batch_size: int = 16
    epochs: int = 15
    # Base learning rate.
    lr: float = 1.25e-4
    # Decay rate of the learning rate.
    lr_decay: float = 0.1
    # Epochs at which the learning rate decays.
    # NOTE(review): with the default epochs = 15 these milestones are never
    # reached, so step decay presumably never fires -- confirm this is intended.
    lr_decay_epoch: Tuple[int, int] = (90, 120)
    # Learning-rate scheduler mode; options are 'step', 'poly' and 'cosine'.
    lr_mode: str = 'step'
    # Starting learning rate for warmup.
    warmup_lr: float = 0.0
    # Number of warmup epochs.
    warmup_epochs: int = 0
    # CPU workers; larger values use more processes.
    num_workers: int = 16
    start_epoch: int = 0
    # SGD momentum.
    momentum: float = 0.9
    # Weight decay.
    wd: float = 1e-4
    # Logging interval.
    log_interval: int = 100
43
@dataclass
class ValidCfg:
    """Validation settings for CenterNet."""

    # Whether to use flipping in the validation test.
    flip_test: bool = True
    # NMS threshold; 0 disables NMS.
    nms_thresh: Union[float, int] = 0
    # Top-k kept before NMS.
    nms_topk: int = 400
    # Top-k kept after NMS.
    post_nms: int = 100
    # CPU workers; larger values use more processes.
    num_workers: int = 16
    # Validation batch size.
    batch_size: int = 8
    # Validation epoch interval, useful when validation is slow.
    interval: int = 1
    # Evaluation metric: 'voc' or 'voc07'.
    metric: str = 'voc07'
    # IoU threshold for VOC-type metrics.
    iou_thresh: float = 0.5
55
@dataclass
class CenterNetCfg:
    """Top-level CenterNet config: network, training and validation settings."""

    # Each sub-config is built per instance through default_factory.
    center_net: CenterNet = field(default_factory=CenterNet)
    train: TrainCfg = field(default_factory=TrainCfg)
    valid: ValidCfg = field(default_factory=ValidCfg)
    # Individual GPU ids; they need not be consecutive.
    gpus: Union[Tuple, list] = (0, 1, 2, 3, 4, 5, 6, 7)
62