1 // 2 // gsSparseRows.hpp 3 // 4 // Clemens Hofreither 5 // 6 #pragma once 7 8 #include <vector> 9 #include <stdexcept> 10 11 #include <gsCore/gsLinearAlgebra.h> 12 13 namespace gismo 14 { 15 16 /** 17 * \brief A specialized sparse matrix class which stores each row 18 * as a separate sparse vector. 19 * 20 * This allows efficient row resizing and insertion 21 * operations, particularly for knot insertion algorithms. 22 */ 23 template <class T> 24 class gsSparseRows 25 { 26 public: 27 typedef Eigen::SparseVector<T> Row; 28 29 struct RowBlockXpr; 30 gsSparseRows()31 gsSparseRows() 32 { } 33 gsSparseRows(index_t rows,index_t cols)34 gsSparseRows(index_t rows, index_t cols) 35 : m_rows(rows) 36 { 37 for (index_t i = 0; i < rows; ++i) 38 m_rows[i] = new Row(cols); 39 } 40 gsSparseRows(const gsSparseRows & other)41 gsSparseRows(const gsSparseRows& other) 42 : m_rows(other.rows()) 43 { 44 for (int i = 0; i < rows(); ++i) 45 m_rows[i] = new Row( *other.m_rows[i] ); 46 } 47 gsSparseRows(const RowBlockXpr & rowxpr)48 gsSparseRows(const RowBlockXpr& rowxpr) 49 : m_rows(rowxpr.num) 50 { 51 for (index_t i = 0; i < rowxpr.num; ++i) 52 m_rows[i] = new Row( *rowxpr.mat.m_rows[rowxpr.start + i] ); 53 } 54 ~gsSparseRows()55 ~gsSparseRows() 56 { 57 clear(); 58 } 59 operator =(const gsSparseRows other)60 gsSparseRows& operator= (const gsSparseRows other) 61 { 62 this->swap( other ); 63 return *this; 64 } 65 operator =(const RowBlockXpr & rowxpr)66 gsSparseRows& operator= (const RowBlockXpr& rowxpr) 67 { 68 gsSparseRows temp(rowxpr); 69 this->swap( temp ); 70 return *this; 71 } 72 rows() const73 index_t rows() const { return m_rows.size(); } cols() const74 index_t cols() const { return (m_rows.size() > 0) ? m_rows[0]->size() : 0; } 75 row(index_t i)76 Row& row(index_t i) { return *m_rows[i]; } row(index_t i) const77 const Row& row(index_t i) const { return *m_rows[i]; } 78 clear()79 void clear() 80 { 81 for (int i = 0; i < rows(); ++i) 82 delete m_rows[i]; 83 m_rows.clear(); 84 } 85 swap(gsSparseRows & other)86 void swap(gsSparseRows& other) 87 { 88 m_rows.swap( other.m_rows ); 89 } 90 setIdentity(index_t n)91 void setIdentity(index_t n) 92 { 93 assert( n >= 0 ); 94 95 resize(n, n); 96 97 for (index_t i = 0; i < n; ++i) 98 m_rows[i]->insert(i) = T(1.0); 99 } 100 resize(index_t rows,index_t cols)101 void resize(index_t rows, index_t cols) 102 { 103 assert( rows >= 0 && cols >= 0 ); 104 105 clear(); 106 m_rows.resize(rows); 107 for (index_t i = 0; i < rows; ++i) 108 m_rows[i] = new Row(cols); 109 } 110 conservativeResize(index_t newRows,index_t newCols)111 void conservativeResize(index_t newRows, index_t newCols) 112 { 113 if (rows() > 0 && cols() != newCols) 114 throw std::runtime_error("cannot resize columns -- not implemented"); 115 116 const index_t oldRows = rows(); 117 resizeRows(newRows); 118 119 // allocate newly added rows, if any 120 for (index_t i = oldRows; i < newRows; ++i) 121 m_rows[i] = new Row(newCols); 122 } 123 duplicateRow(index_t k)124 void duplicateRow(index_t k) 125 { 126 assert ( 0 <= k && k < rows() ); 127 128 // add one new row 129 resizeRows( rows() + 1 ); 130 131 // shift rows [k+1,...) down to [k+2,...) 132 for (index_t i = rows() - 1; i > k + 1; --i) 133 m_rows[i] = m_rows[i-1]; 134 135 // allocate new row 136 m_rows[k+1] = new Row( row(k) ); 137 } 138 139 // row expressions 140 topRows(index_t num)141 RowBlockXpr topRows(index_t num) { return RowBlockXpr(*this, 0, num); } topRows(index_t num) const142 const RowBlockXpr topRows(index_t num) const { return RowBlockXpr(*this, 0, num); } 143 bottomRows(index_t num)144 RowBlockXpr bottomRows(index_t num) { return RowBlockXpr(*this, rows() - num, num); } bottomRows(index_t num) const145 const RowBlockXpr bottomRows(index_t num) const { return RowBlockXpr(*this, rows() - num, num); } 146 middleRows(index_t start,index_t num)147 RowBlockXpr middleRows(index_t start, index_t num) { return RowBlockXpr(*this, start, num); } middleRows(index_t start,index_t num) const148 const RowBlockXpr middleRows(index_t start, index_t num) const { return RowBlockXpr(*this, start, num); } 149 nonZeros() const150 index_t nonZeros() const 151 { 152 index_t nnz = 0; 153 for (index_t i = 0; i < rows(); ++i) 154 nnz += m_rows[i]->nonZeros(); 155 return nnz; 156 } 157 158 template <class Derived> toSparseMatrix(Eigen::SparseMatrixBase<Derived> & m) const159 void toSparseMatrix(Eigen::SparseMatrixBase<Derived>& m) const 160 { 161 m.derived().resize( rows(), cols() ); 162 m.derived().reserve( nonZeros() ); 163 for (index_t i = 0; i < rows(); ++i) 164 { 165 for (typename Row::InnerIterator it(*m_rows[i]); it; ++it) 166 m.derived().insert(i, it.index()) = it.value(); 167 } 168 m.derived().makeCompressed(); 169 } 170 171 struct RowBlockXpr 172 { RowBlockXprgismo::gsSparseRows::RowBlockXpr173 RowBlockXpr(const gsSparseRows& _mat, index_t _start, index_t _num) 174 : mat(const_cast<gsSparseRows&>(_mat)), start(_start), num(_num) 175 { 176 // HACK: We cast away the constness of the matrix, otherwise we would need two versions of 177 // this expression class. 178 // It's still safe because the row block methods in gsSparseRows above return the proper constness. 179 assert( 0 <= num && 0 <= start ); 180 assert( start < mat.rows() ); 181 assert( start + num <= mat.rows() ); 182 } 183 184 gsSparseRows & mat; 185 index_t start, num; 186 operator =gismo::gsSparseRows::RowBlockXpr187 RowBlockXpr& operator= (const RowBlockXpr& other) 188 { 189 assert(num == other.num); 190 for (index_t i = 0; i < num; ++i) 191 mat.row(start + i) = other.mat.row(other.start + i); 192 return *this; 193 } 194 operator =gismo::gsSparseRows::RowBlockXpr195 RowBlockXpr& operator= (const gsSparseRows& other) 196 { 197 assert(num == other.rows()); 198 for (index_t i = 0; i < num; ++i) 199 mat.row(start + i) = other.row(i); 200 return *this; 201 } 202 }; 203 204 private: 205 std::vector< Row* > m_rows; 206 207 /// Change the number of rows without allocating newly added rows resizeRows(index_t newRows)208 void resizeRows(index_t newRows) 209 { 210 // delete rows which will be removed from the array 211 // (does nothing if newRows >= rows()) 212 for (index_t i = newRows; i < rows(); ++i) 213 delete m_rows[i]; 214 215 m_rows.resize(newRows); 216 } 217 218 }; 219 220 } // namespace gismo 221