1// Copyright ©2019 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package testlapack
6
7import (
8	"fmt"
9	"testing"
10
11	"golang.org/x/exp/rand"
12
13	"gonum.org/v1/gonum/blas"
14	"gonum.org/v1/gonum/blas/blas64"
15)
16
17type Dpotrier interface {
18	Dpotri(uplo blas.Uplo, n int, a []float64, lda int) bool
19
20	Dpotrf(uplo blas.Uplo, n int, a []float64, lda int) bool
21}
22
23func DpotriTest(t *testing.T, impl Dpotrier) {
24	for _, uplo := range []blas.Uplo{blas.Upper, blas.Lower} {
25		name := "Upper"
26		if uplo == blas.Lower {
27			name = "Lower"
28		}
29		t.Run(name, func(t *testing.T) {
30			// Include small and large sizes to make sure that both
31			// unblocked and blocked paths are taken.
32			ns := []int{0, 1, 2, 3, 4, 5, 10, 25, 31, 32, 33, 63, 64, 65, 127, 128, 129}
33			const tol = 1e-12
34
35			bi := blas64.Implementation()
36			rnd := rand.New(rand.NewSource(1))
37			for _, n := range ns {
38				for _, lda := range []int{max(1, n), n + 11} {
39					prefix := fmt.Sprintf("n=%v,lda=%v", n, lda)
40
41					// Generate a random diagonal matrix D with positive entries.
42					d := make([]float64, n)
43					Dlatm1(d, 3, 10000, false, 2, rnd)
44
45					// Construct a positive definite matrix A as
46					//  A = U * D * U^T
47					// where U is a random orthogonal matrix.
48					a := make([]float64, n*lda)
49					Dlagsy(n, 0, d, a, lda, rnd, make([]float64, 2*n))
50					// Create a copy of A.
51					aCopy := make([]float64, len(a))
52					copy(aCopy, a)
53					// Compute the Cholesky factorization of A.
54					ok := impl.Dpotrf(uplo, n, a, lda)
55					if !ok {
56						t.Fatalf("%v: unexpected Cholesky failure", prefix)
57					}
58
59					// Compute the inverse inv(A).
60					ok = impl.Dpotri(uplo, n, a, lda)
61					if !ok {
62						t.Errorf("%v: unexpected failure", prefix)
63						continue
64					}
65
66					// Check that the triangle of A opposite to uplo has not been modified.
67					if uplo == blas.Upper && !sameLowerTri(n, aCopy, lda, a, lda) {
68						t.Errorf("%v: unexpected modification in lower triangle", prefix)
69						continue
70					}
71					if uplo == blas.Lower && !sameUpperTri(n, aCopy, lda, a, lda) {
72						t.Errorf("%v: unexpected modification in upper triangle", prefix)
73						continue
74					}
75
76					// Change notation for the sake of clarity.
77					ainv := a
78					ldainv := lda
79
80					// Expand ainv into a full dense matrix so that we can call Dsymm below.
81					if uplo == blas.Upper {
82						for i := 1; i < n; i++ {
83							for j := 0; j < i; j++ {
84								ainv[i*ldainv+j] = ainv[j*ldainv+i]
85							}
86						}
87					} else {
88						for i := 0; i < n-1; i++ {
89							for j := i + 1; j < n; j++ {
90								ainv[i*ldainv+j] = ainv[j*ldainv+i]
91							}
92						}
93					}
94
95					// Compute A*inv(A) and store the result into want.
96					ldwant := max(1, n)
97					want := make([]float64, n*ldwant)
98					bi.Dsymm(blas.Left, uplo, n, n, 1, aCopy, lda, ainv, ldainv, 0, want, ldwant)
99
100					// Check that want is close to the identity matrix.
101					dist := distFromIdentity(n, want, ldwant)
102					if dist > tol {
103						t.Errorf("%v: |A * inv(A) - I| = %v is too large", prefix, dist)
104					}
105				}
106			}
107		})
108	}
109}
110