1import os
2import argparse
3
4import torch
5import torch.nn as nn
6import torch.distributed as dist
7import torch.optim
8from tensorboardX import SummaryWriter
9
10from gluoncv.torch.model_zoo import get_model
11from gluoncv.torch.data import build_dataloader
12from gluoncv.torch.utils.model_utils import deploy_model, load_model, save_model
13from gluoncv.torch.utils.task_utils import train_classification, validation_classification
14from gluoncv.torch.engine.config import get_cfg_defaults
15from gluoncv.torch.engine.launch import spawn_workers
16from gluoncv.torch.utils.utils import build_log_dir
17from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler
18
19
def main_worker(cfg):
    """Run the training / validation loop for one worker process.

    Builds the model, data loaders, SGD optimizer, learning-rate schedule and
    cross-entropy loss from ``cfg``, then trains for
    ``cfg.CONFIG.TRAIN.EPOCH_NUM`` epochs, validating every
    ``cfg.CONFIG.VAL.FREQ`` epochs (and on the last epoch) and saving a
    checkpoint every ``cfg.CONFIG.LOG.SAVE_FREQ`` epochs.

    Args:
        cfg: Experiment configuration. ``cfg.freeze()`` is called on it once
            the log directory has been built, so it is read-only afterwards.

    Raises:
        ValueError: If ``cfg.CONFIG.TRAIN.LR_POLICY`` is neither ``'Step'``
            nor ``'Cosine'``.
    """
    train_cfg = cfg.CONFIG.TRAIN
    lr_policy = train_cfg.LR_POLICY
    # Fail fast, before any model/data/log-dir work is done: an unsupported
    # policy used to print an unformatted message and then crash later with a
    # NameError on the never-assigned scheduler.
    if lr_policy not in ('Step', 'Cosine'):
        raise ValueError(
            'Learning rate schedule %s is not supported yet. '
            'Please use Step or Cosine.' % lr_policy)

    # create tensorboard and logs (only on the rank-0 worker)
    is_rank_zero = cfg.DDP_CONFIG.GPU_WORLD_RANK == 0
    if is_rank_zero:
        tb_logdir = build_log_dir(cfg)
        writer = SummaryWriter(log_dir=tb_logdir)
    else:
        writer = None

    try:
        cfg.freeze()

        # create model
        model = get_model(cfg)
        model = deploy_model(model, cfg)

        # create dataset and dataloader; the remaining samplers are not used here
        train_loader, val_loader, train_sampler, _, _ = build_dataloader(cfg)

        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=train_cfg.LR,
                                    momentum=train_cfg.MOMENTUM,
                                    weight_decay=train_cfg.W_DECAY)

        if cfg.CONFIG.MODEL.LOAD:
            model, _ = load_model(model, optimizer, cfg, load_fc=True)

        if lr_policy == 'Step':
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=train_cfg.LR_MILESTONE,
                gamma=train_cfg.STEP)
        else:
            # 'Cosine' -- the policy was validated at the top of the function.
            scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                optimizer,
                T_max=train_cfg.EPOCH_NUM - train_cfg.WARMUP_EPOCHS,
                eta_min=0,
                last_epoch=train_cfg.RESUME_EPOCH)

        if train_cfg.USE_WARMUP:
            # The warmup scheduler wraps the main one and becomes the object
            # that is stepped each epoch; the multiplier is the ratio of the
            # warmup end LR to the base LR.
            scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=(train_cfg.WARMUP_END_LR / train_cfg.LR),
                total_epoch=train_cfg.WARMUP_EPOCHS,
                after_scheduler=scheduler)

        criterion = nn.CrossEntropyLoss().cuda()

        base_iter = 0
        last_epoch = train_cfg.EPOCH_NUM - 1
        for epoch in range(train_cfg.EPOCH_NUM):
            if cfg.DDP_CONFIG.DISTRIBUTED:
                train_sampler.set_epoch(epoch)

            base_iter = train_classification(base_iter, model, train_loader, epoch,
                                             criterion, optimizer, cfg, writer=writer)
            scheduler.step()

            if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == last_epoch:
                validation_classification(model, val_loader, epoch, criterion, cfg, writer)

            if epoch % cfg.CONFIG.LOG.SAVE_FREQ == 0:
                # Only one process writes checkpoints: rank 0, or the single
                # process of a non-distributed run.
                if is_rank_zero or not cfg.DDP_CONFIG.DISTRIBUTED:
                    save_model(model, optimizer, epoch, cfg)
    finally:
        # Close the tensorboard writer even if training raised.
        if writer is not None:
            writer.close()
82
83
if __name__ == '__main__':
    # Script entry point: load the default config, overlay the user's config
    # file on top of it, and hand main_worker to the worker launcher.
    parser = argparse.ArgumentParser(description='Train video action recognition models.')
    # Required: without it args.config_file would be None and would be passed
    # straight to cfg.merge_from_file; argparse now reports a usage error.
    parser.add_argument('--config-file', type=str, required=True,
                        help='path to config file.')
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    spawn_workers(main_worker, cfg)
92