1// Copyright ©2017 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package fd
6
7import (
8	"gonum.org/v1/gonum/floats"
9	"gonum.org/v1/gonum/mat"
10)
11
12// ConstFunc is a constant function returning the value held by the type.
13type ConstFunc float64
14
15func (c ConstFunc) Func(x []float64) float64 {
16	return float64(c)
17}
18
19func (c ConstFunc) Grad(grad, x []float64) {
20	for i := range grad {
21		grad[i] = 0
22	}
23}
24
25func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) {
26	n := len(x)
27	for i := 0; i < n; i++ {
28		for j := i; j < n; j++ {
29			dst.SetSym(i, j, 0)
30		}
31	}
32}
33
34// LinearFunc is a linear function returning w*x+c.
35type LinearFunc struct {
36	w []float64
37	c float64
38}
39
40func (l LinearFunc) Func(x []float64) float64 {
41	return floats.Dot(l.w, x) + l.c
42}
43
44func (l LinearFunc) Grad(grad, x []float64) {
45	copy(grad, l.w)
46}
47
48func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) {
49	n := len(x)
50	for i := 0; i < n; i++ {
51		for j := i; j < n; j++ {
52			dst.SetSym(i, j, 0)
53		}
54	}
55}
56
57// QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c.
58type QuadFunc struct {
59	a *mat.SymDense
60	b *mat.VecDense
61	c float64
62}
63
64func (q QuadFunc) Func(x []float64) float64 {
65	v := mat.NewVecDense(len(x), x)
66	var tmp mat.VecDense
67	tmp.MulVec(q.a, v)
68	return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c
69}
70
71func (q QuadFunc) Grad(grad, x []float64) {
72	var tmp mat.VecDense
73	v := mat.NewVecDense(len(x), x)
74	tmp.MulVec(q.a, v)
75	for i := range grad {
76		grad[i] = tmp.At(i, 0) + q.b.At(i, 0)
77	}
78}
79
80func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) {
81	n := len(x)
82	for i := 0; i < n; i++ {
83		for j := i; j < n; j++ {
84			dst.SetSym(i, j, q.a.At(i, j))
85		}
86	}
87}
88