1 #ifndef PARSPARSEMATRIX_H 2 #define PARSPARSEMATRIX_H 3 4 #include "mpi.h" 5 #include "MPI_Wrappers.h" 6 #include "SparseMatrix.h" 7 #include "DiagonalMatrix.h" 8 #include <algorithm> 9 10 namespace ATC_matrix { 11 12 /** 13 * @class ParSparseMatrix 14 * @brief Parallelized version of SparseMatrix class. 15 * 16 * ParSparseMatrix<double>::MultMv is used in LinearSolver, which is then 17 * used in NonLinearSolver, PoissonSolver, and SchrodingerSolver. These 18 * parallelized solvers are used in the following locations: 19 * 20 * - LinearSolver 21 * - ExtrinsicModelDriftDiffusion.cpp (lines 511 and 522) 22 * - AtomicRegulator.cpp (line 926) 23 * - TransferLibrary.cpp (lines 72 and 260) 24 * - PoissonSolver 25 * - ExtrinsicModelDriftDiffusion.cpp (line 232) 26 * - SchrodingerSolver 27 * - ExtrinsicModelDriftDiffusion.cpp (line 251) 28 * - SliceSchrodingerSolver 29 * - ExtrinsicModelDriftDiffusion.cpp (line 246) 30 */ 31 32 template <typename T> 33 class ParSparseMatrix : public SparseMatrix<T> 34 { 35 public: 36 ParSparseMatrix(MPI_Comm comm, INDEX rows = 0, INDEX cols = 0) 37 : SparseMatrix<T>(rows, cols), _comm(comm){} 38 ParSparseMatrix(MPI_Comm comm,const SparseMatrix<T> & c)39 ParSparseMatrix(MPI_Comm comm, const SparseMatrix<T> &c) 40 : SparseMatrix<T>(c), _comm(comm){} 41 ParSparseMatrix(MPI_Comm comm,INDEX * rows,INDEX * cols,T * vals,INDEX size,INDEX nRows,INDEX nCols,INDEX nRowsCRS)42 ParSparseMatrix(MPI_Comm comm, INDEX* rows, INDEX* cols, T* vals, 43 INDEX size, INDEX nRows, INDEX nCols, INDEX nRowsCRS) 44 : SparseMatrix<T>(rows, cols, vals, size, nRows, nCols, nRowsCRS) 45 ,_comm(comm){} 46 ParSparseMatrix(MPI_Comm comm)47 ParSparseMatrix(MPI_Comm comm) 48 : SparseMatrix<T>(), _comm(comm){} 49 50 virtual void operator=(const SparseMatrix<T> &source) 51 { 52 copy(source); 53 } 54 55 template<typename U> 56 friend void ParMultAB(MPI_Comm comm, const SparseMatrix<U>& A, 57 const Matrix<U>& B, DenseMatrix<U>& C); 58 59 private: 60 MPI_Comm _comm; 61 }; 62 63 template <> 64 class ParSparseMatrix<double> : public SparseMatrix<double> 65 { 66 public: 67 // All the same constructors as for SparseMatrix 68 ParSparseMatrix(MPI_Comm comm, INDEX rows = 0, INDEX cols=0); 69 ParSparseMatrix(MPI_Comm comm, const SparseMatrix<double> &c); 70 ParSparseMatrix(MPI_Comm comm, INDEX* rows, INDEX* cols, double* vals, INDEX size, 71 INDEX nRows, INDEX nCols, INDEX nRowsCRS); 72 73 // Parallel sparse matrix multiplication functions 74 void MultMv(const Vector<double>& v, DenseVector<double>& c) const; 75 DenseVector<double> transMat(const Vector<double>& v) const; 76 void MultAB(const Matrix<double>& B, DenseMatrix<double>& C) const; 77 DenseMatrix<double> transMat(const DenseMatrix<double>& B) const; 78 DenseMatrix<double> transMat(const SparseMatrix<double>& B) const; 79 80 virtual void operator=(const SparseMatrix<double> &source); 81 82 template<typename U> 83 friend void ParMultAB(MPI_Comm comm, const SparseMatrix<U>& A, const Matrix<U>& B, DenseMatrix<U>& C); 84 85 private: 86 void partition(ParSparseMatrix<double>& A_local) const; 87 void finalize(); 88 MPI_Comm _comm; 89 }; 90 91 // The SparseMatrix versions of these functions will call the correct 92 // MultMv/MultAB: 93 // DenseVector<double> operator*(const ParSparseMatrix<double> &A, const Vector<double> &v); 94 // DenseVector<double> operator*(const Vector<double> &v, const ParSparseMatrix<double> &A); 95 // DenseMatrix<double> operator*(const ParSparseMatrix<double> &A, const Matrix<double> &B); 96 97 98 template<typename T> ParMultAB(MPI_Comm comm,const SparseMatrix<T> & A,const Matrix<T> & B,DenseMatrix<T> & C)99 void ParMultAB(MPI_Comm comm, const SparseMatrix<T>& A, const Matrix<T>& B, DenseMatrix<T>& C) 100 { 101 SparseMatrix<T>::compress(A); 102 103 INDEX M = A.nRows(), N = B.nCols(); 104 if (!C.is_size(M, N)) 105 { 106 C.resize(M, N); 107 C.zero(); 108 } 109 110 // Temporarily put fields into a ParSparseMatrix for distributed multiplication 111 ParSparseMatrix<T> Ap(comm); 112 Ap._nRows = A._nRows; 113 Ap._nCols = A._nCols; 114 Ap._size = A._size; 115 Ap._nRowsCRS = A._nRowsCRS; 116 Ap._val = A._val; 117 Ap._ja = A._ja; 118 Ap._ia = A._ia; 119 Ap.hasTemplate_ = A.hasTemplate_; 120 121 // MultAB calls compress(), but we hope that does nothing because we just 122 // compressed A. If it did something, it might mess up other members 123 // (e.g. _tri). 124 Ap.MultAB(B, C); 125 126 // We're not changing the matrix's values, so we can justify calling A const. 127 SparseMatrix<T> &Avar = const_cast<SparseMatrix<T> &>(A); 128 Avar._nRows = Ap._nRows; 129 Avar._nCols = Ap._nCols; 130 Avar._size = Ap._size; 131 Avar._nRowsCRS = Ap._nRowsCRS; 132 Avar._val = Ap._val; 133 Avar._ja = Ap._ja; 134 Avar._ia = Ap._ia; 135 Avar.hasTemplate_ = Ap.hasTemplate_; 136 137 // Avoid catastrophe 138 Ap._val = NULL; 139 Ap._ja = NULL; 140 Ap._ia = NULL; 141 } 142 143 144 /*SparseMatrix<double> operator*(const ParSparseMatrix<double> &A, const SparseMatrix<double> &B); 145 146 SparseMatrix<double> operator*(const ParSparseMatrix<double> &A, const DiagonalMatrix<double> &B); 147 */ 148 } // end namespace 149 #endif 150 151