1// Copyright ©2017 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package fd 6 7import ( 8 "testing" 9 10 "gonum.org/v1/gonum/floats" 11 "gonum.org/v1/gonum/mat" 12) 13 14type CrossLaplacianTester interface { 15 Func(x, y []float64) float64 16 CrossLaplacian(x, y []float64) float64 17} 18 19type WrapperCL struct { 20 Tester HessianTester 21} 22 23func (WrapperCL) constructZ(x, y []float64) []float64 { 24 z := make([]float64, len(x)+len(y)) 25 copy(z, x) 26 copy(z[len(x):], y) 27 return z 28} 29 30func (w WrapperCL) Func(x, y []float64) float64 { 31 z := w.constructZ(x, y) 32 return w.Tester.Func(z) 33} 34 35func (w WrapperCL) CrossLaplacian(x, y []float64) float64 { 36 z := w.constructZ(x, y) 37 hess := mat.NewSymDense(len(z), nil) 38 w.Tester.Hess(hess, z) 39 // The CrossLaplacian is the trace of the off-diagonal block of the Hessian. 40 var l float64 41 for i := 0; i < len(x); i++ { 42 l += hess.At(i, i+len(x)) 43 } 44 return l 45} 46 47func TestCrossLaplacian(t *testing.T) { 48 for cas, test := range []struct { 49 l CrossLaplacianTester 50 x, y []float64 51 settings *Settings 52 tol float64 53 }{ 54 { 55 l: WrapperCL{Watson{}}, 56 x: []float64{0.2, 0.3}, 57 y: []float64{0.1, 0.4}, 58 tol: 1e-3, 59 }, 60 { 61 l: WrapperCL{Watson{}}, 62 x: []float64{2, 3, 1}, 63 y: []float64{1, 4, 1}, 64 tol: 1e-3, 65 }, 66 { 67 l: WrapperCL{ConstFunc(6)}, 68 x: []float64{2, -3, 1}, 69 y: []float64{1, 4, -5}, 70 tol: 1e-6, 71 }, 72 { 73 l: WrapperCL{LinearFunc{w: []float64{10, 6, -1, 5}, c: 5}}, 74 x: []float64{3, 1}, 75 y: []float64{8, 6}, 76 tol: 1e-6, 77 }, 78 { 79 l: WrapperCL{QuadFunc{ 80 a: mat.NewSymDense(4, []float64{ 81 10, 2, 1, 9, 82 2, 5, -3, 4, 83 1, -3, 6, 2, 84 9, 4, 2, -14, 85 }), 86 b: mat.NewVecDense(4, []float64{3, -2, -1, 4}), 87 c: 5, 88 }}, 89 x: []float64{-1.6, -3}, 90 y: []float64{1.8, 3.4}, 91 tol: 1e-6, 92 }, 93 } { 94 got := CrossLaplacian(test.l.Func, test.x, test.y, test.settings) 95 want := test.l.CrossLaplacian(test.x, test.y) 96 if !floats.EqualWithinAbsOrRel(got, want, test.tol, test.tol) { 97 t.Errorf("Cas %d: CrossLaplacian mismatch serial. got %v, want %v", cas, got, want) 98 } 99 100 // Test that concurrency works. 101 settings := test.settings 102 if settings == nil { 103 settings = &Settings{} 104 } 105 settings.Concurrent = true 106 got2 := CrossLaplacian(test.l.Func, test.x, test.y, settings) 107 if !floats.EqualWithinAbsOrRel(got, got2, 1e-6, 1e-6) { 108 t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got) 109 } 110 } 111} 112