# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
17"""Train module using Caffe operator in MXNet"""
18import os
19import logging
20import mxnet as mx
21
22
23def fit(args, network, data_loader, eval_metrics=None, batch_end_callback=None):
24    """Train the model using Caffe operator in MXNet"""
    # kvstore
    kv = mx.kvstore.create(args.kv_store)

    # logging
    head = '%(asctime)-15s Node[' + str(kv.rank) + '] %(message)s'
    if 'log_file' in args and args.log_file is not None:
        log_file = args.log_file
        log_dir = args.log_dir
        log_file_full_name = os.path.join(log_dir, log_file)
        if not os.path.exists(log_dir):
            os.mkdir(log_dir)
        logger = logging.getLogger()
        handler = logging.FileHandler(log_file_full_name)
        formatter = logging.Formatter(head)
        handler.setFormatter(formatter)
        logger.addHandler(handler)
        logger.setLevel(logging.DEBUG)
        logger.info('start with arguments %s', args)
    else:
        logging.basicConfig(level=logging.DEBUG, format=head)
        logging.info('start with arguments %s', args)
    # load model
    model_prefix = args.model_prefix
    if model_prefix is not None:
        model_prefix += "-%d" % kv.rank
    model_args = {}
    if args.load_epoch is not None:
        assert model_prefix is not None
        tmp = mx.model.FeedForward.load(model_prefix, args.load_epoch)
        model_args = {'arg_params': tmp.arg_params,
                      'aux_params': tmp.aux_params,
                      'begin_epoch': args.load_epoch}

    # save model
    save_model_prefix = args.save_model_prefix
    if save_model_prefix is None:
        save_model_prefix = model_prefix
    checkpoint = None if save_model_prefix is None else mx.callback.do_checkpoint(save_model_prefix)

    # data
    (train, val) = data_loader(args, kv)

    # train
    devs = mx.cpu() if args.gpus is None else [
        mx.gpu(int(i)) for i in args.gpus.split(',')]

    # Integer division: epoch_size counts whole batches per epoch.
    epoch_size = args.num_examples // args.batch_size

    if args.kv_store == 'dist_sync':
        # Each worker sees only its shard of the examples. Module.fit has no
        # epoch_size argument, so the per-worker value is only used for the
        # learning-rate scheduler step below.
        epoch_size //= kv.num_workers

    # Module.fit routes optimizer settings through optimizer_params, so the
    # learning-rate scheduler and gradient clipping belong there rather than
    # in the keyword arguments of fit itself.
    optimizer_params = {'learning_rate': args.lr, 'momentum': 0.9, 'wd': 0.00001}

    if 'lr_factor' in args and args.lr_factor < 1:
        optimizer_params['lr_scheduler'] = mx.lr_scheduler.FactorScheduler(
            step=max(int(epoch_size * args.lr_factor_epoch), 1),
            factor=args.lr_factor)

    if 'clip_gradient' in args and args.clip_gradient is not None:
        optimizer_params['clip_gradient'] = args.clip_gradient

    # disable kvstore for single device
    if 'local' in kv.type and (
            args.gpus is None or len(args.gpus.split(',')) == 1):
        kv = None

    mod = mx.mod.Module(network, context=devs)

    if eval_metrics is None:
        eval_metrics = ['accuracy']
        # TopKAccuracy only allows top_k > 1
        for top_k in [5, 10, 20]:
            eval_metrics.append(mx.metric.create('top_k_accuracy', top_k=top_k))

    # Always report training throughput, after any user-supplied callbacks.
    if batch_end_callback is not None:
        if not isinstance(batch_end_callback, list):
            batch_end_callback = [batch_end_callback]
    else:
        batch_end_callback = []
    batch_end_callback.append(mx.callback.Speedometer(args.batch_size, 50))

    mod.fit(train_data=train, eval_metric=eval_metrics, eval_data=val, optimizer='sgd',
            optimizer_params=optimizer_params,
            num_epoch=args.num_epochs, batch_end_callback=batch_end_callback,
            initializer=mx.init.Xavier(factor_type="in", magnitude=2.34),
            kvstore=kv, epoch_end_callback=checkpoint, **model_args)
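

# A minimal usage sketch (an illustration added here, not part of the original
# module): it trains a tiny plain-MXNet MLP on synthetic data to show which
# attributes fit() reads from `args`. A real caller would instead pass a
# network built with mx.symbol.CaffeOp and a loader over real data.
if __name__ == '__main__':
    import argparse
    import numpy as np

    def toy_loader(args, kv):
        # 100 random 784-dim "images" with 10 fake classes.
        data = np.random.uniform(size=(100, 784)).astype('float32')
        label = np.random.randint(0, 10, (100,)).astype('float32')
        train = mx.io.NDArrayIter(data, label, args.batch_size, shuffle=True)
        val = mx.io.NDArrayIter(data, label, args.batch_size)
        return (train, val)

    # Stand-in network; a Caffe-operator version would replace these layers.
    net = mx.sym.Variable('data')
    net = mx.sym.FullyConnected(net, num_hidden=64)
    net = mx.sym.Activation(net, act_type='relu')
    net = mx.sym.FullyConnected(net, num_hidden=10)
    net = mx.sym.SoftmaxOutput(net, name='softmax')

    # Only the attributes that fit() actually reads need to be present.
    toy_args = argparse.Namespace(
        kv_store='local', log_file=None, log_dir=None,
        model_prefix=None, save_model_prefix=None, load_epoch=None,
        gpus=None, num_examples=100, batch_size=10, lr=0.1,
        lr_factor=1, lr_factor_epoch=1, clip_gradient=None, num_epochs=1)
    fit(toy_args, net, toy_loader)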