1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package fd
6
7import (
8	"testing"
9
10	"gonum.org/v1/gonum/floats"
11	"gonum.org/v1/gonum/mat"
12)
13
14type CrossLaplacianTester interface {
15	Func(x, y []float64) float64
16	CrossLaplacian(x, y []float64) float64
17}
18
19type WrapperCL struct {
20	Tester HessianTester
21}
22
23func (WrapperCL) constructZ(x, y []float64) []float64 {
24	z := make([]float64, len(x)+len(y))
25	copy(z, x)
26	copy(z[len(x):], y)
27	return z
28}
29
30func (w WrapperCL) Func(x, y []float64) float64 {
31	z := w.constructZ(x, y)
32	return w.Tester.Func(z)
33}
34
35func (w WrapperCL) CrossLaplacian(x, y []float64) float64 {
36	z := w.constructZ(x, y)
37	hess := mat.NewSymDense(len(z), nil)
38	w.Tester.Hess(hess, z)
39	// The CrossLaplacian is the trace of the off-diagonal block of the Hessian.
40	var l float64
41	for i := 0; i < len(x); i++ {
42		l += hess.At(i, i+len(x))
43	}
44	return l
45}
46
47func TestCrossLaplacian(t *testing.T) {
48	for cas, test := range []struct {
49		l        CrossLaplacianTester
50		x, y     []float64
51		settings *Settings
52		tol      float64
53	}{
54		{
55			l:   WrapperCL{Watson{}},
56			x:   []float64{0.2, 0.3},
57			y:   []float64{0.1, 0.4},
58			tol: 1e-3,
59		},
60		{
61			l:   WrapperCL{Watson{}},
62			x:   []float64{2, 3, 1},
63			y:   []float64{1, 4, 1},
64			tol: 1e-3,
65		},
66		{
67			l:   WrapperCL{ConstFunc(6)},
68			x:   []float64{2, -3, 1},
69			y:   []float64{1, 4, -5},
70			tol: 1e-6,
71		},
72		{
73			l:   WrapperCL{LinearFunc{w: []float64{10, 6, -1, 5}, c: 5}},
74			x:   []float64{3, 1},
75			y:   []float64{8, 6},
76			tol: 1e-6,
77		},
78		{
79			l: WrapperCL{QuadFunc{
80				a: mat.NewSymDense(4, []float64{
81					10, 2, 1, 9,
82					2, 5, -3, 4,
83					1, -3, 6, 2,
84					9, 4, 2, -14,
85				}),
86				b: mat.NewVecDense(4, []float64{3, -2, -1, 4}),
87				c: 5,
88			}},
89			x:   []float64{-1.6, -3},
90			y:   []float64{1.8, 3.4},
91			tol: 1e-6,
92		},
93	} {
94		got := CrossLaplacian(test.l.Func, test.x, test.y, test.settings)
95		want := test.l.CrossLaplacian(test.x, test.y)
96		if !floats.EqualWithinAbsOrRel(got, want, test.tol, test.tol) {
97			t.Errorf("Cas %d: CrossLaplacian mismatch serial. got %v, want %v", cas, got, want)
98		}
99
100		// Test that concurrency works.
101		settings := test.settings
102		if settings == nil {
103			settings = &Settings{}
104		}
105		settings.Concurrent = true
106		got2 := CrossLaplacian(test.l.Func, test.x, test.y, settings)
107		if !floats.EqualWithinAbsOrRel(got, got2, 1e-6, 1e-6) {
108			t.Errorf("Cas %d: Laplacian mismatch. got %v, want %v", cas, got2, got)
109		}
110	}
111}
112