1// Copyright ©2019 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testblas
6
7import (
8	"fmt"
9	"testing"
10
11	"golang.org/x/exp/rand"
12
13	"gonum.org/v1/gonum/blas"
14)
15
16type Zsyrker interface {
17	Zsyrk(uplo blas.Uplo, trans blas.Transpose, n, k int, alpha complex128, a []complex128, lda int, beta complex128, c []complex128, ldc int)
18}
19
20func ZsyrkTest(t *testing.T, impl Zsyrker) {
21	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
22		for _, trans := range []blas.Transpose{blas.NoTrans, blas.Trans} {
23			name := uploString(uplo) + "-" + transString(trans)
24			t.Run(name, func(t *testing.T) {
25				for _, n := range []int{0, 1, 2, 3, 4, 5} {
26					for _, k := range []int{0, 1, 2, 3, 4, 5, 7} {
27						zsyrkTest(t, impl, uplo, trans, n, k)
28					}
29				}
30			})
31		}
32	}
33}
34
35func zsyrkTest(t *testing.T, impl Zsyrker, uplo blas.Uplo, trans blas.Transpose, n, k int) {
36	const tol = 1e-13
37
38	rnd := rand.New(rand.NewSource(1))
39
40	rowA, colA := n, k
41	if trans == blas.Trans {
42		rowA, colA = k, n
43	}
44	for _, lda := range []int{max(1, colA), colA + 2} {
45		for _, ldc := range []int{max(1, n), n + 4} {
46			for _, alpha := range []complex128{0, 1, complex(0.7, -0.9)} {
47				for _, beta := range []complex128{0, 1, complex(1.3, -1.1)} {
48					// Allocate the matrix A and fill it with random numbers.
49					a := make([]complex128, rowA*lda)
50					for i := range a {
51						a[i] = rndComplex128(rnd)
52					}
53					// Create a copy of A for checking that
54					// Zsyrk does not modify A.
55					aCopy := make([]complex128, len(a))
56					copy(aCopy, a)
57
58					// Allocate the matrix C and fill it with random numbers.
59					c := make([]complex128, n*ldc)
60					for i := range c {
61						c[i] = rndComplex128(rnd)
62					}
63					// Create a copy of C for checking that
64					// Zsyrk does not modify its triangle
65					// opposite to uplo.
66					cCopy := make([]complex128, len(c))
67					copy(cCopy, c)
68					// Create a copy of C expanded into a
69					// full symmetric matrix for computing
70					// the expected result using zmm.
71					cSym := make([]complex128, len(c))
72					copy(cSym, c)
73					if uplo == blas.Upper {
74						for i := 0; i < n-1; i++ {
75							for j := i + 1; j < n; j++ {
76								cSym[j*ldc+i] = cSym[i*ldc+j]
77							}
78						}
79					} else {
80						for i := 1; i < n; i++ {
81							for j := 0; j < i; j++ {
82								cSym[j*ldc+i] = cSym[i*ldc+j]
83							}
84						}
85					}
86
87					// Compute the expected result using an internal Zgemm implementation.
88					var want []complex128
89					if trans == blas.NoTrans {
90						want = zmm(blas.NoTrans, blas.Trans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc)
91					} else {
92						want = zmm(blas.Trans, blas.NoTrans, n, n, k, alpha, a, lda, a, lda, beta, cSym, ldc)
93					}
94
95					// Compute the result using Zsyrk.
96					impl.Zsyrk(uplo, trans, n, k, alpha, a, lda, beta, c, ldc)
97
98					prefix := fmt.Sprintf("n=%v,k=%v,lda=%v,ldc=%v,alpha=%v,beta=%v", n, k, lda, ldc, alpha, beta)
99
100					if !zsame(a, aCopy) {
101						t.Errorf("%v: unexpected modification of A", prefix)
102						continue
103					}
104					if uplo == blas.Upper && !zSameLowerTri(n, c, ldc, cCopy, ldc) {
105						t.Errorf("%v: unexpected modification in lower triangle of C", prefix)
106						continue
107					}
108					if uplo == blas.Lower && !zSameUpperTri(n, c, ldc, cCopy, ldc) {
109						t.Errorf("%v: unexpected modification in upper triangle of C", prefix)
110						continue
111					}
112
113					// Expand C into a full symmetric matrix
114					// for comparison with the result from zmm.
115					if uplo == blas.Upper {
116						for i := 0; i < n-1; i++ {
117							for j := i + 1; j < n; j++ {
118								c[j*ldc+i] = c[i*ldc+j]
119							}
120						}
121					} else {
122						for i := 1; i < n; i++ {
123							for j := 0; j < i; j++ {
124								c[j*ldc+i] = c[i*ldc+j]
125							}
126						}
127					}
128					if !zEqualApprox(c, want, tol) {
129						t.Errorf("%v: unexpected result", prefix)
130					}
131				}
132			}
133		}
134	}
135}
136