1# Tencent is pleased to support the open source community by making ncnn available.
2#
3# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4#
5# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6# in compliance with the License. You may obtain a copy of the License at
7#
8# https://opensource.org/licenses/BSD-3-Clause
9#
10# Unless required by applicable law or agreed to in writing, software distributed
11# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12# CONDITIONS OF ANY KIND, either express or implied. See the License for the
13# specific language governing permissions and limitations under the License.
14
15import torch
16import torch.nn as nn
17import torch.nn.functional as F
18
class Model(nn.Module):
    """Fixture exercising three ``nn.BatchNorm2d`` configurations for pnnx.

    * ``bn_0``: 32 channels, default settings.
    * ``bn_1``: 32 channels, ``eps=0.1``, no learnable affine parameters.
    * ``bn_2``: 11 channels, affine parameters requested explicitly.

    ``bn_0`` and ``bn_1`` are applied one after the other to the first input;
    ``bn_2`` is applied on its own to the second input.
    """

    def __init__(self):
        super().__init__()

        # Attribute names and creation order are kept fixed: they become the
        # submodule names of the traced module and the state_dict keys.
        self.bn_0 = nn.BatchNorm2d(num_features=32)
        self.bn_1 = nn.BatchNorm2d(num_features=32, eps=1e-1, affine=False)
        self.bn_2 = nn.BatchNorm2d(num_features=11, affine=True)

    def forward(self, x, y):
        """Return ``(bn_1(bn_0(x)), bn_2(y))``.

        ``x`` must have 32 channels and ``y`` 11 channels, matching the
        ``num_features`` of the layers applied to them.
        """
        out_x = self.bn_1(self.bn_0(x))
        out_y = self.bn_2(y)
        return out_x, out_y
34
def test():
    """Verify that pnnx reproduces ``Model``'s outputs bit-for-bit.

    The eager model is run on fixed random inputs (seed 0). It is then traced
    to TorchScript and saved as ``test_nn_BatchNorm2d.pt``. That file is
    converted by the pnnx executable at ``../src/pnnx``, a path relative to
    the current working directory. Finally, the generated
    ``test_nn_BatchNorm2d_pnnx`` module is imported, and its
    ``test_inference()`` outputs are compared with the eager outputs.

    Returns:
        bool: True when the pnnx conversion exits with status 0 and both
        outputs equal the eager outputs exactly (``torch.equal``); False
        otherwise.

    Raises:
        OSError: (e.g. FileNotFoundError) if the pnnx executable cannot be
        launched.
    """
    import subprocess

    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 32, 12, 64)
    y = torch.rand(1, 11, 1, 1)

    a0, a1 = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_nn_BatchNorm2d.pt")

    # torchscript to pnnx
    # Arguments are passed as a list, so no shell is involved and the "[...]"
    # in inputshape cannot be glob-expanded. A failed conversion must fail the
    # test; otherwise the import below could pick up a stale generated module
    # left over from an earlier run.
    result = subprocess.run(
        [
            "../src/pnnx",
            "test_nn_BatchNorm2d.pt",
            "inputshape=[1,32,12,64],[1,11,1,1]",
        ]
    )
    if result.returncode != 0:
        return False

    # pnnx inference
    import test_nn_BatchNorm2d_pnnx
    b0, b1 = test_nn_BatchNorm2d_pnnx.test_inference()

    return torch.equal(a0, b0) and torch.equal(a1, b1)
58
if __name__ == "__main__":
    # sys.exit is the programmatic way to set the exit status; the bare
    # exit() builtin is injected by the site module for interactive use and
    # is missing when Python runs with -S.
    import sys

    # Exit status 0 signals success to the test runner, 1 signals failure.
    sys.exit(0 if test() else 1)
64