1// Copyright ©2014 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"runtime"
9	"sync"
10
11	"gonum.org/v1/gonum/blas"
12	"gonum.org/v1/gonum/internal/asm/f64"
13)
14
// Dgemm performs one of the matrix-matrix operations
//  C = alpha * A * B + beta * C
//  C = alpha * A^T * B + beta * C
//  C = alpha * A * B^T + beta * C
//  C = alpha * A^T * B^T + beta * C
// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
// B are transposed.
//
// Dgemm panics if tA or tB is not a valid blas.Transpose value, if m, n or k
// is negative, if a leading dimension is too small for the corresponding
// matrix shape, or if a, b or c is too short to hold its matrix.
func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
	// Validate the transposition flags. ConjTrans is accepted and treated
	// the same as Trans since the data are real-valued.
	switch tA {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	switch tB {
	default:
		panic(badTranspose)
	case blas.NoTrans, blas.Trans, blas.ConjTrans:
	}
	if m < 0 {
		panic(mLT0)
	}
	if n < 0 {
		panic(nLT0)
	}
	if k < 0 {
		panic(kLT0)
	}
	// The minimum legal leading dimension depends on storage orientation:
	// A is stored k×m when transposed and m×k otherwise; likewise for B.
	aTrans := tA == blas.Trans || tA == blas.ConjTrans
	if aTrans {
		if lda < max(1, m) {
			panic(badLdA)
		}
	} else {
		if lda < max(1, k) {
			panic(badLdA)
		}
	}
	bTrans := tB == blas.Trans || tB == blas.ConjTrans
	if bTrans {
		if ldb < max(1, k) {
			panic(badLdB)
		}
	} else {
		if ldb < max(1, n) {
			panic(badLdB)
		}
	}
	if ldc < max(1, n) {
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if aTrans {
		if len(a) < (k-1)*lda+m {
			panic(shortA)
		}
	} else {
		if len(a) < (m-1)*lda+k {
			panic(shortA)
		}
	}
	if bTrans {
		if len(b) < (n-1)*ldb+k {
			panic(shortB)
		}
	} else {
		if len(b) < (k-1)*ldb+n {
			panic(shortB)
		}
	}
	if len(c) < (m-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible: alpha*op(A)*op(B) contributes nothing and
	// beta == 1 leaves C unchanged.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// Scale C by beta up front so the multiplication below can accumulate
	// into it. beta == 0 assigns zero explicitly rather than multiplying,
	// which also clears any NaN/Inf values already present in C.
	if beta != 1 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ctmp := c[i*ldc : i*ldc+n]
				for j := range ctmp {
					ctmp[j] *= beta
				}
			}
		}
	}

	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
}
121
// dgemmParallel computes C += alpha * op(A) * op(B) by partitioning C into
// blockSize×blockSize tiles that are updated concurrently by a pool of
// worker goroutines. Small problems are delegated to dgemmSerial.
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// dgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the
	// sub-blocks. In all cases,
	// A = [ 	A_11	A_12 ... 	A_1j
	//			A_21	A_22 ...	A_2j
	//				...
	//			A_i1	A_i2 ...	A_ij]
	//
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	// Cij = \sum_k Aik Bki,	(A * B)
	// Cij = \sum_k Aki Bkj,	(A^T * B)
	// Cij = \sum_k Aik Bjk,	(A * B^T)
	// Cij = \sum_k Aki Bjk,	(A^T * B^T)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race-conditions.
	// Instead of launching a goroutine for each possible concurrent computation,
	// a number of worker goroutines are created and channels are used to pass
	// available and completed cases.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	// Remember the full extent of the k dimension; the tile loop inside the
	// workers shadows k with the tile's starting index.
	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// Use at most one worker per available CPU, and never more workers than
	// there are tiles of work.
	nWorkers := runtime.GOMAXPROCS(0)
	if parBlocks < nWorkers {
		nWorkers = parBlocks
	}
	// There is a tradeoff between the workers having to wait for work
	// and a large buffer making operations slow.
	buf := buffMul * nWorkers
	if buf > parBlocks {
		buf = parBlocks
	}

	sendChan := make(chan subMul, buf)

	// Launch workers. A worker receives an {i, j} submatrix of c, and computes
	// A_ik B_ki (or the transposed version) storing the result in c_ij. When the
	// channel is finally closed, it signals to the waitgroup that it has finished
	// computing.
	var wg sync.WaitGroup
	for i := 0; i < nWorkers; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for sub := range sendChan {
				// i and j are the top-left coordinates of this tile of C;
				// they deliberately shadow the worker-launch loop variable.
				i := sub.i
				j := sub.j
				// Clip the tile dimensions at the matrix edges.
				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

				cSub := sliceView64(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					// Select the sub-block views, swapping the row/column
					// indices when an operand is stored transposed. The
					// views retain the original strides, so lda, ldb and
					// ldc are passed through unchanged.
					var aSub, bSub []float64
					if aTrans {
						aSub = sliceView64(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView64(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
					}
					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}
		}()
	}

	// Send out all of the {i, j} subblocks for computation.
	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			sendChan <- subMul{
				i: i,
				j: j,
			}
		}
	}
	close(sendChan)
	wg.Wait()
}
232
233// dgemmSerial is serial matrix multiply
234func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
235	switch {
236	case !aTrans && !bTrans:
237		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
238		return
239	case aTrans && !bTrans:
240		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
241		return
242	case !aTrans && bTrans:
243		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
244		return
245	case aTrans && bTrans:
246		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
247		return
248	default:
249		panic("unreachable")
250	}
251}
252
253// dgemmSerial where neither a nor b are transposed
254func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
255	// This style is used instead of the literal [i*stride +j]) is used because
256	// approximately 5 times faster as of go 1.3.
257	for i := 0; i < m; i++ {
258		ctmp := c[i*ldc : i*ldc+n]
259		for l, v := range a[i*lda : i*lda+k] {
260			tmp := alpha * v
261			if tmp != 0 {
262				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
263			}
264		}
265	}
266}
267
268// dgemmSerial where neither a is transposed and b is not
269func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
270	// This style is used instead of the literal [i*stride +j]) is used because
271	// approximately 5 times faster as of go 1.3.
272	for l := 0; l < k; l++ {
273		btmp := b[l*ldb : l*ldb+n]
274		for i, v := range a[l*lda : l*lda+m] {
275			tmp := alpha * v
276			if tmp != 0 {
277				ctmp := c[i*ldc : i*ldc+n]
278				f64.AxpyUnitary(tmp, btmp, ctmp)
279			}
280		}
281	}
282}
283
284// dgemmSerial where neither a is not transposed and b is
285func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
286	// This style is used instead of the literal [i*stride +j]) is used because
287	// approximately 5 times faster as of go 1.3.
288	for i := 0; i < m; i++ {
289		atmp := a[i*lda : i*lda+k]
290		ctmp := c[i*ldc : i*ldc+n]
291		for j := 0; j < n; j++ {
292			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
293		}
294	}
295}
296
297// dgemmSerial where both are transposed
298func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
299	// This style is used instead of the literal [i*stride +j]) is used because
300	// approximately 5 times faster as of go 1.3.
301	for l := 0; l < k; l++ {
302		for i, v := range a[l*lda : l*lda+m] {
303			tmp := alpha * v
304			if tmp != 0 {
305				ctmp := c[i*ldc : i*ldc+n]
306				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
307			}
308		}
309	}
310}
311
// sliceView64 returns the subslice of a covering the r×c submatrix whose
// top-left element is (i, j) of the row-major matrix with stride lda. The
// view aliases a's backing array and keeps the original stride, running from
// the first element of the submatrix to the last.
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	start := i*lda + j
	return a[start : start+(r-1)*lda+c]
}
315