1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2020 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "gemm.h"
16
17 namespace ncnn {
18
Gemm()19 Gemm::Gemm()
20 {
21 one_blob_only = false;
22 support_inplace = false;
23 }
24
load_param(const ParamDict & pd)25 int Gemm::load_param(const ParamDict& pd)
26 {
27 alpha = pd.get(0, 1.f);
28 beta = pd.get(1, 1.f);
29 transA = pd.get(2, 0);
30 transB = pd.get(3, 0);
31
32 return 0;
33 }
34
forward(const std::vector<Mat> & bottom_blobs,std::vector<Mat> & top_blobs,const Option & opt) const35 int Gemm::forward(const std::vector<Mat>& bottom_blobs, std::vector<Mat>& top_blobs, const Option& opt) const
36 {
37 const Mat& A0 = bottom_blobs[0];
38 const Mat& B0 = bottom_blobs[1];
39
40 size_t elemsize = A0.elemsize;
41
42 Mat A;
43 if (transA == 0)
44 {
45 A = A0;
46 }
47 else
48 {
49 // transpose A to row-major
50 A.create(A0.h, A0.w, elemsize, opt.workspace_allocator);
51
52 for (int i = 0; i < A.h; i++)
53 {
54 float* ptr = A.row(i);
55 for (int j = 0; j < A.w; j++)
56 {
57 ptr[j] = A0.row(j)[i];
58 }
59 }
60 }
61
62 Mat B;
63 if (transB == 0)
64 {
65 // transpose B to col-major
66 B.create(B0.h, B0.w, elemsize, opt.workspace_allocator);
67
68 for (int i = 0; i < B.h; i++)
69 {
70 float* ptr = B.row(i);
71 for (int j = 0; j < B.w; j++)
72 {
73 ptr[j] = B0.row(j)[i];
74 }
75 }
76 }
77 else
78 {
79 B = B0;
80 }
81
82 int M = A.h;
83 int K = A.w; // assert A.w == B.w
84 int N = B.h;
85
86 bool has_C = bottom_blobs.size() == 3;
87
88 const float* ptrC = 0;
89 int broadcast_type_C = 0;
90 if (has_C)
91 {
92 const Mat& C = bottom_blobs[2];
93
94 ptrC = C;
95
96 if (C.dims == 1 && C.w == 1)
97 {
98 // scalar
99 broadcast_type_C = 0;
100 }
101 if (C.dims == 1 && C.w == M)
102 {
103 // M
104 // auto broadcast from h to w is the ncnn-style convention
105 broadcast_type_C = 1;
106 }
107 if (C.dims == 2 && C.w == 1 && C.h == M)
108 {
109 // Mx1
110 broadcast_type_C = 2;
111 }
112 if (C.dims == 2 && C.w == N && C.h == M)
113 {
114 // MxN
115 broadcast_type_C = 3;
116 }
117 if (C.dims == 2 && C.w == N && C.h == 1)
118 {
119 // 1xN
120 broadcast_type_C = 4;
121 }
122 }
123
124 Mat& top_blob = top_blobs[0];
125 top_blob.create(N, M, elemsize, opt.blob_allocator);
126 if (top_blob.empty())
127 return -100;
128
129 float* outptr = top_blob;
130 for (int i = 0; i < M; i++)
131 {
132 const float* ptrA = A.row(i);
133
134 for (int j = 0; j < N; j++)
135 {
136 const float* ptrB = B.row(j);
137
138 float sum = 0.f;
139 if (has_C)
140 {
141 if (broadcast_type_C == 0)
142 {
143 sum = ptrC[0];
144 }
145 if (broadcast_type_C == 1)
146 {
147 sum = ptrC[i];
148 }
149 if (broadcast_type_C == 2)
150 {
151 sum = ptrC[i];
152 }
153 if (broadcast_type_C == 3)
154 {
155 sum = ptrC[i * N + j];
156 }
157 if (broadcast_type_C == 4)
158 {
159 sum = ptrC[j];
160 }
161
162 sum *= beta;
163 }
164
165 for (int k = 0; k < K; k++)
166 {
167 sum += ptrA[k] * ptrB[k];
168 }
169
170 *outptr++ = sum * alpha;
171 }
172 }
173
174 return 0;
175 }
176
177 } // namespace ncnn
178