1 use num::Zero; 2 use simba::scalar::ClosedAdd; 3 use std::iter; 4 use std::marker::PhantomData; 5 use std::ops::Range; 6 use std::slice; 7 8 use crate::allocator::Allocator; 9 use crate::sparse::cs_utils; 10 use crate::{Const, DefaultAllocator, Dim, Dynamic, Matrix, OVector, Scalar, Vector, U1}; 11 12 pub struct ColumnEntries<'a, T> { 13 curr: usize, 14 i: &'a [usize], 15 v: &'a [T], 16 } 17 18 impl<'a, T> ColumnEntries<'a, T> { 19 #[inline] new(i: &'a [usize], v: &'a [T]) -> Self20 pub fn new(i: &'a [usize], v: &'a [T]) -> Self { 21 assert_eq!(i.len(), v.len()); 22 Self { curr: 0, i, v } 23 } 24 } 25 26 impl<'a, T: Clone> Iterator for ColumnEntries<'a, T> { 27 type Item = (usize, T); 28 29 #[inline] next(&mut self) -> Option<Self::Item>30 fn next(&mut self) -> Option<Self::Item> { 31 if self.curr >= self.i.len() { 32 None 33 } else { 34 let res = Some((unsafe { *self.i.get_unchecked(self.curr) }, unsafe { 35 self.v.get_unchecked(self.curr).clone() 36 })); 37 self.curr += 1; 38 res 39 } 40 } 41 } 42 43 // TODO: this structure exists for now only because impl trait 44 // cannot be used for trait method return types. 45 /// Trait for iterable compressed-column matrix storage. 46 pub trait CsStorageIter<'a, T, R, C = U1> { 47 /// Iterator through all the rows of a specific columns. 48 /// 49 /// The elements are given as a tuple (`row_index`, value). 50 type ColumnEntries: Iterator<Item = (usize, T)>; 51 /// Iterator through the row indices of a specific column. 52 type ColumnRowIndices: Iterator<Item = usize>; 53 54 /// Iterates through all the row indices of the j-th column. column_row_indices(&'a self, j: usize) -> Self::ColumnRowIndices55 fn column_row_indices(&'a self, j: usize) -> Self::ColumnRowIndices; 56 /// Iterates through all the entries of the j-th column. column_entries(&'a self, j: usize) -> Self::ColumnEntries57 fn column_entries(&'a self, j: usize) -> Self::ColumnEntries; 58 } 59 60 /// Trait for mutably iterable compressed-column sparse matrix storage. 61 pub trait CsStorageIterMut<'a, T: 'a, R, C = U1> { 62 /// Mutable iterator through all the values of the sparse matrix. 63 type ValuesMut: Iterator<Item = &'a mut T>; 64 /// Mutable iterator through all the rows of a specific columns. 65 /// 66 /// The elements are given as a tuple (`row_index`, value). 67 type ColumnEntriesMut: Iterator<Item = (usize, &'a mut T)>; 68 69 /// A mutable iterator through the values buffer of the sparse matrix. values_mut(&'a mut self) -> Self::ValuesMut70 fn values_mut(&'a mut self) -> Self::ValuesMut; 71 /// Iterates mutably through all the entries of the j-th column. column_entries_mut(&'a mut self, j: usize) -> Self::ColumnEntriesMut72 fn column_entries_mut(&'a mut self, j: usize) -> Self::ColumnEntriesMut; 73 } 74 75 /// Trait for compressed column sparse matrix storage. 76 pub trait CsStorage<T, R, C = U1>: for<'a> CsStorageIter<'a, T, R, C> { 77 /// The shape of the stored matrix. shape(&self) -> (R, C)78 fn shape(&self) -> (R, C); 79 /// Retrieve the i-th row index of the underlying row index buffer. 80 /// 81 /// # Safety 82 /// No bound-checking is performed. row_index_unchecked(&self, i: usize) -> usize83 unsafe fn row_index_unchecked(&self, i: usize) -> usize; 84 /// The i-th value on the contiguous value buffer of this storage. 85 /// 86 /// # Safety 87 /// No bound-checking is performed. get_value_unchecked(&self, i: usize) -> &T88 unsafe fn get_value_unchecked(&self, i: usize) -> &T; 89 /// The i-th value on the contiguous value buffer of this storage. get_value(&self, i: usize) -> &T90 fn get_value(&self, i: usize) -> &T; 91 /// Retrieve the i-th row index of the underlying row index buffer. row_index(&self, i: usize) -> usize92 fn row_index(&self, i: usize) -> usize; 93 /// The value indices for the `i`-th column. column_range(&self, i: usize) -> Range<usize>94 fn column_range(&self, i: usize) -> Range<usize>; 95 /// The size of the value buffer (i.e. the entries known as possibly being non-zero). len(&self) -> usize96 fn len(&self) -> usize; 97 } 98 99 /// Trait for compressed column sparse matrix mutable storage. 100 pub trait CsStorageMut<T, R, C = U1>: 101 CsStorage<T, R, C> + for<'a> CsStorageIterMut<'a, T, R, C> 102 { 103 } 104 105 /// A storage of column-compressed sparse matrix based on a Vec. 106 #[derive(Clone, Debug, PartialEq)] 107 pub struct CsVecStorage<T: Scalar, R: Dim, C: Dim> 108 where 109 DefaultAllocator: Allocator<usize, C>, 110 { 111 pub(crate) shape: (R, C), 112 pub(crate) p: OVector<usize, C>, 113 pub(crate) i: Vec<usize>, 114 pub(crate) vals: Vec<T>, 115 } 116 117 impl<T: Scalar, R: Dim, C: Dim> CsVecStorage<T, R, C> 118 where 119 DefaultAllocator: Allocator<usize, C>, 120 { 121 /// The value buffer of this storage. 122 #[must_use] values(&self) -> &[T]123 pub fn values(&self) -> &[T] { 124 &self.vals 125 } 126 127 /// The column shifts buffer. 128 #[must_use] p(&self) -> &[usize]129 pub fn p(&self) -> &[usize] { 130 self.p.as_slice() 131 } 132 133 /// The row index buffers. 134 #[must_use] i(&self) -> &[usize]135 pub fn i(&self) -> &[usize] { 136 &self.i 137 } 138 } 139 140 impl<T: Scalar, R: Dim, C: Dim> CsVecStorage<T, R, C> where DefaultAllocator: Allocator<usize, C> {} 141 142 impl<'a, T: Scalar, R: Dim, C: Dim> CsStorageIter<'a, T, R, C> for CsVecStorage<T, R, C> 143 where 144 DefaultAllocator: Allocator<usize, C>, 145 { 146 type ColumnEntries = ColumnEntries<'a, T>; 147 type ColumnRowIndices = iter::Cloned<slice::Iter<'a, usize>>; 148 149 #[inline] column_entries(&'a self, j: usize) -> Self::ColumnEntries150 fn column_entries(&'a self, j: usize) -> Self::ColumnEntries { 151 let rng = self.column_range(j); 152 ColumnEntries::new(&self.i[rng.clone()], &self.vals[rng]) 153 } 154 155 #[inline] column_row_indices(&'a self, j: usize) -> Self::ColumnRowIndices156 fn column_row_indices(&'a self, j: usize) -> Self::ColumnRowIndices { 157 let rng = self.column_range(j); 158 self.i[rng].iter().cloned() 159 } 160 } 161 162 impl<T: Scalar, R: Dim, C: Dim> CsStorage<T, R, C> for CsVecStorage<T, R, C> 163 where 164 DefaultAllocator: Allocator<usize, C>, 165 { 166 #[inline] shape(&self) -> (R, C)167 fn shape(&self) -> (R, C) { 168 self.shape 169 } 170 171 #[inline] len(&self) -> usize172 fn len(&self) -> usize { 173 self.vals.len() 174 } 175 176 #[inline] row_index(&self, i: usize) -> usize177 fn row_index(&self, i: usize) -> usize { 178 self.i[i] 179 } 180 181 #[inline] row_index_unchecked(&self, i: usize) -> usize182 unsafe fn row_index_unchecked(&self, i: usize) -> usize { 183 *self.i.get_unchecked(i) 184 } 185 186 #[inline] get_value_unchecked(&self, i: usize) -> &T187 unsafe fn get_value_unchecked(&self, i: usize) -> &T { 188 self.vals.get_unchecked(i) 189 } 190 191 #[inline] get_value(&self, i: usize) -> &T192 fn get_value(&self, i: usize) -> &T { 193 &self.vals[i] 194 } 195 196 #[inline] column_range(&self, j: usize) -> Range<usize>197 fn column_range(&self, j: usize) -> Range<usize> { 198 let end = if j + 1 == self.p.len() { 199 self.len() 200 } else { 201 self.p[j + 1] 202 }; 203 204 self.p[j]..end 205 } 206 } 207 208 impl<'a, T: Scalar, R: Dim, C: Dim> CsStorageIterMut<'a, T, R, C> for CsVecStorage<T, R, C> 209 where 210 DefaultAllocator: Allocator<usize, C>, 211 { 212 type ValuesMut = slice::IterMut<'a, T>; 213 type ColumnEntriesMut = iter::Zip<iter::Cloned<slice::Iter<'a, usize>>, slice::IterMut<'a, T>>; 214 215 #[inline] values_mut(&'a mut self) -> Self::ValuesMut216 fn values_mut(&'a mut self) -> Self::ValuesMut { 217 self.vals.iter_mut() 218 } 219 220 #[inline] column_entries_mut(&'a mut self, j: usize) -> Self::ColumnEntriesMut221 fn column_entries_mut(&'a mut self, j: usize) -> Self::ColumnEntriesMut { 222 let rng = self.column_range(j); 223 self.i[rng.clone()] 224 .iter() 225 .cloned() 226 .zip(self.vals[rng].iter_mut()) 227 } 228 } 229 230 impl<T: Scalar, R: Dim, C: Dim> CsStorageMut<T, R, C> for CsVecStorage<T, R, C> where 231 DefaultAllocator: Allocator<usize, C> 232 { 233 } 234 235 /* 236 pub struct CsSliceStorage<'a, T: Scalar, R: Dim, C: DimAdd<U1>> { 237 shape: (R, C), 238 p: VectorSlice<usize, DimSum<C, U1>>, 239 i: VectorSlice<usize, Dynamic>, 240 vals: VectorSlice<T, Dynamic>, 241 }*/ 242 243 /// A compressed sparse column matrix. 244 #[derive(Clone, Debug, PartialEq)] 245 pub struct CsMatrix< 246 T: Scalar, 247 R: Dim = Dynamic, 248 C: Dim = Dynamic, 249 S: CsStorage<T, R, C> = CsVecStorage<T, R, C>, 250 > { 251 pub(crate) data: S, 252 _phantoms: PhantomData<(T, R, C)>, 253 } 254 255 /// A column compressed sparse vector. 256 pub type CsVector<T, R = Dynamic, S = CsVecStorage<T, R, U1>> = CsMatrix<T, R, U1, S>; 257 258 impl<T: Scalar, R: Dim, C: Dim> CsMatrix<T, R, C> 259 where 260 DefaultAllocator: Allocator<usize, C>, 261 { 262 /// Creates a new compressed sparse column matrix with the specified dimension and 263 /// `nvals` possible non-zero values. new_uninitialized_generic(nrows: R, ncols: C, nvals: usize) -> Self264 pub fn new_uninitialized_generic(nrows: R, ncols: C, nvals: usize) -> Self { 265 let mut i = Vec::with_capacity(nvals); 266 unsafe { 267 i.set_len(nvals); 268 } 269 i.shrink_to_fit(); 270 271 let mut vals = Vec::with_capacity(nvals); 272 unsafe { 273 vals.set_len(nvals); 274 } 275 vals.shrink_to_fit(); 276 277 CsMatrix { 278 data: CsVecStorage { 279 shape: (nrows, ncols), 280 p: OVector::zeros_generic(ncols, Const::<1>), 281 i, 282 vals, 283 }, 284 _phantoms: PhantomData, 285 } 286 } 287 288 /* 289 pub(crate) fn from_parts_generic( 290 nrows: R, 291 ncols: C, 292 p: OVector<usize, C>, 293 i: Vec<usize>, 294 vals: Vec<T>, 295 ) -> Self 296 where 297 T: Zero + ClosedAdd, 298 DefaultAllocator: Allocator<T, R>, 299 { 300 assert_eq!(ncols.value(), p.len(), "Invalid inptr size."); 301 assert_eq!(i.len(), vals.len(), "Invalid value size."); 302 303 // Check p. 304 for ptr in &p { 305 assert!(*ptr < i.len(), "Invalid inptr value."); 306 } 307 308 for ptr in p.as_slice().windows(2) { 309 assert!(ptr[0] <= ptr[1], "Invalid inptr ordering."); 310 } 311 312 // Check i. 313 for i in &i { 314 assert!(*i < nrows.value(), "Invalid row ptr value.") 315 } 316 317 let mut res = CsMatrix { 318 data: CsVecStorage { 319 shape: (nrows, ncols), 320 p, 321 i, 322 vals, 323 }, 324 _phantoms: PhantomData, 325 }; 326 327 // Sort and remove duplicates. 328 res.sort(); 329 res.dedup(); 330 331 res 332 }*/ 333 } 334 335 /* 336 impl<T: Scalar + Zero + ClosedAdd> CsMatrix<T> { 337 pub(crate) fn from_parts( 338 nrows: usize, 339 ncols: usize, 340 p: Vec<usize>, 341 i: Vec<usize>, 342 vals: Vec<T>, 343 ) -> Self 344 { 345 let nrows = Dynamic::new(nrows); 346 let ncols = Dynamic::new(ncols); 347 let p = DVector::from_data(VecStorage::new(ncols, U1, p)); 348 Self::from_parts_generic(nrows, ncols, p, i, vals) 349 } 350 } 351 */ 352 353 impl<T: Scalar, R: Dim, C: Dim, S: CsStorage<T, R, C>> CsMatrix<T, R, C, S> { from_data(data: S) -> Self354 pub(crate) fn from_data(data: S) -> Self { 355 CsMatrix { 356 data, 357 _phantoms: PhantomData, 358 } 359 } 360 361 /// The size of the data buffer. 362 #[must_use] len(&self) -> usize363 pub fn len(&self) -> usize { 364 self.data.len() 365 } 366 367 /// The number of rows of this matrix. 368 #[must_use] nrows(&self) -> usize369 pub fn nrows(&self) -> usize { 370 self.data.shape().0.value() 371 } 372 373 /// The number of rows of this matrix. 374 #[must_use] ncols(&self) -> usize375 pub fn ncols(&self) -> usize { 376 self.data.shape().1.value() 377 } 378 379 /// The shape of this matrix. 380 #[must_use] shape(&self) -> (usize, usize)381 pub fn shape(&self) -> (usize, usize) { 382 let (nrows, ncols) = self.data.shape(); 383 (nrows.value(), ncols.value()) 384 } 385 386 /// Whether this matrix is square or not. 387 #[must_use] is_square(&self) -> bool388 pub fn is_square(&self) -> bool { 389 let (nrows, ncols) = self.data.shape(); 390 nrows.value() == ncols.value() 391 } 392 393 /// Should always return `true`. 394 /// 395 /// This method is generally used for debugging and should typically not be called in user code. 396 /// This checks that the row inner indices of this matrix are sorted. It takes `O(n)` time, 397 /// where n` is `self.len()`. 398 /// All operations of CSC matrices on nalgebra assume, and will return, sorted indices. 399 /// If at any time this `is_sorted` method returns `false`, then, something went wrong 400 /// and an issue should be open on the nalgebra repository with details on how to reproduce 401 /// this. 402 #[must_use] is_sorted(&self) -> bool403 pub fn is_sorted(&self) -> bool { 404 for j in 0..self.ncols() { 405 let mut curr = None; 406 for idx in self.data.column_row_indices(j) { 407 if let Some(curr) = curr { 408 if idx <= curr { 409 return false; 410 } 411 } 412 413 curr = Some(idx); 414 } 415 } 416 417 true 418 } 419 420 /// Computes the transpose of this sparse matrix. 421 #[must_use = "This function does not mutate the matrix. Consider using the return value or removing the function call. There's also transpose_mut() for square matrices."] transpose(&self) -> CsMatrix<T, C, R> where DefaultAllocator: Allocator<usize, R>,422 pub fn transpose(&self) -> CsMatrix<T, C, R> 423 where 424 DefaultAllocator: Allocator<usize, R>, 425 { 426 let (nrows, ncols) = self.data.shape(); 427 428 let nvals = self.len(); 429 let mut res = CsMatrix::new_uninitialized_generic(ncols, nrows, nvals); 430 let mut workspace = Vector::zeros_generic(nrows, Const::<1>); 431 432 // Compute p. 433 for i in 0..nvals { 434 let row_id = self.data.row_index(i); 435 workspace[row_id] += 1; 436 } 437 438 let _ = cs_utils::cumsum(&mut workspace, &mut res.data.p); 439 440 // Fill the result. 441 for j in 0..ncols.value() { 442 for (row_id, value) in self.data.column_entries(j) { 443 let shift = workspace[row_id]; 444 445 res.data.vals[shift] = value; 446 res.data.i[shift] = j; 447 workspace[row_id] += 1; 448 } 449 } 450 451 res 452 } 453 } 454 455 impl<T: Scalar, R: Dim, C: Dim, S: CsStorageMut<T, R, C>> CsMatrix<T, R, C, S> { 456 /// Iterator through all the mutable values of this sparse matrix. 457 #[inline] values_mut(&mut self) -> impl Iterator<Item = &mut T>458 pub fn values_mut(&mut self) -> impl Iterator<Item = &mut T> { 459 self.data.values_mut() 460 } 461 } 462 463 impl<T: Scalar, R: Dim, C: Dim> CsMatrix<T, R, C> 464 where 465 DefaultAllocator: Allocator<usize, C>, 466 { sort(&mut self) where T: Zero, DefaultAllocator: Allocator<T, R>,467 pub(crate) fn sort(&mut self) 468 where 469 T: Zero, 470 DefaultAllocator: Allocator<T, R>, 471 { 472 // Size = R 473 let nrows = self.data.shape().0; 474 let mut workspace = Matrix::zeros_generic(nrows, Const::<1>); 475 self.sort_with_workspace(workspace.as_mut_slice()); 476 } 477 sort_with_workspace(&mut self, workspace: &mut [T])478 pub(crate) fn sort_with_workspace(&mut self, workspace: &mut [T]) { 479 assert!( 480 workspace.len() >= self.nrows(), 481 "Workspace must be able to hold at least self.nrows() elements." 482 ); 483 484 for j in 0..self.ncols() { 485 // Scatter the row in the workspace. 486 for (irow, val) in self.data.column_entries(j) { 487 workspace[irow] = val; 488 } 489 490 // Sort the index vector. 491 let range = self.data.column_range(j); 492 self.data.i[range.clone()].sort_unstable(); 493 494 // Permute the values too. 495 for (i, irow) in range.clone().zip(self.data.i[range].iter().cloned()) { 496 self.data.vals[i] = workspace[irow].clone(); 497 } 498 } 499 } 500 501 // Remove dupliate entries on a sorted CsMatrix. dedup(&mut self) where T: Zero + ClosedAdd,502 pub(crate) fn dedup(&mut self) 503 where 504 T: Zero + ClosedAdd, 505 { 506 let mut curr_i = 0; 507 508 for j in 0..self.ncols() { 509 let range = self.data.column_range(j); 510 self.data.p[j] = curr_i; 511 512 if range.start != range.end { 513 let mut value = T::zero(); 514 let mut irow = self.data.i[range.start]; 515 516 for idx in range { 517 let curr_irow = self.data.i[idx]; 518 519 if curr_irow == irow { 520 value += self.data.vals[idx].clone(); 521 } else { 522 self.data.i[curr_i] = irow; 523 self.data.vals[curr_i] = value; 524 value = self.data.vals[idx].clone(); 525 irow = curr_irow; 526 curr_i += 1; 527 } 528 } 529 530 // Handle the last entry. 531 self.data.i[curr_i] = irow; 532 self.data.vals[curr_i] = value; 533 curr_i += 1; 534 } 535 } 536 537 self.data.i.truncate(curr_i); 538 self.data.i.shrink_to_fit(); 539 self.data.vals.truncate(curr_i); 540 self.data.vals.shrink_to_fit(); 541 } 542 } 543