1// Copyright ©2019 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gonum
6
7import (
8	"math/cmplx"
9
10	"gonum.org/v1/gonum/blas"
11	"gonum.org/v1/gonum/internal/asm/c128"
12)
13
// Compile-time assertion: Implementation must provide every method of the
// blas.Complex128Level3 interface, otherwise this assignment fails to build.
var _ blas.Complex128Level3 = Implementation{}
15
16// Zgemm performs one of the matrix-matrix operations
17//  C = alpha * op(A) * op(B) + beta * C
18// where op(X) is one of
19//  op(X) = X  or  op(X) = Xᵀ  or  op(X) = Xᴴ,
20// alpha and beta are scalars, and A, B and C are matrices, with op(A) an m×k matrix,
21// op(B) a k×n matrix and C an m×n matrix.
22func (Implementation) Zgemm(tA, tB blas.Transpose, m, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
23	switch tA {
24	default:
25		panic(badTranspose)
26	case blas.NoTrans, blas.Trans, blas.ConjTrans:
27	}
28	switch tB {
29	default:
30		panic(badTranspose)
31	case blas.NoTrans, blas.Trans, blas.ConjTrans:
32	}
33	switch {
34	case m < 0:
35		panic(mLT0)
36	case n < 0:
37		panic(nLT0)
38	case k < 0:
39		panic(kLT0)
40	}
41	rowA, colA := m, k
42	if tA != blas.NoTrans {
43		rowA, colA = k, m
44	}
45	if lda < max(1, colA) {
46		panic(badLdA)
47	}
48	rowB, colB := k, n
49	if tB != blas.NoTrans {
50		rowB, colB = n, k
51	}
52	if ldb < max(1, colB) {
53		panic(badLdB)
54	}
55	if ldc < max(1, n) {
56		panic(badLdC)
57	}
58
59	// Quick return if possible.
60	if m == 0 || n == 0 {
61		return
62	}
63
64	// For zero matrix size the following slice length checks are trivially satisfied.
65	if len(a) < (rowA-1)*lda+colA {
66		panic(shortA)
67	}
68	if len(b) < (rowB-1)*ldb+colB {
69		panic(shortB)
70	}
71	if len(c) < (m-1)*ldc+n {
72		panic(shortC)
73	}
74
75	// Quick return if possible.
76	if (alpha == 0 || k == 0) && beta == 1 {
77		return
78	}
79
80	if alpha == 0 {
81		if beta == 0 {
82			for i := 0; i < m; i++ {
83				for j := 0; j < n; j++ {
84					c[i*ldc+j] = 0
85				}
86			}
87		} else {
88			for i := 0; i < m; i++ {
89				for j := 0; j < n; j++ {
90					c[i*ldc+j] *= beta
91				}
92			}
93		}
94		return
95	}
96
97	switch tA {
98	case blas.NoTrans:
99		switch tB {
100		case blas.NoTrans:
101			// Form  C = alpha * A * B + beta * C.
102			for i := 0; i < m; i++ {
103				switch {
104				case beta == 0:
105					for j := 0; j < n; j++ {
106						c[i*ldc+j] = 0
107					}
108				case beta != 1:
109					for j := 0; j < n; j++ {
110						c[i*ldc+j] *= beta
111					}
112				}
113				for l := 0; l < k; l++ {
114					tmp := alpha * a[i*lda+l]
115					for j := 0; j < n; j++ {
116						c[i*ldc+j] += tmp * b[l*ldb+j]
117					}
118				}
119			}
120		case blas.Trans:
121			// Form  C = alpha * A * Bᵀ + beta * C.
122			for i := 0; i < m; i++ {
123				switch {
124				case beta == 0:
125					for j := 0; j < n; j++ {
126						c[i*ldc+j] = 0
127					}
128				case beta != 1:
129					for j := 0; j < n; j++ {
130						c[i*ldc+j] *= beta
131					}
132				}
133				for l := 0; l < k; l++ {
134					tmp := alpha * a[i*lda+l]
135					for j := 0; j < n; j++ {
136						c[i*ldc+j] += tmp * b[j*ldb+l]
137					}
138				}
139			}
140		case blas.ConjTrans:
141			// Form  C = alpha * A * Bᴴ + beta * C.
142			for i := 0; i < m; i++ {
143				switch {
144				case beta == 0:
145					for j := 0; j < n; j++ {
146						c[i*ldc+j] = 0
147					}
148				case beta != 1:
149					for j := 0; j < n; j++ {
150						c[i*ldc+j] *= beta
151					}
152				}
153				for l := 0; l < k; l++ {
154					tmp := alpha * a[i*lda+l]
155					for j := 0; j < n; j++ {
156						c[i*ldc+j] += tmp * cmplx.Conj(b[j*ldb+l])
157					}
158				}
159			}
160		}
161	case blas.Trans:
162		switch tB {
163		case blas.NoTrans:
164			// Form  C = alpha * Aᵀ * B + beta * C.
165			for i := 0; i < m; i++ {
166				for j := 0; j < n; j++ {
167					var tmp complex128
168					for l := 0; l < k; l++ {
169						tmp += a[l*lda+i] * b[l*ldb+j]
170					}
171					if beta == 0 {
172						c[i*ldc+j] = alpha * tmp
173					} else {
174						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
175					}
176				}
177			}
178		case blas.Trans:
179			// Form  C = alpha * Aᵀ * Bᵀ + beta * C.
180			for i := 0; i < m; i++ {
181				for j := 0; j < n; j++ {
182					var tmp complex128
183					for l := 0; l < k; l++ {
184						tmp += a[l*lda+i] * b[j*ldb+l]
185					}
186					if beta == 0 {
187						c[i*ldc+j] = alpha * tmp
188					} else {
189						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
190					}
191				}
192			}
193		case blas.ConjTrans:
194			// Form  C = alpha * Aᵀ * Bᴴ + beta * C.
195			for i := 0; i < m; i++ {
196				for j := 0; j < n; j++ {
197					var tmp complex128
198					for l := 0; l < k; l++ {
199						tmp += a[l*lda+i] * cmplx.Conj(b[j*ldb+l])
200					}
201					if beta == 0 {
202						c[i*ldc+j] = alpha * tmp
203					} else {
204						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
205					}
206				}
207			}
208		}
209	case blas.ConjTrans:
210		switch tB {
211		case blas.NoTrans:
212			// Form  C = alpha * Aᴴ * B + beta * C.
213			for i := 0; i < m; i++ {
214				for j := 0; j < n; j++ {
215					var tmp complex128
216					for l := 0; l < k; l++ {
217						tmp += cmplx.Conj(a[l*lda+i]) * b[l*ldb+j]
218					}
219					if beta == 0 {
220						c[i*ldc+j] = alpha * tmp
221					} else {
222						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
223					}
224				}
225			}
226		case blas.Trans:
227			// Form  C = alpha * Aᴴ * Bᵀ + beta * C.
228			for i := 0; i < m; i++ {
229				for j := 0; j < n; j++ {
230					var tmp complex128
231					for l := 0; l < k; l++ {
232						tmp += cmplx.Conj(a[l*lda+i]) * b[j*ldb+l]
233					}
234					if beta == 0 {
235						c[i*ldc+j] = alpha * tmp
236					} else {
237						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
238					}
239				}
240			}
241		case blas.ConjTrans:
242			// Form  C = alpha * Aᴴ * Bᴴ + beta * C.
243			for i := 0; i < m; i++ {
244				for j := 0; j < n; j++ {
245					var tmp complex128
246					for l := 0; l < k; l++ {
247						tmp += cmplx.Conj(a[l*lda+i]) * cmplx.Conj(b[j*ldb+l])
248					}
249					if beta == 0 {
250						c[i*ldc+j] = alpha * tmp
251					} else {
252						c[i*ldc+j] = alpha*tmp + beta*c[i*ldc+j]
253					}
254				}
255			}
256		}
257	}
258}
259
// Zhemm performs one of the matrix-matrix operations
//  C = alpha*A*B + beta*C  if side == blas.Left
//  C = alpha*B*A + beta*C  if side == blas.Right
// where alpha and beta are scalars, A is an m×m or n×n hermitian matrix and B
// and C are m×n matrices. The imaginary parts of the diagonal elements of A are
// assumed to be zero.
//
// Only the triangle of A selected by uplo is read; an entry in the other
// triangle is taken as the conjugate of its stored mirror image. A, B and C
// are stored row-major with leading dimensions lda, ldb and ldc.
//
// Zhemm panics if side or uplo is invalid, m or n is negative, a leading
// dimension is too small, or a slice is too short for its matrix.
func (Implementation) Zhemm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of A: m when A multiplies B from the left, n from the right.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible.
	if alpha == 0 && beta == 1 {
		return
	}

	// With alpha zero only beta*C remains. When beta is also zero, C is
	// overwritten with zeros rather than multiplied, so its previous contents
	// do not matter.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form  C = alpha*A*B + beta*C.
		for i := 0; i < m; i++ {
			// Initialize row i of C from the real part of the diagonal entry
			// A[i][i] times row i of B, plus beta times the old row.
			atmp := alpha * complex(real(a[i*lda+i]), 0)
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			// Add alpha*A[i][k] times row k of B for every k != i.
			if uplo == blas.Upper {
				// k < i: A[i][k] lies in the unstored lower triangle, so use
				// conj(A[k][i]).
				for k := 0; k < i; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// k > i: A[i][k] is stored in the upper triangle.
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				// k < i: A[i][k] is stored in the lower triangle.
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// k > i: A[i][k] lies in the unstored upper triangle, so use
				// conj(A[k][i]).
				for k := i + 1; k < m; k++ {
					atmp = alpha * cmplx.Conj(a[k*lda+i])
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form  C = alpha*B*A + beta*C.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Visit the columns of row i from right to left. At column j the
				// stored row A[j][j+1:] lets B[i][j] contribute to the columns
				// to the right (already initialized by earlier steps), while tmp
				// gathers B[i][l]*conj(A[j][l]) = B[i][l]*A[l][j] for l > j.
				// C[i][j] itself is still unmodified and is finalized here.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Visit the columns of row i from left to right. At column j the
				// stored row A[j][:j] lets B[i][j] contribute to the columns to
				// the left (already initialized by earlier steps), while tmp
				// gathers B[i][l]*conj(A[j][l]) = B[i][l]*A[l][j] for l < j.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * cmplx.Conj(ajk)
					}
					ajj := complex(real(a[j*lda+j]), 0)
					if beta == 0 {
						c[i*ldc+j] = abij*ajj + alpha*tmp
					} else {
						c[i*ldc+j] = abij*ajj + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}
406
// Zherk performs one of the hermitian rank-k operations
//  C = alpha*A*Aᴴ + beta*C  if trans == blas.NoTrans
//  C = alpha*Aᴴ*A + beta*C  if trans == blas.ConjTrans
// where alpha and beta are real scalars, C is an n×n hermitian matrix and A is
// an n×k matrix in the first case and a k×n matrix in the second case.
//
// Only the triangle of C selected by uplo is referenced and updated. A and C
// are stored row-major with leading dimensions lda and ldc.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero. The one exception is the quick return
// taken when alpha or k is zero and beta is one, which leaves C untouched.
//
// Zherk panics if trans or uplo is invalid, n or k is negative, a leading
// dimension is too small, or a slice is too short for its matrix.
func (Implementation) Zherk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha float64, a []complex128, lda int, beta float64, c []complex128, ldc int) {
	// rowA×colA is the stored shape of A.
	var rowA, colA int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		rowA, colA = n, k
	case blas.ConjTrans:
		rowA, colA = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, colA):
		panic(badLdA)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (rowA-1)*lda+colA {
		panic(shortA)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible. This leaves C completely untouched, including
	// any nonzero imaginary parts on its diagonal.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha zero only beta*C remains. Diagonal entries are replaced by
	// beta times their real part, which clears their imaginary parts.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					ci[0] = complex(beta*real(ci[0]), 0)
					// The last row of the upper triangle has no entries
					// right of the diagonal.
					if i != n-1 {
						c128.DscalUnitary(beta, ci[1:])
					}
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					// The first row of the lower triangle has no entries
					// left of the diagonal.
					if i != 0 {
						c128.DscalUnitary(beta, ci[:i])
					}
					ci[i] = complex(beta*real(ci[i]), 0)
				}
			}
		}
		return
	}

	// The real alpha as a complex value, for mixing with complex products.
	calpha := complex(alpha, 0)
	if trans == blas.NoTrans {
		// Form  C = alpha*A*Aᴴ + beta*C.
		// Each referenced C[i][j] pairs row j with row i of A through
		// c128.DotcUnitary (presumably conjugating its first argument —
		// confirm in internal/asm/c128). For the diagonal only the real part
		// is kept.
		cbeta := complex(beta, 0)
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				switch {
				case beta == 0:
					// Handle the i-th diagonal element of C.
					ci[0] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
					// Handle the remaining elements on the i-th row of C.
					for jc := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
					}
				case beta != 1:
					cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[0]
					ci[0] = complex(real(cii), 0)
					for jc, cij := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
					}
				default:
					// beta == 1: accumulate onto C without scaling.
					cii := calpha*c128.DotcUnitary(ai, ai) + ci[0]
					ci[0] = complex(real(cii), 0)
					for jc, cij := range ci[1:] {
						j := i + 1 + jc
						ci[jc+1] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				ai := a[i*lda : i*lda+k]
				switch {
				case beta == 0:
					// Handle the first i elements (columns 0..i-1) on the i-th row of C.
					for j := range ci[:i] {
						ci[j] = calpha * c128.DotcUnitary(a[j*lda:j*lda+k], ai)
					}
					// Handle the i-th diagonal element of C.
					ci[i] = complex(alpha*real(c128.DotcUnitary(ai, ai)), 0)
				case beta != 1:
					for j, cij := range ci[:i] {
						ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cbeta*cij
					}
					cii := calpha*c128.DotcUnitary(ai, ai) + cbeta*ci[i]
					ci[i] = complex(real(cii), 0)
				default:
					// beta == 1: accumulate onto C without scaling.
					for j, cij := range ci[:i] {
						ci[j] = calpha*c128.DotcUnitary(a[j*lda:j*lda+k], ai) + cij
					}
					cii := calpha*c128.DotcUnitary(ai, ai) + ci[i]
					ci[i] = complex(real(cii), 0)
				}
			}
		}
	} else {
		// Form  C = alpha*Aᴴ*A + beta*C.
		// Row i of the referenced triangle is first scaled by beta (its
		// diagonal made real), then receives alpha*conj(A[j][i]) times the
		// matching part of row j of A for every row j of A. Zero multipliers
		// are skipped.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[0] = complex(real(ci[0]), 0)
				default:
					ci[0] = complex(real(ci[0]), 0)
				}
				for j := 0; j < k; j++ {
					aji := cmplx.Conj(a[j*lda+i])
					if aji != 0 {
						c128.AxpyUnitary(calpha*aji, a[j*lda+i:j*lda+n], ci)
					}
				}
				// Store the diagonal entry as exactly real.
				c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[i] = complex(real(ci[i]), 0)
				default:
					ci[i] = complex(real(ci[i]), 0)
				}
				for j := 0; j < k; j++ {
					aji := cmplx.Conj(a[j*lda+i])
					if aji != 0 {
						c128.AxpyUnitary(calpha*aji, a[j*lda:j*lda+i+1], ci)
					}
				}
				// Store the diagonal entry as exactly real.
				c[i*ldc+i] = complex(real(c[i*ldc+i]), 0)
			}
		}
	}
}
604
// Zher2k performs one of the hermitian rank-2k operations
//  C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C  if trans == blas.NoTrans
//  C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C  if trans == blas.ConjTrans
// where alpha and beta are scalars with beta real, C is an n×n hermitian matrix
// and A and B are n×k matrices in the first case and k×n matrices in the second case.
//
// Only the triangle of C selected by uplo is referenced and updated. A, B and
// C are stored row-major with leading dimensions lda, ldb and ldc.
//
// The imaginary parts of the diagonal elements of C are assumed to be zero, and
// on return they will be set to zero. The one exception is the quick return
// taken when alpha or k is zero and beta is one, which leaves C untouched.
//
// Zher2k panics if trans or uplo is invalid, n or k is negative, a leading
// dimension is too small, or a slice is too short for its matrix.
func (Implementation) Zher2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta float64, c []complex128, ldc int) {
	// row×col is the stored shape shared by A and B.
	var row, col int
	switch trans {
	default:
		panic(badTranspose)
	case blas.NoTrans:
		row, col = n, k
	case blas.ConjTrans:
		row, col = k, n
	}
	switch {
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case n < 0:
		panic(nLT0)
	case k < 0:
		panic(kLT0)
	case lda < max(1, col):
		panic(badLdA)
	case ldb < max(1, col):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (row-1)*lda+col {
		panic(shortA)
	}
	if len(b) < (row-1)*ldb+col {
		panic(shortB)
	}
	if len(c) < (n-1)*ldc+n {
		panic(shortC)
	}

	// Quick return if possible. This leaves C completely untouched, including
	// any nonzero imaginary parts on its diagonal.
	if (alpha == 0 || k == 0) && beta == 1 {
		return
	}

	// With alpha zero only beta*C remains. Diagonal entries are replaced by
	// beta times their real part, which clears their imaginary parts.
	if alpha == 0 {
		if uplo == blas.Upper {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc+i : i*ldc+n]
					ci[0] = complex(beta*real(ci[0]), 0)
					// The last row of the upper triangle has no entries
					// right of the diagonal.
					if i != n-1 {
						c128.DscalUnitary(beta, ci[1:])
					}
				}
			}
		} else {
			if beta == 0 {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					for j := range ci {
						ci[j] = 0
					}
				}
			} else {
				for i := 0; i < n; i++ {
					ci := c[i*ldc : i*ldc+i+1]
					// The first row of the lower triangle has no entries
					// left of the diagonal.
					if i != 0 {
						c128.DscalUnitary(beta, ci[:i])
					}
					ci[i] = complex(beta*real(ci[i]), 0)
				}
			}
		}
		return
	}

	conjalpha := cmplx.Conj(alpha)
	cbeta := complex(beta, 0)
	if trans == blas.NoTrans {
		// Form  C = alpha*A*Bᴴ + conj(alpha)*B*Aᴴ + beta*C.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				// ci covers only the entries strictly right of the diagonal;
				// the diagonal entry is computed separately and stored as
				// exactly real.
				ci := c[i*ldc+i+1 : i*ldc+n]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					c[i*ldc+i] = complex(real(cii), 0)
					for jc := range ci {
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
				} else {
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					c[i*ldc+i] = complex(real(cii), 0)
					for jc, cij := range ci {
						j := i + 1 + jc
						ci[jc] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
				}
			}
		} else {
			for i := 0; i < n; i++ {
				// ci covers only the entries strictly left of the diagonal
				// (columns 0..i-1); the diagonal entry is computed separately
				// and stored as exactly real.
				ci := c[i*ldc : i*ldc+i]
				ai := a[i*lda : i*lda+k]
				bi := b[i*ldb : i*ldb+k]
				if beta == 0 {
					for j := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi)
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi)
					c[i*ldc+i] = complex(real(cii), 0)
				} else {
					for j, cij := range ci {
						ci[j] = alpha*c128.DotcUnitary(b[j*ldb:j*ldb+k], ai) + conjalpha*c128.DotcUnitary(a[j*lda:j*lda+k], bi) + cbeta*cij
					}
					cii := alpha*c128.DotcUnitary(bi, ai) + conjalpha*c128.DotcUnitary(ai, bi) + cbeta*c[i*ldc+i]
					c[i*ldc+i] = complex(real(cii), 0)
				}
			}
		}
	} else {
		// Form  C = alpha*Aᴴ*B + conj(alpha)*Bᴴ*A + beta*C.
		// Row i of the referenced triangle is first scaled by beta (its
		// diagonal made real). Then, for each row j of A and B, it receives
		// alpha*conj(A[j][i]) times the matching part of row j of B and
		// conj(alpha)*conj(B[j][i]) times the matching part of row j of A.
		// Zero multipliers are skipped.
		if uplo == blas.Upper {
			for i := 0; i < n; i++ {
				ci := c[i*ldc+i : i*ldc+n]
				switch {
				case beta == 0:
					for jc := range ci {
						ci[jc] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[0] = complex(real(ci[0]), 0)
				default:
					ci[0] = complex(real(ci[0]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb+i:j*ldb+n], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda+i:j*lda+n], ci)
					}
				}
				// Store the diagonal entry as exactly real.
				ci[0] = complex(real(ci[0]), 0)
			}
		} else {
			for i := 0; i < n; i++ {
				ci := c[i*ldc : i*ldc+i+1]
				switch {
				case beta == 0:
					for j := range ci {
						ci[j] = 0
					}
				case beta != 1:
					c128.DscalUnitary(beta, ci)
					ci[i] = complex(real(ci[i]), 0)
				default:
					ci[i] = complex(real(ci[i]), 0)
				}
				for j := 0; j < k; j++ {
					aji := a[j*lda+i]
					bji := b[j*ldb+i]
					if aji != 0 {
						c128.AxpyUnitary(alpha*cmplx.Conj(aji), b[j*ldb:j*ldb+i+1], ci)
					}
					if bji != 0 {
						c128.AxpyUnitary(conjalpha*cmplx.Conj(bji), a[j*lda:j*lda+i+1], ci)
					}
				}
				// Store the diagonal entry as exactly real.
				ci[i] = complex(real(ci[i]), 0)
			}
		}
	}
}
800
// Zsymm performs one of the matrix-matrix operations
//  C = alpha*A*B + beta*C  if side == blas.Left
//  C = alpha*B*A + beta*C  if side == blas.Right
// where alpha and beta are scalars, A is an m×m or n×n symmetric matrix and B
// and C are m×n matrices.
//
// Only the triangle of A selected by uplo is read; an entry in the other
// triangle is taken from its stored mirror image without conjugation. A, B and
// C are stored row-major with leading dimensions lda, ldb and ldc.
//
// Zsymm panics if side or uplo is invalid, m or n is negative, a leading
// dimension is too small, or a slice is too short for its matrix.
func (Implementation) Zsymm(side blas.Side, uplo blas.Uplo, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
	// na is the order of A: m when A multiplies B from the left, n from the right.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	case ldc < max(1, n):
		panic(badLdC)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < lda*(na-1)+na {
		panic(shortA)
	}
	if len(b) < ldb*(m-1)+n {
		panic(shortB)
	}
	if len(c) < ldc*(m-1)+n {
		panic(shortC)
	}

	// Quick return if possible.
	if alpha == 0 && beta == 1 {
		return
	}

	// With alpha zero only beta*C remains. When beta is also zero, C is
	// overwritten with zeros rather than multiplied, so its previous contents
	// do not matter.
	if alpha == 0 {
		if beta == 0 {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				for j := range ci {
					ci[j] = 0
				}
			}
		} else {
			for i := 0; i < m; i++ {
				ci := c[i*ldc : i*ldc+n]
				c128.ScalUnitary(beta, ci)
			}
		}
		return
	}

	if side == blas.Left {
		// Form  C = alpha*A*B + beta*C.
		for i := 0; i < m; i++ {
			// Initialize row i of C from the diagonal entry A[i][i] times
			// row i of B, plus beta times the old row.
			atmp := alpha * a[i*lda+i]
			bi := b[i*ldb : i*ldb+n]
			ci := c[i*ldc : i*ldc+n]
			if beta == 0 {
				for j, bij := range bi {
					ci[j] = atmp * bij
				}
			} else {
				for j, bij := range bi {
					ci[j] = atmp*bij + beta*ci[j]
				}
			}
			// Add alpha*A[i][k] times row k of B for every k != i.
			if uplo == blas.Upper {
				// k < i: A[i][k] lies in the unstored lower triangle, so read
				// its mirror A[k][i].
				for k := 0; k < i; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// k > i: A[i][k] is stored in the upper triangle.
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			} else {
				// k < i: A[i][k] is stored in the lower triangle.
				for k := 0; k < i; k++ {
					atmp = alpha * a[i*lda+k]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
				// k > i: A[i][k] lies in the unstored upper triangle, so read
				// its mirror A[k][i].
				for k := i + 1; k < m; k++ {
					atmp = alpha * a[k*lda+i]
					c128.AxpyUnitary(atmp, b[k*ldb:k*ldb+n], ci)
				}
			}
		}
	} else {
		// Form  C = alpha*B*A + beta*C.
		if uplo == blas.Upper {
			for i := 0; i < m; i++ {
				// Visit the columns of row i from right to left. At column j the
				// stored row A[j][j+1:] lets B[i][j] contribute to the columns
				// to the right (already initialized by earlier steps), while tmp
				// gathers B[i][l]*A[j][l] = B[i][l]*A[l][j] for l > j.
				// C[i][j] itself is still unmodified and is finalized here.
				for j := n - 1; j >= 0; j-- {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda+j+1 : j*lda+n]
					bi := b[i*ldb+j+1 : i*ldb+n]
					ci := c[i*ldc+j+1 : i*ldc+n]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		} else {
			for i := 0; i < m; i++ {
				// Visit the columns of row i from left to right. At column j the
				// stored row A[j][:j] lets B[i][j] contribute to the columns to
				// the left (already initialized by earlier steps), while tmp
				// gathers B[i][l]*A[j][l] = B[i][l]*A[l][j] for l < j.
				for j := 0; j < n; j++ {
					abij := alpha * b[i*ldb+j]
					aj := a[j*lda : j*lda+j]
					bi := b[i*ldb : i*ldb+j]
					ci := c[i*ldc : i*ldc+j]
					var tmp complex128
					for k, ajk := range aj {
						ci[k] += abij * ajk
						tmp += bi[k] * ajk
					}
					if beta == 0 {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp
					} else {
						c[i*ldc+j] = abij*a[j*lda+j] + alpha*tmp + beta*c[i*ldc+j]
					}
				}
			}
		}
	}
}
944
945// Zsyrk performs one of the symmetric rank-k operations
946//  C = alpha*A*Aᵀ + beta*C  if trans == blas.NoTrans
947//  C = alpha*Aᵀ*A + beta*C  if trans == blas.Trans
948// where alpha and beta are scalars, C is an n×n symmetric matrix and A is
949// an n×k matrix in the first case and a k×n matrix in the second case.
950func (Implementation) Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int) {
951	var rowA, colA int
952	switch trans {
953	default:
954		panic(badTranspose)
955	case blas.NoTrans:
956		rowA, colA = n, k
957	case blas.Trans:
958		rowA, colA = k, n
959	}
960	switch {
961	case uplo != blas.Lower && uplo != blas.Upper:
962		panic(badUplo)
963	case n < 0:
964		panic(nLT0)
965	case k < 0:
966		panic(kLT0)
967	case lda < max(1, colA):
968		panic(badLdA)
969	case ldc < max(1, n):
970		panic(badLdC)
971	}
972
973	// Quick return if possible.
974	if n == 0 {
975		return
976	}
977
978	// For zero matrix size the following slice length checks are trivially satisfied.
979	if len(a) < (rowA-1)*lda+colA {
980		panic(shortA)
981	}
982	if len(c) < (n-1)*ldc+n {
983		panic(shortC)
984	}
985
986	// Quick return if possible.
987	if (alpha == 0 || k == 0) && beta == 1 {
988		return
989	}
990
991	if alpha == 0 {
992		if uplo == blas.Upper {
993			if beta == 0 {
994				for i := 0; i < n; i++ {
995					ci := c[i*ldc+i : i*ldc+n]
996					for j := range ci {
997						ci[j] = 0
998					}
999				}
1000			} else {
1001				for i := 0; i < n; i++ {
1002					ci := c[i*ldc+i : i*ldc+n]
1003					c128.ScalUnitary(beta, ci)
1004				}
1005			}
1006		} else {
1007			if beta == 0 {
1008				for i := 0; i < n; i++ {
1009					ci := c[i*ldc : i*ldc+i+1]
1010					for j := range ci {
1011						ci[j] = 0
1012					}
1013				}
1014			} else {
1015				for i := 0; i < n; i++ {
1016					ci := c[i*ldc : i*ldc+i+1]
1017					c128.ScalUnitary(beta, ci)
1018				}
1019			}
1020		}
1021		return
1022	}
1023
1024	if trans == blas.NoTrans {
1025		// Form  C = alpha*A*Aᵀ + beta*C.
1026		if uplo == blas.Upper {
1027			for i := 0; i < n; i++ {
1028				ci := c[i*ldc+i : i*ldc+n]
1029				ai := a[i*lda : i*lda+k]
1030				if beta == 0 {
1031					for jc := range ci {
1032						j := i + jc
1033						ci[jc] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
1034					}
1035				} else {
1036					for jc, cij := range ci {
1037						j := i + jc
1038						ci[jc] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
1039					}
1040				}
1041			}
1042		} else {
1043			for i := 0; i < n; i++ {
1044				ci := c[i*ldc : i*ldc+i+1]
1045				ai := a[i*lda : i*lda+k]
1046				if beta == 0 {
1047					for j := range ci {
1048						ci[j] = alpha * c128.DotuUnitary(ai, a[j*lda:j*lda+k])
1049					}
1050				} else {
1051					for j, cij := range ci {
1052						ci[j] = beta*cij + alpha*c128.DotuUnitary(ai, a[j*lda:j*lda+k])
1053					}
1054				}
1055			}
1056		}
1057	} else {
1058		// Form  C = alpha*Aᵀ*A + beta*C.
1059		if uplo == blas.Upper {
1060			for i := 0; i < n; i++ {
1061				ci := c[i*ldc+i : i*ldc+n]
1062				switch {
1063				case beta == 0:
1064					for jc := range ci {
1065						ci[jc] = 0
1066					}
1067				case beta != 1:
1068					for jc := range ci {
1069						ci[jc] *= beta
1070					}
1071				}
1072				for j := 0; j < k; j++ {
1073					aji := a[j*lda+i]
1074					if aji != 0 {
1075						c128.AxpyUnitary(alpha*aji, a[j*lda+i:j*lda+n], ci)
1076					}
1077				}
1078			}
1079		} else {
1080			for i := 0; i < n; i++ {
1081				ci := c[i*ldc : i*ldc+i+1]
1082				switch {
1083				case beta == 0:
1084					for j := range ci {
1085						ci[j] = 0
1086					}
1087				case beta != 1:
1088					for j := range ci {
1089						ci[j] *= beta
1090					}
1091				}
1092				for j := 0; j < k; j++ {
1093					aji := a[j*lda+i]
1094					if aji != 0 {
1095						c128.AxpyUnitary(alpha*aji, a[j*lda:j*lda+i+1], ci)
1096					}
1097				}
1098			}
1099		}
1100	}
1101}
1102
1103// Zsyr2k performs one of the symmetric rank-2k operations
1104//  C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C  if trans == blas.NoTrans
1105//  C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C  if trans == blas.Trans
1106// where alpha and beta are scalars, C is an n×n symmetric matrix and A and B
1107// are n×k matrices in the first case and k×n matrices in the second case.
1108func (Implementation) Zsyr2k(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, b []complex128, ldb int, beta complex128, c []complex128, ldc int) {
1109	var row, col int
1110	switch trans {
1111	default:
1112		panic(badTranspose)
1113	case blas.NoTrans:
1114		row, col = n, k
1115	case blas.Trans:
1116		row, col = k, n
1117	}
1118	switch {
1119	case uplo != blas.Lower && uplo != blas.Upper:
1120		panic(badUplo)
1121	case n < 0:
1122		panic(nLT0)
1123	case k < 0:
1124		panic(kLT0)
1125	case lda < max(1, col):
1126		panic(badLdA)
1127	case ldb < max(1, col):
1128		panic(badLdB)
1129	case ldc < max(1, n):
1130		panic(badLdC)
1131	}
1132
1133	// Quick return if possible.
1134	if n == 0 {
1135		return
1136	}
1137
1138	// For zero matrix size the following slice length checks are trivially satisfied.
1139	if len(a) < (row-1)*lda+col {
1140		panic(shortA)
1141	}
1142	if len(b) < (row-1)*ldb+col {
1143		panic(shortB)
1144	}
1145	if len(c) < (n-1)*ldc+n {
1146		panic(shortC)
1147	}
1148
1149	// Quick return if possible.
1150	if (alpha == 0 || k == 0) && beta == 1 {
1151		return
1152	}
1153
1154	if alpha == 0 {
1155		if uplo == blas.Upper {
1156			if beta == 0 {
1157				for i := 0; i < n; i++ {
1158					ci := c[i*ldc+i : i*ldc+n]
1159					for j := range ci {
1160						ci[j] = 0
1161					}
1162				}
1163			} else {
1164				for i := 0; i < n; i++ {
1165					ci := c[i*ldc+i : i*ldc+n]
1166					c128.ScalUnitary(beta, ci)
1167				}
1168			}
1169		} else {
1170			if beta == 0 {
1171				for i := 0; i < n; i++ {
1172					ci := c[i*ldc : i*ldc+i+1]
1173					for j := range ci {
1174						ci[j] = 0
1175					}
1176				}
1177			} else {
1178				for i := 0; i < n; i++ {
1179					ci := c[i*ldc : i*ldc+i+1]
1180					c128.ScalUnitary(beta, ci)
1181				}
1182			}
1183		}
1184		return
1185	}
1186
1187	if trans == blas.NoTrans {
1188		// Form  C = alpha*A*Bᵀ + alpha*B*Aᵀ + beta*C.
1189		if uplo == blas.Upper {
1190			for i := 0; i < n; i++ {
1191				ci := c[i*ldc+i : i*ldc+n]
1192				ai := a[i*lda : i*lda+k]
1193				bi := b[i*ldb : i*ldb+k]
1194				if beta == 0 {
1195					for jc := range ci {
1196						j := i + jc
1197						ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
1198					}
1199				} else {
1200					for jc, cij := range ci {
1201						j := i + jc
1202						ci[jc] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
1203					}
1204				}
1205			}
1206		} else {
1207			for i := 0; i < n; i++ {
1208				ci := c[i*ldc : i*ldc+i+1]
1209				ai := a[i*lda : i*lda+k]
1210				bi := b[i*ldb : i*ldb+k]
1211				if beta == 0 {
1212					for j := range ci {
1213						ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k])
1214					}
1215				} else {
1216					for j, cij := range ci {
1217						ci[j] = alpha*c128.DotuUnitary(ai, b[j*ldb:j*ldb+k]) + alpha*c128.DotuUnitary(bi, a[j*lda:j*lda+k]) + beta*cij
1218					}
1219				}
1220			}
1221		}
1222	} else {
1223		// Form  C = alpha*Aᵀ*B + alpha*Bᵀ*A + beta*C.
1224		if uplo == blas.Upper {
1225			for i := 0; i < n; i++ {
1226				ci := c[i*ldc+i : i*ldc+n]
1227				switch {
1228				case beta == 0:
1229					for jc := range ci {
1230						ci[jc] = 0
1231					}
1232				case beta != 1:
1233					for jc := range ci {
1234						ci[jc] *= beta
1235					}
1236				}
1237				for j := 0; j < k; j++ {
1238					aji := a[j*lda+i]
1239					bji := b[j*ldb+i]
1240					if aji != 0 {
1241						c128.AxpyUnitary(alpha*aji, b[j*ldb+i:j*ldb+n], ci)
1242					}
1243					if bji != 0 {
1244						c128.AxpyUnitary(alpha*bji, a[j*lda+i:j*lda+n], ci)
1245					}
1246				}
1247			}
1248		} else {
1249			for i := 0; i < n; i++ {
1250				ci := c[i*ldc : i*ldc+i+1]
1251				switch {
1252				case beta == 0:
1253					for j := range ci {
1254						ci[j] = 0
1255					}
1256				case beta != 1:
1257					for j := range ci {
1258						ci[j] *= beta
1259					}
1260				}
1261				for j := 0; j < k; j++ {
1262					aji := a[j*lda+i]
1263					bji := b[j*ldb+i]
1264					if aji != 0 {
1265						c128.AxpyUnitary(alpha*aji, b[j*ldb:j*ldb+i+1], ci)
1266					}
1267					if bji != 0 {
1268						c128.AxpyUnitary(alpha*bji, a[j*lda:j*lda+i+1], ci)
1269					}
1270				}
1271			}
1272		}
1273	}
1274}
1275
// Ztrmm performs one of the matrix-matrix operations
//  B = alpha * op(A) * B  if side == blas.Left,
//  B = alpha * B * op(A)  if side == blas.Right,
// where alpha is a scalar, B is an m×n matrix, A is a unit, or non-unit,
// upper or lower triangular matrix and op(A) is one of
//  op(A) = A   if trans == blas.NoTrans,
//  op(A) = Aᵀ  if trans == blas.Trans,
//  op(A) = Aᴴ  if trans == blas.ConjTrans.
//
// Matrices are stored in row-major order: element (i,j) of A is a[i*lda+j]
// and element (i,j) of B is b[i*ldb+j]. A is m×m if side == blas.Left and
// n×n if side == blas.Right. Only the triangle of A selected by uplo is
// referenced, and if diag == blas.Unit the diagonal elements of A are not
// referenced and are taken to be one.
//
// The product overwrites B. If alpha is zero, B is set to zero without
// reading its previous contents.
//
// Ztrmm panics if side, uplo, trans or diag has an invalid value, if m or n
// is negative, if lda or ldb is too small, or, when m and n are both
// positive, if a or b is too short for the given dimensions.
func (Implementation) Ztrmm(side blas.Side, uplo blas.Uplo, trans blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of the triangular matrix A.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case trans != blas.NoTrans && trans != blas.Trans && trans != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}

	// Quick return if possible.
	// With alpha == 0 the result is zero regardless of A and B.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			bi := b[i*ldb : i*ldb+n]
			for j := range bi {
				bi[j] = 0
			}
		}
		return
	}

	// noConj is false only for blas.ConjTrans, in which case elements of A
	// are conjugated before use. noUnit reports whether the diagonal of A is
	// read; for a unit-diagonal A only alpha is applied on the diagonal.
	noConj := trans != blas.ConjTrans
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if trans == blas.NoTrans {
			// Form B = alpha*A*B.
			if uplo == blas.Upper {
				// Row i of A*B combines rows j >= i of B. Going top to
				// bottom, rows j > i are still unmodified when row i is
				// formed.
				for i := 0; i < m; i++ {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						j := ja + i + 1
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			} else {
				// Row i of A*B combines rows j <= i of B. Going bottom to
				// top, rows j < i are still unmodified when row i is formed.
				for i := m - 1; i >= 0; i-- {
					aii := alpha
					if noUnit {
						aii *= a[i*lda+i]
					}
					bi := b[i*ldb : i*ldb+n]
					for j := range bi {
						bi[j] *= aii
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(alpha*aij, b[j*ldb:j*ldb+n], bi)
						}
					}
				}
			}
		} else {
			// Form B = alpha*Aᵀ*B  or  B = alpha*Aᴴ*B.
			if uplo == blas.Upper {
				// op(A) is lower triangular, so result row j combines rows
				// k <= j of B. Going bottom to top, the still-unmodified
				// row k is first added into rows j > k and only then scaled
				// by its own (alpha-weighted) diagonal element.
				for k := m - 1; k >= 0; k-- {
					bk := b[k*ldb : k*ldb+n]
					for ja, ajk := range a[k*lda+k+1 : k*lda+m] {
						if ajk == 0 {
							continue
						}
						j := k + 1 + ja
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			} else {
				// op(A) is upper triangular, so result row j combines rows
				// k >= j of B. Going top to bottom, the still-unmodified row
				// k is first added into rows j < k and only then scaled by
				// its own (alpha-weighted) diagonal element.
				for k := 0; k < m; k++ {
					bk := b[k*ldb : k*ldb+n]
					for j, ajk := range a[k*lda : k*lda+k] {
						if ajk == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(alpha*ajk, bk, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(alpha*cmplx.Conj(ajk), bk, b[j*ldb:j*ldb+n])
						}
					}
					akk := alpha
					if noUnit {
						if noConj {
							akk *= a[k*lda+k]
						} else {
							akk *= cmplx.Conj(a[k*lda+k])
						}
					}
					if akk != 1 {
						c128.ScalUnitary(akk, bk)
					}
				}
			}
		}
	} else {
		if trans == blas.NoTrans {
			// Form B = alpha*B*A.
			if uplo == blas.Upper {
				// Element k of row i of B contributes to elements k..n-1 of
				// that row of B*A. Going right to left, bi[k] is still
				// unmodified when it is scaled and added into bi[k+1:].
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := n - 1; k >= 0; k-- {
						abik := alpha * bi[k]
						// A zero multiplier contributes nothing; bi[k] is
						// left as it is.
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda+k+1:k*lda+n], bi[k+1:])
					}
				}
			} else {
				// Element k of row i of B contributes to elements 0..k of
				// that row of B*A. Going left to right, bi[k] is still
				// unmodified when it is scaled and added into bi[:k].
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for k := 0; k < n; k++ {
						abik := alpha * bi[k]
						// A zero multiplier contributes nothing; bi[k] is
						// left as it is.
						if abik == 0 {
							continue
						}
						bi[k] = abik
						if noUnit {
							bi[k] *= a[k*lda+k]
						}
						c128.AxpyUnitary(abik, a[k*lda:k*lda+k], bi[:k])
					}
				}
			}
		} else {
			// Form B = alpha*B*Aᵀ  or  B = alpha*B*Aᴴ.
			if uplo == blas.Upper {
				// Element j of row i of B*op(A) combines elements j..n-1 of
				// that row. Going left to right, bi[j+1:] is still
				// unmodified when bi[j] is overwritten.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							// For Aᴴ the row of A must be conjugated; this
							// relies on DotcUnitary conjugating its first
							// argument.
							bij += c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
						}
						bi[j] = alpha * bij
					}
				}
			} else {
				// Element j of row i of B*op(A) combines elements 0..j of
				// that row. Going right to left, bi[:j] is still unmodified
				// when bi[j] is overwritten.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := bi[j]
						if noConj {
							if noUnit {
								bij *= a[j*lda+j]
							}
							bij += c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
						} else {
							if noUnit {
								bij *= cmplx.Conj(a[j*lda+j])
							}
							// For Aᴴ the row of A must be conjugated; this
							// relies on DotcUnitary conjugating its first
							// argument.
							bij += c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
						}
						bi[j] = alpha * bij
					}
				}
			}
		}
	}
}
1503
// Ztrsm solves one of the matrix equations
//  op(A) * X = alpha * B  if side == blas.Left,
//  X * op(A) = alpha * B  if side == blas.Right,
// where alpha is a scalar, X and B are m×n matrices, A is a unit or
// non-unit, upper or lower triangular matrix and op(A) is one of
//  op(A) = A   if transA == blas.NoTrans,
//  op(A) = Aᵀ  if transA == blas.Trans,
//  op(A) = Aᴴ  if transA == blas.ConjTrans.
// On return the matrix X is overwritten on B.
//
// Matrices are stored in row-major order: element (i,j) of A is a[i*lda+j]
// and element (i,j) of B is b[i*ldb+j]. A is m×m if side == blas.Left and
// n×n if side == blas.Right. Only the triangle of A selected by uplo is
// referenced, and if diag == blas.Unit the diagonal elements of A are not
// referenced and are taken to be one.
//
// If alpha is zero, B is set to zero without reading its previous contents.
// Ztrsm does not test A for singularity: for a non-unit A, diagonal elements
// are divided by without any check for zero.
//
// Ztrsm panics if side, uplo, transA or diag has an invalid value, if m or n
// is negative, if lda or ldb is too small, or, when m and n are both
// positive, if a or b is too short for the given dimensions.
func (Implementation) Ztrsm(side blas.Side, uplo blas.Uplo, transA blas.Transpose, diag blas.Diag, m, n int, alpha complex128, a []complex128, lda int, b []complex128, ldb int) {
	// na is the order of the triangular matrix A.
	na := m
	if side == blas.Right {
		na = n
	}
	switch {
	case side != blas.Left && side != blas.Right:
		panic(badSide)
	case uplo != blas.Lower && uplo != blas.Upper:
		panic(badUplo)
	case transA != blas.NoTrans && transA != blas.Trans && transA != blas.ConjTrans:
		panic(badTranspose)
	case diag != blas.Unit && diag != blas.NonUnit:
		panic(badDiag)
	case m < 0:
		panic(mLT0)
	case n < 0:
		panic(nLT0)
	case lda < max(1, na):
		panic(badLdA)
	case ldb < max(1, n):
		panic(badLdB)
	}

	// Quick return if possible.
	if m == 0 || n == 0 {
		return
	}

	// For zero matrix size the following slice length checks are trivially satisfied.
	if len(a) < (na-1)*lda+na {
		panic(shortA)
	}
	if len(b) < (m-1)*ldb+n {
		panic(shortB)
	}

	// With alpha == 0 the solution is X = 0 regardless of A and B.
	if alpha == 0 {
		for i := 0; i < m; i++ {
			for j := 0; j < n; j++ {
				b[i*ldb+j] = 0
			}
		}
		return
	}

	// noConj is false only for blas.ConjTrans, in which case elements of A
	// are conjugated before use. noUnit reports whether the diagonal of A is
	// read and divided by.
	noConj := transA != blas.ConjTrans
	noUnit := diag == blas.NonUnit
	if side == blas.Left {
		if transA == blas.NoTrans {
			// Form  B = alpha*inv(A)*B.
			if uplo == blas.Upper {
				// Back substitution: rows are solved bottom to top, so rows
				// k > i already hold their solution when row i is reduced.
				for i := m - 1; i >= 0; i-- {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for ka, aik := range a[i*lda+i+1 : i*lda+m] {
						k := i + 1 + ka
						if aik != 0 {
							c128.AxpyUnitary(-aik, b[k*ldb:k*ldb+n], bi)
						}
					}
					if noUnit {
						c128.ScalUnitary(1/a[i*lda+i], bi)
					}
				}
			} else {
				// Forward substitution: rows are solved top to bottom, so
				// rows j < i already hold their solution when row i is
				// reduced.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij != 0 {
							c128.AxpyUnitary(-aij, b[j*ldb:j*ldb+n], bi)
						}
					}
					if noUnit {
						c128.ScalUnitary(1/a[i*lda+i], bi)
					}
				}
			}
		} else {
			// Form  B = alpha*inv(Aᵀ)*B  or  B = alpha*inv(Aᴴ)*B.
			//
			// Each row is first solved for op(A)*Y = B (division by the
			// diagonal), then eliminated from the rows not yet solved, and
			// only afterwards scaled by alpha. The eliminations therefore
			// all use the unscaled Y, and X = alpha*Y row by row.
			if uplo == blas.Upper {
				// op(A) is lower triangular: solve rows top to bottom.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if noUnit {
						if noConj {
							c128.ScalUnitary(1/a[i*lda+i], bi)
						} else {
							c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
						}
					}
					for ja, aij := range a[i*lda+i+1 : i*lda+m] {
						if aij == 0 {
							continue
						}
						j := i + 1 + ja
						if noConj {
							c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
						}
					}
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
				}
			} else {
				// op(A) is upper triangular: solve rows bottom to top.
				for i := m - 1; i >= 0; i-- {
					bi := b[i*ldb : i*ldb+n]
					if noUnit {
						if noConj {
							c128.ScalUnitary(1/a[i*lda+i], bi)
						} else {
							c128.ScalUnitary(1/cmplx.Conj(a[i*lda+i]), bi)
						}
					}
					for j, aij := range a[i*lda : i*lda+i] {
						if aij == 0 {
							continue
						}
						if noConj {
							c128.AxpyUnitary(-aij, bi, b[j*ldb:j*ldb+n])
						} else {
							c128.AxpyUnitary(-cmplx.Conj(aij), bi, b[j*ldb:j*ldb+n])
						}
					}
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
				}
			}
		}
	} else {
		if transA == blas.NoTrans {
			// Form  B = alpha*B*inv(A).
			if uplo == blas.Upper {
				// Each row of B is solved independently, left to right:
				// once bi[j] is final it is eliminated from bi[j+1:].
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					// Ranging over a slice reads bi[j] at the start of each
					// iteration, so bij includes the eliminations applied by
					// earlier iterations.
					for j, bij := range bi {
						if bij == 0 {
							continue
						}
						if noUnit {
							bi[j] /= a[j*lda+j]
						}
						c128.AxpyUnitary(-bi[j], a[j*lda+j+1:j*lda+n], bi[j+1:n])
					}
				}
			} else {
				// Each row of B is solved independently, right to left:
				// once bi[j] is final it is eliminated from bi[:j].
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					if alpha != 1 {
						c128.ScalUnitary(alpha, bi)
					}
					for j := n - 1; j >= 0; j-- {
						if bi[j] == 0 {
							continue
						}
						if noUnit {
							bi[j] /= a[j*lda+j]
						}
						c128.AxpyUnitary(-bi[j], a[j*lda:j*lda+j], bi[:j])
					}
				}
			}
		} else {
			// Form  B = alpha*B*inv(Aᵀ)  or   B = alpha*B*inv(Aᴴ).
			if uplo == blas.Upper {
				// Element j of a row depends on the already-solved elements
				// j+1..n-1 of that row, so each row is solved right to left.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j := n - 1; j >= 0; j-- {
						bij := alpha * bi[j]
						if noConj {
							bij -= c128.DotuUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
							if noUnit {
								bij /= a[j*lda+j]
							}
						} else {
							// For Aᴴ the row of A must be conjugated; this
							// relies on DotcUnitary conjugating its first
							// argument.
							bij -= c128.DotcUnitary(a[j*lda+j+1:j*lda+n], bi[j+1:n])
							if noUnit {
								bij /= cmplx.Conj(a[j*lda+j])
							}
						}
						bi[j] = bij
					}
				}
			} else {
				// Element j of a row depends on the already-solved elements
				// 0..j-1 of that row, so each row is solved left to right.
				for i := 0; i < m; i++ {
					bi := b[i*ldb : i*ldb+n]
					for j, bij := range bi {
						bij *= alpha
						if noConj {
							bij -= c128.DotuUnitary(a[j*lda:j*lda+j], bi[:j])
							if noUnit {
								bij /= a[j*lda+j]
							}
						} else {
							// For Aᴴ the row of A must be conjugated; this
							// relies on DotcUnitary conjugating its first
							// argument.
							bij -= c128.DotcUnitary(a[j*lda:j*lda+j], bi[:j])
							if noUnit {
								bij /= cmplx.Conj(a[j*lda+j])
							}
						}
						bi[j] = bij
					}
				}
			}
		}
	}
}
1729