# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

""" example train fit utility """
import logging
import os
import time
import re
import math
import mxnet as mx


def get_epoch_size(args, kv):
    return math.ceil(int(args.num_examples / kv.num_workers) / args.batch_size)


def _get_lr_scheduler(args, kv):
    if 'lr_factor' not in args or args.lr_factor >= 1:
        return (args.lr, None)
    epoch_size = get_epoch_size(args, kv)
    begin_epoch = args.load_epoch if args.load_epoch else 0
    if 'pow' in args.lr_step_epochs:
        lr = args.lr
        max_up = args.num_epochs * epoch_size
        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
        return (lr, poly_sched)
    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
    lr = args.lr
    for s in step_epochs:
        if begin_epoch >= s:
            lr *= args.lr_factor
    if lr != args.lr:
        logging.info('Adjust learning rate to %e for epoch %d',
                     lr, begin_epoch)

    steps = [epoch_size * (x - begin_epoch)
             for x in step_epochs if x - begin_epoch > 0]
    if steps:
        return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor,
                                                         base_lr=args.lr))
    else:
        return (lr, None)
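
# Illustrative arithmetic for the two helpers above (the numbers are assumed,
# not taken from any particular dataset): with num_examples=50000,
# num_workers=1 and batch_size=128, get_epoch_size returns
# ceil(50000 / 128) = 391 updates per epoch; with lr_step_epochs='30,60' and
# no --load-epoch, _get_lr_scheduler then builds a MultiFactorScheduler with
# step=[391 * 30, 391 * 60] = [11730, 23460], i.e. the learning rate is
# multiplied by lr_factor after those update counts.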


def _load_model(args, rank=0):
    if 'load_epoch' not in args or args.load_epoch is None:
        return (None, None, None)
    assert args.model_prefix is not None
    model_prefix = args.model_prefix
    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
        model_prefix += "-%d" % (rank)
    sym, arg_params, aux_params = mx.model.load_checkpoint(
        model_prefix, args.load_epoch)
    logging.info('Loaded model %s-%04d.params', model_prefix, args.load_epoch)
    return (sym, arg_params, aux_params)


def _save_model(args, rank=0):
    if args.model_prefix is None:
        return None
    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
        args.model_prefix, rank), period=args.save_period)


def add_fit_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    train = parser.add_argument_group('Training', 'model training')
    train.add_argument('--network', type=str,
                       help='the neural network to use')
    train.add_argument('--num-layers', type=int,
                       help='number of layers in the neural network, '
                            'required by some networks such as resnet')
    train.add_argument('--gpus', type=str,
                       help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
    train.add_argument('--kv-store', type=str, default='device',
                       help='key-value store type')
    train.add_argument('--num-epochs', type=int, default=100,
                       help='max num of epochs')
    train.add_argument('--lr', type=float, default=0.1,
                       help='initial learning rate')
    train.add_argument('--lr-factor', type=float, default=0.1,
                       help='the ratio to reduce lr on each step')
    train.add_argument('--lr-step-epochs', type=str,
                       help='the epochs to reduce the lr, e.g. 30,60')
    train.add_argument('--initializer', type=str, default='default',
                       help='the initializer type')
    train.add_argument('--optimizer', type=str, default='sgd',
                       help='the optimizer type')
    train.add_argument('--mom', type=float, default=0.9,
                       help='momentum for sgd')
    train.add_argument('--wd', type=float, default=0.0001,
                       help='weight decay for sgd')
    train.add_argument('--batch-size', type=int, default=128,
                       help='the batch size')
    train.add_argument('--disp-batches', type=int, default=20,
                       help='show progress for every n batches')
    train.add_argument('--model-prefix', type=str,
                       help='model prefix')
    train.add_argument('--save-period', type=int, default=1,
                       help='params saving period')
    train.add_argument('--monitor', dest='monitor', type=int, default=0,
                       help='log network parameters every N iters if larger than 0')
    train.add_argument('--load-epoch', type=int,
                       help='load the model saved at the given epoch under model-prefix')
    train.add_argument('--top-k', type=int, default=0,
                       help='report the top-k accuracy. 0 means no report.')
    train.add_argument('--loss', type=str, default='',
                       help='show the cross-entropy or nll loss. ce stands for cross-entropy, '
                            'nll-loss stands for negative log-likelihood loss')
    train.add_argument('--test-io', type=int, default=0,
                       help='1 means test reading speed without training')
    train.add_argument('--dtype', type=str, default='float32',
                       help='precision: float32 or float16')
    train.add_argument('--gc-type', type=str, default='none',
                       help='type of gradient compression to use, '
                            'takes `2bit` or `none` for now')
    train.add_argument('--gc-threshold', type=float, default=0.5,
                       help='threshold for 2bit gradient compression')
    # additional parameters for large batch sgd
    train.add_argument('--macrobatch-size', type=int, default=0,
                       help='distributed effective batch size')
    train.add_argument('--warmup-epochs', type=int, default=5,
                       help='the epochs to ramp-up lr to scaled large-batch value')
    train.add_argument('--warmup-strategy', type=str, default='linear',
                       help='the ramping-up strategy for large batch sgd')
    train.add_argument('--profile-worker-suffix', type=str, default='',
                       help='profile worker actions into this file. during distributed training '
                            'the saved filename will be rank<k>_ followed by this suffix')
    train.add_argument('--profile-server-suffix', type=str, default='',
                       help='profile server actions into a file with name like rank<k>_ followed '
                            'by this suffix during distributed training')
    train.add_argument('--use-imagenet-data-augmentation', type=int, default=0,
                       help='enable data augmentation of ImageNet data, default disabled')
    return train
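

# A hypothetical command line exercising these flags (the driver script name
# and all values below are illustrative, not mandated by this file):
#
#   python train_imagenet.py --network resnet --num-layers 50 --gpus 0,1 \
#       --batch-size 128 --lr 0.1 --lr-factor 0.1 --lr-step-epochs 30,60 \
#       --model-prefix checkpoints/resnet50 --top-k 5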


def fit(args, network, data_loader, **kwargs):
    """
    train a model
    args : argparse returns
    network : the symbol definition of the neural network
    data_loader : function that returns the train and val data iterators
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)
    if args.gc_type != 'none':
        kv.set_gradient_compression({'type': args.gc_type,
                                     'threshold': args.gc_threshold})
    if args.profile_server_suffix:
        mx.profiler.set_config(filename=args.profile_server_suffix,
                               profile_all=True, profile_process='server')
        mx.profiler.set_state(state='run', profile_process='server')

    if args.profile_worker_suffix:
        if kv.num_workers > 1:
            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
        else:
            filename = args.profile_worker_suffix
        mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
        mx.profiler.set_state(state='run', profile_process='worker')

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    logging.basicConfig(level=logging.DEBUG, format=head)
    logging.info('start with arguments %s', args)

    epoch_size = get_epoch_size(args, kv)

    # data iterators
    (train, val) = data_loader(args, kv)
    if 'dist' in args.kv_store and 'async' not in args.kv_store:
        logging.info('Resizing training data to %d batches per machine', epoch_size)
        # resize train iter to ensure each machine has the same number of batches per epoch;
        # if not, dist_sync can hang at the end with one machine waiting for the others
        train = mx.io.ResizeIter(train, epoch_size)

    if args.test_io:
        tic = time.time()
        for i, batch in enumerate(train):
            if isinstance(batch, list):
                for b in batch:
                    for j in b.data:
                        j.wait_to_read()
            else:
                for j in batch.data:
                    j.wait_to_read()
            if (i + 1) % args.disp_batches == 0:
                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
                             args.disp_batches * args.batch_size / (time.time() - tic))
                tic = time.time()
        return

    # load model
    if 'arg_params' in kwargs and 'aux_params' in kwargs:
        arg_params = kwargs['arg_params']
        aux_params = kwargs['aux_params']
    else:
        sym, arg_params, aux_params = _load_model(args, kv.rank)
        if sym is not None:
            assert sym.tojson() == network.tojson()

    # save model
    checkpoint = _save_model(args, kv.rank)

    # devices for training
    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # learning rate
    lr, lr_scheduler = _get_lr_scheduler(args, kv)

    # create model
    model = mx.mod.Module(
        context=devs,
        symbol=network
    )

    optimizer_params = {
        'learning_rate': lr,
        'wd': args.wd,
        'lr_scheduler': lr_scheduler,
        'multi_precision': True}

    # Only a limited number of optimizers have a 'momentum' property
    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
    if args.optimizer in has_momentum:
        optimizer_params['momentum'] = args.mom
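
    # Illustration only (assumed default flags, not output from a real run): with
    # --optimizer sgd, --lr 0.1, --wd 0.0001 and --mom 0.9, the dict above becomes
    # {'learning_rate': 0.1, 'wd': 0.0001, 'lr_scheduler': <scheduler or None>,
    #  'multi_precision': True, 'momentum': 0.9}.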

    monitor = mx.mon.Monitor(
        args.monitor, pattern=".*") if args.monitor > 0 else None

    # A limited number of optimizers have a warmup period
    has_warmup = {'lbsgd', 'lbnag'}
    if args.optimizer in has_warmup:
        nworkers = kv.num_workers
        if epoch_size < 1:
            epoch_size = 1
        macrobatch_size = args.macrobatch_size
        if macrobatch_size < args.batch_size * nworkers:
            macrobatch_size = args.batch_size * nworkers
        batch_scale = math.ceil(
            float(macrobatch_size) / args.batch_size / nworkers)
        optimizer_params['updates_per_epoch'] = epoch_size
        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
        optimizer_params['batch_scale'] = batch_scale
        optimizer_params['warmup_strategy'] = args.warmup_strategy
        optimizer_params['warmup_epochs'] = args.warmup_epochs
        optimizer_params['num_epochs'] = args.num_epochs
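
    # Worked example with assumed values (illustrative only): batch_size=128,
    # 8 workers and macrobatch_size=4096 give
    # batch_scale = ceil(4096 / 128 / 8) = 4 in the large-batch branch above.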

    if args.initializer == 'default':
        if args.network == 'alexnet':
            # AlexNet will not converge using Xavier
            initializer = mx.init.Normal()
        elif args.network and 'vgg' in args.network:
            # VGG does not tend to converge using Xavier-Gaussian
            initializer = mx.init.Xavier()
        else:
            initializer = mx.init.Xavier(
                rnd_type='gaussian', factor_type="in", magnitude=2)
    elif args.initializer == 'xavier':
        initializer = mx.init.Xavier()
    elif args.initializer == 'msra':
        initializer = mx.init.MSRAPrelu()
    elif args.initializer == 'orthogonal':
        initializer = mx.init.Orthogonal()
    elif args.initializer == 'normal':
        initializer = mx.init.Normal()
    elif args.initializer == 'uniform':
        initializer = mx.init.Uniform()
    elif args.initializer == 'one':
        initializer = mx.init.One()
    elif args.initializer == 'zero':
        initializer = mx.init.Zero()
    else:
        raise ValueError('unknown initializer type %s' % args.initializer)

    # evaluation metrics
    eval_metrics = ['accuracy']
    if args.top_k > 0:
        eval_metrics.append(mx.metric.create(
            'top_k_accuracy', top_k=args.top_k))

    supported_loss = ['ce', 'nll_loss']
    if len(args.loss) > 0:
        # ce or nll loss is only applicable to softmax output
        loss_type_list = args.loss.split(',')
        if 'softmax_output' in network.list_outputs():
            for loss_type in loss_type_list:
                loss_type = loss_type.strip()
                if loss_type == 'nll':
                    loss_type = 'nll_loss'
                if loss_type not in supported_loss:
                    logging.warning('%s is not a valid loss type; only cross-entropy and '
                                    'negative log-likelihood loss are supported!', loss_type)
                else:
                    eval_metrics.append(mx.metric.create(loss_type))
        else:
            logging.warning("The output is not softmax_output, loss argument will be skipped!")

    # callbacks that run after each batch
    batch_end_callbacks = [mx.callback.Speedometer(
        args.batch_size, args.disp_batches)]
    if 'batch_end_callback' in kwargs:
        cbs = kwargs['batch_end_callback']
        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]

    # run
    model.fit(train,
              begin_epoch=args.load_epoch if args.load_epoch else 0,
              num_epoch=args.num_epochs,
              eval_data=val,
              eval_metric=eval_metrics,
              kvstore=kv,
              optimizer=args.optimizer,
              optimizer_params=optimizer_params,
              initializer=initializer,
              arg_params=arg_params,
              aux_params=aux_params,
              batch_end_callback=batch_end_callbacks,
              epoch_end_callback=checkpoint,
              allow_missing=True,
              monitor=monitor)

    # stop profiling once training has finished
    if args.profile_server_suffix:
        mx.profiler.set_state(state='stop', profile_process='server')
    if args.profile_worker_suffix:
        mx.profiler.set_state(state='stop', profile_process='worker')
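

# Minimal usage sketch (assumptions: a driver script next to this file and the
# companion common/data.py helpers add_data_args/get_rec_iter, as in the MXNet
# image-classification examples; the network-building helper is hypothetical
# and not defined here):
#
#   import argparse
#   from common import data, fit
#
#   parser = argparse.ArgumentParser()
#   fit.add_fit_args(parser)
#   data.add_data_args(parser)
#   args = parser.parse_args()
#   net = build_symbol(args)               # hypothetical helper returning an mx.sym network
#   fit.fit(args, net, data.get_rec_iter)  # data.get_rec_iter returns (train, val) iterators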