1# Tencent is pleased to support the open source community by making ncnn available. 2# 3# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved. 4# 5# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except 6# in compliance with the License. You may obtain a copy of the License at 7# 8# https://opensource.org/licenses/BSD-3-Clause 9# 10# Unless required by applicable law or agreed to in writing, software distributed 11# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR 12# CONDITIONS OF ANY KIND, either express or implied. See the License for the 13# specific language governing permissions and limitations under the License. 14 15import torch 16import torch.nn as nn 17import torch.nn.functional as F 18 19class Model(nn.Module): 20 def __init__(self): 21 super(Model, self).__init__() 22 23 self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4) 24 25 if torch.__version__ >= '1.9': 26 self.attention_1_0 = nn.MultiheadAttention(embed_dim=40, num_heads=4, batch_first=True) 27 28 def forward(self, x, y): 29 x0, _ = self.attention_0_0(x, x, x) 30 31 if torch.__version__ < '1.9': 32 return x0 33 34 y0, _ = self.attention_1_0(y, y, y) 35 36 return x0, y0 37 38def test(): 39 net = Model() 40 net.eval() 41 42 torch.manual_seed(0) 43 x = torch.rand(1, 1, 64) 44 y = torch.rand(1, 15, 40) 45 46 a = net(x, y) 47 48 # export torchscript 49 mod = torch.jit.trace(net, (x, y)) 50 mod.save("test_nn_MultiheadAttention.pt") 51 52 # torchscript to pnnx 53 import os 54 os.system("../../src/pnnx test_nn_MultiheadAttention.pt inputshape=[1,1,64],[1,15,40]") 55 56 # ncnn inference 57 import test_nn_MultiheadAttention_ncnn 58 b = test_nn_MultiheadAttention_ncnn.test_inference() 59 60 for a0, b0 in zip(a, b): 61 if not torch.allclose(a0, b0, 1e-4, 1e-4): 62 return False 63 return True 64 65if __name__ == "__main__": 66 if test(): 67 exit(0) 68 else: 69 exit(1) 70