1// Copyright ©2015 The Gonum Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package mat 6 7import "gonum.org/v1/gonum/blas/blas64" 8 9// checkOverlap returns false if the receiver does not overlap data elements 10// referenced by the parameter and panics otherwise. 11// 12// checkOverlap methods return a boolean to allow the check call to be added to a 13// boolean expression, making use of short-circuit operators. 14func checkOverlap(a, b blas64.General) bool { 15 if cap(a.Data) == 0 || cap(b.Data) == 0 { 16 return false 17 } 18 19 off := offset(a.Data[:1], b.Data[:1]) 20 21 if off == 0 { 22 // At least one element overlaps. 23 if a.Cols == b.Cols && a.Rows == b.Rows && a.Stride == b.Stride { 24 panic(regionIdentity) 25 } 26 panic(regionOverlap) 27 } 28 29 if off > 0 && len(a.Data) <= off { 30 // We know a is completely before b. 31 return false 32 } 33 if off < 0 && len(b.Data) <= -off { 34 // We know a is completely after b. 35 return false 36 } 37 38 if a.Stride != b.Stride && a.Stride != 1 && b.Stride != 1 { 39 // Too hard, so assume the worst; if either stride 40 // is one it will be caught in rectanglesOverlap. 41 panic(mismatchedStrides) 42 } 43 44 if off < 0 { 45 off = -off 46 a.Cols, b.Cols = b.Cols, a.Cols 47 } 48 if rectanglesOverlap(off, a.Cols, b.Cols, min(a.Stride, b.Stride)) { 49 panic(regionOverlap) 50 } 51 return false 52} 53 54func (m *Dense) checkOverlap(a blas64.General) bool { 55 return checkOverlap(m.RawMatrix(), a) 56} 57 58func (m *Dense) checkOverlapMatrix(a Matrix) bool { 59 if m == a { 60 return false 61 } 62 var amat blas64.General 63 switch ar := a.(type) { 64 default: 65 return false 66 case RawMatrixer: 67 amat = ar.RawMatrix() 68 case RawSymmetricer: 69 amat = generalFromSymmetric(ar.RawSymmetric()) 70 case RawSymBander: 71 amat = generalFromSymmetricBand(ar.RawSymBand()) 72 case RawTriangular: 73 amat = generalFromTriangular(ar.RawTriangular()) 74 case RawVectorer: 75 r, c := a.Dims() 76 amat = generalFromVector(ar.RawVector(), r, c) 77 } 78 return m.checkOverlap(amat) 79} 80 81func (s *SymDense) checkOverlap(a blas64.General) bool { 82 return checkOverlap(generalFromSymmetric(s.RawSymmetric()), a) 83} 84 85func (s *SymDense) checkOverlapMatrix(a Matrix) bool { 86 if s == a { 87 return false 88 } 89 var amat blas64.General 90 switch ar := a.(type) { 91 default: 92 return false 93 case RawMatrixer: 94 amat = ar.RawMatrix() 95 case RawSymmetricer: 96 amat = generalFromSymmetric(ar.RawSymmetric()) 97 case RawSymBander: 98 amat = generalFromSymmetricBand(ar.RawSymBand()) 99 case RawTriangular: 100 amat = generalFromTriangular(ar.RawTriangular()) 101 case RawVectorer: 102 r, c := a.Dims() 103 amat = generalFromVector(ar.RawVector(), r, c) 104 } 105 return s.checkOverlap(amat) 106} 107 108// generalFromSymmetric returns a blas64.General with the backing 109// data and dimensions of a. 110func generalFromSymmetric(a blas64.Symmetric) blas64.General { 111 return blas64.General{ 112 Rows: a.N, 113 Cols: a.N, 114 Stride: a.Stride, 115 Data: a.Data, 116 } 117} 118 119func (t *TriDense) checkOverlap(a blas64.General) bool { 120 return checkOverlap(generalFromTriangular(t.RawTriangular()), a) 121} 122 123func (t *TriDense) checkOverlapMatrix(a Matrix) bool { 124 if t == a { 125 return false 126 } 127 var amat blas64.General 128 switch ar := a.(type) { 129 default: 130 return false 131 case RawMatrixer: 132 amat = ar.RawMatrix() 133 case RawSymmetricer: 134 amat = generalFromSymmetric(ar.RawSymmetric()) 135 case RawSymBander: 136 amat = generalFromSymmetricBand(ar.RawSymBand()) 137 case RawTriangular: 138 amat = generalFromTriangular(ar.RawTriangular()) 139 case RawVectorer: 140 r, c := a.Dims() 141 amat = generalFromVector(ar.RawVector(), r, c) 142 } 143 return t.checkOverlap(amat) 144} 145 146// generalFromTriangular returns a blas64.General with the backing 147// data and dimensions of a. 148func generalFromTriangular(a blas64.Triangular) blas64.General { 149 return blas64.General{ 150 Rows: a.N, 151 Cols: a.N, 152 Stride: a.Stride, 153 Data: a.Data, 154 } 155} 156 157func (v *VecDense) checkOverlap(a blas64.Vector) bool { 158 mat := v.mat 159 if cap(mat.Data) == 0 || cap(a.Data) == 0 { 160 return false 161 } 162 163 off := offset(mat.Data[:1], a.Data[:1]) 164 165 if off == 0 { 166 // At least one element overlaps. 167 if mat.Inc == a.Inc && len(mat.Data) == len(a.Data) { 168 panic(regionIdentity) 169 } 170 panic(regionOverlap) 171 } 172 173 if off > 0 && len(mat.Data) <= off { 174 // We know v is completely before a. 175 return false 176 } 177 if off < 0 && len(a.Data) <= -off { 178 // We know v is completely after a. 179 return false 180 } 181 182 if mat.Inc != a.Inc && mat.Inc != 1 && a.Inc != 1 { 183 // Too hard, so assume the worst; if either 184 // increment is one it will be caught below. 185 panic(mismatchedStrides) 186 } 187 inc := min(mat.Inc, a.Inc) 188 189 if inc == 1 || off&inc == 0 { 190 panic(regionOverlap) 191 } 192 return false 193} 194 195// generalFromVector returns a blas64.General with the backing 196// data and dimensions of a. 197func generalFromVector(a blas64.Vector, r, c int) blas64.General { 198 return blas64.General{ 199 Rows: r, 200 Cols: c, 201 Stride: a.Inc, 202 Data: a.Data, 203 } 204} 205 206func (s *SymBandDense) checkOverlap(a blas64.General) bool { 207 return checkOverlap(generalFromSymmetricBand(s.RawSymBand()), a) 208} 209 210func (s *SymBandDense) checkOverlapMatrix(a Matrix) bool { 211 if s == a { 212 return false 213 } 214 var amat blas64.General 215 switch ar := a.(type) { 216 default: 217 return false 218 case RawMatrixer: 219 amat = ar.RawMatrix() 220 case RawSymmetricer: 221 amat = generalFromSymmetric(ar.RawSymmetric()) 222 case RawSymBander: 223 amat = generalFromSymmetricBand(ar.RawSymBand()) 224 case RawTriangular: 225 amat = generalFromTriangular(ar.RawTriangular()) 226 case RawVectorer: 227 r, c := a.Dims() 228 amat = generalFromVector(ar.RawVector(), r, c) 229 } 230 return s.checkOverlap(amat) 231} 232 233// generalFromSymmetricBand returns a blas64.General with the backing 234// data and dimensions of a. 235func generalFromSymmetricBand(a blas64.SymmetricBand) blas64.General { 236 return blas64.General{ 237 Rows: a.N, 238 Cols: a.K + 1, 239 Data: a.Data, 240 Stride: a.Stride, 241 } 242} 243