# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch.nn as nn
from fvcore.nn.weight_init import c2_msra_fill
from pytorchvideo.layers import SpatioTemporalClsPositionalEncoding


def _init_resnet_weights(model: nn.Module, fc_init_std: float = 0.01) -> nn.Module:
    """
    Performs ResNet style weight initialization. That is, recursively initialize the
    given model in the following way for each type:
        Conv - Follow the initialization of kaiming_normal:
            https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
        BatchNorm (any ``_NormBase`` subclass) - Set weight to 1, except for layers
            flagged with a truthy ``block_final_bn`` attribute (the last BatchNorm
            of a residual bottleneck), whose weight is set to 0. Bias is always
            set to 0.
        Linear - Set weight to 0 mean Gaussian with std deviation fc_init_std and bias
            to 0.

    Args:
        model (nn.Module): Model to be initialized (modified in place).
        fc_init_std (float): the expected standard deviation for fully-connected layer.

    Returns:
        nn.Module: the same ``model`` instance, after in-place initialization.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            # Follow the initialization method proposed in:
            # He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
            # human-level performance on imagenet classification."
            # arXiv preprint arXiv:1502.01852 (2015)
            c2_msra_fill(m)
        elif isinstance(m, nn.modules.batchnorm._NormBase):
            # weight/bias are None when the norm layer was built with affine=False.
            if m.weight is not None:
                if getattr(m, "block_final_bn", False):
                    m.weight.data.fill_(0.0)
                else:
                    m.weight.data.fill_(1.0)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(mean=0.0, std=fc_init_std)
            if m.bias is not None:
                m.bias.data.zero_()
    return model


def _init_vit_weights(model: nn.Module, trunc_normal_std: float = 0.02) -> nn.Module:
    """
    Weight initialization for vision transformers:
        Linear - Truncated-normal weight with std ``trunc_normal_std``; bias (when
            present) set to 0.
        LayerNorm - Weight set to 1 and bias set to 0, each only when present.
        SpatioTemporalClsPositionalEncoding - Every parameter drawn from a
            truncated normal with std ``trunc_normal_std``.

    Args:
        model (nn.Module): Model to be initialized (modified in place).
        trunc_normal_std (float): the expected standard deviation for fully-connected
            layer and ClsPositionalEncoding.

    Returns:
        nn.Module: the same ``model`` instance, after in-place initialization.
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=trunc_normal_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            # LayerNorm(elementwise_affine=False) has weight/bias set to None, and
            # bias can be None on its own (bias=False); skip whichever is missing
            # instead of crashing inside nn.init.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
            if m.weight is not None:
                nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, SpatioTemporalClsPositionalEncoding):
            for weights in m.parameters():
                nn.init.trunc_normal_(weights, std=trunc_normal_std)
    return model


def init_net_weights(
    model: nn.Module,
    init_std: float = 0.01,
    style: str = "resnet",
) -> nn.Module:
    """
    Performs weight initialization. Options include ResNet style weight initialization
    and transformer style weight initialization.

    Args:
        model (nn.Module): Model to be initialized (modified in place).
        init_std (float): The expected standard deviation for initialization.
        style (str): Options include "resnet" and "vit".

    Returns:
        nn.Module: the same ``model`` instance, after in-place initialization.

    Raises:
        AssertionError: if ``style`` is not one of the supported options.
        NotImplementedError: same condition, when assertions are disabled
            (``python -O``).
    """
    # Kept as an assert so callers that already catch AssertionError keep working.
    assert style in ("resnet", "vit"), (
        f"Unsupported weight init style {style!r}; expected 'resnet' or 'vit'."
    )
    if style == "resnet":
        return _init_resnet_weights(model, init_std)
    elif style == "vit":
        return _init_vit_weights(model, init_std)
    else:
        # Only reachable when assertions are stripped with -O.
        raise NotImplementedError(
            f"Unsupported weight init style {style!r}; expected 'resnet' or 'vit'."
        )