1 /*
2  *  Copyright (C) 2004-2021 Edward F. Valeev
3  *
4  *  This file is part of Libint.
5  *
6  *  Libint is free software: you can redistribute it and/or modify
7  *  it under the terms of the GNU General Public License as published by
8  *  the Free Software Foundation, either version 3 of the License, or
9  *  (at your option) any later version.
10  *
11  *  Libint is distributed in the hope that it will be useful,
12  *  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  *  GNU General Public License for more details.
15  *
16  *  You should have received a copy of the GNU General Public License
17  *  along with Libint.  If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 #define HAVE_STD_BINARY_COMPOSE 0
22 
23 #include <algorithm>
24 #include <functional>
25 #include <list>
26 #include <stdexcept>
27 #include <iostream>
28 #include <memory.h>
29 
30 using namespace std;
31 using namespace libint2;
32 
33 #if HAVE_STD_BINARY_COMPOSE
34 #include <ext/functional>
35 using namespace __gnu_cxx;
36 #endif
37 
MemoryManager(const Size & maxmem)38 MemoryManager::MemoryManager(const Size& maxmem) :
39   maxmem_(maxmem), blks_(), superblock_(new MemBlock(Address(0),maxmem,true,SafePtr<MemBlock>(),SafePtr<MemBlock>())),
40   max_memory_used_(0)
41 {
42 }
43 
~MemoryManager()44 MemoryManager::~MemoryManager()
45 {
46   reset();
47 }
48 
49 SafePtr<MemoryManager::MemBlock>
steal_from_block(const SafePtr<MemBlock> & blk,const Size & size)50 MemoryManager::steal_from_block(const SafePtr<MemBlock>& blk, const Size& size)
51 {
52   if (!blk->free())
53     throw std::runtime_error("MemoryManager::steal_from_block() -- block is not free");
54 
55   Size old_size = blk->size();
56   if (old_size < size)
57     throw std::runtime_error("MemoryManager::steal_from_block() -- block is too small");
58   if (old_size == size) {
59     blk->set_free(false);
60     return blk;
61   }
62 
63   Size new_size = old_size - size;
64   Address address = blk->address();
65   blk->set_size(new_size);
66   blk->set_address(address+size);
67   SafePtr<MemBlock> left = blk->left();
68   SafePtr<MemBlock> newblk(new MemBlock(address,size,false,left,blk));
69   if (left)
70     left->right(newblk);
71   blk->left(newblk);
72   blks_.push_back(newblk);
73 
74   update_max_memory();
75 
76   return newblk;
77 }
78 
79 SafePtr<MemoryManager::MemBlock>
find_block(const Address & address)80 MemoryManager::find_block(const Address& address)
81 {
82   typedef memblkset::iterator iter;
83   using std::placeholders::_1;
84   iter blk = find_if(blks_.begin(),blks_.end(),std::bind(MemBlock::address_eq,_1,address));
85   if (blk == blks_.end())
86     throw std::runtime_error("MemoryManager::find_block() -- didn't find a block at this address");
87   else
88     return *blk;
89 }
90 
91 void
free(const Address & address)92 MemoryManager::free(const Address& address)
93 {
94   SafePtr<MemBlock> blk = find_block(address);
95   if (!blk->free())
96     blk->set_free(true);
97   else
98     throw std::runtime_error("WorstFitMemoryManager::free() tried to free a free block");
99 
100   // Find blocks adjacent to this one and, if they are free, merge them
101   SafePtr<MemBlock> left = blk->left();
102   SafePtr<MemBlock> right = blk->right();
103   if (left && left->free())
104     blk = merge_blocks(left,blk);
105   if (right && right->free())
106     merge_blocks(blk,right);
107 }
108 
109 SafePtr<MemoryManager::MemBlock>
merge_blocks(const SafePtr<MemBlock> & left,const SafePtr<MemBlock> & right)110 MemoryManager::merge_blocks(const SafePtr<MemBlock>& left, const SafePtr<MemBlock>& right)
111 {
112   if (left->free() != right->free())
113     throw std::runtime_error("MemoryManager::merge_block() -- both blocks must be occupied or free");
114   bool free = left->free();
115   Address address = left->address();
116   if (right->address() <= address)
117     throw std::runtime_error("MemoryManager::merge_block() -- address of left block >= address of right block");
118   if (left->address() + static_cast<Address>(left->size()) != right->address())
119     throw std::runtime_error("MemoryManager::merge_block() -- address of left block + size of left block != address of right block");
120   Size size = left->size() + right->size();
121 
122   if (right == superblock())
123     return merge_to_superblock(left);
124   else {
125     SafePtr<MemBlock> lleft = left->left();
126     SafePtr<MemBlock> rright = right->right();
127 
128     typedef memblkset::iterator iter;
129     iter liter = find(blks_.begin(),blks_.end(),left);
130     if (liter != blks_.end())
131       blks_.erase(liter);
132     else
133       throw std::runtime_error("MemoryManager::merge_block() -- left block is not found");
134     iter riter = find(blks_.begin(),blks_.end(),right);
135     if (riter != blks_.end())
136       blks_.erase(riter);
137     else
138       throw std::runtime_error("MemoryManager::merge_block() -- right block is not found");
139 
140     SafePtr<MemBlock> newblk(new MemBlock(address,size,free,lleft,rright));
141     blks_.push_back(newblk);
142     if (lleft) {
143       lleft->right(newblk);
144     }
145     if (rright) {
146       rright->left(newblk);
147     }
148 
149     return newblk;
150   }
151 }
152 
153 SafePtr<MemoryManager::MemBlock>
merge_to_superblock(const SafePtr<MemBlock> & blk)154 MemoryManager::merge_to_superblock(const SafePtr<MemBlock>& blk)
155 {
156   SafePtr<MemBlock> sblk = superblock();
157   typedef memblkset::iterator iter;
158   iter biter = find(blks_.begin(),blks_.end(),blk);
159   if (biter != blks_.end())
160     blks_.erase(biter);
161   else
162     throw std::runtime_error("MemoryManager::merge_to_superblock(blk) --  blk is not found");
163   sblk->set_address(blk->address());
164   sblk->set_size(sblk->size() + blk->size());
165   SafePtr<MemBlock> left = blk->left();
166   if (left)
167     left->right(sblk);
168   sblk->left(left);
169   return sblk;
170 }
171 
172 void
update_max_memory()173 MemoryManager::update_max_memory()
174 {
175   Address saddr =  superblock()->address();
176   if (static_cast<Size>(saddr) > max_memory_used_)
177     max_memory_used_  = saddr;
178 }
179 
180 void
reset()181 MemoryManager::reset()
182 {
183   // for each block reset left and right pointers to break up cyclic dependencies that prevent automatic destruction of SafePtr-managed MemBlock objects
184   superblock_->left(SafePtr<MemBlock>());
185   superblock_->right(SafePtr<MemBlock>());
186   for(memblkset::iterator b=blks_.begin(); b!=blks_.end(); ++b) {
187     (*b)->left(SafePtr<MemBlock>());
188     (*b)->right(SafePtr<MemBlock>());
189   }
190   memblkset empty_blks;
191   swap(blks_,empty_blks);
192   superblock_ = SafePtr<MemBlock>(new MemBlock(Address(0),maxmem_,true,SafePtr<MemBlock>(),SafePtr<MemBlock>()));
193 }
194 
195 ///////////////
196 
WorstFitMemoryManager(bool search_exact,const Size & maxsize)197 WorstFitMemoryManager::WorstFitMemoryManager(bool search_exact, const Size& maxsize) :
198   MemoryManager(maxsize), search_exact_(search_exact)
199 {
200 }
201 
~WorstFitMemoryManager()202 WorstFitMemoryManager::~WorstFitMemoryManager()
203 {
204 }
205 
206 MemoryManager::Address
alloc(const Size & size)207 WorstFitMemoryManager::alloc(const Size& size)
208 {
209   if (size > maxmem())
210     throw std::runtime_error("WorstFitMemoryManager::alloc() -- requested more memory than available");
211   if (size == 0)
212     throw std::runtime_error("WorstFitMemoryManager::alloc(size) -- size is 0");
213 
214   typedef memblkset::iterator iter;
215   memblkset& blks = blocks();
216 
217   // try to find the exact match first
218   if (search_exact_) {
219 #if HAVE_STD_BINARY_COMPOSE
220     iter blk;
221     blk = find_if(blks.begin(),blks.end(),
222                   compose2(logical_and<bool>(),
223                            bind2nd(ptr_fun(MemBlock::size_eq),size),
224                            &MemBlock::is_free));
225     if (blk != blks.end()) {
226       (*blk)->set_free(false);
227       return (*blk)->address();
228     }
229 #else
230     iter begin = blks.begin();
231     iter end = blks.end();
232     for(iter b=begin; b!=end; b++) {
233       if((*b)->size() == size && (*b)->free()) {
234         (*b)->set_free(false);
235         return (*b)->address();
236       }
237     }
238 #endif
239   }
240 
241   // find all free_blocks
242   std::list< SafePtr<MemBlock> > free_blks;
243   for(iter b=blks.begin(); b!=blks.end(); b++) {
244     b = find_if(b,blks.end(),&MemBlock::is_free);
245     if (b != blks.end())
246       free_blks.push_back(*b);
247     else // No more blocks left
248       break;
249   }
250 
251   // if no exact match found -- find the largest free block and grab memory from it
252   std::list< SafePtr<MemBlock> >::iterator largest_free_block = max_element(free_blks.begin(),free_blks.end(),&MemBlock::size_less_than);
253   if (largest_free_block != free_blks.end() &&
254       (*largest_free_block)->size() > size) {
255     SafePtr<MemBlock> result = steal_from_block(*largest_free_block,size);
256     return result->address();
257   }
258 
259   // lastly, if all failed -- steal from the super block
260   SafePtr<MemBlock> result = steal_from_block(superblock(),size);
261   return result->address();
262 }
263 
264 ///////////////
265 
BestFitMemoryManager(bool search_exact,const Size & tight_fit,const Size & maxsize)266 BestFitMemoryManager::BestFitMemoryManager(bool search_exact, const Size& tight_fit, const Size& maxsize) :
267   MemoryManager(maxsize), search_exact_(search_exact), tight_fit_(tight_fit)
268 {
269 }
270 
~BestFitMemoryManager()271 BestFitMemoryManager::~BestFitMemoryManager()
272 {
273 }
274 
275 MemoryManager::Address
alloc(const Size & size)276 BestFitMemoryManager::alloc(const Size& size)
277 {
278   if (size > maxmem())
279     throw std::runtime_error("BestFitMemoryManager::alloc() -- requested more memory than available");
280   if (size == 0)
281     throw std::runtime_error("BestFitMemoryManager::alloc(size) -- size is 0");
282 
283   typedef memblkset::iterator iter;
284   memblkset& blks = blocks();
285 
286   // try to find the exact match first
287   if (search_exact_) {
288 #if HAVE_STD_BINARY_COMPOSE
289     iter blk;
290     blk = find_if(blks.begin(),blks.end(),
291                   compose2(logical_and<bool>(),
292                            bind2nd(ptr_fun(MemBlock::size_eq),size),
293                            &MemBlock::is_free));
294     if (blk != blks.end()) {
295       (*blk)->set_free(false);
296       return (*blk)->address();
297     }
298 #else
299     iter begin = blks.begin();
300     iter end = blks.end();
301     for(iter b=begin; b!=end; b++) {
302       if((*b)->size() == size && (*b)->free()) {
303         (*b)->set_free(false);
304         return (*b)->address();
305       }
306     }
307 #endif
308   }
309 
310   // find all free_blocks
311   std::list< SafePtr<MemBlock> > free_blks;
312   typedef std::list< SafePtr<MemBlock> >::iterator fiter;
313   for(iter b=blks.begin(); b!=blks.end(); b++) {
314     b = find_if(b,blks.end(),&MemBlock::is_free);
315     if (b != blks.end())
316       free_blks.push_back(*b);
317     else // No more blocks left
318       break;
319   }
320 
321   // if there are no free blocks left -- steal from the super block
322   if (free_blks.empty()) {
323     SafePtr<MemBlock> result = steal_from_block(superblock(),size);
324     return result->address();
325   }
326 
327   // else find the smallest free block and grab memory from it
328   fiter smallest_free_block = min_element(free_blks.begin(),free_blks.end(),&MemBlock::size_less_than);
329 
330   do {
331 
332     if ((*smallest_free_block)->size() > size + tight_fit_) {
333       SafePtr<MemBlock> result = steal_from_block(*smallest_free_block,size);
334       return result->address();
335     }
336     else {
337       free_blks.erase(smallest_free_block);
338     }
339 
340     smallest_free_block = min_element(free_blks.begin(),free_blks.end(),&MemBlock::size_less_than);
341 
342   } while (smallest_free_block != free_blks.end());
343 
344   // Steal from superblock as a last resort
345   SafePtr<MemBlock> result = steal_from_block(superblock(),size);
346   return result->address();
347 
348 }
349 
350 ///////////////
351 
FirstFitMemoryManager(bool search_exact,const Size & maxsize)352 FirstFitMemoryManager::FirstFitMemoryManager(bool search_exact, const Size& maxsize) :
353   MemoryManager(maxsize), search_exact_(search_exact)
354 {
355 }
356 
~FirstFitMemoryManager()357 FirstFitMemoryManager::~FirstFitMemoryManager()
358 {
359 }
360 
361 MemoryManager::Address
alloc(const Size & size)362 FirstFitMemoryManager::alloc(const Size& size)
363 {
364   if (size > maxmem())
365     throw std::runtime_error("FirstFitMemoryManager::alloc() -- requested more memory than available");
366   if (size == 0)
367     throw std::runtime_error("FirstFitMemoryManager::alloc(size) -- size is 0");
368 
369   typedef memblkset::iterator iter;
370   memblkset& blks = blocks();
371 
372   // try to find the exact match first
373   if (search_exact_) {
374 #if HAVE_STD_BINARY_COMPOSE
375     iter blk;
376     blk = find_if(blks.begin(),blks.end(),
377                   compose2(logical_and<bool>(),
378                            bind2nd(ptr_fun(MemBlock::size_eq),size),
379                            &MemBlock::is_free));
380     if (blk != blks.end()) {
381       (*blk)->set_free(false);
382       return (*blk)->address();
383     }
384 #else
385     iter begin = blks.begin();
386     iter end = blks.end();
387     for(iter b=begin; b!=end; b++) {
388       if((*b)->size() == size && (*b)->free()) {
389         (*b)->set_free(false);
390         return (*b)->address();
391       }
392     }
393 #endif
394   }
395 
396   // Find the first free block larger than size
397 #if HAVE_STD_BINARY_COMPOSE
398     iter blk;
399     blk = find_if(blks.begin(),blks.end(),
400                   compose2(logical_and<bool>(),
401                            bind2nd(ptr_fun(MemBlock::size_geq),size),
402                            &MemBlock::is_free));
403     if (blk != blks.end()) {
404       SafePtr<MemBlock> result = steal_from_block(*blk,size);
405       return result->address();
406     }
407 #else
408     iter begin = blks.begin();
409     iter end = blks.end();
410     for(iter b=begin; b!=end; b++) {
411       if((*b)->size() >= size && (*b)->free()) {
412         SafePtr<MemBlock> result = steal_from_block(*b,size);
413         return result->address();
414       }
415     }
416 #endif
417 
418   // Steal from superblock as a last resort
419   SafePtr<MemBlock> result = steal_from_block(superblock(),size);
420   return result->address();
421 
422 }
423 
424 ///////////////
425 
LastFitMemoryManager(bool search_exact,const Size & maxsize)426 LastFitMemoryManager::LastFitMemoryManager(bool search_exact, const Size& maxsize) :
427   MemoryManager(maxsize), search_exact_(search_exact)
428 {
429 }
430 
~LastFitMemoryManager()431 LastFitMemoryManager::~LastFitMemoryManager()
432 {
433 }
434 
435 MemoryManager::Address
alloc(const Size & size)436 LastFitMemoryManager::alloc(const Size& size)
437 {
438   if (size > maxmem())
439     throw std::runtime_error("LastFitMemoryManager::alloc() -- requested more memory than available");
440   if (size == 0)
441     throw std::runtime_error("LastFitMemoryManager::alloc(size) -- size is 0");
442 
443   typedef memblkset::iterator iter;
444   typedef memblkset::reverse_iterator riter;
445 
446   memblkset& blks = blocks();
447   riter rbegin = blks.rbegin();
448   riter rend = blks.rend();
449 
450   // try to find the exact match first
451   if (search_exact_) {
452 #if HAVE_STD_BINARY_COMPOSE
453     riter blk;
454     blk = find_if(rbegin,rend,
455                   compose2(logical_and<bool>(),
456                            bind2nd(ptr_fun(MemBlock::size_eq),size),
457                            &MemBlock::is_free));
458     if (blk != rend) {
459       (*blk)->set_free(false);
460       return (*blk)->address();
461     }
462 #else
463     for(riter b=rbegin; b!=rend; b++) {
464       if((*b)->size() == size && (*b)->free()) {
465         (*b)->set_free(false);
466         return (*b)->address();
467       }
468     }
469 #endif
470   }
471 
472   // Find the first free block larger than size
473 #if HAVE_STD_BINARY_COMPOSE
474     riter blk;
475     blk = find_if(rbegin,rend,
476                   compose2(logical_and<bool>(),
477                            bind2nd(ptr_fun(MemBlock::size_geq),size),
478                            &MemBlock::is_free));
479     if (blk != rend) {
480       SafePtr<MemBlock> result = steal_from_block(*blk,size);
481       return result->address();
482     }
483 #else
484     for(riter b=rbegin; b!=rend; b++) {
485       if((*b)->size() >= size && (*b)->free()) {
486         SafePtr<MemBlock> result = steal_from_block(*b,size);
487         return result->address();
488       }
489     }
490 #endif
491 
492   // Steal from superblock as a last resort
493   SafePtr<MemBlock> result = steal_from_block(superblock(),size);
494   return result->address();
495 
496 }
497 
498 //////////////
499 
500 SafePtr<MemoryManager>
memman(unsigned int type) const501 MemoryManagerFactory::memman(unsigned int type) const
502 {
503   switch (type) {
504   case 0:
505     {
506       SafePtr<MemoryManager> result(new WorstFitMemoryManager(true));
507       return result;
508     }
509   case 1:
510     {
511       SafePtr<MemoryManager> result(new WorstFitMemoryManager(false));
512       return result;
513     }
514   case 2:
515     {
516       SafePtr<MemoryManager> result(new BestFitMemoryManager(true));
517       return result;
518     }
519   case 3:
520     {
521       SafePtr<MemoryManager> result(new BestFitMemoryManager(false));
522       return result;
523     }
524   case 4:
525     {
526       SafePtr<MemoryManager> result(new FirstFitMemoryManager(true));
527       return result;
528     }
529   case 5:
530     {
531       SafePtr<MemoryManager> result(new FirstFitMemoryManager(false));
532       return result;
533     }
534   case 6:
535     {
536       SafePtr<MemoryManager> result(new LastFitMemoryManager(true));
537       return result;
538     }
539   case 7:
540     {
541       SafePtr<MemoryManager> result(new LastFitMemoryManager(false));
542       return result;
543     }
544   default:
545     throw std::runtime_error("MemoryManagerFactory::memman(type) -- invalid type");
546   }
547 }
548 
549 namespace MMTypes {
550   static const char labels_[MemoryManagerFactory::ntypes][80] = {
551     "WorstFitMemoryManager(true)",
552       "WorstFitMemoryManager(false)",
553       "BestFitMemoryManager(true)",
554       "BestFitMemoryManager(false)",
555       "FirstFitMemoryManager(true)",
556       "FirstFitMemoryManager(false)",
557       "LastFitMemoryManager(true)",
558       "LastFitMemoryManager(false)"
559       };
560 
561 };
562 
563 std::string
label(unsigned int type) const564 MemoryManagerFactory::label(unsigned int type) const
565 {
566   return MMTypes::labels_[type];
567 }
568 
569 ////
570 
571 namespace libint2 {
572 
size_lessthan(const MemoryManager::MemBlock & A,const MemoryManager::MemBlock & B)573   bool size_lessthan(const MemoryManager::MemBlock& A, const MemoryManager::MemBlock& B) {
574     return A.size() < B.size();
575   }
address_lessthan(const MemoryManager::MemBlock & A,const MemoryManager::MemBlock & B)576   bool address_lessthan(const MemoryManager::MemBlock& A, const MemoryManager::MemBlock& B) {
577     return A.address() < B.address();
578   }
579 
can_merge(const MemoryManager::MemBlock & A,const MemoryManager::MemBlock & B)580   bool can_merge(const MemoryManager::MemBlock& A, const MemoryManager::MemBlock& B) {
581     if (A.address() < B.address()) {
582       return (A.free() == B.free()) && (A.address() + static_cast<MemBlock::Address>(A.size()) == B.address());
583     }
584     else {
585       return (A.free() == B.free()) && (B.address() + static_cast<MemBlock::Address>(B.size()) == A.address());
586     }
587   }
588 
589   void
merge(MemBlockSet & blocks)590   merge(MemBlockSet& blocks) {
591     typedef MemBlockSet::const_iterator citer;
592     typedef MemBlockSet::iterator iter;
593 
594     if (blocks.size() <= 1) return;
595 
596     // Sort by increasing address
597     //sort(blocks.begin(),blocks.end(),libint2::address_lessthan);
598     blocks.sort(address_lessthan);
599 
600     // Iterate over pais of adjacent blocks and merge, if possible
601     const citer end = blocks.end();
602     iter b = blocks.begin();
603     iter bp1(b);  ++bp1;
604     while (bp1 != end) {
605       if (can_merge(*b,*bp1)) {
606 	b->merge(*bp1);
607 	bp1 = blocks.erase(bp1);
608       }
609       else {
610 	++b;
611 	++bp1;
612       }
613     }
614   }
615 
616 };
617