1 //
2 // BAGEL - Brilliantly Advanced General Electronic Structure Library
3 // Filename: cistring.h
4 // Copyright (C) 2013 Toru Shiozaki
5 //
6 // Author: Shane Parker <shane.parker@u.northwestern.edu>
7 // Maintainer: Shiozaki group
8 //
9 // This file is part of the BAGEL package.
10 //
11 // This program is free software: you can redistribute it and/or modify
12 // it under the terms of the GNU General Public License as published by
13 // the Free Software Foundation, either version 3 of the License, or
14 // (at your option) any later version.
15 //
16 // This program is distributed in the hope that it will be useful,
17 // but WITHOUT ANY WARRANTY; without even the implied warranty of
18 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
19 // GNU General Public License for more details.
20 //
21 // You should have received a copy of the GNU General Public License
22 // along with this program.  If not, see <http://www.gnu.org/licenses/>.
23 //
24 
25 
26 #ifndef BAGEL_CIUTIL_STRINGSPACE_H
27 #define BAGEL_CIUTIL_STRINGSPACE_H
28 
#include <algorithm>
#include <array>
#include <bitset>
#include <cassert>
#include <initializer_list>
#include <memory>
#include <numeric>
#include <utility>
#include <vector>
#include <src/util/constants.h>
#include <src/util/parallel/staticdist.h>
#include <src/util/parallel/mpi_interface.h>
#include <src/util/serialization.h>
#include <src/ci/ciutil/cistringmap.h>
36 
37 namespace bagel {
38 
// Lexical-ordering weights for one graph of strings (nele electrons in norb orbitals).
//   A RAS string space is built from three such graphs (one each for RAS I, RAS II, RAS III);
//   an FCI string space uses a single graph.
41 class CIGraph {
42   protected:
43     size_t nele_;
44     size_t norb_;
45     size_t size_;
46     std::vector<size_t> weights_;
47 
48   private:
49     friend class boost::serialization::access;
50     template <class Archive>
serialize(Archive & ar,const unsigned int)51     void serialize(Archive& ar, const unsigned int) {
52       ar & nele_ & norb_ & size_& weights_;
53     }
54 
55   public:
CIGraph()56     CIGraph() { }
57     CIGraph(const size_t nele, const size_t norb);
58 
weight(const size_t i,const size_t j)59     size_t& weight(const size_t i, const size_t j) { assert(nele_*norb_ > 0); return weights_[i + j*norb_]; }
weight(const size_t i,const size_t j)60     const size_t& weight(const size_t i, const size_t j) const { assert(nele_*norb_ > 0); return weights_[i + j*norb_]; }
61 
size()62     size_t size() const { return size_; }
63 
lexical(const int & start,const int & fence,const std::bitset<nbit__> & abit)64     size_t lexical(const int& start, const int& fence, const std::bitset<nbit__>& abit) const {
65       size_t out = 0;
66 
67       int k = 0;
68       for (int i = start; i < fence; ++i)
69         if (abit[i]) out += weight(i-start,k++);
70       return out;
71     }
72 };
73 
74 
75 class CIString_base {
76   private:
77     friend class boost::serialization::access;
78     template <class Archive>
serialize(Archive & ar,const unsigned int)79     void serialize(Archive& ar, const unsigned int) { }
80   public:
CIString_base()81     CIString_base() { }
~CIString_base()82     virtual ~CIString_base() { }
83 };
84 
85 
86 template<int N, class Derived>
87 class CIString_base_impl : public CIString_base {
88   protected:
89     int norb_;
90     int nele_;
91     size_t offset_;
92     std::vector<std::bitset<nbit__>> strings_;
93 
94     std::array<std::pair<int, int>, N> subspace_;
95     std::array<std::shared_ptr<CIGraph>, N> graphs_;
96     std::shared_ptr<const StaticDist> dist_;
97 
init()98     void init() {
99       compute_strings();
100     }
compute_strings()101     void compute_strings() { static_cast<Derived*>(this)->compute_strings_impl(); }
102 
103   private:
104     friend class boost::serialization::access;
105     template <class Archive>
serialize(Archive & ar,const unsigned int version)106     void serialize(Archive& ar, const unsigned int version) {
107       boost::serialization::split_member(ar, *this, version);
108     }
109     template <class Archive>
save(Archive & ar,const unsigned int)110     void save(Archive& ar, const unsigned int) const {
111       ar << boost::serialization::base_object<CIString_base>(*this) << norb_ << nele_ << offset_ << strings_ << subspace_ << graphs_;
112     }
113     template <class Archive>
load(Archive & ar,const unsigned int)114     void load(Archive& ar, const unsigned int) {
115       ar >> boost::serialization::base_object<CIString_base>(*this) >> norb_ >> nele_ >> offset_ >> strings_ >> subspace_ >> graphs_;
116       const size_t size = std::accumulate(graphs_.begin(), graphs_.end(), 1, [](size_t n, const std::shared_ptr<CIGraph>& i) { return n*i->size(); });
117       dist_ = std::make_shared<StaticDist>(size, mpi__->size());
118     }
119 
120   public:
CIString_base_impl()121     CIString_base_impl() : norb_(0), nele_(0), offset_(0) { }
CIString_base_impl(std::initializer_list<size_t> args)122     CIString_base_impl(std::initializer_list<size_t> args) {
123       assert(args.size() == 2*N+1);
124       auto iter = args.begin();
125       for (int i = 0; i != N; ++i) {
126         const int a = *iter++;
127         const int b = *iter++;
128         subspace_[i] = {a, b};
129         graphs_[i] = std::make_shared<CIGraph>(a, b);
130       }
131       // setting to CIString_base
132       offset_ = *iter++;
133       norb_ = std::accumulate(subspace_.begin(), subspace_.end(), 0, [](int n, const std::pair<int, int>& i) { return n+i.second; });
134       nele_ = std::accumulate(subspace_.begin(), subspace_.end(), 0, [](int n, const std::pair<int, int>& i) { return n+i.first; });
135       assert(iter == args.end());
136 
137       const size_t size = std::accumulate(graphs_.begin(), graphs_.end(), 1, [](size_t n, const std::shared_ptr<CIGraph>& i) { return n*i->size(); });
138       dist_ = std::make_shared<StaticDist>(size, mpi__->size());
139     }
140 
141     // copy construct with an offset update
CIString_base_impl(const CIString_base_impl<N,Derived> & o,const size_t offset)142     CIString_base_impl(const CIString_base_impl<N,Derived>& o, const size_t offset)
143       : norb_(o.norb_), nele_(o.nele_), offset_(offset), strings_(o.strings_), subspace_(o.subspace_),
144         graphs_(o.graphs_), dist_(o.dist_) { }
145 
~CIString_base_impl()146     virtual ~CIString_base_impl() { }
147 
nele()148     int nele() const { return nele_; }
norb()149     int norb() const { return norb_; }
size()150     size_t size() const { return strings_.size(); }
offset()151     size_t offset() const { return offset_; }
152 
empty()153     bool empty() const { return size() == 0; }
154 
strings()155     const std::vector<std::bitset<nbit__>>& strings() const { return strings_; }
strings(const size_t i)156     const std::bitset<nbit__>& strings(const size_t i) const { return strings_[i]; }
157 
begin()158     std::vector<std::bitset<nbit__>>::iterator begin() { return strings_.begin(); }
end()159     std::vector<std::bitset<nbit__>>::iterator end() { return strings_.end(); }
begin()160     std::vector<std::bitset<nbit__>>::const_iterator begin() const { return strings_.cbegin(); }
end()161     std::vector<std::bitset<nbit__>>::const_iterator end() const { return strings_.cend(); }
162 
size(const int & i)163     size_t size(const int& i) const { return graphs_[i]->size(); }
164 
dist()165     std::shared_ptr<const StaticDist> dist() const { return dist_; }
166 
lexical_zero(const std::bitset<nbit__> & bit)167     size_t lexical_zero(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->lexical_zero_impl(bit); }
lexical_offset(const std::bitset<nbit__> & bit)168     size_t lexical_offset(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->lexical_offset_impl(bit); }
169 
contains(const std::bitset<nbit__> & bit)170     bool contains(const std::bitset<nbit__>& bit) const { return static_cast<const Derived*>(this)->contains_impl(bit); }
matches(const int i,const int j)171     bool matches(const int i, const int j) const { return static_cast<const Derived*>(this)->matches_impl(i,j); }
172     template <typename U>
matches(std::shared_ptr<const U> o)173     bool matches(std::shared_ptr<const U> o) const { return static_cast<const Derived*>(this)->matches_impl(o); }
174 };
175 
176 
177 class RASString : public CIString_base_impl<3,RASString> {
178   friend class CIString_base_impl<3,RASString>;
179   protected:
180     void compute_strings_impl();
181 
contains_impl(const std::bitset<nbit__> & bit)182     bool contains_impl(const std::bitset<nbit__>& bit) const {
183       assert(bit.count() == nele_);
184       return nholes(bit) == nholes() && nparticles(bit) == nparticles();
185     }
186 
matches_impl(const int nh,const int np)187     bool matches_impl(const int nh, const int np) const {
188       return nh == nholes() && np == nparticles();
189     }
matches_impl(const std::shared_ptr<const RASString> o)190     bool matches_impl(const std::shared_ptr<const RASString> o) const {
191       return matches_impl(o->nholes(), o->nparticles());
192     }
193 
lexical_offset_impl(const std::bitset<nbit__> & bit)194     size_t lexical_offset_impl(const std::bitset<nbit__>& bit) const {
195       return lexical_zero(bit) + offset_;
196     }
197 
lexical_zero_impl(const std::bitset<nbit__> & bit)198     size_t lexical_zero_impl(const std::bitset<nbit__>& bit) const {
199       const size_t r1 = subspace_[0].second;
200       const size_t r2 = subspace_[1].second;
201       const size_t r3 = subspace_[2].second;
202 
203       const size_t n2 = graphs_[1]->size();
204       const size_t n1 = graphs_[0]->size();
205 
206       size_t out = 0;
207       out += graphs_[1]->lexical(r1, r1+r2, bit);
208       out += n2 * graphs_[0]->lexical(0, r1, bit);
209       out += n2 * n1 * graphs_[2]->lexical(r1+r2, r1+r2+r3, bit);
210 
211       return out;
212     }
213 
214     // helper functions
nholes(const std::bitset<nbit__> & bit)215     int nholes(const std::bitset<nbit__>& bit) const {
216       return subspace_[0].second - (bit & (~std::bitset<nbit__>(0ull) >> (nbit__ - subspace_[0].second))).count();
217     }
nparticles(const std::bitset<nbit__> & bit)218     int nparticles(const std::bitset<nbit__>& bit) const {
219       return (bit & (~(~std::bitset<nbit__>(0ull) << subspace_[2].second) << subspace_[0].second + subspace_[1].second)).count();
220     }
221 
222   private:
223     friend class boost::serialization::access;
224     template <class Archive>
serialize(Archive & ar,const unsigned int)225     void serialize(Archive& ar, const unsigned int) {
226       ar & boost::serialization::base_object<CIString_base_impl<3,RASString>>(*this);
227     }
228 
229   public:
RASString()230     RASString() { }
231     RASString(const size_t nele1, const size_t norb1, const size_t nele2, const size_t norb2, const size_t nele3, const size_t norb3, const size_t offset = 0);
232     RASString(const RASString& o, const size_t offset = 0) : CIString_base_impl<3,RASString>(o, offset) { }
233 
nholes()234     int nholes() const { return subspace_[0].second - subspace_[0].first; }
nele2()235     int nele2() const { return nele_ - subspace_[0].first - subspace_[2].first; }
nparticles()236     int nparticles() const { return subspace_[2].first; }
237 
tag()238     size_t tag() const { return nholes() + (nparticles() << 8); }
239 
ras()240     template <int S> const std::pair<const int, const int> ras() const {
241       static_assert(S == 0 || S == 1 || S == 2, "illegal call of RAString::ras");
242       return std::get<S>(subspace_);
243     }
244 
245 };
246 
247 
248 class FCIString : public CIString_base_impl<1,FCIString> {
249   friend class CIString_base_impl<1,FCIString>;
250   protected:
251     void compute_strings_impl();
252 
contains_impl(const std::bitset<nbit__> & bit)253     bool contains_impl(const std::bitset<nbit__>& bit) const { assert(bit.count() == nele_); return true; }
matches_impl(const int n,const int m)254     bool matches_impl(const int n, const int m) const { return true; }
matches_impl(const std::shared_ptr<const FCIString> o)255     bool matches_impl(const std::shared_ptr<const FCIString> o) const { return true; }
256 
lexical_offset_impl(const std::bitset<nbit__> & bit)257     size_t lexical_offset_impl(const std::bitset<nbit__>& bit) const { return lexical(bit)+offset_; }
lexical_zero_impl(const std::bitset<nbit__> & bit)258     size_t lexical_zero_impl(const std::bitset<nbit__>& bit) const { return lexical(bit); }
259 
260   private:
261     friend class boost::serialization::access;
262     template <class Archive>
serialize(Archive & ar,const unsigned int)263     void serialize(Archive& ar, const unsigned int) {
264       ar & boost::serialization::base_object<CIString_base_impl<1,FCIString>>(*this);
265     }
266 
267   public:
FCIString()268     FCIString() { }
269     FCIString(const size_t nele1, const size_t norb1, const size_t offset = 0);
270     FCIString(const FCIString& o, const size_t offset = 0) : CIString_base_impl<1,FCIString>(o, offset) { }
271 
lexical(const std::bitset<nbit__> & bit)272     size_t lexical(const std::bitset<nbit__>& bit) const {
273       assert(contains(bit));
274       return graphs_[0]->lexical(0, norb_, bit);
275     }
276 };
277 
278 using FCIString_base = CIString_base_impl<1,FCIString>;
279 using RASString_base = CIString_base_impl<3,RASString>;
280 
281 }
282 
283 extern template class bagel::CIString_base_impl<1,bagel::FCIString>;
284 extern template class bagel::CIString_base_impl<3,bagel::RASString>;
285 
#include <src/util/archive.h>
// Declare boost.serialization export keys (GUIDs) for every class defined in this header.
// NOTE(review): the matching BOOST_CLASS_EXPORT_IMPLEMENT calls are presumably in a source file - confirm.
BOOST_CLASS_EXPORT_KEY(bagel::CIGraph)
BOOST_CLASS_EXPORT_KEY(bagel::CIString_base)
BOOST_CLASS_EXPORT_KEY(bagel::RASString_base)
BOOST_CLASS_EXPORT_KEY(bagel::FCIString_base)
BOOST_CLASS_EXPORT_KEY(bagel::RASString)
BOOST_CLASS_EXPORT_KEY(bagel::FCIString)
293 
294 #endif
295