1// Copyright ©2015 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package mat 6 7import ( 8 "gonum.org/v1/gonum/blas/blas64" 9) 10 11const ( 12 // regionOverlap is the panic string used for the general case 13 // of a matrix region overlap between a source and destination. 14 regionOverlap = "mat: bad region: overlap" 15 16 // regionIdentity is the panic string used for the specific 17 // case of complete agreement between a source and a destination. 18 regionIdentity = "mat: bad region: identical" 19 20 // mismatchedStrides is the panic string used for overlapping 21 // data slices with differing strides. 22 mismatchedStrides = "mat: bad region: different strides" 23) 24 25// checkOverlap returns false if the receiver does not overlap data elements 26// referenced by the parameter and panics otherwise. 27// 28// checkOverlap methods return a boolean to allow the check call to be added to a 29// boolean expression, making use of short-circuit operators. 30func checkOverlap(a, b blas64.General) bool { 31 if cap(a.Data) == 0 || cap(b.Data) == 0 { 32 return false 33 } 34 35 off := offset(a.Data[:1], b.Data[:1]) 36 37 if off == 0 { 38 // At least one element overlaps. 39 if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride { 40 panic(regionIdentity) 41 } 42 panic(regionOverlap) 43 } 44 45 if off > 0 && len(a.Data) <= off { 46 // We know a is completely before b. 47 return false 48 } 49 if off < 0 && len(b.Data) <= -off { 50 // We know a is completely after b. 51 return false 52 } 53 54 if a.Stride != b.Stride { 55 // Too hard, so assume the worst. 56 panic(mismatchedStrides) 57 } 58 59 if off < 0 { 60 off = -off 61 a.Cols, b.Cols = b.Cols, a.Cols 62 } 63 if rectanglesOverlap(off, a.Cols, b.Cols, a.Stride) { 64 panic(regionOverlap) 65 } 66 return false 67} 68 69func (m *Dense) checkOverlap(a blas64.General) bool { 70 return checkOverlap(m.RawMatrix(), a) 71} 72 73func (m *Dense) checkOverlapMatrix(a Matrix) bool { 74 if m == a { 75 return false 76 } 77 var amat blas64.General 78 switch a := a.(type) { 79 default: 80 return false 81 case RawMatrixer: 82 amat = a.RawMatrix() 83 case RawSymmetricer: 84 amat = generalFromSymmetric(a.RawSymmetric()) 85 case RawTriangular: 86 amat = generalFromTriangular(a.RawTriangular()) 87 } 88 return m.checkOverlap(amat) 89} 90 91func (s *SymDense) checkOverlap(a blas64.General) bool { 92 return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) 93} 94 95func (s *SymDense) checkOverlapMatrix(a Matrix) bool { 96 if s == a { 97 return false 98 } 99 var amat blas64.General 100 switch a := a.(type) { 101 default: 102 return false 103 case RawMatrixer: 104 amat = a.RawMatrix() 105 case RawSymmetricer: 106 amat = generalFromSymmetric(a.RawSymmetric()) 107 case RawTriangular: 108 amat = generalFromTriangular(a.RawTriangular()) 109 } 110 return s.checkOverlap(amat) 111} 112 113// generalFromSymmetric returns a blas64.General with the backing 114// data and dimensions of a. 115func generalFromSymmetric(a blas64.Symmetric) blas64.General { 116 return blas64.General{ 117 Rows: a.N, 118 Cols: a.N, 119 Stride: a.Stride, 120 Data: a.Data, 121 } 122} 123 124func (t *TriDense) checkOverlap(a blas64.General) bool { 125 return checkOverlap(generalFromTriangular(t.RawTriangular()), a) 126} 127 128func (t *TriDense) checkOverlapMatrix(a Matrix) bool { 129 if t == a { 130 return false 131 } 132 var amat blas64.General 133 switch a := a.(type) { 134 default: 135 return false 136 case RawMatrixer: 137 amat = a.RawMatrix() 138 case RawSymmetricer: 139 amat = generalFromSymmetric(a.RawSymmetric()) 140 case RawTriangular: 141 amat = generalFromTriangular(a.RawTriangular()) 142 } 143 return t.checkOverlap(amat) 144} 145 146// generalFromTriangular returns a blas64.General with the backing 147// data and dimensions of a. 148func generalFromTriangular(a blas64.Triangular) blas64.General { 149 return blas64.General{ 150 Rows: a.N, 151 Cols: a.N, 152 Stride: a.Stride, 153 Data: a.Data, 154 } 155} 156 157func (v *VecDense) checkOverlap(a blas64.Vector) bool { 158 mat := v.mat 159 if cap(mat.Data) == 0 || cap(a.Data) == 0 { 160 return false 161 } 162 163 off := offset(mat.Data[:1], a.Data[:1]) 164 165 if off == 0 { 166 // At least one element overlaps. 167 if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) { 168 panic(regionIdentity) 169 } 170 panic(regionOverlap) 171 } 172 173 if off > 0 && len(mat.Data) <= off { 174 // We know v is completely before a. 175 return false 176 } 177 if off < 0 && len(a.Data) <= -off { 178 // We know v is completely after a. 179 return false 180 } 181 182 if mat.Inc != a.Inc { 183 // Too hard, so assume the worst. 184 panic(mismatchedStrides) 185 } 186 187 if mat.Inc == 1 || off&mat.Inc == 0 { 188 panic(regionOverlap) 189 } 190 return false 191} 192 193// rectanglesOverlap returns whether the strided rectangles a and b overlap 194// when b is offset by off elements after a but has at least one element before 195// the end of a. off must be positive. a and b have aCols and bCols respectively. 196// 197// rectanglesOverlap works by shifting both matrices left such that the left 198// column of a is at 0. The column indexes are flattened by obtaining the shifted 199// relative left and right column positions modulo the common stride. This allows 200// direct comparison of the column offsets when the matrix backing data slices 201// are known to overlap. 202func rectanglesOverlap(off, aCols, bCols, stride int) bool { 203 if stride == 1 { 204 // Unit stride means overlapping data 205 // slices must overlap as matrices. 206 return true 207 } 208 209 // Flatten the shifted matrix column positions 210 // so a starts at 0, modulo the common stride. 211 aTo := aCols 212 // The mod stride operations here make the from 213 // and to indexes comparable between a and b when 214 // the data slices of a and b overlap. 215 bFrom := off % stride 216 bTo := (bFrom + bCols) % stride 217 218 if bTo == 0 || bFrom < bTo { 219 // b matrix is not wrapped: compare for 220 // simple overlap. 221 return bFrom < aTo 222 } 223 224 // b strictly wraps and so must overlap with a. 225 return true 226} 227