1 //
2 // BAGEL - Brilliantly Advanced General Electronic Structure Library
3 // Filename: matview.cc
4 // Copyright (C) 2014 Toru Shiozaki
5 //
6 // Author: Toru Shiozaki <shiozaki@northwestern.edu>
7 // Maintainer: Shiozaki group
8 //
9 // This file is part of the BAGEL package.
10 //
11 // This program is free software: you can redistribute it and/or modify
12 // it under the terms of the GNU General Public License as published by
13 // the Free Software Foundation, either version 3 of the License, or
14 // (at your option) any later version.
15 //
16 // This program is distributed in the hope that it will be useful,
17 // but WITHOUT ANY WARRANTY; without even the implied warranty of
18 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU General Public License
22 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 //
24 
25 #include <src/util/math/matrix_base.h>
26 #include <src/util/math/matview.h>
27 
28 using namespace std;
29 using namespace bagel;
30 
31 
32 template<typename DataType>
init()33 void MatView_<DataType>::init() {
34 #ifdef HAVE_SCALAPACK
35   if (!localized_) {
36     desc_ = mpi__->descinit(ndim(), mdim());
37     localsize_ = mpi__->numroc(ndim(), mdim());
38   }
39 #endif
40 }
41 
42 
43 #ifdef HAVE_SCALAPACK
44 template<typename DataType>
setlocal_(const std::unique_ptr<DataType[]> & local)45 void MatView_<DataType>::setlocal_(const std::unique_ptr<DataType[]>& local) {
46   zero();
47 
48   const int localrow = std::get<0>(localsize_);
49   const int localcol = std::get<1>(localsize_);
50 
51   const int nblock = localrow/blocksize__;
52   const int mblock = localcol/blocksize__;
53   const size_t nstride = blocksize__*mpi__->nprow();
54   const size_t mstride = blocksize__*mpi__->npcol();
55   const int myprow = mpi__->myprow()*blocksize__;
56   const int mypcol = mpi__->mypcol()*blocksize__;
57 
58   for (int i = 0; i != mblock; ++i)
59     for (int j = 0; j != nblock; ++j)
60       for (int id = 0; id != blocksize__; ++id)
61         std::copy_n(&local[j*blocksize__+localrow*(i*blocksize__+id)], blocksize__, element_ptr(myprow+j*nstride, mypcol+i*mstride+id));
62 
63   for (int id = 0; id != localcol % blocksize__; ++id) {
64     for (int j = 0; j != nblock; ++j)
65       std::copy_n(&local[j*blocksize__+localrow*(mblock*blocksize__+id)], blocksize__, element_ptr(myprow+j*nstride, mypcol+mblock*mstride+id));
66     for (int jd = 0; jd != localrow % blocksize__; ++jd)
67       element(myprow+nblock*nstride+jd, mypcol+mblock*mstride+id) = local[nblock*blocksize__+jd+localrow*(mblock*blocksize__+id)];
68   }
69   for (int i = 0; i != mblock; ++i)
70     for (int id = 0; id != blocksize__; ++id)
71       for (int jd = 0; jd != localrow % blocksize__; ++jd)
72         element(myprow+nblock*nstride+jd, mypcol+i*mstride+id) = local[nblock*blocksize__+jd+localrow*(i*blocksize__+id)];
73 
74   // syncronize (this can be improved, but...)
75   allreduce();
76 }
77 
78 
79 template<typename DataType>
getlocal() const80 std::unique_ptr<DataType[]> MatView_<DataType>::getlocal() const {
81   const int localrow = std::get<0>(localsize_);
82   const int localcol = std::get<1>(localsize_);
83 
84   std::unique_ptr<DataType[]> local(new DataType[localrow*localcol]);
85 
86   const int nblock = localrow/blocksize__;
87   const int mblock = localcol/blocksize__;
88   const size_t nstride = blocksize__*mpi__->nprow();
89   const size_t mstride = blocksize__*mpi__->npcol();
90   const int myprow = mpi__->myprow()*blocksize__;
91   const int mypcol = mpi__->mypcol()*blocksize__;
92 
93   for (int i = 0; i != mblock; ++i)
94     for (int j = 0; j != nblock; ++j)
95       for (int id = 0; id != blocksize__; ++id)
96         std::copy_n(element_ptr(myprow+j*nstride, mypcol+i*mstride+id), blocksize__, &local[j*blocksize__+localrow*(i*blocksize__+id)]);
97 
98   for (int id = 0; id != localcol % blocksize__; ++id) {
99     for (int j = 0; j != nblock; ++j)
100       std::copy_n(element_ptr(myprow+j*nstride, mypcol+mblock*mstride+id), blocksize__, &local[j*blocksize__+localrow*(mblock*blocksize__+id)]);
101     for (int jd = 0; jd != localrow % blocksize__; ++jd)
102       local[nblock*blocksize__+jd+localrow*(mblock*blocksize__+id)] = element(myprow+nblock*nstride+jd, mypcol+mblock*mstride+id);
103   }
104   for (int i = 0; i != mblock; ++i)
105     for (int id = 0; id != blocksize__; ++id)
106       for (int jd = 0; jd != localrow % blocksize__; ++jd)
107         local[nblock*blocksize__+jd+localrow*(i*blocksize__+id)] = element(myprow+nblock*nstride+jd, mypcol+i*mstride+id);
108   return local;
109 }
110 #endif
111 
112 
113 template class bagel::MatView_<double>;
114 template class bagel::MatView_<complex<double>>;
115