1# Tencent is pleased to support the open source community by making ncnn available.
2#
3# Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4#
5# Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6# in compliance with the License. You may obtain a copy of the License at
7#
8# https://opensource.org/licenses/BSD-3-Clause
9#
10# Unless required by applicable law or agreed to in writing, software distributed
11# under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12# CONDITIONS OF ANY KIND, either express or implied. See the License for the
13# specific language governing permissions and limitations under the License.
14
15import torch
16import torch.nn as nn
17import torch.nn.functional as F
18
class Model(nn.Module):
    """Two ``nn.MultiheadAttention`` layers exercised by the pnnx/ncnn conversion test.

    ``attention_0_0`` (embed_dim=64, 4 heads) uses the default input layout.
    ``attention_1_0`` (embed_dim=40, 4 heads) passes ``batch_first=True``, which
    is gated on torch >= 1.9. On older torch it is neither created nor run, and
    ``forward`` returns a single tensor instead of a tuple.
    """

    @staticmethod
    def _torch_at_least(major, minor):
        """Return True when the installed torch is version ``major.minor`` or newer.

        Compares integer (major, minor) pairs rather than raw strings. Where
        ``torch.__version__`` is a plain ``str``, a lexicographic comparison
        would order e.g. '1.10' before '1.9'.
        """
        # drop local build suffixes such as '+cu113' before splitting
        release = str(torch.__version__).split('+')[0]
        parts = release.split('.')
        return (int(parts[0]), int(parts[1])) >= (major, minor)

    def __init__(self):
        super(Model, self).__init__()

        self.attention_0_0 = nn.MultiheadAttention(embed_dim=64, num_heads=4)

        # Decide once so __init__ and forward can never disagree about
        # whether attention_1_0 exists.
        self._batch_first_supported = Model._torch_at_least(1, 9)
        if self._batch_first_supported:
            self.attention_1_0 = nn.MultiheadAttention(embed_dim=40, num_heads=4, batch_first=True)

    def forward(self, x, y):
        """Run self-attention on ``x`` and, when available, on ``y``.

        Returns ``x0`` alone when the batch_first layer was not created,
        otherwise the tuple ``(x0, y0)``. Attention weights are discarded.
        """
        x0, _ = self.attention_0_0(x, x, x)

        if not self._batch_first_supported:
            return x0

        y0, _ = self.attention_1_0(y, y, y)

        return x0, y0
37
def test():
    """Export ``Model`` via TorchScript -> pnnx -> ncnn and compare outputs.

    Returns True when every ncnn output matches the corresponding PyTorch
    output within rtol=atol=1e-4. Returns False otherwise, including when the
    pnnx command exits non-zero or the number of outputs differs.
    """
    net = Model()
    net.eval()

    torch.manual_seed(0)
    x = torch.rand(1, 1, 64)
    y = torch.rand(1, 15, 40)

    a = net(x, y)

    # export torchscript
    mod = torch.jit.trace(net, (x, y))
    mod.save("test_nn_MultiheadAttention.pt")

    # torchscript to pnnx
    import os
    # A failed conversion could otherwise leave a stale generated module from
    # an earlier run to be imported below, silently comparing old results.
    if os.system("../../src/pnnx test_nn_MultiheadAttention.pt inputshape=[1,1,64],[1,15,40]") != 0:
        return False

    # ncnn inference
    import test_nn_MultiheadAttention_ncnn
    b = test_nn_MultiheadAttention_ncnn.test_inference()

    # Model.forward returns a bare tensor (not a tuple) when the batch_first
    # layer is skipped. Wrap single tensors so zip() pairs whole outputs
    # instead of iterating over dim 0.
    if isinstance(a, torch.Tensor):
        a = (a,)
    if isinstance(b, torch.Tensor):
        b = (b,)

    # zip() stops at the shorter sequence, so a missing output would pass unnoticed.
    if len(a) != len(b):
        return False

    for a0, b0 in zip(a, b):
        if not torch.allclose(a0, b0, 1e-4, 1e-4):
            return False
    return True
64
if __name__ == "__main__":
    import sys

    # sys.exit rather than the site-provided exit(), which is intended for the
    # interactive shell and is not defined when Python runs with -S.
    sys.exit(0 if test() else 1)
70