1 // 2 // BAGEL - Brilliantly Advanced General Electronic Structure Library 3 // Filename: cistring.h 4 // Copyright (C) 2013 Toru Shiozaki 5 // 6 // Author: Shane Parker <shane.parker@u.northwestern.edu> 7 // Maintainer: Shiozaki group 8 // 9 // This file is part of the BAGEL package. 10 // 11 // This program is free software: you can redistribute it and/or modify 12 // it under the terms of the GNU General Public License as published by 13 // the Free Software Foundation, either version 3 of the License, or 14 // (at your option) any later version. 15 // 16 // This program is distributed in the hope that it will be useful, 17 // but WITHOUT ANY WARRANTY; without even the implied warranty of 18 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 19 // GNU General Public License for more details. 20 // 21 // You should have received a copy of the GNU General Public License 22 // along with this program. If not, see <http://www.gnu.org/licenses/>. 23 // 24 25 26 #ifndef BAGEL_CIUTIL_STRINGSPACE_H 27 #define BAGEL_CIUTIL_STRINGSPACE_H 28 29 #include <bitset> 30 #include <algorithm> 31 #include <src/util/constants.h> 32 #include <src/util/parallel/staticdist.h> 33 #include <src/util/parallel/mpi_interface.h> 34 #include <src/util/serialization.h> 35 #include <src/ci/ciutil/cistringmap.h> 36 37 namespace bagel { 38 39 // Contains all the strings and information for lexical ordering for one particular graph (set of strings) 40 // comprised of three subgraphs (one each for RASI, RASII, RASIII) 41 class CIGraph { 42 protected: 43 size_t nele_; 44 size_t norb_; 45 size_t size_; 46 std::vector<size_t> weights_; 47 48 private: 49 friend class boost::serialization::access; 50 template <class Archive> serialize(Archive & ar,const unsigned int)51 void serialize(Archive& ar, const unsigned int) { 52 ar & nele_ & norb_ & size_& weights_; 53 } 54 55 public: CIGraph()56 CIGraph() { } 57 CIGraph(const size_t nele, const size_t norb); 58 weight(const size_t i,const size_t j)59 size_t& weight(const size_t i, const size_t j) { assert(nele_*norb_ > 0); return weights_[i + j*norb_]; } weight(const size_t i,const size_t j)60 const size_t& weight(const size_t i, const size_t j) const { assert(nele_*norb_ > 0); return weights_[i + j*norb_]; } 61 size()62 size_t size() const { return size_; } 63 lexical(const int & start,const int & fence,const std::bitset<nbit__> & abit)64 size_t lexical(const int& start, const int& fence, const std::bitset<nbit__>& abit) const { 65 size_t out = 0; 66 67 int k = 0; 68 for (int i = start; i < fence; ++i) 69 if (abit[i]) out += weight(i-start,k++); 70 return out; 71 } 72 }; 73 74 75 class CIString_base { 76 private: 77 friend class boost::serialization::access; 78 template <class Archive> serialize(Archive & ar,const unsigned int)79 void serialize(Archive& ar, const unsigned int) { } 80 public: CIString_base()81 CIString_base() { } ~CIString_base()82 virtual ~CIString_base() { } 83 }; 84 85 86 template<int N, class Derived> 87 class CIString_base_impl : public CIString_base { 88 protected: 89 int norb_; 90 int nele_; 91 size_t offset_; 92 std::vector<std::bitset<nbit__>> strings_; 93 94 std::array<std::pair<int, int>, N> subspace_; 95 std::array<std::shared_ptr<CIGraph>, N> graphs_; 96 std::shared_ptr<const StaticDist> dist_; 97 init()98 void init() { 99 compute_strings(); 100 } compute_strings()101 void compute_strings() { static_cast<Derived*>(this)->compute_strings_impl(); } 102 103 private: 104 friend class boost::serialization::access; 105 template <class Archive> serialize(Archive & ar,const unsigned int version)106 void serialize(Archive& ar, const unsigned int version) { 107 boost::serialization::split_member(ar, *this, version); 108 } 109 template <class Archive> save(Archive & ar,const unsigned int)110 void save(Archive& ar, const unsigned int) const { 111 ar << boost::serialization::base_object<CIString_base>(*this) << norb_ << nele_ << offset_ << strings_ << subspace_ << graphs_; 112 } 113 template <class Archive> load(Archive & ar,const unsigned int)114 void load(Archive& ar, const unsigned int) { 115 ar >> boost::serialization::base_object<CIString_base>(*this) >> norb_ >> nele_ >> offset_ >> strings_ >> subspace_ >> graphs_; 116 const size_t size = std::accumulate(graphs_.begin(), graphs_.end(), 1, [](size_t n, const std::shared_ptr<CIGraph>& i) { return n*i->size(); }); 117 dist_ = std::make_shared<StaticDist>(size, mpi__->size()); 118 } 119 120 public: CIString_base_impl()121 CIString_base_impl() : norb_(0), nele_(0), offset_(0) { } CIString_base_impl(std::initializer_list<size_t> args)122 CIString_base_impl(std::initializer_list<size_t> args) { 123 assert(args.size() == 2*N+1); 124 auto iter = args.begin(); 125 for (int i = 0; i != N; ++i) { 126 const int a = *iter++; 127 const int b = *iter++; 128 subspace_[i] = {a, b}; 129 graphs_[i] = std::make_shared<CIGraph>(a, b); 130 } 131 // setting to CIString_base 132 offset_ = *iter++; 133 norb_ = std::accumulate(subspace_.begin(), subspace_.end(), 0, [](int n, const std::pair<int, int>& i) { return n+i.second; }); 134 nele_ = std::accumulate(subspace_.begin(), subspace_.end(), 0, [](int n, const std::pair<int, int>& i) { return n+i.first; }); 135 assert(iter == args.end()); 136 137 const size_t size = std::accumulate(graphs_.begin(), graphs_.end(), 1, [](size_t n, const std::shared_ptr<CIGraph>& i) { return n*i->size(); }); 138 dist_ = std::make_shared<StaticDist>(size, mpi__->size()); 139 } 140 141 // copy construct with an offset update CIString_base_impl(const CIString_base_impl<N,Derived> & o,const size_t offset)142 CIString_base_impl(const CIString_base_impl<N,Derived>& o, const size_t offset) 143 : norb_(o.norb_), nele_(o.nele_), offset_(offset), strings_(o.strings_), subspace_(o.subspace_), 144 graphs_(o.graphs_), dist_(o.dist_) { } 145 ~CIString_base_impl()146 virtual ~CIString_base_impl() { } 147 nele()148 int nele() const { return nele_; } norb()149 int norb() const { return norb_; } size()150 size_t size() const { return strings_.size(); } offset()151 size_t offset() const { return offset_; } 152 empty()153 bool empty() const { return size() == 0; } 154 strings()155 const std::vector<std::bitset<nbit__>>& strings() const { return strings_; } strings(const size_t i)156 const std::bitset<nbit__>& strings(const size_t i) const { return strings_[i]; } 157 begin()158 std::vector<std::bitset<nbit__>>::iterator begin() { return strings_.begin(); } end()159 std::vector<std::bitset<nbit__>>::iterator end() { return strings_.end(); } begin()160 std::vector<std::bitset<nbit__>>::const_iterator begin() const { return strings_.cbegin(); } end()161 std::vector<std::bitset<nbit__>>::const_iterator end() const { return strings_.cend(); } 162 size(const int & i)163 size_t size(const int& i) const { return graphs_[i]->size(); } 164 dist()165 std::shared_ptr<const StaticDist> dist() const { return dist_; } 166 lexical_zero(const std::bitset<nbit__> & bit)167 size_t lexical_zero(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->lexical_zero_impl(bit); } lexical_offset(const std::bitset<nbit__> & bit)168 size_t lexical_offset(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->lexical_offset_impl(bit); } 169 contains(const std::bitset<nbit__> & bit)170 bool contains(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->contains_impl(bit); } matches(const int i,const int j)171 bool matches(const int i, const int j) const { return static_cast<const Derived*>(this)->matches_impl(i,j); } 172 template <typename U> matches(std::shared_ptr<const U> o)173 bool matches(std::shared_ptr<const U> o) const { return static_cast<const Derived*>(this)->matches_impl(o); } 174 }; 175 176 177 class RASString : public CIString_base_impl<3,RASString> { 178 friend class CIString_base_impl<3,RASString>; 179 protected: 180 void compute_strings_impl(); 181 contains_impl(const std::bitset<nbit__> & bit)182 bool contains_impl(const std::bitset<nbit__>& bit) const { 183 assert(bit.count() == nele_); 184 return nholes(bit) == nholes() && nparticles(bit) == nparticles(); 185 } 186 matches_impl(const int nh,const int np)187 bool matches_impl(const int nh, const int np) const { 188 return nh == nholes() && np == nparticles(); 189 } matches_impl(const std::shared_ptr<const RASString> o)190 bool matches_impl(const std::shared_ptr<const RASString> o) const { 191 return matches_impl(o->nholes(), o->nparticles()); 192 } 193 lexical_offset_impl(const std::bitset<nbit__> & bit)194 size_t lexical_offset_impl(const std::bitset<nbit__>& bit) const { 195 return lexical_zero(bit) + offset_; 196 } 197 lexical_zero_impl(const std::bitset<nbit__> & bit)198 size_t lexical_zero_impl(const std::bitset<nbit__>& bit) const { 199 const size_t r1 = subspace_[0].second; 200 const size_t r2 = subspace_[1].second; 201 const size_t r3 = subspace_[2].second; 202 203 const size_t n2 = graphs_[1]->size(); 204 const size_t n1 = graphs_[0]->size(); 205 206 size_t out = 0; 207 out += graphs_[1]->lexical(r1, r1+r2, bit); 208 out += n2 * graphs_[0]->lexical(0, r1, bit); 209 out += n2 * n1 * graphs_[2]->lexical(r1+r2, r1+r2+r3, bit); 210 211 return out; 212 } 213 214 // helper functions nholes(const std::bitset<nbit__> & bit)215 int nholes(const std::bitset<nbit__>& bit) const { 216 return subspace_[0].second - (bit & (~std::bitset<nbit__>(0ull) >> (nbit__ - subspace_[0].second))).count(); 217 } nparticles(const std::bitset<nbit__> & bit)218 int nparticles(const std::bitset<nbit__>& bit) const { 219 return (bit & (~(~std::bitset<nbit__>(0ull) << subspace_[2].second) << subspace_[0].second + subspace_[1].second)).count(); 220 } 221 222 private: 223 friend class boost::serialization::access; 224 template <class Archive> serialize(Archive & ar,const unsigned int)225 void serialize(Archive& ar, const unsigned int) { 226 ar & boost::serialization::base_object<CIString_base_impl<3,RASString>>(*this); 227 } 228 229 public: RASString()230 RASString() { } 231 RASString(const size_t nele1, const size_t norb1, const size_t nele2, const size_t norb2, const size_t nele3, const size_t norb3, const size_t offset = 0); 232 RASString(const RASString& o, const size_t offset = 0) : CIString_base_impl<3,RASString>(o, offset) { } 233 nholes()234 int nholes() const { return subspace_[0].second - subspace_[0].first; } nele2()235 int nele2() const { return nele_ - subspace_[0].first - subspace_[2].first; } nparticles()236 int nparticles() const { return subspace_[2].first; } 237 tag()238 size_t tag() const { return nholes() + (nparticles() << 8); } 239 ras()240 template <int S> const std::pair<const int, const int> ras() const { 241 static_assert(S == 0 || S == 1 || S == 2, "illegal call of RAString::ras"); 242 return std::get<S>(subspace_); 243 } 244 245 }; 246 247 248 class FCIString : public CIString_base_impl<1,FCIString> { 249 friend class CIString_base_impl<1,FCIString>; 250 protected: 251 void compute_strings_impl(); 252 contains_impl(const std::bitset<nbit__> & bit)253 bool contains_impl(const std::bitset<nbit__>& bit) const { assert(bit.count() == nele_); return true; } matches_impl(const int n,const int m)254 bool matches_impl(const int n, const int m) const { return true; } matches_impl(const std::shared_ptr<const FCIString> o)255 bool matches_impl(const std::shared_ptr<const FCIString> o) const { return true; } 256 lexical_offset_impl(const std::bitset<nbit__> & bit)257 size_t lexical_offset_impl(const std::bitset<nbit__>& bit) const { return lexical(bit)+offset_; } lexical_zero_impl(const std::bitset<nbit__> & bit)258 size_t lexical_zero_impl(const std::bitset<nbit__>& bit) const { return lexical(bit); } 259 260 private: 261 friend class boost::serialization::access; 262 template <class Archive> serialize(Archive & ar,const unsigned int)263 void serialize(Archive& ar, const unsigned int) { 264 ar & boost::serialization::base_object<CIString_base_impl<1,FCIString>>(*this); 265 } 266 267 public: FCIString()268 FCIString() { } 269 FCIString(const size_t nele1, const size_t norb1, const size_t offset = 0); 270 FCIString(const FCIString& o, const size_t offset = 0) : CIString_base_impl<1,FCIString>(o, offset) { } 271 lexical(const std::bitset<nbit__> & bit)272 size_t lexical(const std::bitset<nbit__>& bit) const { 273 assert(contains(bit)); 274 return graphs_[0]->lexical(0, norb_, bit); 275 } 276 }; 277 278 using FCIString_base = CIString_base_impl<1,FCIString>; 279 using RASString_base = CIString_base_impl<3,RASString>; 280 281 } 282 283 extern template class bagel::CIString_base_impl<1,bagel::FCIString>; 284 extern template class bagel::CIString_base_impl<3,bagel::RASString>; 285 286 #include <src/util/archive.h> 287 BOOST_CLASS_EXPORT_KEY(bagel::CIGraph) 288 BOOST_CLASS_EXPORT_KEY(bagel::CIString_base) 289 BOOST_CLASS_EXPORT_KEY(bagel::RASString_base) 290 BOOST_CLASS_EXPORT_KEY(bagel::FCIString_base) 291 BOOST_CLASS_EXPORT_KEY(bagel::RASString) 292 BOOST_CLASS_EXPORT_KEY(bagel::FCIString) 293 294 #endif 295