1 #ifndef PARSPARSEMATRIX_H
2 #define PARSPARSEMATRIX_H
3 
4 #include "mpi.h"
5 #include "MPI_Wrappers.h"
6 #include "SparseMatrix.h"
7 #include "DiagonalMatrix.h"
8 #include <algorithm>
9 
10 namespace ATC_matrix {
11 
/**
 *  @class  ParSparseMatrix
 *  @brief  MPI-parallel version of the SparseMatrix class.
 *
 *  ParSparseMatrix<double>::MultMv is used in LinearSolver, which in turn is
 *  used by NonLinearSolver, PoissonSolver, and SchrodingerSolver. These
 *  parallelized solvers are used in the following locations (line numbers
 *  reflect those files at the time this note was written and may have
 *  drifted since):
 *
 *  - LinearSolver
 *    - ExtrinsicModelDriftDiffusion.cpp (lines 511 and 522)
 *    - AtomicRegulator.cpp (line 926)
 *    - TransferLibrary.cpp (lines 72 and 260)
 *  - PoissonSolver
 *    - ExtrinsicModelDriftDiffusion.cpp (line 232)
 *  - SchrodingerSolver
 *    - ExtrinsicModelDriftDiffusion.cpp (line 251)
 *  - SliceSchrodingerSolver
 *    - ExtrinsicModelDriftDiffusion.cpp (line 246)
 */
31 
32   template <typename T>
33   class ParSparseMatrix : public SparseMatrix<T>
34   {
35     public:
36       ParSparseMatrix(MPI_Comm comm, INDEX rows = 0, INDEX cols = 0)
37         : SparseMatrix<T>(rows, cols), _comm(comm){}
38 
ParSparseMatrix(MPI_Comm comm,const SparseMatrix<T> & c)39       ParSparseMatrix(MPI_Comm comm, const SparseMatrix<T> &c)
40         : SparseMatrix<T>(c), _comm(comm){}
41 
ParSparseMatrix(MPI_Comm comm,INDEX * rows,INDEX * cols,T * vals,INDEX size,INDEX nRows,INDEX nCols,INDEX nRowsCRS)42       ParSparseMatrix(MPI_Comm comm, INDEX* rows, INDEX* cols, T* vals,
43                       INDEX size, INDEX nRows, INDEX nCols, INDEX nRowsCRS)
44         : SparseMatrix<T>(rows, cols, vals, size, nRows, nCols, nRowsCRS)
45          ,_comm(comm){}
46 
ParSparseMatrix(MPI_Comm comm)47       ParSparseMatrix(MPI_Comm comm)
48         : SparseMatrix<T>(), _comm(comm){}
49 
50     virtual void operator=(const SparseMatrix<T> &source)
51     {
52       copy(source);
53     }
54 
55     template<typename U>
56     friend void ParMultAB(MPI_Comm comm, const SparseMatrix<U>& A,
57                           const Matrix<U>& B, DenseMatrix<U>& C);
58 
59     private:
60         MPI_Comm _comm;
61   };
62 
63   template <>
64   class ParSparseMatrix<double> : public SparseMatrix<double>
65   {
66   public:
67     // All the same constructors as for SparseMatrix
68     ParSparseMatrix(MPI_Comm comm, INDEX rows = 0, INDEX cols=0);
69     ParSparseMatrix(MPI_Comm comm, const SparseMatrix<double> &c);
70     ParSparseMatrix(MPI_Comm comm, INDEX* rows, INDEX* cols, double* vals, INDEX size,
71         INDEX nRows, INDEX nCols, INDEX nRowsCRS);
72 
73     // Parallel sparse matrix multiplication functions
74     void MultMv(const Vector<double>& v, DenseVector<double>& c) const;
75     DenseVector<double> transMat(const Vector<double>& v) const;
76     void MultAB(const Matrix<double>& B, DenseMatrix<double>& C) const;
77     DenseMatrix<double> transMat(const DenseMatrix<double>& B) const;
78     DenseMatrix<double> transMat(const SparseMatrix<double>& B) const;
79 
80     virtual void operator=(const SparseMatrix<double> &source);
81 
82     template<typename U>
83     friend void ParMultAB(MPI_Comm comm, const SparseMatrix<U>& A, const Matrix<U>& B, DenseMatrix<U>& C);
84 
85   private:
86     void partition(ParSparseMatrix<double>& A_local) const;
87     void finalize();
88     MPI_Comm _comm;
89   };
90 
// No ParSparseMatrix-specific operator* overloads are declared; the
// SparseMatrix versions of these operators call the correct MultMv/MultAB:
// DenseVector<double> operator*(const ParSparseMatrix<double> &A, const Vector<double> &v);
// DenseVector<double> operator*(const Vector<double> &v, const ParSparseMatrix<double> &A);
// DenseMatrix<double> operator*(const ParSparseMatrix<double> &A, const Matrix<double> &B);
96 
97 
98   template<typename T>
ParMultAB(MPI_Comm comm,const SparseMatrix<T> & A,const Matrix<T> & B,DenseMatrix<T> & C)99   void ParMultAB(MPI_Comm comm, const SparseMatrix<T>& A, const Matrix<T>& B, DenseMatrix<T>& C)
100   {
101     SparseMatrix<T>::compress(A);
102 
103     INDEX M = A.nRows(), N = B.nCols();
104     if (!C.is_size(M, N))
105     {
106       C.resize(M, N);
107       C.zero();
108     }
109 
110     // Temporarily put fields into a ParSparseMatrix for distributed multiplication
111     ParSparseMatrix<T> Ap(comm);
112     Ap._nRows       = A._nRows;
113     Ap._nCols       = A._nCols;
114     Ap._size        = A._size;
115     Ap._nRowsCRS    = A._nRowsCRS;
116     Ap._val         = A._val;
117     Ap._ja          = A._ja;
118     Ap._ia          = A._ia;
119     Ap.hasTemplate_ = A.hasTemplate_;
120 
121     // MultAB calls compress(), but we hope that does nothing because we just
122     // compressed A. If it did something, it might mess up other members
123     // (e.g. _tri).
124     Ap.MultAB(B, C);
125 
126     // We're not changing the matrix's values, so we can justify calling A const.
127     SparseMatrix<T> &Avar = const_cast<SparseMatrix<T> &>(A);
128     Avar._nRows       = Ap._nRows;
129     Avar._nCols       = Ap._nCols;
130     Avar._size        = Ap._size;
131     Avar._nRowsCRS    = Ap._nRowsCRS;
132     Avar._val         = Ap._val;
133     Avar._ja          = Ap._ja;
134     Avar._ia          = Ap._ia;
135     Avar.hasTemplate_ = Ap.hasTemplate_;
136 
137     // Avoid catastrophe
138     Ap._val = NULL;
139     Ap._ja  = NULL;
140     Ap._ia  = NULL;
141   }
142 
143 
/* Disabled declarations, kept for reference:
SparseMatrix<double> operator*(const ParSparseMatrix<double> &A, const SparseMatrix<double> &B);

SparseMatrix<double> operator*(const ParSparseMatrix<double> &A, const DiagonalMatrix<double> &B);
*/
148 } // end namespace
149 #endif
150 
151