1// Copyright ©2015 The Gonum Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package mat
6
7import (
8	"gonum.org/v1/gonum/blas/blas64"
9)
10
const (
	// regionIdentity is the panic message used when a source and a
	// destination describe exactly the same region.
	regionIdentity = "mat: bad region: identical"

	// regionOverlap is the panic message used for the general case of a
	// source region sharing elements with a destination region.
	regionOverlap = "mat: bad region: overlap"

	// mismatchedStrides is the panic message used when the backing data
	// slices overlap but are laid out with different strides.
	mismatchedStrides = "mat: bad region: different strides"
)
24
25// checkOverlap returns false if the receiver does not overlap data elements
26// referenced by the parameter and panics otherwise.
27//
28// checkOverlap methods return a boolean to allow the check call to be added to a
29// boolean expression, making use of short-circuit operators.
30func checkOverlap(a, b blas64.General) bool {
31	if cap(a.Data) == 0 || cap(b.Data) == 0 {
32		return false
33	}
34
35	off := offset(a.Data[:1], b.Data[:1])
36
37	if off == 0 {
38		// At least one element overlaps.
39		if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride {
40			panic(regionIdentity)
41		}
42		panic(regionOverlap)
43	}
44
45	if off > 0 && len(a.Data) <= off {
46		// We know a is completely before b.
47		return false
48	}
49	if off < 0 && len(b.Data) <= -off {
50		// We know a is completely after b.
51		return false
52	}
53
54	if a.Stride != b.Stride {
55		// Too hard, so assume the worst.
56		panic(mismatchedStrides)
57	}
58
59	if off < 0 {
60		off = -off
61		a.Cols, b.Cols = b.Cols, a.Cols
62	}
63	if rectanglesOverlap(off, a.Cols, b.Cols, a.Stride) {
64		panic(regionOverlap)
65	}
66	return false
67}
68
69func (m *Dense) checkOverlap(a blas64.General) bool {
70	return checkOverlap(m.RawMatrix(), a)
71}
72
73func (m *Dense) checkOverlapMatrix(a Matrix) bool {
74	if m == a {
75		return false
76	}
77	var amat blas64.General
78	switch a := a.(type) {
79	default:
80		return false
81	case RawMatrixer:
82		amat = a.RawMatrix()
83	case RawSymmetricer:
84		amat = generalFromSymmetric(a.RawSymmetric())
85	case RawTriangular:
86		amat = generalFromTriangular(a.RawTriangular())
87	}
88	return m.checkOverlap(amat)
89}
90
91func (s *SymDense) checkOverlap(a blas64.General) bool {
92	return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a)
93}
94
95func (s *SymDense) checkOverlapMatrix(a Matrix) bool {
96	if s == a {
97		return false
98	}
99	var amat blas64.General
100	switch a := a.(type) {
101	default:
102		return false
103	case RawMatrixer:
104		amat = a.RawMatrix()
105	case RawSymmetricer:
106		amat = generalFromSymmetric(a.RawSymmetric())
107	case RawTriangular:
108		amat = generalFromTriangular(a.RawTriangular())
109	}
110	return s.checkOverlap(amat)
111}
112
113// generalFromSymmetric returns a blas64.General with the backing
114// data and dimensions of a.
115func generalFromSymmetric(a blas64.Symmetric) blas64.General {
116	return blas64.General{
117		Rows:   a.N,
118		Cols:   a.N,
119		Stride: a.Stride,
120		Data:   a.Data,
121	}
122}
123
124func (t *TriDense) checkOverlap(a blas64.General) bool {
125	return checkOverlap(generalFromTriangular(t.RawTriangular()), a)
126}
127
128func (t *TriDense) checkOverlapMatrix(a Matrix) bool {
129	if t == a {
130		return false
131	}
132	var amat blas64.General
133	switch a := a.(type) {
134	default:
135		return false
136	case RawMatrixer:
137		amat = a.RawMatrix()
138	case RawSymmetricer:
139		amat = generalFromSymmetric(a.RawSymmetric())
140	case RawTriangular:
141		amat = generalFromTriangular(a.RawTriangular())
142	}
143	return t.checkOverlap(amat)
144}
145
146// generalFromTriangular returns a blas64.General with the backing
147// data and dimensions of a.
148func generalFromTriangular(a blas64.Triangular) blas64.General {
149	return blas64.General{
150		Rows:   a.N,
151		Cols:   a.N,
152		Stride: a.Stride,
153		Data:   a.Data,
154	}
155}
156
157func (v *VecDense) checkOverlap(a blas64.Vector) bool {
158	mat := v.mat
159	if cap(mat.Data) == 0 || cap(a.Data) == 0 {
160		return false
161	}
162
163	off := offset(mat.Data[:1], a.Data[:1])
164
165	if off == 0 {
166		// At least one element overlaps.
167		if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) {
168			panic(regionIdentity)
169		}
170		panic(regionOverlap)
171	}
172
173	if off > 0 && len(mat.Data) <= off {
174		// We know v is completely before a.
175		return false
176	}
177	if off < 0 && len(a.Data) <= -off {
178		// We know v is completely after a.
179		return false
180	}
181
182	if mat.Inc != a.Inc {
183		// Too hard, so assume the worst.
184		panic(mismatchedStrides)
185	}
186
187	if mat.Inc == 1 || off&mat.Inc == 0 {
188		panic(regionOverlap)
189	}
190	return false
191}
192
// rectanglesOverlap reports whether two strided rectangles sharing a common
// stride overlap. Rectangle b starts off elements after a (off must be
// positive) and begins before the end of a's data. aCols and bCols are the
// column counts of a and b.
//
// Both rectangles are translated so that a's first column sits at 0, and
// every column position is reduced modulo the stride. Because the backing
// data is already known to overlap, the reduced column ranges can then be
// compared directly.
func rectanglesOverlap(off, aCols, bCols, stride int) bool {
	// A unit stride leaves no gaps between rows, so overlapping data
	// always means overlapping matrices.
	if stride == 1 {
		return true
	}

	// b's first and one-past-last columns in a's frame, modulo stride.
	first := off % stride
	end := (first + bCols) % stride

	// If b's rows run past the end of a stride, they also cover the start
	// of the next stride, which is where a's columns begin.
	if end != 0 && first >= end {
		return true
	}

	// Otherwise b occupies the contiguous columns [first, end), or
	// [first, stride) when end is 0. That span meets a's [0, aCols)
	// exactly when it starts inside it.
	return first < aCols
}
227