1// Copyright ©2014 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package gonum 6 7import ( 8 "runtime" 9 "sync" 10 11 "gonum.org/v1/gonum/blas" 12 "gonum.org/v1/gonum/internal/asm/f64" 13) 14 15// Dgemm performs one of the matrix-matrix operations 16// C = alpha * A * B + beta * C 17// C = alpha * A^T * B + beta * C 18// C = alpha * A * B^T + beta * C 19// C = alpha * A^T * B^T + beta * C 20// where A is an m×k or k×m dense matrix, B is an n×k or k×n dense matrix, C is 21// an m×n matrix, and alpha and beta are scalars. tA and tB specify whether A or 22// B are transposed. 23func (Implementation) Dgemm(tA, tB blas.Transpose, m, n, k int, alpha float64, a []float64, lda int, b []float64, ldb int, beta float64, c []float64, ldc int) { 24 switch tA { 25 default: 26 panic(badTranspose) 27 case blas.NoTrans, blas.Trans, blas.ConjTrans: 28 } 29 switch tB { 30 default: 31 panic(badTranspose) 32 case blas.NoTrans, blas.Trans, blas.ConjTrans: 33 } 34 if m < 0 { 35 panic(mLT0) 36 } 37 if n < 0 { 38 panic(nLT0) 39 } 40 if k < 0 { 41 panic(kLT0) 42 } 43 aTrans := tA == blas.Trans || tA == blas.ConjTrans 44 if aTrans { 45 if lda < max(1, m) { 46 panic(badLdA) 47 } 48 } else { 49 if lda < max(1, k) { 50 panic(badLdA) 51 } 52 } 53 bTrans := tB == blas.Trans || tB == blas.ConjTrans 54 if bTrans { 55 if ldb < max(1, k) { 56 panic(badLdB) 57 } 58 } else { 59 if ldb < max(1, n) { 60 panic(badLdB) 61 } 62 } 63 if ldc < max(1, n) { 64 panic(badLdC) 65 } 66 67 // Quick return if possible. 68 if m == 0 || n == 0 { 69 return 70 } 71 72 // For zero matrix size the following slice length checks are trivially satisfied. 73 if aTrans { 74 if len(a) < (k-1)*lda+m { 75 panic(shortA) 76 } 77 } else { 78 if len(a) < (m-1)*lda+k { 79 panic(shortA) 80 } 81 } 82 if bTrans { 83 if len(b) < (n-1)*ldb+k { 84 panic(shortB) 85 } 86 } else { 87 if len(b) < (k-1)*ldb+n { 88 panic(shortB) 89 } 90 } 91 if len(c) < (m-1)*ldc+n { 92 panic(shortC) 93 } 94 95 // Quick return if possible. 96 if (alpha == 0 || k == 0) && beta == 1 { 97 return 98 } 99 100 // scale c 101 if beta != 1 { 102 if beta == 0 { 103 for i := 0; i < m; i++ { 104 ctmp := c[i*ldc : i*ldc+n] 105 for j := range ctmp { 106 ctmp[j] = 0 107 } 108 } 109 } else { 110 for i := 0; i < m; i++ { 111 ctmp := c[i*ldc : i*ldc+n] 112 for j := range ctmp { 113 ctmp[j] *= beta 114 } 115 } 116 } 117 } 118 119 dgemmParallel(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 120} 121 122func dgemmParallel(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 123 // dgemmParallel computes a parallel matrix multiplication by partitioning 124 // a and b into sub-blocks, and updating c with the multiplication of the sub-block 125 // In all cases, 126 // A = [ A_11 A_12 ... A_1j 127 // A_21 A_22 ... A_2j 128 // ... 129 // A_i1 A_i2 ... A_ij] 130 // 131 // and same for B. All of the submatrix sizes are blockSize×blockSize except 132 // at the edges. 133 // 134 // In all cases, there is one dimension for each matrix along which 135 // C must be updated sequentially. 136 // Cij = \sum_k Aik Bki, (A * B) 137 // Cij = \sum_k Aki Bkj, (A^T * B) 138 // Cij = \sum_k Aik Bjk, (A * B^T) 139 // Cij = \sum_k Aki Bjk, (A^T * B^T) 140 // 141 // This code computes one {i, j} block sequentially along the k dimension, 142 // and computes all of the {i, j} blocks concurrently. This 143 // partitioning allows Cij to be updated in-place without race-conditions. 144 // Instead of launching a goroutine for each possible concurrent computation, 145 // a number of worker goroutines are created and channels are used to pass 146 // available and completed cases. 147 // 148 // http://alexkr.com/docs/matrixmult.pdf is a good reference on matrix-matrix 149 // multiplies, though this code does not copy matrices to attempt to eliminate 150 // cache misses. 151 152 maxKLen := k 153 parBlocks := blocks(m, blockSize) * blocks(n, blockSize) 154 if parBlocks < minParBlock { 155 // The matrix multiplication is small in the dimensions where it can be 156 // computed concurrently. Just do it in serial. 157 dgemmSerial(aTrans, bTrans, m, n, k, a, lda, b, ldb, c, ldc, alpha) 158 return 159 } 160 161 nWorkers := runtime.GOMAXPROCS(0) 162 if parBlocks < nWorkers { 163 nWorkers = parBlocks 164 } 165 // There is a tradeoff between the workers having to wait for work 166 // and a large buffer making operations slow. 167 buf := buffMul * nWorkers 168 if buf > parBlocks { 169 buf = parBlocks 170 } 171 172 sendChan := make(chan subMul, buf) 173 174 // Launch workers. A worker receives an {i, j} submatrix of c, and computes 175 // A_ik B_ki (or the transposed version) storing the result in c_ij. When the 176 // channel is finally closed, it signals to the waitgroup that it has finished 177 // computing. 178 var wg sync.WaitGroup 179 for i := 0; i < nWorkers; i++ { 180 wg.Add(1) 181 go func() { 182 defer wg.Done() 183 for sub := range sendChan { 184 i := sub.i 185 j := sub.j 186 leni := blockSize 187 if i+leni > m { 188 leni = m - i 189 } 190 lenj := blockSize 191 if j+lenj > n { 192 lenj = n - j 193 } 194 195 cSub := sliceView64(c, ldc, i, j, leni, lenj) 196 197 // Compute A_ik B_kj for all k 198 for k := 0; k < maxKLen; k += blockSize { 199 lenk := blockSize 200 if k+lenk > maxKLen { 201 lenk = maxKLen - k 202 } 203 var aSub, bSub []float64 204 if aTrans { 205 aSub = sliceView64(a, lda, k, i, lenk, leni) 206 } else { 207 aSub = sliceView64(a, lda, i, k, leni, lenk) 208 } 209 if bTrans { 210 bSub = sliceView64(b, ldb, j, k, lenj, lenk) 211 } else { 212 bSub = sliceView64(b, ldb, k, j, lenk, lenj) 213 } 214 dgemmSerial(aTrans, bTrans, leni, lenj, lenk, aSub, lda, bSub, ldb, cSub, ldc, alpha) 215 } 216 } 217 }() 218 } 219 220 // Send out all of the {i, j} subblocks for computation. 221 for i := 0; i < m; i += blockSize { 222 for j := 0; j < n; j += blockSize { 223 sendChan <- subMul{ 224 i: i, 225 j: j, 226 } 227 } 228 } 229 close(sendChan) 230 wg.Wait() 231} 232 233// dgemmSerial is serial matrix multiply 234func dgemmSerial(aTrans, bTrans bool, m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 235 switch { 236 case !aTrans && !bTrans: 237 dgemmSerialNotNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 238 return 239 case aTrans && !bTrans: 240 dgemmSerialTransNot(m, n, k, a, lda, b, ldb, c, ldc, alpha) 241 return 242 case !aTrans && bTrans: 243 dgemmSerialNotTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 244 return 245 case aTrans && bTrans: 246 dgemmSerialTransTrans(m, n, k, a, lda, b, ldb, c, ldc, alpha) 247 return 248 default: 249 panic("unreachable") 250 } 251} 252 253// dgemmSerial where neither a nor b are transposed 254func dgemmSerialNotNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 255 // This style is used instead of the literal [i*stride +j]) is used because 256 // approximately 5 times faster as of go 1.3. 257 for i := 0; i < m; i++ { 258 ctmp := c[i*ldc : i*ldc+n] 259 for l, v := range a[i*lda : i*lda+k] { 260 tmp := alpha * v 261 if tmp != 0 { 262 f64.AxpyUnitary(tmp, b[l*ldb:l*ldb+n], ctmp) 263 } 264 } 265 } 266} 267 268// dgemmSerial where neither a is transposed and b is not 269func dgemmSerialTransNot(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 270 // This style is used instead of the literal [i*stride +j]) is used because 271 // approximately 5 times faster as of go 1.3. 272 for l := 0; l < k; l++ { 273 btmp := b[l*ldb : l*ldb+n] 274 for i, v := range a[l*lda : l*lda+m] { 275 tmp := alpha * v 276 if tmp != 0 { 277 ctmp := c[i*ldc : i*ldc+n] 278 f64.AxpyUnitary(tmp, btmp, ctmp) 279 } 280 } 281 } 282} 283 284// dgemmSerial where neither a is not transposed and b is 285func dgemmSerialNotTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 286 // This style is used instead of the literal [i*stride +j]) is used because 287 // approximately 5 times faster as of go 1.3. 288 for i := 0; i < m; i++ { 289 atmp := a[i*lda : i*lda+k] 290 ctmp := c[i*ldc : i*ldc+n] 291 for j := 0; j < n; j++ { 292 ctmp[j] += alpha * f64.DotUnitary(atmp, b[j*ldb:j*ldb+k]) 293 } 294 } 295} 296 297// dgemmSerial where both are transposed 298func dgemmSerialTransTrans(m, n, k int, a []float64, lda int, b []float64, ldb int, c []float64, ldc int, alpha float64) { 299 // This style is used instead of the literal [i*stride +j]) is used because 300 // approximately 5 times faster as of go 1.3. 301 for l := 0; l < k; l++ { 302 for i, v := range a[l*lda : l*lda+m] { 303 tmp := alpha * v 304 if tmp != 0 { 305 ctmp := c[i*ldc : i*ldc+n] 306 f64.AxpyInc(tmp, b[l:], ctmp, uintptr(n), uintptr(ldb), 1, 0, 0) 307 } 308 } 309 } 310} 311 312func sliceView64(a []float64, lda, i, j, r, c int) []float64 { 313 return a[i*lda+j : (i+r-1)*lda+j+c] 314} 315