# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import tools.find_mxnet
import mxnet as mx
import logging
import sys
import os
import importlib
import re
from dataset.iterator import DetRecordIter
from train.metric import MultiBoxMetric
from evaluate.eval_metric import MApMetric, VOC07MApMetric
from config.config import cfg
from symbol.symbol_factory import get_symbol_train

def convert_pretrained(name, args):
    """
    Apply any special handling needed due to naming inconsistencies, etc.

    Parameters:
    ---------
    name : str
        pretrained model name
    args : dict
        loaded arguments

    Returns:
    ---------
    processed arguments as dict
    """
    return args

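# Example (illustrative numbers): with num_example=32000, batch_size=32,
# begin_epoch=0, lr_refactor_step='80, 160' and lr_refactor_ratio=0.1,
# epoch_size is 1000 iterations, so the returned scheduler multiplies the
# learning rate by 0.1 after 80000 and 160000 iterations.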
def get_lr_scheduler(learning_rate, lr_refactor_step, lr_refactor_ratio,
                     num_example, batch_size, begin_epoch):
    """
    Compute the adjusted learning rate and build a learning rate scheduler.

    Parameters:
    ---------
    learning_rate : float
        original learning rate
    lr_refactor_step : comma separated str
        epochs at which to change the learning rate
    lr_refactor_ratio : float
        lr *= ratio at each refactor step
    num_example : int
        number of training images, used to estimate the iterations given epochs
    batch_size : int
        training batch size
    begin_epoch : int
        starting epoch

    Returns:
    ---------
    (learning_rate, mx.lr_scheduler) as tuple
    """
    assert lr_refactor_ratio > 0
    iter_refactor = [int(r) for r in lr_refactor_step.split(',') if r.strip()]
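    # a ratio >= 1 would never decrease the learning rate, so no scheduler is needed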
    if lr_refactor_ratio >= 1:
        return (learning_rate, None)
    else:
        lr = learning_rate
        epoch_size = num_example // batch_size
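        # apply the decay for refactor steps already passed when resuming at begin_epoch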
        for s in iter_refactor:
            if begin_epoch >= s:
                lr *= lr_refactor_ratio
        if lr != learning_rate:
            logging.getLogger().info("Adjusted learning rate to {} for epoch {}".format(lr, begin_epoch))
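        # remaining refactor points, converted to iteration counts relative to begin_epoch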
        steps = [epoch_size * (x - begin_epoch) for x in iter_refactor if x > begin_epoch]
        if not steps:
            return (lr, None)
        lr_scheduler = mx.lr_scheduler.MultiFactorScheduler(step=steps, factor=lr_refactor_ratio)
        return (lr, lr_scheduler)

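# Typical invocation (illustrative values, usually supplied by a command-line
# driver script rather than hard-coded):
#   train_net('vgg16_reduced', 'data/train.rec', num_classes=20, batch_size=32,
#             data_shape=300, mean_pixels=(123, 117, 104), resume=-1, finetune=-1,
#             pretrained='model/vgg16_reduced', epoch=1, prefix='model/ssd',
#             ctx=[mx.gpu(0)], begin_epoch=0, end_epoch=240, frequent=20,
#             learning_rate=0.004, momentum=0.9, weight_decay=0.0005,
#             lr_refactor_step='80, 160', lr_refactor_ratio=0.1)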
def train_net(net, train_path, num_classes, batch_size,
              data_shape, mean_pixels, resume, finetune, pretrained, epoch,
              prefix, ctx, begin_epoch, end_epoch, frequent, learning_rate,
              momentum, weight_decay, lr_refactor_step, lr_refactor_ratio,
              freeze_layer_pattern='',
              num_example=10000, label_pad_width=350,
              nms_thresh=0.45, force_nms=False, ovp_thresh=0.5,
              use_difficult=False, class_names=None,
              voc07_metric=False, nms_topk=400, force_suppress=False,
              train_list="", val_path="", val_list="", iter_monitor=0,
              monitor_pattern=".*", log_file=None, kv_store=None):
    """
    Wrapper for the training phase.

    Parameters:
    ----------
    net : str
        symbol name for the network structure
    train_path : str
        record file path for training
    num_classes : int
        number of object classes, not including background
    batch_size : int
        training batch size
    data_shape : int or tuple
        width/height as integer or (3, height, width) tuple
    mean_pixels : tuple of floats
        mean pixel values for red, green and blue
    resume : int
        resume from previous checkpoint if > 0
    finetune : int
        fine-tune from previous checkpoint if > 0
    pretrained : str
        prefix of pretrained model, including path
    epoch : int
        load epoch of either resume/finetune/pretrained model
    prefix : str
        prefix for saving checkpoints
    ctx : [mx.cpu()] or [mx.gpu(x)]
        list of mxnet contexts
    begin_epoch : int
        starting epoch for training, should be 0 if not otherwise specified
    end_epoch : int
        end epoch of training
    frequent : int
        frequency to print out training status
    learning_rate : float
        training learning rate
    momentum : float
        training momentum
    weight_decay : float
        training weight decay param
    lr_refactor_ratio : float
        multiplier for reducing learning rate
    lr_refactor_step : comma separated integers
        epochs at which to rescale the learning rate, e.g. '30, 60, 90'
    freeze_layer_pattern : str
        regex pattern for layers whose parameters should be fixed
    num_example : int
        number of training images
    label_pad_width : int
        force padding training and validation labels to sync their label widths
    nms_thresh : float
        non-maximum suppression threshold for validation
    force_nms : boolean
        suppress overlapped objects from different classes
    ovp_thresh : float
        overlap (IoU) threshold for matching detections to ground-truths in validation
    use_difficult : boolean
        include objects marked as difficult in evaluation
    class_names : list of str
        class names used by the validation metric
    voc07_metric : boolean
        use the 11-point VOC07 average precision for validation
    nms_topk : int
        apply non-maximum suppression to the top k detections
    force_suppress : boolean
        suppress detections regardless of class during non-maximum suppression
    train_list : str
        list file path for training, this will replace the embedded labels in record
    val_path : str
        record file path for validation
    val_list : str
        list file path for validation, this will replace the embedded labels in record
    iter_monitor : int
        monitor internal stats in networks if > 0, specified by monitor_pattern
    monitor_pattern : str
        regex pattern for monitoring network stats
    log_file : str
        log to file if enabled
    kv_store : str
        key-value store type, if specified
    """
    # set up logger
    logging.basicConfig()
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    if log_file:
        fh = logging.FileHandler(log_file)
        logger.addHandler(fh)

    # check args
    if isinstance(data_shape, int):
        data_shape = (3, data_shape, data_shape)
    assert len(data_shape) == 3 and data_shape[0] == 3
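    # encode the network name and input size into the checkpoint prefix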
    prefix += '_' + net + '_' + str(data_shape[1])

    if isinstance(mean_pixels, (int, float)):
        mean_pixels = [mean_pixels, mean_pixels, mean_pixels]
    assert len(mean_pixels) == 3, "must provide all RGB mean values"

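    # data iterators read RecordIO files; an optional image list overrides the embedded labels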
    train_iter = DetRecordIter(train_path, batch_size, data_shape, mean_pixels=mean_pixels,
        label_pad_width=label_pad_width, path_imglist=train_list, **cfg.train)

    if val_path:
        val_iter = DetRecordIter(val_path, batch_size, data_shape, mean_pixels=mean_pixels,
            label_pad_width=label_pad_width, path_imglist=val_list, **cfg.valid)
    else:
        val_iter = None

    # load symbol
    net = get_symbol_train(net, data_shape[1], num_classes=num_classes,
        nms_thresh=nms_thresh, force_suppress=force_suppress, nms_topk=nms_topk)

    # define layers with fixed weight/bias
    if freeze_layer_pattern.strip():
        re_prog = re.compile(freeze_layer_pattern)
        fixed_param_names = [name for name in net.list_arguments() if re_prog.match(name)]
    else:
        fixed_param_names = None

    # load pretrained or resume from previous state
    ctx_str = '(' + ','.join([str(c) for c in ctx]) + ')'
    if resume > 0:
        logger.info("Resume training with {} from epoch {}"
            .format(ctx_str, resume))
        _, args, auxs = mx.model.load_checkpoint(prefix, resume)
        begin_epoch = resume
    elif finetune > 0:
        logger.info("Start finetuning with {} from epoch {}"
            .format(ctx_str, finetune))
        _, args, auxs = mx.model.load_checkpoint(prefix, finetune)
        begin_epoch = finetune
        # the prediction convolution layers' names start with 'relu', so they are not frozen here
        fixed_param_names = [name for name in net.list_arguments() \
            if name.startswith('conv')]
    elif pretrained:
        logger.info("Start training with {} from pretrained model {}"
            .format(ctx_str, pretrained))
        _, args, auxs = mx.model.load_checkpoint(pretrained, epoch)
        args = convert_pretrained(pretrained, args)
    else:
        logger.info("Experimental: start training from scratch with {}"
            .format(ctx_str))
        args = None
        auxs = None
        fixed_param_names = None

    # helper information
    if fixed_param_names:
        logger.info("Frozen parameters: [" + ','.join(fixed_param_names) + ']')

    # init training module
    mod = mx.mod.Module(net, label_names=('label',), logger=logger, context=ctx,
                        fixed_param_names=fixed_param_names)

    # fit parameters
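    # Speedometer reports throughput every `frequent` batches; do_checkpoint saves parameters each epoch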
    batch_end_callback = mx.callback.Speedometer(train_iter.batch_size, frequent=frequent)
    epoch_end_callback = mx.callback.do_checkpoint(prefix)
    learning_rate, lr_scheduler = get_lr_scheduler(learning_rate, lr_refactor_step,
        lr_refactor_ratio, num_example, batch_size, begin_epoch)
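    # no gradient clipping; gradients are rescaled by 1 / number of contexts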
    optimizer_params = {'learning_rate': learning_rate,
                        'momentum': momentum,
                        'wd': weight_decay,
                        'lr_scheduler': lr_scheduler,
                        'clip_gradient': None,
                        'rescale_grad': 1.0 / len(ctx) if len(ctx) > 0 else 1.0}
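    # optionally monitor internal tensors matching monitor_pattern every iter_monitor batches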
    monitor = mx.mon.Monitor(iter_monitor, pattern=monitor_pattern) if iter_monitor > 0 else None

    # run fit; the validation metric (mAP) is evaluated on the validation set during training
    if voc07_metric:
        valid_metric = VOC07MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)
    else:
        valid_metric = MApMetric(ovp_thresh, use_difficult, class_names, pred_idx=3)

    # create kvstore if a type is specified
    kv = mx.kvstore.create(kv_store) if kv_store else None

    mod.fit(train_iter,
            val_iter,
            eval_metric=MultiBoxMetric(),
            validation_metric=valid_metric,
            batch_end_callback=batch_end_callback,
            epoch_end_callback=epoch_end_callback,
            optimizer='sgd',
            optimizer_params=optimizer_params,
            begin_epoch=begin_epoch,
            num_epoch=end_epoch,
            initializer=mx.init.Xavier(),
            arg_params=args,
            aux_params=auxs,
            allow_missing=True,
            monitor=monitor,
            kvstore=kv)