1// Copyright ©2017 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package fd 6 7import ( 8 "gonum.org/v1/gonum/floats" 9 "gonum.org/v1/gonum/mat" 10) 11 12// ConstFunc is a constant function returning the value held by the type. 13type ConstFunc float64 14 15func (c ConstFunc) Func(x []float64) float64 { 16 return float64(c) 17} 18 19func (c ConstFunc) Grad(grad, x []float64) { 20 for i := range grad { 21 grad[i] = 0 22 } 23} 24 25func (c ConstFunc) Hess(dst mat.MutableSymmetric, x []float64) { 26 n := len(x) 27 for i := 0; i < n; i++ { 28 for j := i; j < n; j++ { 29 dst.SetSym(i, j, 0) 30 } 31 } 32} 33 34// LinearFunc is a linear function returning w*x+c. 35type LinearFunc struct { 36 w []float64 37 c float64 38} 39 40func (l LinearFunc) Func(x []float64) float64 { 41 return floats.Dot(l.w, x) + l.c 42} 43 44func (l LinearFunc) Grad(grad, x []float64) { 45 copy(grad, l.w) 46} 47 48func (l LinearFunc) Hess(dst mat.MutableSymmetric, x []float64) { 49 n := len(x) 50 for i := 0; i < n; i++ { 51 for j := i; j < n; j++ { 52 dst.SetSym(i, j, 0) 53 } 54 } 55} 56 57// QuadFunc is a quadratic function returning 0.5*x'*a*x + b*x + c. 58type QuadFunc struct { 59 a *mat.SymDense 60 b *mat.VecDense 61 c float64 62} 63 64func (q QuadFunc) Func(x []float64) float64 { 65 v := mat.NewVecDense(len(x), x) 66 var tmp mat.VecDense 67 tmp.MulVec(q.a, v) 68 return 0.5*mat.Dot(&tmp, v) + mat.Dot(q.b, v) + q.c 69} 70 71func (q QuadFunc) Grad(grad, x []float64) { 72 var tmp mat.VecDense 73 v := mat.NewVecDense(len(x), x) 74 tmp.MulVec(q.a, v) 75 for i := range grad { 76 grad[i] = tmp.At(i, 0) + q.b.At(i, 0) 77 } 78} 79 80func (q QuadFunc) Hess(dst mat.MutableSymmetric, x []float64) { 81 n := len(x) 82 for i := 0; i < n; i++ { 83 for j := i; j < n; j++ { 84 dst.SetSym(i, j, q.a.At(i, j)) 85 } 86 } 87} 88