# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch.nn as nn
from fvcore.nn.weight_init import c2_msra_fill
from pytorchvideo.layers import SpatioTemporalClsPositionalEncoding


def _init_resnet_weights(
    model: nn.Module, fc_init_std: float = 0.01
) -> nn.Module:
    """
    Performs ResNet style weight initialization. That is, recursively initializes
    the given model in the following way for each module type:
        Conv - Follow the initialization of kaiming_normal:
            https://pytorch.org/docs/stable/_modules/torch/nn/init.html#kaiming_normal_
        BatchNorm - Set weight to 1 and bias to 0. For the final BatchNorm of each
            residual block (modules tagged with ``block_final_bn``), set weight to 0
            so the residual branch initially acts as an identity.
        Linear - Set weight to a zero-mean Gaussian with standard deviation
            fc_init_std and bias to 0.

    Args:
        model (nn.Module): Model to be initialized.
        fc_init_std (float): Standard deviation used for fully-connected layers.

    Returns:
        The initialized model.
    """
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.Conv3d)):
            # Follow the initialization method proposed in:
            # He, Kaiming, et al. "Delving deep into rectifiers: Surpassing
            # human-level performance on imagenet classification."
            # arXiv preprint arXiv:1502.01852 (2015).
            c2_msra_fill(m)
        elif isinstance(m, nn.modules.batchnorm._NormBase):
            if m.weight is not None:
                if hasattr(m, "block_final_bn") and m.block_final_bn:
                    m.weight.data.fill_(0.0)
                else:
                    m.weight.data.fill_(1.0)
            if m.bias is not None:
                m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(mean=0.0, std=fc_init_std)
            if m.bias is not None:
                m.bias.data.zero_()
    return model


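# Illustration only (hypothetical helper, not part of the original API): the
# ``block_final_bn`` attribute checked in _init_resnet_weights above is assumed
# to be set by the model builder on the last BatchNorm of each residual block.
# A minimal sketch of how a builder might tag it:
def _tag_block_final_bn(block: nn.Module) -> nn.Module:
    """
    Hypothetical example: mark the last normalization layer of ``block`` so that
    _init_resnet_weights zero-initializes its scale.
    """
    last_norm = None
    for m in block.modules():
        if isinstance(m, nn.modules.batchnorm._NormBase):
            last_norm = m
    if last_norm is not None:
        last_norm.block_final_bn = True
    return block

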
def _init_vit_weights(
    model: nn.Module, trunc_normal_std: float = 0.02
) -> nn.Module:
    """
    Performs weight initialization for vision transformers.

    Args:
        model (nn.Module): Model to be initialized.
        trunc_normal_std (float): Standard deviation of the truncated normal
            distribution used for Linear layers and for the parameters of
            SpatioTemporalClsPositionalEncoding.

    Returns:
        The initialized model.
    """
    for m in model.modules():
        if isinstance(m, nn.Linear):
            nn.init.trunc_normal_(m.weight, std=trunc_normal_std)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, SpatioTemporalClsPositionalEncoding):
            for weights in m.parameters():
                nn.init.trunc_normal_(weights, std=trunc_normal_std)
    return model


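# Illustration only (hypothetical self-check, not part of the original API): a
# small sketch verifying the ViT-style rules above on a throwaway block built
# from torch built-ins; layer sizes are arbitrary.
def _vit_init_smoke_test(trunc_normal_std: float = 0.02) -> bool:
    block = nn.Sequential(nn.Linear(16, 32), nn.LayerNorm(32))
    _init_vit_weights(block, trunc_normal_std=trunc_normal_std)
    linear, norm = block[0], block[1]
    return (
        bool((norm.weight == 1.0).all())  # LayerNorm scale reset to 1
        and bool((norm.bias == 0.0).all())  # LayerNorm bias reset to 0
        and bool((linear.bias == 0.0).all())  # Linear bias reset to 0
        # trunc_normal_ clips to its default cutoff range [-2, 2]
        and float(linear.weight.abs().max()) <= 2.0
    )

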
def init_net_weights(
    model: nn.Module,
    init_std: float = 0.01,
    style: str = "resnet",
) -> nn.Module:
    """
    Performs weight initialization. Options include ResNet style weight
    initialization and vision transformer (ViT) style weight initialization.

    Args:
        model (nn.Module): Model to be initialized.
        init_std (float): Standard deviation used for initialization (passed as
            fc_init_std for "resnet" and as trunc_normal_std for "vit").
        style (str): Initialization style; one of "resnet" or "vit".

    Returns:
        The initialized model.
    """
    if style == "resnet":
        return _init_resnet_weights(model, init_std)
    elif style == "vit":
        return _init_vit_weights(model, init_std)
    else:
        raise NotImplementedError(f"Unsupported initialization style: {style}")


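# Minimal usage sketch (illustration only, assumes this module is run directly
# as a script). It applies both initialization styles to small throwaway models;
# the layer shapes below are arbitrary and chosen only for demonstration.
if __name__ == "__main__":
    # ResNet style: conv weights get an MSRA fill, BatchNorm scales become 1
    # (or 0 when tagged with ``block_final_bn``), and the Linear head is drawn
    # from a zero-mean Gaussian with std 0.01.
    resnet_like = nn.Sequential(
        nn.Conv3d(3, 8, kernel_size=3, padding=1, bias=False),
        nn.BatchNorm3d(8),
        nn.ReLU(inplace=True),
        nn.AdaptiveAvgPool3d(1),
        nn.Flatten(),
        nn.Linear(8, 10),
    )
    resnet_like[1].block_final_bn = True  # zero-init this BatchNorm's scale
    init_net_weights(resnet_like, init_std=0.01, style="resnet")
    print("final BN scale sum:", float(resnet_like[1].weight.sum()))  # 0.0

    # ViT style: Linear weights use a truncated normal with std 0.02 and
    # LayerNorm is reset to weight=1 / bias=0.
    vit_like = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.LayerNorm(32))
    init_net_weights(vit_like, init_std=0.02, style="vit")
    print("LayerNorm scale mean:", float(vit_like[2].weight.mean()))  # 1.0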