# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """ example train fit utility """ import logging import os import time import re import math import mxnet as mx def get_epoch_size(args, kv): return math.ceil(int(args.num_examples / kv.num_workers) / args.batch_size) def _get_lr_scheduler(args, kv): if 'lr_factor' not in args or args.lr_factor >= 1: return (args.lr, None) epoch_size = get_epoch_size(args, kv) begin_epoch = args.load_epoch if args.load_epoch else 0 if 'pow' in args.lr_step_epochs: lr = args.lr max_up = args.num_epochs * epoch_size pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs)) poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr) return (lr, poly_sched) step_epochs = [int(l) for l in args.lr_step_epochs.split(',')] lr = args.lr for s in step_epochs: if begin_epoch >= s: lr *= args.lr_factor if lr != args.lr: logging.info('Adjust learning rate to %e for epoch %d', lr, begin_epoch) steps = [epoch_size * (x - begin_epoch) for x in step_epochs if x - begin_epoch > 0] if steps: return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor, base_lr=args.lr)) else: return (lr, None) def _load_model(args, rank=0): if 'load_epoch' not in args or args.load_epoch is None: return (None, None, None) assert args.model_prefix is not None model_prefix = args.model_prefix if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)): model_prefix += "-%d" % (rank) sym, arg_params, aux_params = mx.model.load_checkpoint( model_prefix, args.load_epoch) logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch) return (sym, arg_params, aux_params) def _save_model(args, rank=0): if args.model_prefix is None: return None return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % ( args.model_prefix, rank), period=args.save_period) def add_fit_args(parser): """ parser : argparse.ArgumentParser return a parser added with args required by fit """ train = parser.add_argument_group('Training', 'model training') train.add_argument('--network', type=str, help='the neural network to use') train.add_argument('--num-layers', type=int, help='number of layers in the neural network, \ required by some networks such as resnet') train.add_argument('--gpus', type=str, help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu') train.add_argument('--kv-store', type=str, default='device', help='key-value store type') train.add_argument('--num-epochs', type=int, default=100, help='max num of epochs') train.add_argument('--lr', type=float, default=0.1, help='initial learning rate') train.add_argument('--lr-factor', type=float, default=0.1, help='the ratio to reduce lr on each step') train.add_argument('--lr-step-epochs', type=str, help='the epochs to reduce the lr, e.g. 30,60') train.add_argument('--initializer', type=str, default='default', help='the initializer type') train.add_argument('--optimizer', type=str, default='sgd', help='the optimizer type') train.add_argument('--mom', type=float, default=0.9, help='momentum for sgd') train.add_argument('--wd', type=float, default=0.0001, help='weight decay for sgd') train.add_argument('--batch-size', type=int, default=128, help='the batch size') train.add_argument('--disp-batches', type=int, default=20, help='show progress for every n batches') train.add_argument('--model-prefix', type=str, help='model prefix') train.add_argument('--save-period', type=int, default=1, help='params saving period') parser.add_argument('--monitor', dest='monitor', type=int, default=0, help='log network parameters every N iters if larger than 0') train.add_argument('--load-epoch', type=int, help='load the model on an epoch using the model-load-prefix') train.add_argument('--top-k', type=int, default=0, help='report the top-k accuracy. 0 means no report.') train.add_argument('--loss', type=str, default='', help='show the cross-entropy or nll loss. ce strands for cross-entropy, nll-loss stands for likelihood loss') train.add_argument('--test-io', type=int, default=0, help='1 means test reading speed without training') train.add_argument('--dtype', type=str, default='float32', help='precision: float32 or float16') train.add_argument('--gc-type', type=str, default='none', help='type of gradient compression to use, \ takes `2bit` or `none` for now') train.add_argument('--gc-threshold', type=float, default=0.5, help='threshold for 2bit gradient compression') # additional parameters for large batch sgd train.add_argument('--macrobatch-size', type=int, default=0, help='distributed effective batch size') train.add_argument('--warmup-epochs', type=int, default=5, help='the epochs to ramp-up lr to scaled large-batch value') train.add_argument('--warmup-strategy', type=str, default='linear', help='the ramping-up strategy for large batch sgd') train.add_argument('--profile-worker-suffix', type=str, default='', help='profile workers actions into this file. During distributed training\ filename saved will be rank1_ followed by this suffix') train.add_argument('--profile-server-suffix', type=str, default='', help='profile server actions into a file with name like rank1_ followed by this suffix \ during distributed training') train.add_argument('--use-imagenet-data-augmentation', type=int, default=0, help='enable data augmentation of ImageNet data, default disabled') return train def fit(args, network, data_loader, **kwargs): """ train a model args : argparse returns network : the symbol definition of the nerual network data_loader : function that returns the train and val data iterators """ # kvstore kv = mx.kvstore.create(args.kv_store) if args.gc_type != 'none': kv.set_gradient_compression({'type': args.gc_type, 'threshold': args.gc_threshold}) if args.profile_server_suffix: mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server') mx.profiler.set_state(state='run', profile_process='server') if args.profile_worker_suffix: if kv.num_workers > 1: filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix else: filename = args.profile_worker_suffix mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker') mx.profiler.set_state(state='run', profile_process='worker') # logging head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s' logging.basicConfig(level=logging.DEBUG, format=head) logging.info('start with arguments %s', args) epoch_size = get_epoch_size(args, kv) # data iterators (train, val) = data_loader(args, kv) if 'dist' in args.kv_store and not 'async' in args.kv_store: logging.info('Resizing training data to %d batches per machine', epoch_size) # resize train iter to ensure each machine has same number of batches per epoch # if not, dist_sync can hang at the end with one machine waiting for other machines train = mx.io.ResizeIter(train, epoch_size) if args.test_io: tic = time.time() for i, batch in enumerate(train): if isinstance(batch, list): for b in batch: for j in b.data: j.wait_to_read() else: for j in batch.data: j.wait_to_read() if (i + 1) % args.disp_batches == 0: logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i, args.disp_batches * args.batch_size / (time.time() - tic)) tic = time.time() return # load model if 'arg_params' in kwargs and 'aux_params' in kwargs: arg_params = kwargs['arg_params'] aux_params = kwargs['aux_params'] else: sym, arg_params, aux_params = _load_model(args, kv.rank) if sym is not None: assert sym.tojson() == network.tojson() # save model checkpoint = _save_model(args, kv.rank) # devices for training devs = mx.cpu() if args.gpus is None or args.gpus == "" else [ mx.gpu(int(i)) for i in args.gpus.split(',')] # learning rate lr, lr_scheduler = _get_lr_scheduler(args, kv) # create model model = mx.mod.Module( context=devs, symbol=network ) lr_scheduler = lr_scheduler optimizer_params = { 'learning_rate': lr, 'wd': args.wd, 'lr_scheduler': lr_scheduler, 'multi_precision': True} # Only a limited number of optimizers have 'momentum' property has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'} if args.optimizer in has_momentum: optimizer_params['momentum'] = args.mom monitor = mx.mon.Monitor( args.monitor, pattern=".*") if args.monitor > 0 else None # A limited number of optimizers have a warmup period has_warmup = {'lbsgd', 'lbnag'} if args.optimizer in has_warmup: nworkers = kv.num_workers if epoch_size < 1: epoch_size = 1 macrobatch_size = args.macrobatch_size if macrobatch_size < args.batch_size * nworkers: macrobatch_size = args.batch_size * nworkers #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999) batch_scale = math.ceil( float(macrobatch_size) / args.batch_size / nworkers) optimizer_params['updates_per_epoch'] = epoch_size optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0 optimizer_params['batch_scale'] = batch_scale optimizer_params['warmup_strategy'] = args.warmup_strategy optimizer_params['warmup_epochs'] = args.warmup_epochs optimizer_params['num_epochs'] = args.num_epochs if args.initializer == 'default': if args.network == 'alexnet': # AlexNet will not converge using Xavier initializer = mx.init.Normal() # VGG will not trend to converge using Xavier-Gaussian elif args.network and 'vgg' in args.network: initializer = mx.init.Xavier() else: initializer = mx.init.Xavier( rnd_type='gaussian', factor_type="in", magnitude=2) # initializer = mx.init.Xavier(factor_type="in", magnitude=2.34), elif args.initializer == 'xavier': initializer = mx.init.Xavier() elif args.initializer == 'msra': initializer = mx.init.MSRAPrelu() elif args.initializer == 'orthogonal': initializer = mx.init.Orthogonal() elif args.initializer == 'normal': initializer = mx.init.Normal() elif args.initializer == 'uniform': initializer = mx.init.Uniform() elif args.initializer == 'one': initializer = mx.init.One() elif args.initializer == 'zero': initializer = mx.init.Zero() # evaluation metrices eval_metrics = ['accuracy'] if args.top_k > 0: eval_metrics.append(mx.metric.create( 'top_k_accuracy', top_k=args.top_k)) supported_loss = ['ce', 'nll_loss'] if len(args.loss) > 0: # ce or nll loss is only applicable to softmax output loss_type_list = args.loss.split(',') if 'softmax_output' in network.list_outputs(): for loss_type in loss_type_list: loss_type = loss_type.strip() if loss_type == 'nll': loss_type = 'nll_loss' if loss_type not in supported_loss: logging.warning(loss_type + ' is not an valid loss type, only cross-entropy or ' \ 'negative likelihood loss is supported!') else: eval_metrics.append(mx.metric.create(loss_type)) else: logging.warning("The output is not softmax_output, loss argument will be skipped!") # callbacks that run after each batch batch_end_callbacks = [mx.callback.Speedometer( args.batch_size, args.disp_batches)] if 'batch_end_callback' in kwargs: cbs = kwargs['batch_end_callback'] batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs] # run model.fit(train, begin_epoch=args.load_epoch if args.load_epoch else 0, num_epoch=args.num_epochs, eval_data=val, eval_metric=eval_metrics, kvstore=kv, optimizer=args.optimizer, optimizer_params=optimizer_params, initializer=initializer, arg_params=arg_params, aux_params=aux_params, batch_end_callback=batch_end_callbacks, epoch_end_callback=checkpoint, allow_missing=True, monitor=monitor) if args.profile_server_suffix: mx.profiler.set_state(state='run', profile_process='server') if args.profile_worker_suffix: mx.profiler.set_state(state='run', profile_process='worker')