1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import "gonum.org/v1/gonum/blas/blas64"
8
// checkOverlap returns false if the data elements referenced by a do not
// overlap the data elements referenced by b, and panics otherwise.
//
// It panics with regionIdentity when both views start at the same element
// and have identical Rows, Cols and Stride, with mismatchedStrides when the
// strides differ and neither is one (the check gives up and assumes the
// worst in that case), and with regionOverlap for any other detected overlap.
//
// checkOverlap methods return a boolean to allow the check call to be added to a
// boolean expression, making use of short-circuit operators.
func checkOverlap(a, b blas64.General) bool {
	// A view with no backing storage cannot share elements with anything.
	if cap(a.Data) == 0 || cap(b.Data) == 0 {
		return false
	}

	// off relates the first element of b to the first element of a; as the
	// checks below show, a positive value means b starts after a.
	// NOTE(review): offset is defined elsewhere — confirm its exact contract.
	off := offset(a.Data[:1], b.Data[:1])

	if off == 0 {
		// At least one element overlaps.
		if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride {
			panic(regionIdentity)
		}
		panic(regionOverlap)
	}

	if off > 0 && len(a.Data) <= off {
		// We know a is completely before b.
		return false
	}
	if off < 0 && len(b.Data) <= -off {
		// We know a is completely after b.
		return false
	}

	if a.Stride != b.Stride && a.Stride != 1 && b.Stride != 1 {
		// Too hard, so assume the worst; if either stride
		// is one it will be caught in rectanglesOverlap.
		panic(mismatchedStrides)
	}

	// Normalize so that off is non-negative: when b starts first, swap the
	// column counts so the first column argument describes the earlier view.
	// a and b are value copies, so the swap is not visible to the caller.
	if off < 0 {
		off = -off
		a.Cols, b.Cols = b.Cols, a.Cols
	}
	// The backing spans intersect at this point; rectanglesOverlap (defined
	// elsewhere) makes the final decision using the smaller stride.
	if rectanglesOverlap(off, a.Cols, b.Cols, min(a.Stride, b.Stride)) {
		panic(regionOverlap)
	}
	return false
}
53
54func (m *Dense) checkOverlap(a blas64.General) bool {
55	return checkOverlap(m.RawMatrix(), a)
56}
57
58func (m *Dense) checkOverlapMatrix(a Matrix) bool {
59	if m == a {
60		return false
61	}
62	var amat blas64.General
63	switch ar := a.(type) {
64	default:
65		return false
66	case RawMatrixer:
67		amat = ar.RawMatrix()
68	case RawSymmetricer:
69		amat = generalFromSymmetric(ar.RawSymmetric())
70	case RawSymBander:
71		amat = generalFromSymmetricBand(ar.RawSymBand())
72	case RawTriangular:
73		amat = generalFromTriangular(ar.RawTriangular())
74	case RawVectorer:
75		r, c := a.Dims()
76		amat = generalFromVector(ar.RawVector(), r, c)
77	}
78	return m.checkOverlap(amat)
79}
80
81func (s *SymDense) checkOverlap(a blas64.General) bool {
82	return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
83}
84
85func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
86	if s == a {
87		return false
88	}
89	var amat blas64.General
90	switch ar := a.(type) {
91	default:
92		return false
93	case RawMatrixer:
94		amat = ar.RawMatrix()
95	case RawSymmetricer:
96		amat = generalFromSymmetric(ar.RawSymmetric())
97	case RawSymBander:
98		amat = generalFromSymmetricBand(ar.RawSymBand())
99	case RawTriangular:
100		amat = generalFromTriangular(ar.RawTriangular())
101	case RawVectorer:
102		r, c := a.Dims()
103		amat = generalFromVector(ar.RawVector(), r, c)
104	}
105	return s.checkOverlap(amat)
106}
107
108// generalFromSymmetric returns a blas64.General with the backing
109// data and dimensions of a.
110func generalFromSymmetric(a blas64.Symmetric) blas64.General {
111	return blas64.General{
112		Rows:   a.N,
113		Cols:   a.N,
114		Stride: a.Stride,
115		Data:   a.Data,
116	}
117}
118
119func (t *TriDense) checkOverlap(a blas64.General) bool {
120	return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
121}
122
123func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
124	if t == a {
125		return false
126	}
127	var amat blas64.General
128	switch ar := a.(type) {
129	default:
130		return false
131	case RawMatrixer:
132		amat = ar.RawMatrix()
133	case RawSymmetricer:
134		amat = generalFromSymmetric(ar.RawSymmetric())
135	case RawSymBander:
136		amat = generalFromSymmetricBand(ar.RawSymBand())
137	case RawTriangular:
138		amat = generalFromTriangular(ar.RawTriangular())
139	case RawVectorer:
140		r, c := a.Dims()
141		amat = generalFromVector(ar.RawVector(), r, c)
142	}
143	return t.checkOverlap(amat)
144}
145
146// generalFromTriangular returns a blas64.General with the backing
147// data and dimensions of a.
148func generalFromTriangular(a blas64.Triangular) blas64.General {
149	return blas64.General{
150		Rows:   a.N,
151		Cols:   a.N,
152		Stride: a.Stride,
153		Data:   a.Data,
154	}
155}
156
157func (v *VecDense) checkOverlap(a blas64.Vector) bool {
158	mat := v.mat
159	if cap(mat.Data) == 0 || cap(a.Data) == 0 {
160		return false
161	}
162
163	off := offset(mat.Data[:1], a.Data[:1])
164
165	if off == 0 {
166		// At least one element overlaps.
167		if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) {
168			panic(regionIdentity)
169		}
170		panic(regionOverlap)
171	}
172
173	if off > 0 && len(mat.Data) <= off {
174		// We know v is completely before a.
175		return false
176	}
177	if off < 0 && len(a.Data) <= -off {
178		// We know v is completely after a.
179		return false
180	}
181
182	if mat.Inc != a.Inc && mat.Inc != 1 && a.Inc != 1 {
183		// Too hard, so assume the worst; if either
184		// increment is one it will be caught below.
185		panic(mismatchedStrides)
186	}
187	inc := min(mat.Inc, a.Inc)
188
189	if inc == 1 || off&inc == 0 {
190		panic(regionOverlap)
191	}
192	return false
193}
194
195// generalFromVector returns a blas64.General with the backing
196// data and dimensions of a.
197func generalFromVector(a blas64.Vector, r, c int) blas64.General {
198	return blas64.General{
199		Rows:   r,
200		Cols:   c,
201		Stride: a.Inc,
202		Data:   a.Data,
203	}
204}
205
206func (s *SymBandDense) checkOverlap(a blas64.General) bool {
207	return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a)
208}
209
210func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool {
211	if s == a {
212		return false
213	}
214	var amat blas64.General
215	switch ar := a.(type) {
216	default:
217		return false
218	case RawMatrixer:
219		amat = ar.RawMatrix()
220	case RawSymmetricer:
221		amat = generalFromSymmetric(ar.RawSymmetric())
222	case RawSymBander:
223		amat = generalFromSymmetricBand(ar.RawSymBand())
224	case RawTriangular:
225		amat = generalFromTriangular(ar.RawTriangular())
226	case RawVectorer:
227		r, c := a.Dims()
228		amat = generalFromVector(ar.RawVector(), r, c)
229	}
230	return s.checkOverlap(amat)
231}
232
233// generalFromSymmetricBand returns a blas64.General with the backing
234// data and dimensions of a.
235func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General {
236	return blas64.General{
237		Rows:   a.N,
238		Cols:   a.K + 1,
239		Data:   a.Data,
240		Stride: a.Stride,
241	}
242}
243