1// Copyright ©2013 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"gonum.org/v1/gonum/blas"
9	"gonum.org/v1/gonum/blas/blas64"
10	"gonum.org/v1/gonum/internal/asm/f64"
11)
12
13var (
14	vector *VecDense
15
16	_ Matrix  = vector
17	_ Vector  = vector
18	_ Reseter = vector
19)
20
21// Vector is a vector.
22type Vector interface {
23	Matrix
24	AtVec(int) float64
25	Len() int
26}
27
28// TransposeVec is a type for performing an implicit transpose of a Vector.
29// It implements the Vector interface, returning values from the transpose
30// of the vector within.
31type TransposeVec struct {
32	Vector Vector
33}
34
35// At returns the value of the element at row i and column j of the transposed
36// matrix, that is, row j and column i of the Vector field.
37func (t TransposeVec) At(i, j int) float64 {
38	return t.Vector.At(j, i)
39}
40
41// AtVec returns the element at position i. It panics if i is out of bounds.
42func (t TransposeVec) AtVec(i int) float64 {
43	return t.Vector.AtVec(i)
44}
45
46// Dims returns the dimensions of the transposed vector.
47func (t TransposeVec) Dims() (r, c int) {
48	c, r = t.Vector.Dims()
49	return r, c
50}
51
52// T performs an implicit transpose by returning the Vector field.
53func (t TransposeVec) T() Matrix {
54	return t.Vector
55}
56
57// Len returns the number of columns in the vector.
58func (t TransposeVec) Len() int {
59	return t.Vector.Len()
60}
61
62// TVec performs an implicit transpose by returning the Vector field.
63func (t TransposeVec) TVec() Vector {
64	return t.Vector
65}
66
67// Untranspose returns the Vector field.
68func (t TransposeVec) Untranspose() Matrix {
69	return t.Vector
70}
71
72func (t TransposeVec) UntransposeVec() Vector {
73	return t.Vector
74}
75
76// VecDense represents a column vector.
77type VecDense struct {
78	mat blas64.Vector
79	// A BLAS vector can have a negative increment, but allowing this
80	// in the mat type complicates a lot of code, and doesn't gain anything.
81	// VecDense must have positive increment in this package.
82}
83
84// NewVecDense creates a new VecDense of length n. If data == nil,
85// a new slice is allocated for the backing slice. If len(data) == n, data is
86// used as the backing slice, and changes to the elements of the returned VecDense
87// will be reflected in data. If neither of these is true, NewVecDense will panic.
88// NewVecDense will panic if n is zero.
89func NewVecDense(n int, data []float64) *VecDense {
90	if n <= 0 {
91		if n == 0 {
92			panic(ErrZeroLength)
93		}
94		panic("mat: negative dimension")
95	}
96	if len(data) != n && data != nil {
97		panic(ErrShape)
98	}
99	if data == nil {
100		data = make([]float64, n)
101	}
102	return &VecDense{
103		mat: blas64.Vector{
104			N:    n,
105			Inc:  1,
106			Data: data,
107		},
108	}
109}
110
111// SliceVec returns a new Vector that shares backing data with the receiver.
112// The returned matrix starts at i of the receiver and extends k-i elements.
113// SliceVec panics with ErrIndexOutOfRange if the slice is outside the capacity
114// of the receiver.
115func (v *VecDense) SliceVec(i, k int) Vector {
116	if i < 0 || k <= i || v.Cap() < k {
117		panic(ErrIndexOutOfRange)
118	}
119	return &VecDense{
120		mat: blas64.Vector{
121			N:    k - i,
122			Inc:  v.mat.Inc,
123			Data: v.mat.Data[i*v.mat.Inc : (k-1)*v.mat.Inc+1],
124		},
125	}
126}
127
128// Dims returns the number of rows and columns in the matrix. Columns is always 1
129// for a non-Reset vector.
130func (v *VecDense) Dims() (r, c int) {
131	if v.IsZero() {
132		return 0, 0
133	}
134	return v.mat.N, 1
135}
136
137// Caps returns the number of rows and columns in the backing matrix. Columns is always 1
138// for a non-Reset vector.
139func (v *VecDense) Caps() (r, c int) {
140	if v.IsZero() {
141		return 0, 0
142	}
143	return v.Cap(), 1
144}
145
146// Len returns the length of the vector.
147func (v *VecDense) Len() int {
148	return v.mat.N
149}
150
151// Cap returns the capacity of the vector.
152func (v *VecDense) Cap() int {
153	if v.IsZero() {
154		return 0
155	}
156	return (cap(v.mat.Data)-1)/v.mat.Inc + 1
157}
158
159// T performs an implicit transpose by returning the receiver inside a Transpose.
160func (v *VecDense) T() Matrix {
161	return Transpose{v}
162}
163
164// TVec performs an implicit transpose by returning the receiver inside a TransposeVec.
165func (v *VecDense) TVec() Vector {
166	return TransposeVec{v}
167}
168
169// Reset zeros the length of the vector so that it can be reused as the
170// receiver of a dimensionally restricted operation.
171//
172// See the Reseter interface for more information.
173func (v *VecDense) Reset() {
174	// No change of Inc or N to 0 may be
175	// made unless both are set to 0.
176	v.mat.Inc = 0
177	v.mat.N = 0
178	v.mat.Data = v.mat.Data[:0]
179}
180
181// Zero sets all of the matrix elements to zero.
182func (v *VecDense) Zero() {
183	for i := 0; i < v.mat.N; i++ {
184		v.mat.Data[v.mat.Inc*i] = 0
185	}
186}
187
188// CloneVec makes a copy of a into the receiver, overwriting the previous value
189// of the receiver.
190func (v *VecDense) CloneVec(a Vector) {
191	if v == a {
192		return
193	}
194	n := a.Len()
195	v.mat = blas64.Vector{
196		N:    n,
197		Inc:  1,
198		Data: use(v.mat.Data, n),
199	}
200	if r, ok := a.(RawVectorer); ok {
201		blas64.Copy(r.RawVector(), v.mat)
202		return
203	}
204	for i := 0; i < a.Len(); i++ {
205		v.SetVec(i, a.AtVec(i))
206	}
207}
208
209// VecDenseCopyOf returns a newly allocated copy of the elements of a.
210func VecDenseCopyOf(a Vector) *VecDense {
211	v := &VecDense{}
212	v.CloneVec(a)
213	return v
214}
215
216func (v *VecDense) RawVector() blas64.Vector {
217	return v.mat
218}
219
220// CopyVec makes a copy of elements of a into the receiver. It is similar to the
221// built-in copy; it copies as much as the overlap between the two vectors and
222// returns the number of elements it copied.
223func (v *VecDense) CopyVec(a Vector) int {
224	n := min(v.Len(), a.Len())
225	if v == a {
226		return n
227	}
228	if r, ok := a.(RawVectorer); ok {
229		blas64.Copy(r.RawVector(), v.mat)
230		return n
231	}
232	for i := 0; i < n; i++ {
233		v.setVec(i, a.AtVec(i))
234	}
235	return n
236}
237
238// ScaleVec scales the vector a by alpha, placing the result in the receiver.
239func (v *VecDense) ScaleVec(alpha float64, a Vector) {
240	n := a.Len()
241
242	if v == a {
243		if v.mat.Inc == 1 {
244			f64.ScalUnitary(alpha, v.mat.Data)
245			return
246		}
247		f64.ScalInc(alpha, v.mat.Data, uintptr(n), uintptr(v.mat.Inc))
248		return
249	}
250
251	v.reuseAs(n)
252
253	if rv, ok := a.(RawVectorer); ok {
254		mat := rv.RawVector()
255		v.checkOverlap(mat)
256		if v.mat.Inc == 1 && mat.Inc == 1 {
257			f64.ScalUnitaryTo(v.mat.Data, alpha, mat.Data)
258			return
259		}
260		f64.ScalIncTo(v.mat.Data, uintptr(v.mat.Inc),
261			alpha, mat.Data, uintptr(n), uintptr(mat.Inc))
262		return
263	}
264
265	for i := 0; i < n; i++ {
266		v.setVec(i, alpha*a.AtVec(i))
267	}
268}
269
270// AddScaledVec adds the vectors a and alpha*b, placing the result in the receiver.
271func (v *VecDense) AddScaledVec(a Vector, alpha float64, b Vector) {
272	if alpha == 1 {
273		v.AddVec(a, b)
274		return
275	}
276	if alpha == -1 {
277		v.SubVec(a, b)
278		return
279	}
280
281	ar := a.Len()
282	br := b.Len()
283
284	if ar != br {
285		panic(ErrShape)
286	}
287
288	var amat, bmat blas64.Vector
289	fast := true
290	aU, _ := untranspose(a)
291	if rv, ok := aU.(RawVectorer); ok {
292		amat = rv.RawVector()
293		if v != a {
294			v.checkOverlap(amat)
295		}
296	} else {
297		fast = false
298	}
299	bU, _ := untranspose(b)
300	if rv, ok := bU.(RawVectorer); ok {
301		bmat = rv.RawVector()
302		if v != b {
303			v.checkOverlap(bmat)
304		}
305	} else {
306		fast = false
307	}
308
309	v.reuseAs(ar)
310
311	switch {
312	case alpha == 0: // v <- a
313		if v == a {
314			return
315		}
316		v.CopyVec(a)
317	case v == a && v == b: // v <- v + alpha * v = (alpha + 1) * v
318		blas64.Scal(alpha+1, v.mat)
319	case !fast: // v <- a + alpha * b without blas64 support.
320		for i := 0; i < ar; i++ {
321			v.setVec(i, a.AtVec(i)+alpha*b.AtVec(i))
322		}
323	case v == a && v != b: // v <- v + alpha * b
324		if v.mat.Inc == 1 && bmat.Inc == 1 {
325			// Fast path for a common case.
326			f64.AxpyUnitaryTo(v.mat.Data, alpha, bmat.Data, amat.Data)
327		} else {
328			f64.AxpyInc(alpha, bmat.Data, v.mat.Data,
329				uintptr(ar), uintptr(bmat.Inc), uintptr(v.mat.Inc), 0, 0)
330		}
331	default: // v <- a + alpha * b or v <- a + alpha * v
332		if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
333			// Fast path for a common case.
334			f64.AxpyUnitaryTo(v.mat.Data, alpha, bmat.Data, amat.Data)
335		} else {
336			f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
337				alpha, bmat.Data, amat.Data,
338				uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
339		}
340	}
341}
342
343// AddVec adds the vectors a and b, placing the result in the receiver.
344func (v *VecDense) AddVec(a, b Vector) {
345	ar := a.Len()
346	br := b.Len()
347
348	if ar != br {
349		panic(ErrShape)
350	}
351
352	v.reuseAs(ar)
353
354	aU, _ := untranspose(a)
355	bU, _ := untranspose(b)
356
357	if arv, ok := aU.(RawVectorer); ok {
358		if brv, ok := bU.(RawVectorer); ok {
359			amat := arv.RawVector()
360			bmat := brv.RawVector()
361
362			if v != a {
363				v.checkOverlap(amat)
364			}
365			if v != b {
366				v.checkOverlap(bmat)
367			}
368
369			if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
370				// Fast path for a common case.
371				f64.AxpyUnitaryTo(v.mat.Data, 1, bmat.Data, amat.Data)
372				return
373			}
374			f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
375				1, bmat.Data, amat.Data,
376				uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
377			return
378		}
379	}
380
381	for i := 0; i < ar; i++ {
382		v.setVec(i, a.AtVec(i)+b.AtVec(i))
383	}
384}
385
386// SubVec subtracts the vector b from a, placing the result in the receiver.
387func (v *VecDense) SubVec(a, b Vector) {
388	ar := a.Len()
389	br := b.Len()
390
391	if ar != br {
392		panic(ErrShape)
393	}
394
395	v.reuseAs(ar)
396
397	aU, _ := untranspose(a)
398	bU, _ := untranspose(b)
399
400	if arv, ok := aU.(RawVectorer); ok {
401		if brv, ok := bU.(RawVectorer); ok {
402			amat := arv.RawVector()
403			bmat := brv.RawVector()
404
405			if v != a {
406				v.checkOverlap(amat)
407			}
408			if v != b {
409				v.checkOverlap(bmat)
410			}
411
412			if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
413				// Fast path for a common case.
414				f64.AxpyUnitaryTo(v.mat.Data, -1, bmat.Data, amat.Data)
415				return
416			}
417			f64.AxpyIncTo(v.mat.Data, uintptr(v.mat.Inc), 0,
418				-1, bmat.Data, amat.Data,
419				uintptr(ar), uintptr(bmat.Inc), uintptr(amat.Inc), 0, 0)
420			return
421		}
422	}
423
424	for i := 0; i < ar; i++ {
425		v.setVec(i, a.AtVec(i)-b.AtVec(i))
426	}
427}
428
429// MulElemVec performs element-wise multiplication of a and b, placing the result
430// in the receiver.
431func (v *VecDense) MulElemVec(a, b Vector) {
432	ar := a.Len()
433	br := b.Len()
434
435	if ar != br {
436		panic(ErrShape)
437	}
438
439	v.reuseAs(ar)
440
441	aU, _ := untranspose(a)
442	bU, _ := untranspose(b)
443
444	if arv, ok := aU.(RawVectorer); ok {
445		if brv, ok := bU.(RawVectorer); ok {
446			amat := arv.RawVector()
447			bmat := brv.RawVector()
448
449			if v != a {
450				v.checkOverlap(amat)
451			}
452			if v != b {
453				v.checkOverlap(bmat)
454			}
455
456			if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
457				// Fast path for a common case.
458				for i, a := range amat.Data {
459					v.mat.Data[i] = a * bmat.Data[i]
460				}
461				return
462			}
463			var ia, ib int
464			for i := 0; i < ar; i++ {
465				v.setVec(i, amat.Data[ia]*bmat.Data[ib])
466				ia += amat.Inc
467				ib += bmat.Inc
468			}
469			return
470		}
471	}
472
473	for i := 0; i < ar; i++ {
474		v.setVec(i, a.AtVec(i)*b.AtVec(i))
475	}
476}
477
478// DivElemVec performs element-wise division of a by b, placing the result
479// in the receiver.
480func (v *VecDense) DivElemVec(a, b Vector) {
481	ar := a.Len()
482	br := b.Len()
483
484	if ar != br {
485		panic(ErrShape)
486	}
487
488	v.reuseAs(ar)
489
490	aU, _ := untranspose(a)
491	bU, _ := untranspose(b)
492
493	if arv, ok := aU.(RawVectorer); ok {
494		if brv, ok := bU.(RawVectorer); ok {
495			amat := arv.RawVector()
496			bmat := brv.RawVector()
497
498			if v != a {
499				v.checkOverlap(amat)
500			}
501			if v != b {
502				v.checkOverlap(bmat)
503			}
504
505			if v.mat.Inc == 1 && amat.Inc == 1 && bmat.Inc == 1 {
506				// Fast path for a common case.
507				for i, a := range amat.Data {
508					v.setVec(i, a/bmat.Data[i])
509				}
510				return
511			}
512			var ia, ib int
513			for i := 0; i < ar; i++ {
514				v.setVec(i, amat.Data[ia]/bmat.Data[ib])
515				ia += amat.Inc
516				ib += bmat.Inc
517			}
518		}
519	}
520
521	for i := 0; i < ar; i++ {
522		v.setVec(i, a.AtVec(i)/b.AtVec(i))
523	}
524}
525
526// MulVec computes a * b. The result is stored into the receiver.
527// MulVec panics if the number of columns in a does not equal the number of rows in b
528// or if the number of columns in b does not equal 1.
529func (v *VecDense) MulVec(a Matrix, b Vector) {
530	r, c := a.Dims()
531	br, bc := b.Dims()
532	if c != br || bc != 1 {
533		panic(ErrShape)
534	}
535
536	aU, trans := untranspose(a)
537	var bmat blas64.Vector
538	fast := true
539	bU, _ := untranspose(b)
540	if rv, ok := bU.(RawVectorer); ok {
541		bmat = rv.RawVector()
542		if v != b {
543			v.checkOverlap(bmat)
544		}
545	} else {
546		fast = false
547	}
548
549	v.reuseAs(r)
550	var restore func()
551	if v == aU {
552		v, restore = v.isolatedWorkspace(aU.(*VecDense))
553		defer restore()
554	} else if v == b {
555		v, restore = v.isolatedWorkspace(b)
556		defer restore()
557	}
558
559	// TODO(kortschak): Improve the non-fast paths.
560	switch aU := aU.(type) {
561	case Vector:
562		if b.Len() == 1 {
563			// {n,1} x {1,1}
564			v.ScaleVec(b.AtVec(0), aU)
565			return
566		}
567
568		// {1,n} x {n,1}
569		if fast {
570			if rv, ok := aU.(RawVectorer); ok {
571				amat := rv.RawVector()
572				if v != aU {
573					v.checkOverlap(amat)
574				}
575
576				if amat.Inc == 1 && bmat.Inc == 1 {
577					// Fast path for a common case.
578					v.setVec(0, f64.DotUnitary(amat.Data, bmat.Data))
579					return
580				}
581				v.setVec(0, f64.DotInc(amat.Data, bmat.Data,
582					uintptr(c), uintptr(amat.Inc), uintptr(bmat.Inc), 0, 0))
583				return
584			}
585		}
586		var sum float64
587		for i := 0; i < c; i++ {
588			sum += aU.AtVec(i) * b.AtVec(i)
589		}
590		v.setVec(0, sum)
591		return
592	case RawSymmetricer:
593		if fast {
594			amat := aU.RawSymmetric()
595			// We don't know that a is a *SymDense, so make
596			// a temporary SymDense to check overlap.
597			(&SymDense{mat: amat}).checkOverlap(v.asGeneral())
598			blas64.Symv(1, amat, bmat, 0, v.mat)
599			return
600		}
601	case RawTriangular:
602		v.CopyVec(b)
603		amat := aU.RawTriangular()
604		// We don't know that a is a *TriDense, so make
605		// a temporary TriDense to check overlap.
606		(&TriDense{mat: amat}).checkOverlap(v.asGeneral())
607		ta := blas.NoTrans
608		if trans {
609			ta = blas.Trans
610		}
611		blas64.Trmv(ta, amat, v.mat)
612	case RawMatrixer:
613		if fast {
614			amat := aU.RawMatrix()
615			// We don't know that a is a *Dense, so make
616			// a temporary Dense to check overlap.
617			(&Dense{mat: amat}).checkOverlap(v.asGeneral())
618			t := blas.NoTrans
619			if trans {
620				t = blas.Trans
621			}
622			blas64.Gemv(t, 1, amat, bmat, 0, v.mat)
623			return
624		}
625	default:
626		if fast {
627			for i := 0; i < r; i++ {
628				var f float64
629				for j := 0; j < c; j++ {
630					f += a.At(i, j) * bmat.Data[j*bmat.Inc]
631				}
632				v.setVec(i, f)
633			}
634			return
635		}
636	}
637
638	for i := 0; i < r; i++ {
639		var f float64
640		for j := 0; j < c; j++ {
641			f += a.At(i, j) * b.AtVec(j)
642		}
643		v.setVec(i, f)
644	}
645}
646
647// reuseAs resizes an empty vector to a r×1 vector,
648// or checks that a non-empty matrix is r×1.
649func (v *VecDense) reuseAs(r int) {
650	if r == 0 {
651		panic(ErrZeroLength)
652	}
653	if v.IsZero() {
654		v.mat = blas64.Vector{
655			N:    r,
656			Inc:  1,
657			Data: use(v.mat.Data, r),
658		}
659		return
660	}
661	if r != v.mat.N {
662		panic(ErrShape)
663	}
664}
665
666// IsZero returns whether the receiver is zero-sized. Zero-sized vectors can be the
667// receiver for size-restricted operations. VecDenses can be zeroed using Reset.
668func (v *VecDense) IsZero() bool {
669	// It must be the case that v.Dims() returns
670	// zeros in this case. See comment in Reset().
671	return v.mat.Inc == 0
672}
673
674func (v *VecDense) isolatedWorkspace(a Vector) (n *VecDense, restore func()) {
675	l := a.Len()
676	if l == 0 {
677		panic(ErrZeroLength)
678	}
679	n = getWorkspaceVec(l, false)
680	return n, func() {
681		v.CopyVec(n)
682		putWorkspaceVec(n)
683	}
684}
685
686// asDense returns a Dense representation of the receiver with the same
687// underlying data.
688func (v *VecDense) asDense() *Dense {
689	return &Dense{
690		mat:     v.asGeneral(),
691		capRows: v.mat.N,
692		capCols: 1,
693	}
694}
695
696// asGeneral returns a blas64.General representation of the receiver with the
697// same underlying data.
698func (v *VecDense) asGeneral() blas64.General {
699	return blas64.General{
700		Rows:   v.mat.N,
701		Cols:   1,
702		Stride: v.mat.Inc,
703		Data:   v.mat.Data,
704	}
705}
706
707// ColViewOf reflects the column j of the RawMatrixer m, into the receiver
708// backed by the same underlying data. The length of the receiver must either be
709// zero or match the number of rows in m.
710func (v *VecDense) ColViewOf(m RawMatrixer, j int) {
711	rm := m.RawMatrix()
712
713	if j >= rm.Cols || j < 0 {
714		panic(ErrColAccess)
715	}
716	if !v.IsZero() && v.mat.N != rm.Rows {
717		panic(ErrShape)
718	}
719
720	v.mat.Inc = rm.Stride
721	v.mat.Data = rm.Data[j : (rm.Rows-1)*rm.Stride+j+1]
722	v.mat.N = rm.Rows
723}
724
725// RowViewOf reflects the row i of the RawMatrixer m, into the receiver
726// backed by the same underlying data. The length of the receiver must either be
727// zero or match the number of columns in m.
728func (v *VecDense) RowViewOf(m RawMatrixer, i int) {
729	rm := m.RawMatrix()
730
731	if i >= rm.Rows || i < 0 {
732		panic(ErrRowAccess)
733	}
734	if !v.IsZero() && v.mat.N != rm.Cols {
735		panic(ErrShape)
736	}
737
738	v.mat.Inc = 1
739	v.mat.Data = rm.Data[i*rm.Stride : i*rm.Stride+rm.Cols]
740	v.mat.N = rm.Cols
741}
742