"""Default configs for center net"""
# pylint: disable=bad-whitespace,missing-class-docstring,bad-indentation
import os
from typing import Union, Tuple
from autocfg import dataclass, field


@dataclass
class CenterNetHead:
    """Default settings for the CenterNet output heads."""

    # Bias value, chosen as bias = -log((1 - 0.1) / 0.1).
    bias: float = -2.19
    # Number of channels of the wh head.
    wh_outputs: int = 2
    # Number of channels of the regression head.
    reg_outputs: int = 2
    # Number of channels of the additional conv.
    head_conv_channel: int = 64


@dataclass
class CenterNet:
    """Default settings for the CenterNet network itself."""

    # Name of the base feature network.
    base_network: str = 'dla34_deconv'
    # Head settings; a fresh CenterNetHead is built for every instance.
    heads: CenterNetHead = field(default_factory=CenterNetHead)
    # Output vs input scaling ratio, e.g., input_h // feature_h.
    scale: float = 4.0
    # Number of top-k detection results kept after inference.
    topk: int = 100
    # Model zoo root directory; '~' is expanded once, when this module is imported.
    root: str = os.path.expanduser(os.path.join('~', '.mxnet', 'models'))
    # Loss weight for width/height.
    wh_weight: float = 0.1
    # Loss weight for center regression.
    center_reg_weight: float = 1.0
    # NOTE(review): presumably the network input size; confirm whether the
    # order is (height, width) against the code that consumes it.
    data_shape: Tuple[int, int] = (512, 512)
    # Pre-trained detector used for transfer learning (uses the preset and
    # ignores the other network settings).
    transfer: str = 'center_net_resnet50_v1b_coco'


@dataclass
class TrainCfg:
    """Default training hyper-parameters."""

    # Whether to load the ImageNet pre-trained base.
    pretrained_base: bool = True
    # Training batch size.
    batch_size: int = 16
    # Number of training epochs.
    epochs: int = 15
    # Learning rate.
    lr: float = 1.25e-4
    # Decay rate of the learning rate.
    lr_decay: float = 0.1
    # Epochs at which the learning rate decays.
    # NOTE(review): both default milestones are larger than the default
    # `epochs` (15) -- confirm this is intended.
    lr_decay_epoch: Tuple[int, int] = (90, 120)
    # Learning rate scheduler mode; options are step, poly and cosine.
    lr_mode: str = 'step'
    # Starting warmup learning rate.
    warmup_lr: float = 0.0
    # Number of warmup epochs.
    warmup_epochs: int = 0
    # CPU workers; the larger the value, the more processes are used.
    num_workers: int = 16
    # Epoch index to start from.
    start_epoch: int = 0
    # SGD momentum.
    momentum: float = 0.9
    # Weight decay.
    wd: float = 1e-4
    # Logging interval.
    log_interval: int = 100


@dataclass
class ValidCfg:
    """Default validation settings."""

    # Use flip in the validation test.
    flip_test: bool = True
    # NMS threshold; 0 means disabled.
    nms_thresh: Union[float, int] = 0
    # Pre-NMS top-k.
    nms_topk: int = 400
    # Post-NMS top-k.
    post_nms: int = 100
    # CPU workers; the larger the value, the more processes are used.
    num_workers: int = 16
    # Validation batch size.
    batch_size: int = 8
    # Validation epoch interval, for slow validations.
    interval: int = 1
    # Metric name: 'voc' or 'voc07'.
    metric: str = 'voc07'
    # IoU threshold for VOC type metrics.
    iou_thresh: float = 0.5


@dataclass
class CenterNetCfg:
    """Top-level CenterNet config grouping network, train and valid settings."""

    # Each nested section gets its own freshly constructed default instance.
    center_net: CenterNet = field(default_factory=CenterNet)
    train: TrainCfg = field(default_factory=TrainCfg)
    valid: ValidCfg = field(default_factory=ValidCfg)
    # Individual GPU ids, not necessarily consecutive.
    gpus: Union[Tuple, list] = (0, 1, 2, 3, 4, 5, 6, 7)