# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
import re

import tools.find_mxnet  # imported for its side effect of locating mxnet
import mxnet as mx
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train


def convert_pretrained(name, args):
    """
    Apply any special conversions needed when loading pretrained weights,
    e.g. renaming parameters whose names are inconsistent between models.
    Currently a no-op that returns the arguments unchanged.

    Parameters
    ----------
    name : str
        pretrained model name
    args : dict
        loaded arguments

    Returns
    -------
    processed arguments as dict
    """
    return args


def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the adjusted initial learning rate and a matching scheduler.

    Parameters
    ----------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs at which to change the learning rate
    lr_refactor_ratio : float
        lr *= ratio at each of those epochs
    num_example : int
        number of training images, used to estimate iterations per epoch
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns
    -------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
        # apply the decay retroactively for refactor epochs already passed
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}"
                                     .format(lr, begin_epoch))
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)
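
# A worked example of get_lr_scheduler's arithmetic (illustrative numbers, not
# repo defaults): with num_example=32000 and batch_size=32, epoch_size is 1000
# iterations per epoch. For lr_refactor_step='80,160', lr_refactor_ratio=0.1
# and begin_epoch=0, the scheduler steps are [80000, 160000]. Resuming with
# begin_epoch=100 instead decays lr once up front (100 >= 80) and leaves a
# single remaining step at (160 - 100) * 1000 = 60000 iterations.
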
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, kv_store=None):
    """
    Wrapper for the training phase.

    Parameters
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from a previous checkpoint if > 0
    finetune : int
        fine-tune from a previous checkpoint if > 0
    pretrained : str
        prefix of the pretrained model, including path
    epoch : int
        epoch to load for the resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        last epoch of training
    frequent : int
        frequency (in batches) at which to print training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay parameter
    lr_refactor_ratio : float
        multiplier for reducing the learning rate
    lr_refactor_step : comma separated integers
        epochs at which to rescale the learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers whose weights should be fixed
    num_example : int
        number of training images
    label_pad_width : int
        pad training and validation labels to the same width
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects even if they belong to different classes
    ovp_thresh : float
        overlap (IoU) threshold for a detection to count as a true positive
    use_difficult : boolean
        include objects marked as difficult in validation
    class_names : list of str
        class names used by the validation metric
    voc07_metric : boolean
        use the 11-point PASCAL VOC07 mAP instead of the integral metric
    nms_topk : int
        keep at most this many detections after non-maximum suppression
    force_suppress : boolean
        suppress overlapped detections regardless of class
    train_list : str
        list file path for training; this replaces the labels embedded in the record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation; this replaces the labels embedded in the record
    iter_monitor : int
        monitor internal network stats if > 0, filtered by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        also log to this file if given
    kv_store : str
        kvstore type, e.g. 'device' or 'local', for multi-device training
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
                               label_pad_width=label_pad_width, path_imglist=train_list,
                               **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
                                 label_pad_width=label_pad_width, path_imglist=val_list,
                                 **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
                           nms_thresh=nms_thresh, force_suppress=force_suppress,
                           nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None
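    # For example (illustrative pattern, not a repo default):
    # freeze_layer_pattern='^(conv1_|conv2_)' matches arguments such as
    # conv1_1_weight and conv2_2_bias, fixing the first two convolution stages
    # while leaving all later layers trainable.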

    # load pretrained or resume from a previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
                    .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
                    .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so
        # freezing every 'conv*' argument leaves the prediction layers trainable
        fixed_param_names = [name for name in net.list_arguments()
                             if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
                    .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
                    .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Frozen parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
                                                   lr_refactor_ratio, num_example,
                                                   batch_size, begin_epoch)
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; the validation network is evaluated each epoch to get mAP
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    # create a kvstore if one is requested (e.g. 'device' for multi-gpu training)
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)
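
# Hedged usage sketch: how a driver script such as train.py might invoke
# train_net. The network name, file paths and hyper-parameters below are
# illustrative placeholders, not values guaranteed by this repository.
if __name__ == '__main__':
    train_net('vgg16_reduced',          # hypothetical symbol name
              'data/train.rec',         # hypothetical training record file
              num_classes=20,
              batch_size=32,
              data_shape=300,
              mean_pixels=(123, 117, 104),
              resume=-1,                # <= 0: do not resume from a checkpoint
              finetune=-1,              # <= 0: do not fine-tune
              pretrained='model/vgg16_reduced',  # hypothetical pretrained prefix
              epoch=1,
              prefix='model/ssd',
              ctx=[mx.gpu(0)],
              begin_epoch=0,
              end_epoch=240,
              frequent=20,
              learning_rate=0.002,
              momentum=0.9,
              weight_decay=0.0005,
              lr_refactor_step='80,160',
              lr_refactor_ratio=0.1,
              val_path='data/val.rec')  # hypothetical validation record file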