1# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
2
3import torch.nn as nn
4
5
def c2_xavier_fill(module: nn.Module) -> None:
    """
    Apply the Caffe2 flavour of "XavierFill" to a module, in place.

    `module.weight` is re-drawn with `kaiming_uniform_` using `a=1`, and
    `module.bias` is set to 0 when the module has one.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            `weight` attribute and a `bias` attribute (which may be None).
    """
    # Despite the name, Caffe2's XavierFill lines up with PyTorch's
    # kaiming_uniform_ (negative slope a=1), not with xavier_uniform_.
    nn.init.kaiming_uniform_(module.weight, a=1)

    bias = module.bias
    if bias is None:
        # Bias-free module: nothing more to initialize.
        return
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
    #  torch.Tensor]`.
    nn.init.constant_(bias, 0)
21
22
def c2_msra_fill(module: nn.Module) -> None:
    """
    Apply the Caffe2 "MSRAFill" scheme to a module, in place.

    `module.weight` is re-drawn with `kaiming_normal_` in "fan_out" mode for
    a ReLU nonlinearity, and `module.bias` is set to 0 when the module has
    one.

    Args:
        module (torch.nn.Module): module to initialize. It must expose a
            `weight` attribute and a `bias` attribute (which may be None).
    """
    nn.init.kaiming_normal_(
        module.weight,
        nonlinearity="relu",
        mode="fan_out",
    )

    bias = module.bias
    if bias is None:
        # Bias-free module: nothing more to initialize.
        return
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
    #  torch.Tensor]`.
    nn.init.constant_(bias, 0)
36