# Tencent is pleased to support the open source community by making ncnn available.
#
# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
#
# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
# https://opensource.org/licenses/BSD-3-Clause
#
# Unless required by applicable law or agreed to in writing, software distributed
# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
# CONDITIONS OF ANY KIND, either express or implied. See the License for the
# specific language governing permissions and limitations under the License.

import os
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F


class Model(nn.Module):
    """Small module exercising three differently configured BatchNorm2d layers."""

    def __init__(self):
        super(Model, self).__init__()

        # Default configuration.
        self.bn_0 = nn.BatchNorm2d(num_features=32)
        # Non-default epsilon and no learnable scale/shift.
        self.bn_1 = nn.BatchNorm2d(num_features=32, eps=1e-1, affine=False)
        # Different channel count, affine requested explicitly.
        self.bn_2 = nn.BatchNorm2d(num_features=11, affine=True)

    def forward(self, x, y):
        """Run ``x`` through bn_0 then bn_1, and ``y`` through bn_2.

        Returns:
            The tuple ``(x, y)`` of normalized tensors.
        """
        x = self.bn_0(x)
        x = self.bn_1(x)

        y = self.bn_2(y)

        return x, y


def test():
    """Check that the pnnx-converted model reproduces the PyTorch outputs.

    The model is traced and saved as TorchScript, converted by the external
    ``../src/pnnx`` executable, and the generated Python module is imported
    and run for comparison.

    Returns:
        True when both outputs are exactly equal to the reference outputs,
        False when they differ or when the pnnx conversion itself fails.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 32, 12, 64)
    y = torch.rand(1, 11, 1, 1)

    a0, a1 = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_nn_BatchNorm2d.pt")

    # torchscript to pnnx
    status = os.system("../src/pnnx test_nn_BatchNorm2d.pt inputshape=[1,32,12,64],[1,11,1,1]")
    if status != 0:
        # Without this check a failed conversion could go unnoticed: a stale
        # test_nn_BatchNorm2d_pnnx.py from an earlier run would be imported
        # below and the comparison would be made against outdated output.
        return False

    # pnnx inference
    # This module is generated by the pnnx step above, so it can only be
    # imported here, after the conversion has run.
    import test_nn_BatchNorm2d_pnnx
    b0, b1 = test_nn_BatchNorm2d_pnnx.test_inference()

    return torch.equal(a0, b0) and torch.equal(a1, b1)


if __name__ == "__main__":
    # sys.exit rather than the site-provided exit() helper, which is not
    # available when the interpreter is started without the site module.
    sys.exit(0 if test() else 1)