1// Copyright ©2014 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/blas64"
10	"gonum.org/v1/gonum/internal/asm/f64"
11)
12
13// Inner computes the generalized inner product
14//  xᵀ A y
15// between the vectors x and y with matrix A, where x and y are treated as
16// column vectors.
17//
18// This is only a true inner product if A is symmetric positive definite, though
19// the operation works for any matrix A.
20//
21// Inner panics if x.Len != m or y.Len != n when A is an m x n matrix.
22func Inner(x Vector, a Matrix, y Vector) float64 {
23	m, n := a.Dims()
24	if x.Len() != m {
25		panic(ErrShape)
26	}
27	if y.Len() != n {
28		panic(ErrShape)
29	}
30	if m == 0 || n == 0 {
31		return 0
32	}
33
34	var sum float64
35
36	switch a := a.(type) {
37	case RawSymmetricer:
38		amat := a.RawSymmetric()
39		if amat.Uplo != blas.Upper {
40			// Panic as a string not a mat.Error.
41			panic(badSymTriangle)
42		}
43		var xmat, ymat blas64.Vector
44		if xrv, ok := x.(RawVectorer); ok {
45			xmat = xrv.RawVector()
46		} else {
47			break
48		}
49		if yrv, ok := y.(RawVectorer); ok {
50			ymat = yrv.RawVector()
51		} else {
52			break
53		}
54		for i := 0; i < x.Len(); i++ {
55			xi := x.AtVec(i)
56			if xi != 0 {
57				if ymat.Inc == 1 {
58					sum += xi * f64.DotUnitary(
59						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
60						ymat.Data[i:],
61					)
62				} else {
63					sum += xi * f64.DotInc(
64						amat.Data[i*amat.Stride+i:i*amat.Stride+n],
65						ymat.Data[i*ymat.Inc:], uintptr(n-i),
66						1, uintptr(ymat.Inc),
67						0, 0,
68					)
69				}
70			}
71			yi := y.AtVec(i)
72			if i != n-1 && yi != 0 {
73				if xmat.Inc == 1 {
74					sum += yi * f64.DotUnitary(
75						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
76						xmat.Data[i+1:],
77					)
78				} else {
79					sum += yi * f64.DotInc(
80						amat.Data[i*amat.Stride+i+1:i*amat.Stride+n],
81						xmat.Data[(i+1)*xmat.Inc:], uintptr(n-i-1),
82						1, uintptr(xmat.Inc),
83						0, 0,
84					)
85				}
86			}
87		}
88		return sum
89	case RawMatrixer:
90		amat := a.RawMatrix()
91		var ymat blas64.Vector
92		if yrv, ok := y.(RawVectorer); ok {
93			ymat = yrv.RawVector()
94		} else {
95			break
96		}
97		for i := 0; i < x.Len(); i++ {
98			xi := x.AtVec(i)
99			if xi != 0 {
100				if ymat.Inc == 1 {
101					sum += xi * f64.DotUnitary(
102						amat.Data[i*amat.Stride:i*amat.Stride+n],
103						ymat.Data,
104					)
105				} else {
106					sum += xi * f64.DotInc(
107						amat.Data[i*amat.Stride:i*amat.Stride+n],
108						ymat.Data, uintptr(n),
109						1, uintptr(ymat.Inc),
110						0, 0,
111					)
112				}
113			}
114		}
115		return sum
116	}
117	for i := 0; i < x.Len(); i++ {
118		xi := x.AtVec(i)
119		for j := 0; j < y.Len(); j++ {
120			sum += xi * a.At(i, j) * y.AtVec(j)
121		}
122	}
123	return sum
124}
125