import mxnet as mx
import numpy as np
import os, time, logging, argparse

from mxnet import gluon, init
from mxnet import autograd as ag
from mxnet.gluon import nn
from mxnet.gluon.data.vision import transforms
import gluoncv as gcv
gcv.utils.check_version('0.6.0')
from gluoncv.model_zoo import get_model

def parse_opts():
    parser = argparse.ArgumentParser(description='Transfer learning on MINC-2500 dataset',
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--data', type=str, default='',
                        help='directory for the prepared data folder')
    parser.add_argument('--model', required=True, type=str,
                        help='name of the pretrained model from the model zoo')
    parser.add_argument('-j', '--workers', dest='num_workers', default=4, type=int,
                        help='number of preprocessing workers')
    parser.add_argument('--num-gpus', default=0, type=int,
                        help='number of GPUs to use; 0 means CPU only')
    parser.add_argument('--epochs', default=40, type=int,
                        help='number of training epochs')
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        help='mini-batch size per device')
    parser.add_argument('--lr', '--learning-rate', default=0.001, type=float,
                        help='initial learning rate')
    parser.add_argument('--momentum', default=0.9, type=float,
                        help='SGD momentum')
    parser.add_argument('--weight-decay', '--wd', dest='wd', default=1e-4, type=float,
                        help='weight decay (default: 1e-4)')
    parser.add_argument('--lr-factor', default=0.75, type=float,
                        help='learning rate decay ratio')
    parser.add_argument('--lr-steps', default='10,20,30', type=str,
                        help='comma-separated epochs at which to decay the learning rate')
    opts = parser.parse_args()
    return opts

# Preparation
opts = parse_opts()
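# MINC-2500 contains 23 material categories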
classes = 23

model_name = opts.model

epochs = opts.epochs
lr = opts.lr
batch_size = opts.batch_size
momentum = opts.momentum
wd = opts.wd

lr_factor = opts.lr_factor
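# Append inf as a sentinel so lr_counter never indexes past the last decay epoch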
lr_steps = [int(s) for s in opts.lr_steps.split(',')] + [np.inf]

num_gpus = opts.num_gpus
num_workers = opts.num_workers
ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
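# Scale the global batch size with the number of devices so each GPU
# keeps the per-device batch size given on the command line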
batch_size = batch_size * max(num_gpus, 1)

logging.basicConfig(level=logging.INFO,
                    handlers=[logging.StreamHandler()])

train_path = os.path.join(opts.data, 'train')
val_path = os.path.join(opts.data, 'val')
test_path = os.path.join(opts.data, 'test')

jitter_param = 0.4
lighting_param = 0.1
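# ImageNet channel means and standard deviations; the pretrained weights expect them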
normalize = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

transform_train = transforms.Compose([
    transforms.Resize(480),
    transforms.RandomResizedCrop(224),
    transforms.RandomFlipLeftRight(),
    transforms.RandomColorJitter(brightness=jitter_param, contrast=jitter_param,
                                 saturation=jitter_param),
    transforms.RandomLighting(lighting_param),
    transforms.ToTensor(),
    normalize
])

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize
])

def test(net, val_data, ctx):
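    # Compute top-1 accuracy over a loader; even_split=False tolerates a ragged final batch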
    metric = mx.metric.Accuracy()
    for i, batch in enumerate(val_data):
        data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
        label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
        outputs = [net(X) for X in data]
        metric.update(label, outputs)

    return metric.get()

def train(train_path, val_path, test_path):
    # Initialize the net with a pretrained model
    finetune_net = get_model(model_name, pretrained=True)
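    # Replace the pretrained output layer with a fresh 23-way classifier;
    # only the new Dense layer needs initialization, the backbone keeps
    # its pretrained ImageNet weights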
    with finetune_net.name_scope():
        finetune_net.output = nn.Dense(classes)
    finetune_net.output.initialize(init.Xavier(), ctx=ctx)
    finetune_net.collect_params().reset_ctx(ctx)
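    # Compile the imperative graph into a static one for faster execution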
    finetune_net.hybridize()

    # Define DataLoaders
    train_data = gluon.data.DataLoader(
        gluon.data.vision.ImageFolderDataset(train_path).transform_first(transform_train),
        batch_size=batch_size, shuffle=True, num_workers=num_workers)

    val_data = gluon.data.DataLoader(
        gluon.data.vision.ImageFolderDataset(val_path).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=num_workers)

    test_data = gluon.data.DataLoader(
        gluon.data.vision.ImageFolderDataset(test_path).transform_first(transform_test),
        batch_size=batch_size, shuffle=False, num_workers=num_workers)

    # Define Trainer
    trainer = gluon.Trainer(finetune_net.collect_params(), 'sgd', {
        'learning_rate': lr, 'momentum': momentum, 'wd': wd})
    metric = mx.metric.Accuracy()
    L = gluon.loss.SoftmaxCrossEntropyLoss()
    lr_counter = 0
    num_batch = len(train_data)

    # Start Training
    for epoch in range(epochs):
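        # Decay the learning rate by lr_factor at each scheduled epoch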
        if epoch == lr_steps[lr_counter]:
            trainer.set_learning_rate(trainer.learning_rate * lr_factor)
            lr_counter += 1

        tic = time.time()
        train_loss = 0
        metric.reset()

        for i, batch in enumerate(train_data):
            data = gluon.utils.split_and_load(batch[0], ctx_list=ctx, batch_axis=0, even_split=False)
            label = gluon.utils.split_and_load(batch[1], ctx_list=ctx, batch_axis=0, even_split=False)
            with ag.record():
                outputs = [finetune_net(X) for X in data]
                loss = [L(yhat, y) for yhat, y in zip(outputs, label)]
            for l in loss:
                l.backward()

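            # step() normalizes gradients by the global batch size across all device shards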
            trainer.step(batch_size)
            train_loss += sum([l.mean().asscalar() for l in loss]) / len(loss)

            metric.update(label, outputs)

        _, train_acc = metric.get()
        train_loss /= num_batch

        _, val_acc = test(finetune_net, val_data, ctx)

        logging.info('[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f | time: %.1f' %
                     (epoch, train_acc, train_loss, val_acc, time.time() - tic))

    _, test_acc = test(finetune_net, test_data, ctx)
    logging.info('[Finished] Test-acc: %.3f' % test_acc)

if __name__ == "__main__":
    train(train_path, val_path, test_path)
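
# Example invocation (the script name and data path below are illustrative):
#   python transfer_learning_minc.py --model resnet50_v2 --data ./minc-2500 \
#       --num-gpus 1 --epochs 40 --lr 0.001 --lr-steps 10,20,30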