1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "gemm.h"
16 
17 namespace ncnn {
18 
Gemm()19 Gemm::Gemm()
20 {
21     one_blob_only = false;
22     support_inplace = false;
23 }
24 
load_param(const ParamDict & pd)25 int Gemm::load_param(const ParamDict& pd)
26 {
27     alpha = pd.get(0, 1.f);
28     beta = pd.get(1, 1.f);
29     transA = pd.get(2, 0);
30     transB = pd.get(3, 0);
31 
32     return 0;
33 }
34 
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const35 int Gemm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
36 {
37     const Mat& A0 = bottom_blobs[0];
38     const Mat& B0 = bottom_blobs[1];
39 
40     size_t elemsize = A0.elemsize;
41 
42     Mat A;
43     if (transA == 0)
44     {
45         A = A0;
46     }
47     else
48     {
49         // transpose A to row-major
50         A.create(A0.h, A0.w, elemsize, opt.workspace_allocator);
51 
52         for (int i = 0; i < A.h; i++)
53         {
54             float* ptr = A.row(i);
55             for (int j = 0; j < A.w; j++)
56             {
57                 ptr[j] = A0.row(j)[i];
58             }
59         }
60     }
61 
62     Mat B;
63     if (transB == 0)
64     {
65         // transpose B to col-major
66         B.create(B0.h, B0.w, elemsize, opt.workspace_allocator);
67 
68         for (int i = 0; i < B.h; i++)
69         {
70             float* ptr = B.row(i);
71             for (int j = 0; j < B.w; j++)
72             {
73                 ptr[j] = B0.row(j)[i];
74             }
75         }
76     }
77     else
78     {
79         B = B0;
80     }
81 
82     int M = A.h;
83     int K = A.w; // assert A.w == B.w
84     int N = B.h;
85 
86     bool has_C = bottom_blobs.size() == 3;
87 
88     const float* ptrC = 0;
89     int broadcast_type_C = 0;
90     if (has_C)
91     {
92         const Mat& C = bottom_blobs[2];
93 
94         ptrC = C;
95 
96         if (C.dims == 1 && C.w == 1)
97         {
98             // scalar
99             broadcast_type_C = 0;
100         }
101         if (C.dims == 1 && C.w == M)
102         {
103             // M
104             // auto broadcast from h to w is the ncnn-style convention
105             broadcast_type_C = 1;
106         }
107         if (C.dims == 2 && C.w == 1 && C.h == M)
108         {
109             // Mx1
110             broadcast_type_C = 2;
111         }
112         if (C.dims == 2 && C.w == N && C.h == M)
113         {
114             // MxN
115             broadcast_type_C = 3;
116         }
117         if (C.dims == 2 && C.w == N && C.h == 1)
118         {
119             // 1xN
120             broadcast_type_C = 4;
121         }
122     }
123 
124     Mat& top_blob = top_blobs[0];
125     top_blob.create(N, M, elemsize, opt.blob_allocator);
126     if (top_blob.empty())
127         return -100;
128 
129     float* outptr = top_blob;
130     for (int i = 0; i < M; i++)
131     {
132         const float* ptrA = A.row(i);
133 
134         for (int j = 0; j < N; j++)
135         {
136             const float* ptrB = B.row(j);
137 
138             float sum = 0.f;
139             if (has_C)
140             {
141                 if (broadcast_type_C == 0)
142                 {
143                     sum = ptrC[0];
144                 }
145                 if (broadcast_type_C == 1)
146                 {
147                     sum = ptrC[i];
148                 }
149                 if (broadcast_type_C == 2)
150                 {
151                     sum = ptrC[i];
152                 }
153                 if (broadcast_type_C == 3)
154                 {
155                     sum = ptrC[i * N + j];
156                 }
157                 if (broadcast_type_C == 4)
158                 {
159                     sum = ptrC[j];
160                 }
161 
162                 sum *= beta;
163             }
164 
165             for (int k = 0; k < K; k++)
166             {
167                 sum += ptrA[k] * ptrB[k];
168             }
169 
170             *outptr++ = sum * alpha;
171         }
172     }
173 
174     return 0;
175 }
176 
177 } // namespace ncnn
178