1// Copyright ©2014 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"runtime"
9	"sync"
10
11	"gonum.org/v1/gonum/blas"
12	"gonum.org/v1/gonum/internal/asm/f64"
13)
14
15// Dgemm performs one of the matrix-matrix operations
16//  C = alpha * A * B + beta * C
17//  C = alpha * Aᵀ * B + beta * C
18//  C = alpha * A * Bᵀ + beta * C
19//  C = alpha * Aᵀ * Bᵀ + beta * C
20// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is
21// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or
22// B are transposed.
23func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) {
24	switch tA {
25	default:
26		panic(badTranspose)
27	case blas.NoTrans, blas.Trans, blas.ConjTrans:
28	}
29	switch tB {
30	default:
31		panic(badTranspose)
32	case blas.NoTrans, blas.Trans, blas.ConjTrans:
33	}
34	if m < 0 {
35		panic(mLT0)
36	}
37	if n < 0 {
38		panic(nLT0)
39	}
40	if k < 0 {
41		panic(kLT0)
42	}
43	aTrans := tA == blas.Trans || tA == blas.ConjTrans
44	if aTrans {
45		if lda < max(1, m) {
46			panic(badLdA)
47		}
48	} else {
49		if lda < max(1, k) {
50			panic(badLdA)
51		}
52	}
53	bTrans := tB == blas.Trans || tB == blas.ConjTrans
54	if bTrans {
55		if ldb < max(1, k) {
56			panic(badLdB)
57		}
58	} else {
59		if ldb < max(1, n) {
60			panic(badLdB)
61		}
62	}
63	if ldc < max(1, n) {
64		panic(badLdC)
65	}
66
67	// Quick return if possible.
68	if m == 0 || n == 0 {
69		return
70	}
71
72	// For zero matrix size the following slice length checks are trivially satisfied.
73	if aTrans {
74		if len(a) < (k-1)*lda+m {
75			panic(shortA)
76		}
77	} else {
78		if len(a) < (m-1)*lda+k {
79			panic(shortA)
80		}
81	}
82	if bTrans {
83		if len(b) < (n-1)*ldb+k {
84			panic(shortB)
85		}
86	} else {
87		if len(b) < (k-1)*ldb+n {
88			panic(shortB)
89		}
90	}
91	if len(c) < (m-1)*ldc+n {
92		panic(shortC)
93	}
94
95	// Quick return if possible.
96	if (alpha == 0 || k == 0) && beta == 1 {
97		return
98	}
99
100	// scale c
101	if beta != 1 {
102		if beta == 0 {
103			for i := 0; i < m; i++ {
104				ctmp := c[i*ldc : i*ldc+n]
105				for j := range ctmp {
106					ctmp[j] = 0
107				}
108			}
109		} else {
110			for i := 0; i < m; i++ {
111				ctmp := c[i*ldc : i*ldc+n]
112				for j := range ctmp {
113					ctmp[j] *= beta
114				}
115			}
116		}
117	}
118
119	dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
120}
121
// dgemmParallel computes C += alpha * op(A) * op(B) by partitioning C into
// blockSize×blockSize tiles and computing each tile concurrently.
func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
	// dgemmParallel computes a parallel matrix multiplication by partitioning
	// a and b into sub-blocks, and updating c with the multiplication of the sub-block
	// In all cases,
	// A = [ 	A_11	A_12 ... 	A_1j
	//			A_21	A_22 ...	A_2j
	//				...
	//			A_i1	A_i2 ...	A_ij]
	//
	// and same for B. All of the submatrix sizes are blockSize×blockSize except
	// at the edges.
	//
	// In all cases, there is one dimension for each matrix along which
	// C must be updated sequentially.
	// Cij = \sum_l Ail Blj,	(A * B)
	// Cij = \sum_l Ali Blj,	(Aᵀ * B)
	// Cij = \sum_l Ail Bjl,	(A * Bᵀ)
	// Cij = \sum_l Ali Bjl,	(Aᵀ * Bᵀ)
	//
	// This code computes one {i, j} block sequentially along the k dimension,
	// and computes all of the {i, j} blocks concurrently. This
	// partitioning allows Cij to be updated in-place without race-conditions.
	// One goroutine is launched per {i, j} block, with a buffered channel used
	// as a counting semaphore to bound the number running at any one time.
	//
	// http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix
	// multiplies, though this code does not copy matrices to attempt to eliminate
	// cache misses.

	maxKLen := k
	parBlocks := blocks(m, blockSize) * blocks(n, blockSize)
	if parBlocks < minParBlock {
		// The matrix multiplication is small in the dimensions where it can be
		// computed concurrently. Just do it in serial.
		dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha)
		return
	}

	// workerLimit acts as a limit on the number of concurrent workers,
	// with the limit set to the number of procs available.
	workerLimit := make(chan struct{}, runtime.GOMAXPROCS(0))

	// wg is used to wait for all goroutines to finish before returning.
	var wg sync.WaitGroup
	wg.Add(parBlocks)
	defer wg.Wait()

	for i := 0; i < m; i += blockSize {
		for j := 0; j < n; j += blockSize {
			// Acquire a semaphore slot before launching; this blocks when
			// GOMAXPROCS workers are already running.
			workerLimit <- struct{}{}
			go func(i, j int) {
				defer func() {
					wg.Done()
					<-workerLimit
				}()

				// Clip the block extents at the matrix edges.
				leni := blockSize
				if i+leni > m {
					leni = m - i
				}
				lenj := blockSize
				if j+lenj > n {
					lenj = n - j
				}

				cSub := sliceView64(c, ldc, i, j, leni, lenj)

				// Compute A_ik B_kj for all k. This loop is sequential within
				// the goroutine because each iteration accumulates into cSub.
				for k := 0; k < maxKLen; k += blockSize {
					lenk := blockSize
					if k+lenk > maxKLen {
						lenk = maxKLen - k
					}
					// Select the sub-block views according to the transpose
					// state; a transposed operand has its row/column block
					// indices swapped.
					var aSub, bSub []float64
					if aTrans {
						aSub = sliceView64(a, lda, k, i, lenk, leni)
					} else {
						aSub = sliceView64(a, lda, i, k, leni, lenk)
					}
					if bTrans {
						bSub = sliceView64(b, ldb, j, k, lenj, lenk)
					} else {
						bSub = sliceView64(b, ldb, k, j, lenk, lenj)
					}
					dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha)
				}
			}(i, j)
		}
	}
}
213
214// dgemmSerial is serial matrix multiply
215func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
216	switch {
217	case !aTrans && !bTrans:
218		dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
219		return
220	case aTrans && !bTrans:
221		dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha)
222		return
223	case !aTrans && bTrans:
224		dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
225		return
226	case aTrans && bTrans:
227		dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha)
228		return
229	default:
230		panic("unreachable")
231	}
232}
233
234// dgemmSerial where neither a nor b are transposed
235func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
236	// This style is used instead of the literal [i*stride +j]) is used because
237	// approximately 5 times faster as of go 1.3.
238	for i := 0; i < m; i++ {
239		ctmp := c[i*ldc : i*ldc+n]
240		for l, v := range a[i*lda : i*lda+k] {
241			tmp := alpha * v
242			if tmp != 0 {
243				f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp)
244			}
245		}
246	}
247}
248
249// dgemmSerial where neither a is transposed and b is not
250func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
251	// This style is used instead of the literal [i*stride +j]) is used because
252	// approximately 5 times faster as of go 1.3.
253	for l := 0; l < k; l++ {
254		btmp := b[l*ldb : l*ldb+n]
255		for i, v := range a[l*lda : l*lda+m] {
256			tmp := alpha * v
257			if tmp != 0 {
258				ctmp := c[i*ldc : i*ldc+n]
259				f64.AxpyUnitary(tmp, btmp, ctmp)
260			}
261		}
262	}
263}
264
265// dgemmSerial where neither a is not transposed and b is
266func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
267	// This style is used instead of the literal [i*stride +j]) is used because
268	// approximately 5 times faster as of go 1.3.
269	for i := 0; i < m; i++ {
270		atmp := a[i*lda : i*lda+k]
271		ctmp := c[i*ldc : i*ldc+n]
272		for j := 0; j < n; j++ {
273			ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k])
274		}
275	}
276}
277
278// dgemmSerial where both are transposed
279func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) {
280	// This style is used instead of the literal [i*stride +j]) is used because
281	// approximately 5 times faster as of go 1.3.
282	for l := 0; l < k; l++ {
283		for i, v := range a[l*lda : l*lda+m] {
284			tmp := alpha * v
285			if tmp != 0 {
286				ctmp := c[i*ldc : i*ldc+n]
287				f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0)
288			}
289		}
290	}
291}
292
// sliceView64 returns the subslice of a, interpreted as a row-major matrix
// with stride lda, that spans the r×c submatrix whose top-left element is at
// row i, column j. The view aliases a; writes through it are visible in a.
func sliceView64(a []float64, lda, i, j, r, c int) []float64 {
	from := i*lda + j
	to := from + (r-1)*lda + c
	return a[from:to]
}
296