1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18""" example train fit utility """
19import logging
20import os
21import time
22import re
23import math
24import mxnet as mx
25
26def get_epoch_size(args, kv):
27    return math.ceil(int(args.num_examples / kv.num_workers) / args.batch_size)
28
29def _get_lr_scheduler(args, kv):
30    if 'lr_factor' not in args or args.lr_factor >= 1:
31        return (args.lr, None)
32    epoch_size = get_epoch_size(args, kv)
33    begin_epoch = args.load_epoch if args.load_epoch else 0
34    if 'pow' in args.lr_step_epochs:
35        lr = args.lr
36        max_up = args.num_epochs * epoch_size
37        pwr = float(re.sub('pow[- ]*', '', args.lr_step_epochs))
38        poly_sched = mx.lr_scheduler.PolyScheduler(max_up, lr, pwr)
39        return (lr, poly_sched)
40    step_epochs = [int(l) for l in args.lr_step_epochs.split(',')]
41    lr = args.lr
42    for s in step_epochs:
43        if begin_epoch >= s:
44            lr *= args.lr_factor
45    if lr != args.lr:
46        logging.info('Adjust learning rate to %e for epoch %d',
47                     lr, begin_epoch)
48
49    steps = [epoch_size * (x - begin_epoch)
50             for x in step_epochs if x - begin_epoch > 0]
51    if steps:
52        return (lr, mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=args.lr_factor,
53                                                         base_lr=args.lr))
54    else:
55        return (lr, None)
56
57def _load_model(args, rank=0):
58    if 'load_epoch' not in args or args.load_epoch is None:
59        return (None, None, None)
60    assert args.model_prefix is not None
61    model_prefix = args.model_prefix
62    if rank > 0 and os.path.exists("%s-%d-symbol.json" % (model_prefix, rank)):
63        model_prefix += "-%d" % (rank)
64    sym, arg_params, aux_params = mx.model.load_checkpoint(
65        model_prefix, args.load_epoch)
66    logging.info('Loaded model %s_%04d.params', model_prefix, args.load_epoch)
67    return (sym, arg_params, aux_params)
68
69
70def _save_model(args, rank=0):
71    if args.model_prefix is None:
72        return None
73    return mx.callback.do_checkpoint(args.model_prefix if rank == 0 else "%s-%d" % (
74        args.model_prefix, rank), period=args.save_period)
75
76
77def add_fit_args(parser):
78    """
79    parser : argparse.ArgumentParser
80    return a parser added with args required by fit
81    """
82    train = parser.add_argument_group('Training', 'model training')
83    train.add_argument('--network', type=str,
84                       help='the neural network to use')
85    train.add_argument('--num-layers', type=int,
86                       help='number of layers in the neural network, \
87                             required by some networks such as resnet')
88    train.add_argument('--gpus', type=str,
89                       help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu')
90    train.add_argument('--kv-store', type=str, default='device',
91                       help='key-value store type')
92    train.add_argument('--num-epochs', type=int, default=100,
93                       help='max num of epochs')
94    train.add_argument('--lr', type=float, default=0.1,
95                       help='initial learning rate')
96    train.add_argument('--lr-factor', type=float, default=0.1,
97                       help='the ratio to reduce lr on each step')
98    train.add_argument('--lr-step-epochs', type=str,
99                       help='the epochs to reduce the lr, e.g. 30,60')
100    train.add_argument('--initializer', type=str, default='default',
101                       help='the initializer type')
102    train.add_argument('--optimizer', type=str, default='sgd',
103                       help='the optimizer type')
104    train.add_argument('--mom', type=float, default=0.9,
105                       help='momentum for sgd')
106    train.add_argument('--wd', type=float, default=0.0001,
107                       help='weight decay for sgd')
108    train.add_argument('--batch-size', type=int, default=128,
109                       help='the batch size')
110    train.add_argument('--disp-batches', type=int, default=20,
111                       help='show progress for every n batches')
112    train.add_argument('--model-prefix', type=str,
113                       help='model prefix')
114    train.add_argument('--save-period', type=int, default=1, help='params saving period')
115    parser.add_argument('--monitor', dest='monitor', type=int, default=0,
116                        help='log network parameters every N iters if larger than 0')
117    train.add_argument('--load-epoch', type=int,
118                       help='load the model on an epoch using the model-load-prefix')
119    train.add_argument('--top-k', type=int, default=0,
120                       help='report the top-k accuracy. 0 means no report.')
121    train.add_argument('--loss', type=str, default='',
122                       help='show the cross-entropy or nll loss. ce strands for cross-entropy, nll-loss stands for likelihood loss')
123    train.add_argument('--test-io', type=int, default=0,
124                       help='1 means test reading speed without training')
125    train.add_argument('--dtype', type=str, default='float32',
126                       help='precision: float32 or float16')
127    train.add_argument('--gc-type', type=str, default='none',
128                       help='type of gradient compression to use, \
129                             takes `2bit` or `none` for now')
130    train.add_argument('--gc-threshold', type=float, default=0.5,
131                       help='threshold for 2bit gradient compression')
132    # additional parameters for large batch sgd
133    train.add_argument('--macrobatch-size', type=int, default=0,
134                       help='distributed effective batch size')
135    train.add_argument('--warmup-epochs', type=int, default=5,
136                       help='the epochs to ramp-up lr to scaled large-batch value')
137    train.add_argument('--warmup-strategy', type=str, default='linear',
138                       help='the ramping-up strategy for large batch sgd')
139    train.add_argument('--profile-worker-suffix', type=str, default='',
140                       help='profile workers actions into this file. During distributed training\
141                             filename saved will be rank1_ followed by this suffix')
142    train.add_argument('--profile-server-suffix', type=str, default='',
143                       help='profile server actions into a file with name like rank1_ followed by this suffix \
144                             during distributed training')
145    train.add_argument('--use-imagenet-data-augmentation', type=int, default=0,
146                       help='enable data augmentation of ImageNet data, default disabled')
147    return train
148
149
150def fit(args, network, data_loader, **kwargs):
151    """
152    train a model
153    args : argparse returns
154    network : the symbol definition of the nerual network
155    data_loader : function that returns the train and val data iterators
156    """
157    # kvstore
158    kv = mx.kvstore.create(args.kv_store)
159    if args.gc_type != 'none':
160        kv.set_gradient_compression({'type': args.gc_type,
161                                     'threshold': args.gc_threshold})
162    if args.profile_server_suffix:
163        mx.profiler.set_config(filename=args.profile_server_suffix, profile_all=True, profile_process='server')
164        mx.profiler.set_state(state='run', profile_process='server')
165
166    if args.profile_worker_suffix:
167        if kv.num_workers > 1:
168            filename = 'rank' + str(kv.rank) + '_' + args.profile_worker_suffix
169        else:
170            filename = args.profile_worker_suffix
171        mx.profiler.set_config(filename=filename, profile_all=True, profile_process='worker')
172        mx.profiler.set_state(state='run', profile_process='worker')
173
174    # logging
175    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
176    logging.basicConfig(level=logging.DEBUG, format=head)
177    logging.info('start with arguments %s', args)
178
179    epoch_size = get_epoch_size(args, kv)
180
181    # data iterators
182    (train, val) = data_loader(args, kv)
183    if 'dist' in args.kv_store and not 'async' in args.kv_store:
184        logging.info('Resizing training data to %d batches per machine', epoch_size)
185        # resize train iter to ensure each machine has same number of batches per epoch
186        # if not, dist_sync can hang at the end with one machine waiting for other machines
187        train = mx.io.ResizeIter(train, epoch_size)
188
189    if args.test_io:
190        tic = time.time()
191        for i, batch in enumerate(train):
192            if isinstance(batch, list):
193                for b in batch:
194                    for j in b.data:
195                        j.wait_to_read()
196            else:
197                for j in batch.data:
198                    j.wait_to_read()
199            if (i + 1) % args.disp_batches == 0:
200                logging.info('Batch [%d]\tSpeed: %.2f samples/sec', i,
201                             args.disp_batches * args.batch_size / (time.time() - tic))
202                tic = time.time()
203        return
204
205    # load model
206    if 'arg_params' in kwargs and 'aux_params' in kwargs:
207        arg_params = kwargs['arg_params']
208        aux_params = kwargs['aux_params']
209    else:
210        sym, arg_params, aux_params = _load_model(args, kv.rank)
211        if sym is not None:
212            assert sym.tojson() == network.tojson()
213
214    # save model
215    checkpoint = _save_model(args, kv.rank)
216
217    # devices for training
218    devs = mx.cpu() if args.gpus is None or args.gpus == "" else [
219        mx.gpu(int(i)) for i in args.gpus.split(',')]
220
221    # learning rate
222    lr, lr_scheduler = _get_lr_scheduler(args, kv)
223
224    # create model
225    model = mx.mod.Module(
226        context=devs,
227        symbol=network
228    )
229
230    lr_scheduler = lr_scheduler
231    optimizer_params = {
232        'learning_rate': lr,
233        'wd': args.wd,
234        'lr_scheduler': lr_scheduler,
235        'multi_precision': True}
236
237    # Only a limited number of optimizers have 'momentum' property
238    has_momentum = {'sgd', 'dcasgd', 'nag', 'signum', 'lbsgd'}
239    if args.optimizer in has_momentum:
240        optimizer_params['momentum'] = args.mom
241
242    monitor = mx.mon.Monitor(
243        args.monitor, pattern=".*") if args.monitor > 0 else None
244
245    # A limited number of optimizers have a warmup period
246    has_warmup = {'lbsgd', 'lbnag'}
247    if args.optimizer in has_warmup:
248        nworkers = kv.num_workers
249        if epoch_size < 1:
250            epoch_size = 1
251        macrobatch_size = args.macrobatch_size
252        if macrobatch_size < args.batch_size * nworkers:
253            macrobatch_size = args.batch_size * nworkers
254        #batch_scale = round(float(macrobatch_size) / args.batch_size / nworkers +0.4999)
255        batch_scale = math.ceil(
256            float(macrobatch_size) / args.batch_size / nworkers)
257        optimizer_params['updates_per_epoch'] = epoch_size
258        optimizer_params['begin_epoch'] = args.load_epoch if args.load_epoch else 0
259        optimizer_params['batch_scale'] = batch_scale
260        optimizer_params['warmup_strategy'] = args.warmup_strategy
261        optimizer_params['warmup_epochs'] = args.warmup_epochs
262        optimizer_params['num_epochs'] = args.num_epochs
263
264    if args.initializer == 'default':
265        if args.network == 'alexnet':
266            # AlexNet will not converge using Xavier
267            initializer = mx.init.Normal()
268            # VGG will not trend to converge using Xavier-Gaussian
269        elif args.network and 'vgg' in args.network:
270            initializer = mx.init.Xavier()
271        else:
272            initializer = mx.init.Xavier(
273                rnd_type='gaussian', factor_type="in", magnitude=2)
274    # initializer   = mx.init.Xavier(factor_type="in", magnitude=2.34),
275    elif args.initializer == 'xavier':
276        initializer = mx.init.Xavier()
277    elif args.initializer == 'msra':
278        initializer = mx.init.MSRAPrelu()
279    elif args.initializer == 'orthogonal':
280        initializer = mx.init.Orthogonal()
281    elif args.initializer == 'normal':
282        initializer = mx.init.Normal()
283    elif args.initializer == 'uniform':
284        initializer = mx.init.Uniform()
285    elif args.initializer == 'one':
286        initializer = mx.init.One()
287    elif args.initializer == 'zero':
288        initializer = mx.init.Zero()
289
290    # evaluation metrices
291    eval_metrics = ['accuracy']
292    if args.top_k > 0:
293        eval_metrics.append(mx.metric.create(
294            'top_k_accuracy', top_k=args.top_k))
295
296    supported_loss = ['ce', 'nll_loss']
297    if len(args.loss) > 0:
298        # ce or nll loss is only applicable to softmax output
299        loss_type_list = args.loss.split(',')
300        if 'softmax_output' in network.list_outputs():
301            for loss_type in loss_type_list:
302                loss_type = loss_type.strip()
303                if loss_type == 'nll':
304                    loss_type = 'nll_loss'
305                if loss_type not in supported_loss:
306                    logging.warning(loss_type + ' is not an valid loss type, only cross-entropy or ' \
307                                    'negative likelihood loss is supported!')
308                else:
309                    eval_metrics.append(mx.metric.create(loss_type))
310        else:
311            logging.warning("The output is not softmax_output, loss argument will be skipped!")
312
313    # callbacks that run after each batch
314    batch_end_callbacks = [mx.callback.Speedometer(
315        args.batch_size, args.disp_batches)]
316    if 'batch_end_callback' in kwargs:
317        cbs = kwargs['batch_end_callback']
318        batch_end_callbacks += cbs if isinstance(cbs, list) else [cbs]
319
320    # run
321    model.fit(train,
322              begin_epoch=args.load_epoch if args.load_epoch else 0,
323              num_epoch=args.num_epochs,
324              eval_data=val,
325              eval_metric=eval_metrics,
326              kvstore=kv,
327              optimizer=args.optimizer,
328              optimizer_params=optimizer_params,
329              initializer=initializer,
330              arg_params=arg_params,
331              aux_params=aux_params,
332              batch_end_callback=batch_end_callbacks,
333              epoch_end_callback=checkpoint,
334              allow_missing=True,
335              monitor=monitor)
336
337    if args.profile_server_suffix:
338        mx.profiler.set_state(state='run', profile_process='server')
339    if args.profile_worker_suffix:
340        mx.profiler.set_state(state='run', profile_process='worker')
341