import os
import argparse

import torch
import torch.nn as nn
import torch.distributed as dist
import torch.optim
from tensorboardX import SummaryWriter

from gluoncv.torch.model_zoo import get_model
from gluoncv.torch.data import build_dataloader
from gluoncv.torch.utils.model_utils import deploy_model, load_model, save_model
from gluoncv.torch.utils.task_utils import train_classification, validation_classification
from gluoncv.torch.engine.config import get_cfg_defaults
from gluoncv.torch.engine.launch import spawn_workers
from gluoncv.torch.utils.utils import build_log_dir
from gluoncv.torch.utils.lr_policy import GradualWarmupScheduler


def _build_lr_scheduler(optimizer, cfg):
    """Return the per-epoch LR scheduler selected by ``cfg.CONFIG.TRAIN.LR_POLICY``.

    Parameters
    ----------
    optimizer : torch.optim.Optimizer
        Optimizer whose learning rate is scheduled.
    cfg
        Experiment config; reads ``CONFIG.TRAIN.LR_POLICY`` plus the
        policy-specific fields used below.

    Returns
    -------
    A ``MultiStepLR`` for the ``'Step'`` policy or a ``CosineAnnealingLR``
    for the ``'Cosine'`` policy.

    Raises
    ------
    ValueError
        If the policy is neither ``'Step'`` nor ``'Cosine'``.
    """
    train_cfg = cfg.CONFIG.TRAIN
    policy = train_cfg.LR_POLICY

    if policy == 'Step':
        return torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                    milestones=train_cfg.LR_MILESTONE,
                                                    gamma=train_cfg.STEP)
    if policy == 'Cosine':
        # The cosine period covers only the epochs left after warmup.
        return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                          T_max=train_cfg.EPOCH_NUM - train_cfg.WARMUP_EPOCHS,
                                                          eta_min=0,
                                                          last_epoch=train_cfg.RESUME_EPOCH)

    # Fail fast: previously this case only printed a message (with the "%s"
    # left unformatted) and training later died on an unbound ``scheduler``.
    raise ValueError('Learning rate schedule %s is not supported yet. '
                     'Please use Step or Cosine.' % policy)


def main_worker(cfg):
    """Train and periodically validate a classification model described by ``cfg``.

    Passed to ``spawn_workers`` as the worker entry point (presumably invoked
    once per worker process -- confirm against ``spawn_workers``).

    Parameters
    ----------
    cfg
        Experiment config. It is frozen here after the log directory is built.

    Raises
    ------
    ValueError
        If ``cfg.CONFIG.TRAIN.LR_POLICY`` is not ``'Step'`` or ``'Cosine'``.
    """
    # create tensorboard and logs (only on global rank 0)
    if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0:
        tb_logdir = build_log_dir(cfg)
        writer = SummaryWriter(log_dir=tb_logdir)
    else:
        writer = None
    cfg.freeze()

    try:
        # create model
        model = get_model(cfg)
        model = deploy_model(model, cfg)

        # create dataset and dataloader; the last two returned samplers are unused here
        train_loader, val_loader, train_sampler, _, _ = build_dataloader(cfg)

        optimizer = torch.optim.SGD(model.parameters(),
                                    lr=cfg.CONFIG.TRAIN.LR,
                                    momentum=cfg.CONFIG.TRAIN.MOMENTUM,
                                    weight_decay=cfg.CONFIG.TRAIN.W_DECAY)

        if cfg.CONFIG.MODEL.LOAD:
            model, _ = load_model(model, optimizer, cfg, load_fc=True)

        scheduler = _build_lr_scheduler(optimizer, cfg)
        if cfg.CONFIG.TRAIN.USE_WARMUP:
            # Wrap the main schedule so that warmup hands over to it afterwards.
            scheduler = GradualWarmupScheduler(optimizer,
                                               multiplier=(cfg.CONFIG.TRAIN.WARMUP_END_LR / cfg.CONFIG.TRAIN.LR),
                                               total_epoch=cfg.CONFIG.TRAIN.WARMUP_EPOCHS,
                                               after_scheduler=scheduler)
        criterion = nn.CrossEntropyLoss().cuda()

        last_epoch = cfg.CONFIG.TRAIN.EPOCH_NUM - 1
        base_iter = 0
        for epoch in range(cfg.CONFIG.TRAIN.EPOCH_NUM):
            if cfg.DDP_CONFIG.DISTRIBUTED:
                train_sampler.set_epoch(epoch)

            base_iter = train_classification(base_iter, model, train_loader, epoch, criterion, optimizer, cfg,
                                             writer=writer)
            scheduler.step()

            # Validate every VAL.FREQ epochs and always after the final epoch.
            if epoch % cfg.CONFIG.VAL.FREQ == 0 or epoch == last_epoch:
                validation_classification(model, val_loader, epoch, criterion, cfg, writer)

            # Checkpoint from rank 0 only, or from the single process when not distributed.
            if epoch % cfg.CONFIG.LOG.SAVE_FREQ == 0:
                if cfg.DDP_CONFIG.GPU_WORLD_RANK == 0 or not cfg.DDP_CONFIG.DISTRIBUTED:
                    save_model(model, optimizer, epoch, cfg)
    finally:
        # Close the writer even when training raises.
        if writer is not None:
            writer.close()


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Train video action recognition models.')
    parser.add_argument('--config-file', type=str, help='path to config file.')
    args = parser.parse_args()

    cfg = get_cfg_defaults()
    cfg.merge_from_file(args.config_file)
    spawn_workers(main_worker, cfg)