// Code generated by "go generate gonum.org/v1/gonum/blas/gonum"; DO NOT EDIT.
2
3// Copyright ©2014 The Gonum Authors. All rights reserved.
4// Use of this source code is governed by a BSD-style
5// license that can be found in the LICENSE file.
6
7package gonum
8
9import (
10	"runtime"
11	"sync"
12
13	"gonum.org/v1/gonum/blas"
14	"gonum.org/v1/gonum/internal/asm/f32"
15)
16
17// Sgemm performs one of the matrix-matrix operations
18//  C = alpha * A * B + beta * C
19//  C = alpha * Aᵀ * B + beta * C
20//  C = alpha * A * Bᵀ + beta * C
21//  C = alpha * Aᵀ * Bᵀ + beta * C
22// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
23// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
24// B are transposed.
25//
26// Float32 implementations are autogenerated and not directly tested.
27func (Implementation) Sgemm(tA, tB blas.Transpose, m, n, k int, alpha float32, a []float32, lda int, b []float32, ldb int, beta float32, c []float32, ldc int) {
28	switch tA {
29	default:
30		panic(badTranspose)
31	case blas.NoTrans, blas.Trans, blas.ConjTrans:
32	}
33	switch tB {
34	default:
35		panic(badTranspose)
36	case blas.NoTrans, blas.Trans, blas.ConjTrans:
37	}
38	if m < 0 {
39		panic(mLT0)
40	}
41	if n < 0 {
42		panic(nLT0)
43	}
44	if k < 0 {
45		panic(kLT0)
46	}
47	aTrans := tA == blas.Trans || tA == blas.ConjTrans
48	if aTrans {
49		if lda < max(1, m) {
50			panic(badLdA)
51		}
52	} else {
53		if lda < max(1, k) {
54			panic(badLdA)
55		}
56	}
57	bTrans := tB == blas.Trans || tB == blas.ConjTrans
58	if bTrans {
59		if ldb < max(1, k) {
60			panic(badLdB)
61		}
62	} else {
63		if ldb < max(1, n) {
64			panic(badLdB)
65		}
66	}
67	if ldc < max(1, n) {
68		panic(badLdC)
69	}
70
71	// Quick return if possible.
72	if m == 0 || n == 0 {
73		return
74	}
75
76	// For zero matrix size the following slice length checks are trivially satisfied.
77	if aTrans {
78		if len(a) < (k-1)*lda+m {
79			panic(shortA)
80		}
81	} else {
82		if len(a) < (m-1)*lda+k {
83			panic(shortA)
84		}
85	}
86	if bTrans {
87		if len(b) < (n-1)*ldb+k {
88			panic(shortB)
89		}
90	} else {
91		if len(b) < (k-1)*ldb+n {
92			panic(shortB)
93		}
94	}
95	if len(c) < (m-1)*ldc+n {
96		panic(shortC)
97	}
98
99	// Quick return if possible.
100	if (alpha == 0 || k == 0) && beta == 1 {
101		return
102	}
103
104	// scale c
105	if beta != 1 {
106		if beta == 0 {
107			for i := 0; i < m; i++ {
108				ctmp := c[i*ldc : i*ldc+n]
109				for j := range ctmp {
110					ctmp[j] = 0
111				}
112			}
113		} else {
114			for i := 0; i < m; i++ {
115				ctmp := c[i*ldc : i*ldc+n]
116				for j := range ctmp {
117					ctmp[j] *= beta
118				}
119			}
120		}
121	}
122
123	sgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
124}
125
126func sgemmParallel(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
127	// dgemmParallel computes a parallel matrix multiplication by partitioning
128	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
129	// In all cases,
130	// A = [ 	A_11	A_12 ... 	A_1j
131	//			A_21	A_22 ...	A_2j
132	//				...
133	//			A_i1	A_i2 ...	A_ij]
134	//
135	// and same for B. All of the submatrix sizes are blockSize×blockSize except
136	// at the edges.
137	//
138	// In all cases, there is one dimension for each matrix along which
139	// C must be updated sequentially.
140	// Cij = \sum_k Aik Bki,	(A * B)
141	// Cij = \sum_k Aki Bkj,	(Aᵀ * B)
142	// Cij = \sum_k Aik Bjk,	(A * Bᵀ)
143	// Cij = \sum_k Aki Bjk,	(Aᵀ * Bᵀ)
144	//
145	// This code computes one {i, j} block sequentially along the k dimension,
146	// and computes all of the {i, j} blocks concurrently. This
147	// partitioning allows Cij to be updated in-place without race-conditions.
148	// Instead of launching a goroutine for each possible concurrent computation,
149	// a number of worker goroutines are created and channels are used to pass
150	// available and completed cases.
151	//
152	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
153	// multiplies, though this code does not copy matrices to attempt to eliminate
154	// cache misses.
155
156	maxKLen := k
157	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
158	if parBlocks < minParBlock {
159		// The matrix multiplication is small in the dimensions where it can be
160		// computed concurrently. Just do it in serial.
161		sgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
162		return
163	}
164
165	// workerLimit acts a number of maximum concurrent workers,
166	// with the limit set to the number of procs available.
167	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))
168
169	// wg is used to wait for all
170	var wg sync.WaitGroup
171	wg.Add(parBlocks)
172	defer wg.Wait()
173
174	for i := 0; i < m; i += blockSize {
175		for j := 0; j < n; j += blockSize {
176			workerLimit <- struct{}{}
177			go func(i, j int) {
178				defer func() {
179					wg.Done()
180					<-workerLimit
181				}()
182
183				leni := blockSize
184				if i+leni > m {
185					leni = m - i
186				}
187				lenj := blockSize
188				if j+lenj > n {
189					lenj = n - j
190				}
191
192				cSub := sliceView32(c, ldc, i, j, leni, lenj)
193
194				// Compute A_ik B_kj for all k
195				for k := 0; k < maxKLen; k += blockSize {
196					lenk := blockSize
197					if k+lenk > maxKLen {
198						lenk = maxKLen - k
199					}
200					var aSub, bSub []float32
201					if aTrans {
202						aSub = sliceView32(a, lda, k, i, lenk, leni)
203					} else {
204						aSub = sliceView32(a, lda, i, k, leni, lenk)
205					}
206					if bTrans {
207						bSub = sliceView32(b, ldb, j, k, lenj, lenk)
208					} else {
209						bSub = sliceView32(b, ldb, k, j, lenk, lenj)
210					}
211					sgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
212				}
213			}(i, j)
214		}
215	}
216}
217
218// sgemmSerial is serial matrix multiply
219func sgemmSerial(aTrans, bTrans bool, m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
220	switch {
221	case !aTrans && !bTrans:
222		sgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
223		return
224	case aTrans && !bTrans:
225		sgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
226		return
227	case !aTrans && bTrans:
228		sgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
229		return
230	case aTrans && bTrans:
231		sgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
232		return
233	default:
234		panic("unreachable")
235	}
236}
237
238// sgemmSerial where neither a nor b are transposed
239func sgemmSerialNotNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
240	// This style is used instead of the literal [i*stride +j]) is used because
241	// approximately 5 times faster as of go 1.3.
242	for i := 0; i < m; i++ {
243		ctmp := c[i*ldc : i*ldc+n]
244		for l, v := range a[i*lda : i*lda+k] {
245			tmp := alpha * v
246			if tmp != 0 {
247				f32.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
248			}
249		}
250	}
251}
252
253// sgemmSerial where neither a is transposed and b is not
254func sgemmSerialTransNot(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
255	// This style is used instead of the literal [i*stride +j]) is used because
256	// approximately 5 times faster as of go 1.3.
257	for l := 0; l < k; l++ {
258		btmp := b[l*ldb : l*ldb+n]
259		for i, v := range a[l*lda : l*lda+m] {
260			tmp := alpha * v
261			if tmp != 0 {
262				ctmp := c[i*ldc : i*ldc+n]
263				f32.AxpyUnitary(tmp, btmp, ctmp)
264			}
265		}
266	}
267}
268
269// sgemmSerial where neither a is not transposed and b is
270func sgemmSerialNotTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
271	// This style is used instead of the literal [i*stride +j]) is used because
272	// approximately 5 times faster as of go 1.3.
273	for i := 0; i < m; i++ {
274		atmp := a[i*lda : i*lda+k]
275		ctmp := c[i*ldc : i*ldc+n]
276		for j := 0; j < n; j++ {
277			ctmp[j] += alpha * f32.DotUnitary(atmp, b[j*ldb:j*ldb+k])
278		}
279	}
280}
281
282// sgemmSerial where both are transposed
283func sgemmSerialTransTrans(m, n, k int, a []float32, lda int, b []float32, ldb int, c []float32, ldc int, alpha float32) {
284	// This style is used instead of the literal [i*stride +j]) is used because
285	// approximately 5 times faster as of go 1.3.
286	for l := 0; l < k; l++ {
287		for i, v := range a[l*lda : l*lda+m] {
288			tmp := alpha * v
289			if tmp != 0 {
290				ctmp := c[i*ldc : i*ldc+n]
291				f32.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
292			}
293		}
294	}
295}
296
// sliceView32 returns the part of a that holds the r×c block whose top-left
// element is at row i, column j, where consecutive rows of a are lda elements
// apart. The result shares its backing array with a and keeps the stride lda:
// it starts at element (i, j) and ends just after element (i+r-1, j+c-1).
func sliceView32(a []float32, lda, i, j, r, c int) []float32 {
	first := i*lda + j
	// The last row contributes only its c leading elements.
	end := first + (r-1)*lda + c
	return a[first:end]
}
300