# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""Train module using Caffe operator in MXNet"""
import os
import logging
import mxnet as mx


def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
    """Train ``network`` using the Caffe operator in MXNet.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command-line options. Reads ``kv_store``, ``model_prefix``,
        ``save_model_prefix``, ``load_epoch``, ``gpus``, ``num_examples``,
        ``batch_size``, ``num_epochs``, ``lr`` and, when present,
        ``log_file``/``log_dir``, ``lr_factor``/``lr_factor_epoch`` and
        ``clip_gradient``.
    network : mx.symbol.Symbol
        Symbol describing the network to train.
    data_loader : callable
        Called as ``data_loader(args, kv)``; must return a
        ``(train_iter, val_iter)`` pair of data iterators.
    eval_metrics : list, optional
        Evaluation metrics. Defaults to accuracy plus top-5/10/20 accuracy.
    batch_end_callback : callable or list of callables, optional
        Extra per-batch callbacks; a Speedometer is always appended.

    Raises
    ------
    ValueError
        If ``args.load_epoch`` is set but there is no ``model_prefix``
        to load a checkpoint from.
    """
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging: tag every record with this worker's rank
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            # makedirs (not mkdir) so a nested log_dir path also works
            os.makedirs(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)

    # load model: each worker checkpoints under a rank-suffixed prefix so
    # distributed workers never clobber each other's files
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % kv.rank
    model_args = {}
    if args.load_epoch is not None:
        # validate explicitly instead of `assert` (asserts vanish under -O)
        if model_prefix is None:
            raise ValueError("args.load_epoch requires args.model_prefix "
                             "so a checkpoint can be loaded")
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        model_args = {'arg_params': tmp.arg_params,
                      'aux_params': tmp.aux_params,
                      'begin_epoch': args.load_epoch}
    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # train
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # floor division: epoch_size is a batch count; `/` would produce a
    # float under Python 3
    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        epoch_size //= kv.num_workers
    model_args['epoch_size'] = epoch_size

    if 'lr_factor' in args and args.lr_factor < 1:
        model_args['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        model_args['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device (`== 1`, not `is 1`: identity
    # comparison of int literals is a CPython implementation detail)
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    mod = mx.mod.Module(network, context=devs)

    if eval_metrics is None:
        eval_metrics = ['accuracy']
        # TopKAccuracy only allows top_k > 1
        for top_k in [5, 10, 20]:
            eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=top_k))

    if batch_end_callback is None:
        batch_end_callback = []
    elif not isinstance(batch_end_callback, list):
        batch_end_callback = [batch_end_callback]
    else:
        # copy so appending the Speedometer does not mutate the caller's list
        batch_end_callback = list(batch_end_callback)
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    mod.fit(train_data=train, eval_metric=eval_metrics, eval_data=val, optimizer='sgd',
            optimizer_params={'learning_rate': args.lr, 'momentum': 0.9, 'wd': 0.00001},
            num_epoch=args.num_epochs, batch_end_callback=batch_end_callback,
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            kvstore=kv, epoch_end_callback=checkpoint, **model_args)