# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch.nn as nn


def _zero_bias(module: nn.Module) -> None:
    """
    Set `module.bias` to 0 in place, or do nothing when the bias is None.

    Args:
        module (torch.nn.Module): module whose `bias` attribute is cleared.
    """
    bias = module.bias
    if bias is None:
        # e.g. layers constructed with bias=False keep `bias` as None.
        return
    # pyre-fixme[6]: Expected `Tensor` for 1st param but got `Union[nn.Module,
    #  torch.Tensor]`.
    nn.init.constant_(bias, 0)


def c2_xavier_fill(module: nn.Module) -> None:
    """
    Fill `module.weight` as Caffe2's "XavierFill" does, then zero
    `module.bias` when it is present.

    Caffe2's XavierFill is, in fact, what PyTorch calls `kaiming_uniform_`
    with `a=1`: weights are drawn uniformly from [-bound, bound] with
    bound = sqrt(3 / fan_in).

    Args:
        module (torch.nn.Module): module to initialize; it must expose
            `weight` and `bias` attributes (`bias` may be None).
    """
    nn.init.kaiming_uniform_(module.weight, a=1)
    _zero_bias(module)


def c2_msra_fill(module: nn.Module) -> None:
    """
    Fill `module.weight` as Caffe2's "MSRAFill" does, then zero
    `module.bias` when it is present.

    Weights are drawn from a zero-mean normal distribution via
    `kaiming_normal_`, scaled by fan-out with the ReLU gain
    (std = sqrt(2 / fan_out)).

    Args:
        module (torch.nn.Module): module to initialize; it must expose
            `weight` and `bias` attributes (`bias` may be None).
    """
    nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
    _zero_bias(module)