1 /*
2  *  FishersExact.cc
3  *  Apto
4  *
5  *  Created by David on 2/15/11.
6  *  Copyright 2011 David Michael Bryson. All rights reserved.
7  *  http://programerror.com/software/apto
8  *
9  *  Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
10  *  following conditions are met:
11  *
12  *  1.  Redistributions of source code must retain the above copyright notice, this list of conditions and the
13  *      following disclaimer.
14  *  2.  Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
15  *      following disclaimer in the documentation and/or other materials provided with the distribution.
16  *  3.  Neither the name of David Michael Bryson, nor the names of contributors may be used to endorse or promote
17  *      products derived from this software without specific prior written permission.
18  *
19  *  THIS SOFTWARE IS PROVIDED BY DAVID MICHAEL BRYSON AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
20  *  INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21  *  DISCLAIMED. IN NO EVENT SHALL DAVID MICHAEL BRYSON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22  *  SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23  *  SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24  *  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25  *  USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  *
27  *  Authors: David M. Bryson <david@programerror.com>
28  *
 *  Fisher's Exact Test for r x c Contingency Tables
30  *  Based on:
31  *    "A Network Algorithm for Performing Fisher's Exact Test in r x c Contingency Tables" by Mehta and Patel
 *      Journal of the American Statistical Association, June 1983, Volume 78, Number 382, Pages 427-434
33  *    "Algorithm 643: FEXACT: A FORTRAN Subroutine for Fisher's Exact Test on Unordered r x c Contingency Tables"
34  *      by Mehta and Patel, ACM Transactions on Mathematical Software, June 1986, Volume 12, No. 2, Pages 154-161
35  *    "A Remark on Algorithm 643: FEXACT: A FORTRAN Subroutine for Fisher's Exact Test on Unordered r x c Contingency Tables"
36  *      by Clarkson, Fan, and Joe, ACM Transactions on Mathematical Software, Dec. 1993, Volume 19, No. 4, Pages 484-488
37  *
38  */
39 
40 #include "apto/platform.h"
41 
42 #include "apto/stat/ContingencyTable.h"
43 #include "apto/stat/Functions.h"
44 
45 #include "apto/core/Array.h"
46 #include "apto/core/ArrayUtils.h"
47 #include "apto/core/ConditionVariable.h"
48 #include "apto/core/Mutex.h"
49 #include "apto/core/Pair.h"
50 #include "apto/core/RefCount.h"
51 #include "apto/core/SmartPtr.h"
52 #include "apto/core/Thread.h"
53 
54 #include <cmath>
55 #include <limits>
56 
57 #include <iostream>
58 #include <stdio.h>
59 
60 using namespace Apto;
61 
62 
63 // Constant Declarations
64 // --------------------------------------------------------------------------------------------------------------
65 
// Tolerance added to the observed path length, as used in Algorithm 643.
static const double TOLERANCE(3.4525e-7);

// FishersExact() uses the threaded solver when rows * cols exceeds this value.
static const int THREADING_THRESHOLD(25);

// Initial slot count for the internal open-addressing hash tables.
static const int DEFAULT_TABLE_SIZE(300);
69 
70 
71 // Exported Function Declarations
72 // --------------------------------------------------------------------------------------------------------------
73 
74 namespace Apto {
75   namespace Stat {
76     double FishersExact(const ContingencyTable& table);
77   };
78 };
79 
80 
81 
82 // Internal Function Declarations
83 // --------------------------------------------------------------------------------------------------------------
84 
// Numerical helpers; their definitions are outside this view (presumably later in this file).
// NOTE(review): by name, logGamma computes ln(Gamma(x)) and cummulativeGamma an incomplete gamma
// integral, each reporting failure through `fault` -- confirm against the definitions.
static double logGamma(double x, bool& fault);
static double cummulativeGamma(double q, double alpha, bool& fault);
87 
88 
89 
90 // Internal Class/Struct Definitions
91 // --------------------------------------------------------------------------------------------------------------
92 
// Array storage policy that does not own memory. The element buffer is supplied by the caller through
// SetBuffer(); the caller is responsible for its lifetime and for keeping it large enough for every size
// established through the constructor, Resize(), ResizeClear() or operator=.
template <class T> class ManualBuffer
{
private:
  T* m_data;    // Data Array (caller supplied, not owned)
  int m_size;   // Array Size

protected:
  typedef T StoredType;

  explicit ManualBuffer(int size = 0) : m_data(NULL), m_size(size) { ; }

  // A copy starts without a buffer, so copying a non-empty buffer trips the precondition in operator=.
  ManualBuffer(const ManualBuffer& rhs) : m_data(NULL), m_size(0) { this->operator=(rhs); }

  // Nothing to release: the buffer is not owned. (This was previously declared as a second default
  // constructor, which made default construction ambiguous with ManualBuffer(int size = 0).)
  ~ManualBuffer() { ; }

  int GetSize() const { return m_size; }

  // Both only record the new logical size; no memory is (re)allocated.
  void ResizeClear(const int in_size) { m_size = in_size; }
  void Resize(int new_size) { m_size = new_size; }

  // Element-wise copy into this object's own buffer, which must already be set and large enough.
  ManualBuffer& operator=(const ManualBuffer& rhs)
  {
    assert(m_data != NULL || rhs.m_size == 0);  // Copy target has no buffer
    m_size = rhs.m_size;
    for (int i = 0; i < rhs.m_size; i++) m_data[i] = rhs.m_data[i];
    return *this;
  }

  inline T& operator[](const int index) { return m_data[index]; }
  inline const T& operator[](const int index) const { return m_data[index]; }

  void Swap(int idx1, int idx2)
  {
    T v = m_data[idx1];
    m_data[idx1] = m_data[idx2];
    m_data[idx2] = v;
  }

public:
  // Attach the external element buffer.
  void SetBuffer(T* data) { m_data = data; }
};
129 
130 
131 template <class T> class EnhancedSmart : public Smart<T>
132 {
133 public:
EnhancedSmart(int size=0)134   EnhancedSmart(int size = 0) : Smart<T>(size) { ; }
EnhancedSmart(const EnhancedSmart & rhs)135   EnhancedSmart(const EnhancedSmart& rhs) : Smart<T>(rhs) { ; }
136 
137   class Slice
138   {
139     friend class EnhancedSmart;
140   private:
141     const T* m_data;
142     const int m_size;
143 
Slice(const T * data,const int size)144     Slice(const T* data, const int size) : m_data(data), m_size(size) { ; }
145 
146   public:
GetSize() const147     inline int GetSize() const { return m_size; }
148 
operator [](const int index) const149     inline const T& operator[](const int index) const
150     {
151       assert(index >= 0);       // Lower Bounds Error
152       assert(index < m_size); // Upper Bounds Error
153       return m_data[index];
154     }
155   };
156 
GetSlice(int start,int end) const157   inline Slice GetSlice(int start, int end) const
158   {
159     assert(start >= 0);
160     assert(end < Smart<T>::m_active);
161 
162     return Slice(Smart<T>::m_data + start, end - start + 1);
163   }
164 
operator =(const Slice & rhs)165   EnhancedSmart& operator=(const Slice& rhs)
166   {
167     if (Smart<T>::m_active != rhs.GetSize()) Smart<T>::ResizeClear(rhs.GetSize());
168     for (int i = 0; i < rhs.GetSize(); i++) Smart<T>::m_data[i] = rhs.m_data[i];
169     return *this;
170   }
171 };
172 
173 typedef Array<int, EnhancedSmart> MarginalArray;
174 
175 
176 class FExact
177 {
178 public:
179   // FExact Public Methods
180   // ------------------------------------------------------------------------------------------------------------
181 
182   FExact(const Stat::ContingencyTable& table, double tolerance);
183 
184   double Calculate();
185   double ThreadedCalculate();
186 
187 
188 private:
189   // FExact Internal Type Declarations
190   // ------------------------------------------------------------------------------------------------------------
191 
192   // Main Node
193   class FExactNode;
194   struct PastPathLength;
195   class NodeHashTable;
196   typedef SmartPtr<FExactNode, InternalRCObject> NodePtr;
197 
198 
199   // Path Extremes
200   struct PathExtremes;
201   class PathExtremesHashTable;
202   struct PendingPathExtremes;
203   class PendingPathExtremesTable;
204   struct PendingPathNode;
205   class PathExtremesCalc;
206 
207 
208 private:
209   // FExact Internal Methods
210   // ------------------------------------------------------------------------------------------------------------
211 
212   inline double logMultinomial(int numerator, const MarginalArray& denominator);
213   inline double logMultinomial(int numerator, const MarginalArray::Slice& denominator);
214 
215   // Core Algorithm
216   inline bool generateFirstDaughter(const MarginalArray& row_marginals, int n, MarginalArray& row_diff, int& kmax, int& kd);
217   bool generateNewDaughter(int kmax, const MarginalArray& row_marginals, MarginalArray& row_diff, int& idx_dec, int& idx_inc);
218   NodePtr handlePastPaths(NodePtr& cur_node, double obs2, double obs3, double ddf, double drn, int kval, NodeHashTable& nht);
219   void recordPath(double path_length, int path_freq, Array<PastPathLength, Smart>& past_entries);
220 
221   void handleNode(int k, NodePtr cur_node);
222   inline void unpackMarginals(int key, MarginalArray& row_marginals);
223 
224 
225   // Longest Path
226   double longestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, int marginal_total);
227   bool longestPathSpecial(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& val);
228 
229 
230   // Shortest Path
231   void shortestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& shortest_path);
232   inline void removeFromVector(const Array<int, ManualBuffer>& src, int idx_remove, Array<int, ManualBuffer>& dest);
233   inline void reduceZeroInVector(const Array<int, ManualBuffer>& src, int value, int idx_start, Array<int, ManualBuffer>& dest);
234 
235 
236 private:
237   // FExact Internal Type Definitions
238   // ------------------------------------------------------------------------------------------------------------
239 
240   struct PastPathLength
241   {
242     double value;
243     int observed;
244     int next_left;
245     int next_right;
246 
PastPathLengthFExact::PastPathLength247     PastPathLength(double in_value = 0.0) : value(in_value), observed(1), next_left(-1), next_right(-1) { ; }
PastPathLengthFExact::PastPathLength248     PastPathLength(double in_value, int in_freq) : value(in_value), observed(in_freq), next_left(-1), next_right(-1) { ; }
249   };
250 
251   class FExactNode : public RefCountObject
252   {
253   public:
254     int key;
255     Array<PastPathLength, Smart> past_entries;
256 
FExactNode(int in_key)257     FExactNode(int in_key) : key(in_key) { ; }
~FExactNode()258     virtual ~FExactNode() { ; }
259   };
260 
261   class NodeHashTable
262   {
263   public:
264     typedef SmartPtr<FExactNode, InternalRCObject> NodePtr;
265   private:
266     Array<NodePtr> m_table;
267     int m_last;
268 
269   public:
NodeHashTable(int size=DEFAULT_TABLE_SIZE)270     inline NodeHashTable(int size = DEFAULT_TABLE_SIZE) : m_table(size), m_last(-1) { ; }
271 
Find(int key,int & idx)272     bool Find(int key, int& idx)
273     {
274       int init = key % m_table.GetSize();
275       idx = init;
276       for (; idx < m_table.GetSize(); idx++) {
277         if (!m_table[idx]) {
278           m_table[idx] = NodePtr(new FExactNode(key));
279           return false;
280         } else if (m_table[idx]->key == key) {
281           return true;
282         }
283       }
284       for (idx = 0; idx < init; idx++) {
285         if (!m_table[idx]) {
286           m_table[idx] = NodePtr(new FExactNode(key));
287           return false;
288         } else if (m_table[idx]->key == key) {
289           return true;
290         }
291       }
292       Rehash(key, idx);
293       return false;
294     }
295 
operator [](int idx)296     inline FExactNode& operator[](int idx) { return *m_table[idx]; }
Get(int idx)297     inline NodePtr Get(int idx) { return m_table[idx]; }
298 
Pop()299     NodePtr Pop()
300     {
301       for (++m_last; m_last < m_table.GetSize(); m_last++) {
302         if (m_table[m_last]) {
303           NodePtr tmp = m_table[m_last];
304           m_table[m_last] = NodePtr(NULL);
305           return tmp;
306         }
307       }
308       m_last = -1;
309       return NodePtr(NULL);
310     }
311   private:
Rehash(int key,int & idx)312     void Rehash(int key, int& idx)
313     {
314       Array<NodePtr> old_table(m_table);
315       m_table.ResizeClear(old_table.GetSize() * 2);
316       for (int i = 0; i < old_table.GetSize(); i++) {
317         int t_idx;
318         Find(old_table[i]->key, t_idx);
319         m_table[t_idx] = old_table[i];
320       }
321       Find(key, idx);
322     }
323   };
324 
325 
326 
327   struct PathExtremes {
328     int key;
329     double longest_path;
330     double shortest_path;
PathExtremesFExact::PathExtremes331     PathExtremes() : key(-1) { ; }
332 
SetFExact::PathExtremes333     void Set(int in_key, double lp, double sp) { key = in_key; longest_path = lp; shortest_path = sp; }
334   };
335 
336   class PathExtremesHashTable
337   {
338   private:
339     Array<PathExtremes> m_table;
340 
341   public:
PathExtremesHashTable(int size=DEFAULT_TABLE_SIZE)342     inline PathExtremesHashTable(int size = DEFAULT_TABLE_SIZE) : m_table(size) { ClearTable(); }
343 
Find(int key,int & idx)344     bool Find(int key, int& idx)
345     {
346       int init = key % m_table.GetSize();
347       idx = init;
348       for (; idx < m_table.GetSize(); idx++) {
349         if (m_table[idx].key < 0) {
350           m_table[idx].key = key;
351           return false;
352         } else if (m_table[idx].key == key) {
353           return true;
354         }
355       }
356       for (idx = 0; idx < init; idx++) {
357         if (m_table[idx].key < 0) {
358           m_table[idx].key = key;
359           return false;
360         } else if (m_table[idx].key == key) {
361           return true;
362         }
363       }
364       Rehash(key, idx);
365       return false;
366     }
367 
operator [](int idx)368     inline PathExtremes& operator[](int idx) { return m_table[idx]; }
369 
ClearTable()370     inline void ClearTable()
371     {
372       for (int i = 0; i < m_table.GetSize(); i++) m_table[i].key = -1;
373     }
374 
375   private:
Rehash(int key,int & idx)376     void Rehash(int key, int& idx)
377     {
378       Array<PathExtremes> old_table(m_table);
379       m_table.ResizeClear(old_table.GetSize() * 2);
380       for (int i = 0; i < old_table.GetSize(); i++) {
381         int t_idx;
382         Find(old_table[i].key, t_idx);
383         m_table[t_idx].longest_path = old_table[i].longest_path;
384         m_table[t_idx].shortest_path = old_table[i].shortest_path;
385       }
386       Find(key, idx);
387     }
388   };
389 
390 
391   struct PendingPathExtremes {
392     int key;
393     bool inprogress;
394     MarginalArray rows;
395     MarginalArray cols;
396     double obs2;
397     double ddf;
398     int ntot;
PendingPathExtremesFExact::PendingPathExtremes399     PendingPathExtremes() : key(-1), inprogress(false) { ; }
400 
SetFExact::PendingPathExtremes401     inline void Set(const MarginalArray::Slice& in_rows, const MarginalArray::Slice& in_cols,
402                     double in_obs2, double in_ddf, int in_ntot)
403     {
404       rows.ResizeClear(in_rows.GetSize());
405       for (int i = 0; i < in_rows.GetSize(); i++) rows[i] = in_rows[i];
406       cols.ResizeClear(in_cols.GetSize());
407       for (int i = 0; i < in_cols.GetSize(); i++) cols[i] = in_cols[i];
408       ntot = in_ntot;
409       obs2 = in_obs2;
410       ddf = in_ddf;
411     }
412   };
413 
414   class PendingPathExtremesTable
415   {
416   private:
417     Array<PendingPathExtremes> m_table;
418     int m_size;
419 
420   public:
PendingPathExtremesTable(int size=DEFAULT_TABLE_SIZE)421     inline PendingPathExtremesTable(int size = DEFAULT_TABLE_SIZE) : m_table(size), m_size(0) { ; }
422 
GetSize()423     int GetSize() { return m_size; }
424 
Find(int key,int & idx)425     bool Find(int key, int& idx)
426     {
427       int init = key % m_table.GetSize();
428       idx = init;
429       for (; idx < m_table.GetSize(); idx++) {
430         if (m_table[idx].key < 0) {
431           m_size++;
432           m_table[idx].key = key;
433           return false;
434         } else if (m_table[idx].key == key) {
435           return true;
436         }
437       }
438       for (idx = 0; idx < init; idx++) {
439         if (m_table[idx].key < 0) {
440           m_size++;
441           m_table[idx].key = key;
442           return false;
443         } else if (m_table[idx].key == key) {
444           return true;
445         }
446       }
447       Rehash(key, idx);
448       return false;
449     }
450 
Pop()451     int Pop()
452     {
453       assert(m_size > 0);
454       for (int i = 0; i < m_table.GetSize(); i++) if (m_table[i].key >= 0 && !m_table[i].inprogress) {
455         m_table[i].inprogress = true;
456         return i;
457       }
458       return -1;
459     }
Remove(int idx)460     void Remove(int idx) { m_table[idx].key = -1; m_table[idx].inprogress = false; m_size--; }
461 
operator [](int idx)462     inline PendingPathExtremes& operator[](int idx) { return m_table[idx]; }
463 
464   private:
Rehash(int key,int & idx)465     void Rehash(int key, int& idx)
466     {
467       Array<PendingPathExtremes> old_table(m_table);
468       m_table.ResizeClear(old_table.GetSize() * 2);
469       for (int i = 0; i < old_table.GetSize(); i++) {
470         int t_idx;
471         Find(old_table[i].key, t_idx);
472         m_table[t_idx].rows = old_table[i].rows;
473         m_table[t_idx].cols = old_table[i].cols;
474       }
475       Find(key, idx);
476     }
477   };
478 
479 
480   struct PendingPathNode {
481     NodePtr node;
482     double obs2;
483     double drn;
484     double ddf;
485     int kval;
486     bool k1;
487     bool handled;
488 
PendingPathNodeFExact::PendingPathNode489     PendingPathNode() : node(NULL), handled(false) { ; }
SetFExact::PendingPathNode490     inline void Set(NodePtr in_node, double in_obs2, double in_drn, double in_ddf, int in_kval, bool in_k1 = false)
491     {
492       node = in_node; obs2 = in_obs2; drn = in_drn; ddf = in_ddf; kval = in_kval; k1 = in_k1;
493     }
494   };
495 
496 
497   class PathExtremesCalc : public Thread
498   {
499   private:
500     FExact* m_fexact;
501     bool m_run;
502 
503   public:
PathExtremesCalc(FExact * fexact)504     PathExtremesCalc(FExact* fexact) : m_fexact(fexact), m_run(true) { ; }
505 
SetFinished()506     void SetFinished() { m_run = false; }
507   protected:
508     void Run();
509   };
510 
511 
512 private:
513   // FExact Class Variable Declarations
514   // ------------------------------------------------------------------------------------------------------------
515 
516   // Constants
517   const double m_tolerance;
518 
519 
520   // Pre-calculated Values
521   MarginalArray m_row_marginals;
522   MarginalArray m_col_marginals;
523   Array<double> m_facts; // Log factorials
524   Array<int> m_key_multipliers;
525   double m_observed_path;
526   double m_den_observed_path;
527 
528 
529   // Main Calculated Value
530   double m_pvalue;
531 
532 
533   // Core Algorithm Support
534 
535 
536   // Threaded Core Algorithm Support
537   Array<NodeHashTable> m_nht;
538   Array<Array<PendingPathNode, Smart> >  m_pending_path_nodes;
539   Array<PathExtremesHashTable> m_path_extremes;
540 
541   // Path Extremes Threading Support
542   Mutex m_path_extremes_mutex;
543   ConditionVariable m_path_extremes_cond;
544   ConditionVariable m_path_extremes_cond_complete;
545   Array<PendingPathExtremesTable> m_pending_path_extremes;
546   Array<Array<PathExtremes, Smart> > m_completed_path_extremes;
547   Array<int, Smart> m_path_extremes_queue;
548   bool m_pending_paths_term;
549 };
550 
551 
552 
553 // Exported Function Definitions
554 // --------------------------------------------------------------------------------------------------------------
555 
FishersExact(const ContingencyTable & table)556 double Apto::Stat::FishersExact(const ContingencyTable& table)
557 {
558   if (table.MarginalTotal() == 0.0) return std::numeric_limits<double>::quiet_NaN();  // All elements are 0
559 
560   FExact fe(table, TOLERANCE);
561 
562   // Use threaded calculate for larger tables
563   if (table.NumRows() * table.NumCols() > THREADING_THRESHOLD) return fe.ThreadedCalculate();
564 
565   return fe.Calculate();
566 }
567 
568 
569 
570 // FExact Constructor and Support Methods
571 // --------------------------------------------------------------------------------------------------------------
572 
FExact(const Stat::ContingencyTable & table,double tolerance)573 FExact::FExact(const Stat::ContingencyTable& table, double tolerance)
574   : m_tolerance(tolerance)
575   , m_facts(table.MarginalTotal() + 1)
576   , m_pvalue(0.0)
577 {
578   // Store table marginals for use in calculation
579   if (table.NumRows() > table.NumCols()) {
580     m_row_marginals = table.ColMarginals();
581     m_col_marginals = table.RowMarginals();
582   } else {
583     m_row_marginals = table.RowMarginals();
584     m_col_marginals = table.ColMarginals();
585   }
586 
587 
588   // Sort row and column marginals
589   QSort(m_row_marginals);
590   QSort(m_col_marginals);
591 
592 
593   // Set up key multipliers
594   m_key_multipliers.Resize(m_row_marginals.GetSize());
595   m_key_multipliers[0] = 1;
596   for (int i = 1; i < m_row_marginals.GetSize(); i++)
597     m_key_multipliers[i] = m_key_multipliers[i - 1] * (m_row_marginals[i - 1] + 1);
598 
599 
600   // Ensure that the maximum key will fit within an integer
601   assert((m_row_marginals[m_row_marginals.GetSize() - 2] + 1) <=
602          (std::numeric_limits<int>::max() / m_key_multipliers[m_row_marginals.GetSize() - 2]));
603 
604 
605   // Pre-calculate log factorials
606   const int marginal_total = table.MarginalTotal();
607   m_facts[0] = 0.0;
608   m_facts[1] = 0.0;
609   if (marginal_total > 1) {
610     m_facts[2] = log(2.0);
611     for (int i = 3; i <= marginal_total; i++) {
612       m_facts[i] = m_facts[i - 1] + log((double)i);
613       if (++i <= marginal_total)  m_facts[i] = m_facts[i - 1] + m_facts[2] + m_facts[i / 2] - m_facts[i / 2 - 1];
614     }
615   }
616 
617 
618   // Calculate Observed Path Numerator
619   m_observed_path = m_tolerance;
620   for (int j = 0; j < m_col_marginals.GetSize(); j++) {
621     double dd = 0.0;
622     for (int i = 0; i < m_row_marginals.GetSize(); i++) {
623       if (m_row_marginals.GetSize() == table.NumRows()) {
624         dd += m_facts[table[i][j]];
625       } else {
626         dd += m_facts[table[j][i]];
627       }
628     }
629     m_observed_path += m_facts[m_col_marginals[j]] - dd;
630   }
631 
632   // Calculate Observed Path Denominator
633   m_den_observed_path = logMultinomial(marginal_total, m_row_marginals);
634 
635 
636 //
637 //  double prt = exp(m_observed_path - m_den_observed_path);
638 //  std::cout << "prt = " << prt << std::endl;
639 }
640 
641 
logMultinomial(int numerator,const MarginalArray & denominator)642 inline double FExact::logMultinomial(int numerator, const MarginalArray& denominator)
643 {
644   double ret_val = m_facts[numerator];
645   for (int i = 0; i < denominator.GetSize(); i++) ret_val -= m_facts[denominator[i]];
646   return ret_val;
647 }
648 
649 
logMultinomial(int numerator,const MarginalArray::Slice & denominator)650 inline double FExact::logMultinomial(int numerator, const MarginalArray::Slice& denominator)
651 {
652   double ret_val = m_facts[numerator];
653   for (int i = 0; i < denominator.GetSize(); i++) ret_val -= m_facts[denominator[i]];
654   return ret_val;
655 }
656 
657 
658 
659 // FExact Core Algorithm Calculate Method
660 // --------------------------------------------------------------------------------------------------------------
661 
// Single-threaded network traversal. Stages run from k = number of columns down to 2. At each stage the
// current node's daughters (the remaining row totals after filling column kb) are enumerated and passed
// to handlePastPaths(); nodes for the next stage are popped from the alternate hash table.
// Returns m_pvalue (presumably accumulated inside handlePastPaths(), which is not shown here).
// NOTE(review): m_row_marginals is overwritten with each popped node's unpacked marginals, so this
// method mutates object state; calling it twice, or ThreadedCalculate() afterwards, would start from
// the wrong marginals.
double FExact::Calculate()
{
  // Two node tables alternate between the stage being generated and the stage being consumed
  NodeHashTable nht[2];
  // Per-stage cache of longest/shortest path lengths by key (cleared whenever k decrements)
  PathExtremesHashTable path_extremes;

  int k = m_col_marginals.GetSize();
  NodePtr cur_node(new FExactNode(0));
  cur_node->past_entries.Push(PastPathLength(0));

  MarginalArray row_diff(m_row_marginals.GetSize());
  MarginalArray irn(m_row_marginals.GetSize());

//  printf("k = %d\n", k);

  while (true) {
    // kb: index of the column being filled at this stage
    int kb = m_col_marginals.GetSize() - k;
    int ks = -1;
    int kmax;
    int kd;

    if (generateFirstDaughter(m_row_marginals, m_col_marginals[kb], row_diff, kmax, kd)) {
      // ntot: total of the columns after kb
      int ntot = 0;
      for (int i = kb + 1; i < m_col_marginals.GetSize(); i++) ntot += m_col_marginals[i];

      do {
        // irn: row totals remaining after this daughter's allocation to column kb
        for (int i = 0; i < m_row_marginals.GetSize(); i++) irn[i] = m_row_marginals[i] - row_diff[i];

        int nrb;
        if (k > 2) {
          if (irn.GetSize() == 2) {
            if (irn[0] > irn[1]) irn.Swap(0, 1);
          } else {
            QSort(irn);
          }

          // Adjust for zero start
          int i = 0;
          for (; i < irn.GetSize(); i++) if (irn[i] != 0) break;
          nrb = i;
        } else {
          nrb = 0;
        }

        // Build adjusted row array
        MarginalArray::Slice sub_rows = irn.GetSlice(nrb, irn.GetSize() - 1);
        MarginalArray::Slice sub_cols = m_col_marginals.GetSlice(kb + 1, m_col_marginals.GetSize() - 1);

        double ddf = logMultinomial(m_col_marginals[kb], row_diff);
        double drn = logMultinomial(ntot, sub_rows) - m_den_observed_path + ddf;

        int kval = 0;
        int path_idx = -1;
        double obs2, obs3;

        if (k > 2) {
          // compute hash table key for current table
          kval = irn[0] + irn[1] * m_key_multipliers[1];
          for (int i = 2; i < irn.GetSize(); i++) kval += irn[i] * m_key_multipliers[i];

          // A positive longest_path (1.0) marks extremes that have not been computed yet for this key
          if (!path_extremes.Find(kval, path_idx)) {
            path_extremes[path_idx].longest_path = 1.0;
          }

          obs2 = m_observed_path - m_facts[m_col_marginals[kb + 1]] - m_facts[m_col_marginals[kb + 2]] - ddf;
          for (int i = 3; i <= (k - 1); i++) obs2 -= m_facts[m_col_marginals[kb + i]];

          if (path_extremes[path_idx].longest_path > 0.0) {

            // Computed extremes are clamped to <= 0.0, so they never look "uncomputed" again
            path_extremes[path_idx].longest_path = longestPath(sub_rows, sub_cols, ntot);
            if (path_extremes[path_idx].longest_path > 0.0) path_extremes[path_idx].longest_path = 0.0;

            double dspt = m_observed_path - obs2 - ddf;
            path_extremes[path_idx].shortest_path = dspt;
            shortestPath(sub_rows, sub_cols, path_extremes[path_idx].shortest_path);
            path_extremes[path_idx].shortest_path -= dspt;
            if (path_extremes[path_idx].shortest_path > 0.0) path_extremes[path_idx].shortest_path = 0.0;
          }
          obs3 = obs2 - path_extremes[path_idx].longest_path;
          obs2 = obs2 - path_extremes[path_idx].shortest_path;
        } else {
          obs2 = m_observed_path - drn - m_den_observed_path;
          obs3 = obs2;
        }

        handlePastPaths(cur_node, obs2, obs3, ddf, drn, kval, nht[k & 0x1]);
      } while (generateNewDaughter(kmax, m_row_marginals, row_diff, kd, ks));
    }

    // Take the next node of the current stage; when none remain, move down one stage
    do {
      cur_node = nht[(k + 1) & 0x1].Pop();
      if (!cur_node) {
        k--;
//        printf("k = %d\n", k);
        path_extremes.ClearTable();
        if (k < 2) return m_pvalue;
      }
    } while (!cur_node);

    // Unpack node row marginals from key
    int kval = cur_node->key;
    for (int i = m_row_marginals.GetSize() - 1; i > 0; i--) {
      m_row_marginals[i] = kval / m_key_multipliers[i];
      kval -= m_row_marginals[i] * m_key_multipliers[i];
    }
    m_row_marginals[0] = kval;
  }

  return m_pvalue;
}
771 
772 
773 
774 // FExact Threaded Core Algorithm Calculate Method
775 // --------------------------------------------------------------------------------------------------------------
776 
ThreadedCalculate()777 double FExact::ThreadedCalculate()
778 {
779   int k = m_col_marginals.GetSize();
780 
781   // Setup Multi-Stage Data Structures
782   m_nht.Resize(k + 1);
783   m_pending_path_nodes.Resize(k + 1);
784   m_path_extremes.Resize(k + 1);
785   m_pending_path_extremes.Resize(k + 1);
786   m_completed_path_extremes.Resize(k + 1);
787 
788 
789   // Setup Path Calculation Worker Thread
790   Array<PathExtremesCalc*> path_calcs(Platform::AvailableCPUs());
791   for (int i = 0; i < path_calcs.GetSize(); i++) {
792     path_calcs[i] = new PathExtremesCalc(this);
793     path_calcs[i]->Start();
794   }
795 
796 
797   // Single Stage Support Variables
798   int kval = m_row_marginals[0] + m_row_marginals[1] * m_key_multipliers[1];
799   for (int i = 2; i < m_row_marginals.GetSize(); i++) kval += m_row_marginals[i] * m_key_multipliers[i];
800   NodePtr cur_node(new FExactNode(kval));
801   cur_node->past_entries.Push(PastPathLength(0));
802 
803 
804   handleNode(k, cur_node);
805 
806   Array<PathExtremes, Smart> temp_completed;
807   for (; k >= 1; k--) {
808 //    printf("k = %d\n", k);
809     if (m_pending_path_nodes[k].GetSize()) {
810       m_path_extremes_mutex.Lock();
811       while (m_pending_path_extremes[k].GetSize()) {
812         m_path_extremes_cond_complete.Wait(m_path_extremes_mutex);
813         if (m_completed_path_extremes[k].GetSize()) {
814           temp_completed = m_completed_path_extremes[k];
815           m_completed_path_extremes[k].Resize(0);
816           m_path_extremes_mutex.Unlock();
817           for (int i = 0; i < temp_completed.GetSize(); i++) {
818             PathExtremes& path = temp_completed[i];
819             int path_idx;
820             m_path_extremes[k].Find(path.key, path_idx);
821             m_path_extremes[k][path_idx].longest_path = path.longest_path;
822             m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
823 
824             for (int i = 0; i < m_pending_path_nodes[k].GetSize(); i++) {
825               PendingPathNode& p = m_pending_path_nodes[k][i];
826               if (p.kval != path.key) continue;
827               if (p.handled) break;
828               double obs2, obs3;
829               if (p.k1) {
830                 obs2 = p.obs2;
831                 obs3 = p.obs2;
832               } else {
833                 int path_idx;
834                 m_path_extremes[k].Find(p.kval, path_idx);
835                 obs3 = p.obs2 - m_path_extremes[k][path_idx].longest_path;
836                 obs2 = p.obs2 - m_path_extremes[k][path_idx].shortest_path;
837               }
838               cur_node = handlePastPaths(p.node, obs2, obs3, p.ddf, p.drn, p.kval, m_nht[k  - 1]);
839               p.node = NodePtr(NULL);
840               if (cur_node) handleNode(k - 1, cur_node);
841               p.handled = true;
842             }
843           }
844           m_path_extremes_mutex.Lock();
845         }
846       }
847       for (int i = 0; i < m_completed_path_extremes[k].GetSize(); i++) {
848         PathExtremes& path = m_completed_path_extremes[k][i];
849         int path_idx;
850         m_path_extremes[k].Find(path.key, path_idx);
851         m_path_extremes[k][path_idx].longest_path = path.longest_path;
852         m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
853       }
854       m_completed_path_extremes[k].Resize(0);
855       m_path_extremes_mutex.Unlock();
856 
857       for (int i = 0; i < m_pending_path_nodes[k].GetSize(); i++) {
858         PendingPathNode& p = m_pending_path_nodes[k][i];
859         if (p.handled) continue;
860         double obs2, obs3;
861         if (p.k1) {
862           obs2 = p.obs2;
863           obs3 = p.obs2;
864         } else {
865           int path_idx;
866           m_path_extremes[k].Find(p.kval, path_idx);
867           obs3 = p.obs2 - m_path_extremes[k][path_idx].longest_path;
868           obs2 = p.obs2 - m_path_extremes[k][path_idx].shortest_path;
869         }
870         cur_node = handlePastPaths(p.node, obs2, obs3, p.ddf, p.drn, p.kval, m_nht[k  - 1]);
871         p.node = NodePtr(NULL);
872         if (cur_node) handleNode(k - 1, cur_node);
873       }
874       m_pending_path_nodes[k].Resize(0);
875     }
876   }
877 
878   for (int i = 0; i < path_calcs.GetSize(); i++) path_calcs[i]->SetFinished();
879   m_path_extremes_cond.Broadcast();
880   for (int i = 0; i < path_calcs.GetSize(); i++) path_calcs[i]->Join();
881 
882   return m_pvalue;
883 }
884 
885 
handleNode(int k,NodePtr cur_node)886 void FExact::handleNode(int k, NodePtr cur_node)
887 {
888   MarginalArray row_marginals(m_row_marginals.GetSize());
889   MarginalArray row_diff(row_marginals.GetSize());
890   MarginalArray irn(row_marginals.GetSize());
891 
892   unpackMarginals(cur_node->key, row_marginals);
893 
894   int kb = m_col_marginals.GetSize() - k;
895   int ks = -1;
896   int kmax;
897   int kd;
898 
899   if (!generateFirstDaughter(row_marginals, m_col_marginals[kb], row_diff, kmax, kd)) return;
900 
901   int ntot = 0;
902   for (int i = kb + 1; i < m_col_marginals.GetSize(); i++) ntot += m_col_marginals[i];
903 
904   do {
905     for (int i = 0; i < row_marginals.GetSize(); i++) irn[i] = row_marginals[i] - row_diff[i];
906 
907     int nrb = 0;
908     if (k > 2) {
909       QSort(irn);
910 
911       // Adjust for zero start
912       for (; nrb < irn.GetSize(); nrb++) if (irn[nrb] != 0) break;
913     }
914 
915     // Build adjusted row array
916     MarginalArray::Slice sub_rows = irn.GetSlice(nrb, irn.GetSize() - 1);
917     MarginalArray::Slice sub_cols = m_col_marginals.GetSlice(kb + 1, m_col_marginals.GetSize() - 1);
918 
919     double ddf = logMultinomial(m_col_marginals[kb], row_diff);
920     double drn = logMultinomial(ntot, sub_rows) - m_den_observed_path + ddf;
921 
922     int kval = 0;
923     int path_idx = -1;
924     double obs2;
925 
926     if (k > 2) {
927       // compute hash table key for current table
928       kval = irn[0] + irn[1] * m_key_multipliers[1];
929       for (int i = 2; i < irn.GetSize(); i++) kval += irn[i] * m_key_multipliers[i];
930 
931       obs2 = m_observed_path - m_facts[m_col_marginals[kb + 1]] - m_facts[m_col_marginals[kb + 2]] - ddf;
932       for (int i = 3; i <= (k - 1); i++) obs2 -= m_facts[m_col_marginals[kb + i]];
933 
934       if (!m_path_extremes[k].Find(kval, path_idx)) {
935         m_path_extremes_mutex.Lock();
936         {
937           // Move completed paths to the local table
938           bool found = false;
939           for (int i = 0; i < m_completed_path_extremes[k].GetSize(); i++) {
940             PathExtremes& path = m_completed_path_extremes[k][i];
941             m_path_extremes[k].Find(path.key, path_idx);
942             m_path_extremes[k][path_idx].longest_path = path.longest_path;
943             m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
944             if (path.key == kval) found = true;
945           }
946           m_completed_path_extremes[0].Resize(0);
947 
948           // Add new pending path if not found and not already in pending table
949           if (!found && !m_pending_path_extremes[k].Find(kval, path_idx)) {
950             m_pending_path_extremes[k][path_idx].Set(sub_rows, sub_cols, obs2, ddf, ntot);
951             m_path_extremes_queue.Push(k);
952           }
953         }
954         m_path_extremes_mutex.Unlock();
955         m_path_extremes_cond.Signal();
956       }
957 
958       // Push node onto pending stack
959       int pending_idx = m_pending_path_nodes[k].GetSize();
960       m_pending_path_nodes[k].Resize(pending_idx + 1);
961       m_pending_path_nodes[k][pending_idx].Set(cur_node, obs2, drn, ddf, kval);
962     } else {
963       obs2 = m_observed_path - drn - m_den_observed_path;
964       int pending_idx = m_pending_path_nodes[k].GetSize();
965       m_pending_path_nodes[k].Resize(pending_idx + 1);
966       m_pending_path_nodes[k][pending_idx].Set(cur_node, obs2, drn, ddf, kval, true);
967     }
968 
969   } while (generateNewDaughter(kmax, row_marginals, row_diff, kd, ks));
970 
971 }
972 
973 // Path Extremes Threading Methods
974 // --------------------------------------------------------------------------------------------------------------
975 
Run()976 void FExact::PathExtremesCalc::Run()
977 {
978   while (true) {
979     int k;
980     int path_idx;
981 
982     m_fexact->m_path_extremes_mutex.Lock();
983     {
984       while (m_fexact->m_path_extremes_queue.GetSize() == 0 && m_run) {
985         m_fexact->m_path_extremes_cond.Wait(m_fexact->m_path_extremes_mutex);
986       }
987       if (!m_run) {
988         m_fexact->m_path_extremes_mutex.Unlock();
989         return;
990       }
991       k = m_fexact->m_path_extremes_queue.Pop();
992       path_idx = m_fexact->m_pending_path_extremes[k].Pop();
993     }
994     m_fexact->m_path_extremes_mutex.Unlock();
995 
996     if (path_idx < 0) continue;
997 
998     PendingPathExtremes& p = m_fexact->m_pending_path_extremes[k][path_idx];
999 
1000     double longest_path = m_fexact->longestPath(p.rows.GetSlice(0, p.rows.GetSize() - 1), p.cols.GetSlice(0, p.cols.GetSize() - 1), p.ntot);
1001     if (longest_path > 0.0) longest_path = 0.0;
1002 
1003     double dspt = m_fexact->m_observed_path - p.obs2 - p.ddf;
1004     double shortest_path = dspt;
1005     m_fexact->shortestPath(p.rows.GetSlice(0, p.rows.GetSize() - 1), p.cols.GetSlice(0, p.cols.GetSize() - 1), shortest_path);
1006     shortest_path -= dspt;
1007     if (shortest_path > 0.0) shortest_path = 0.0;
1008 
1009     m_fexact->m_path_extremes_mutex.Lock();
1010     {
1011       // Record completed calculation
1012       int completed_idx = m_fexact->m_completed_path_extremes[k].GetSize();
1013       m_fexact->m_completed_path_extremes[k].Resize(completed_idx + 1);
1014       m_fexact->m_completed_path_extremes[k][completed_idx].Set(p.key, longest_path, shortest_path);
1015 
1016       // Clear out the pending record
1017       m_fexact->m_pending_path_extremes[k].Remove(path_idx);
1018     }
1019     m_fexact->m_path_extremes_mutex.Unlock();
1020     m_fexact->m_path_extremes_cond_complete.Signal();
1021   }
1022 }
1023 
1024 
1025 
1026 // FExact Core Algorithm Support Methods
1027 // --------------------------------------------------------------------------------------------------------------
1028 
generateFirstDaughter(const MarginalArray & row_marginals,int n,MarginalArray & row_diff,int & kmax,int & kd)1029 inline bool FExact::generateFirstDaughter(const MarginalArray& row_marginals, int n, MarginalArray& row_diff, int& kmax, int& kd)
1030 {
1031   row_diff.SetAll(0);
1032 
1033   kmax = row_marginals.GetSize() - 1;
1034   kd = row_marginals.GetSize();
1035   do {
1036     kd--;
1037     int ntot = (n < row_marginals[kd]) ? n : row_marginals[kd];
1038     row_diff[kd] = ntot;
1039     if (row_diff[kmax] == 0) kmax--;
1040     n -= ntot;
1041   } while (n > 0 && kd > 0);
1042 
1043   if (n != 0) return false;
1044 
1045   return true;
1046 }
1047 
1048 
generateNewDaughter(int kmax,const MarginalArray & row_marginals,MarginalArray & row_diff,int & idx_dec,int & idx_inc)1049 bool FExact::generateNewDaughter(int kmax, const MarginalArray& row_marginals, MarginalArray& row_diff, int& idx_dec, int& idx_inc)
1050 {
1051   if (idx_inc == -1) {
1052     while (row_diff[++idx_inc] == row_marginals[idx_inc]) ;
1053   }
1054 
1055   // Find node to decrement
1056   if (row_diff[idx_dec] > 0 && idx_dec > idx_inc) {
1057     row_diff[idx_dec]--;
1058     while (row_marginals[--idx_dec] == 0) ;
1059     int m = idx_dec;
1060 
1061     // Find node to increment
1062     while (row_diff[m] >= row_marginals[m]) m--;
1063     row_diff[m]++;
1064 
1065     if (m == idx_inc && row_diff[m] == row_marginals[m]) idx_inc = idx_dec;
1066   } else {
1067     int idx = 0;
1068     do {
1069       // Check for finish
1070       idx = idx_dec + 1;
1071       bool found = false;
1072       for (; idx < row_diff.GetSize(); idx++) {
1073         if (row_diff[idx] > 0) {
1074           found = true;
1075           break;
1076         }
1077       }
1078       if (!found) return false;
1079 
1080       int marginal_total = 1;
1081       for (int i = 0; i <= idx_dec; i++) {
1082         marginal_total += row_diff[i];
1083         row_diff[i] = 0;
1084       }
1085       idx_dec = idx;
1086       do {
1087         idx_dec--;
1088         int m = (marginal_total < row_marginals[idx_dec]) ? marginal_total : row_marginals[idx_dec];
1089         row_diff[idx_dec] = m;
1090         marginal_total -= m;
1091       } while (marginal_total > 0 && idx_dec != 0);
1092 
1093       if (marginal_total > 0) {
1094         if (idx != (kmax)) {
1095           idx_dec = idx;
1096           continue;
1097         }
1098         return false;
1099       } else {
1100         break;
1101       }
1102     } while (true);
1103     row_diff[idx]--;
1104     for (idx_inc = 0; row_diff[idx_inc] >= row_marginals[idx_inc]; idx_inc++) if (idx_inc > idx_dec) break;
1105   }
1106 
1107   return true;
1108 }
1109 
1110 
handlePastPaths(NodePtr & cur_node,double obs2,double obs3,double ddf,double drn,int kval,NodeHashTable & nht)1111 FExact::NodePtr FExact::handlePastPaths(NodePtr& cur_node, double obs2, double obs3, double ddf, double drn, int kval,
1112                              NodeHashTable& nht)
1113 {
1114   NodePtr new_node(NULL);
1115   for (int i = 0; i < cur_node->past_entries.GetSize(); i++) {
1116     double past_path = cur_node->past_entries[i].value;
1117     int path_freq = cur_node->past_entries[i].observed;
1118     if (past_path <= obs3) {
1119       // Path shorter than longest path, add to the pvalue and continue
1120       m_pvalue += (double)(path_freq) * exp(past_path + drn);
1121     } else if (past_path < obs2) {
1122       int nht_idx;
1123       double new_path = past_path + ddf;
1124       if (nht.Find(kval, nht_idx)) {
1125         // Existing Node was found
1126         recordPath(new_path, path_freq, nht[nht_idx].past_entries);
1127       } else {
1128         // New Node added, insert this observed path
1129         new_node = nht.Get(nht_idx);
1130         new_node->past_entries.Push(PastPathLength(new_path, path_freq));
1131       }
1132     }
1133   }
1134   return new_node;
1135 }
1136 
1137 
recordPath(double path_length,int path_freq,Array<PastPathLength,Smart> & past_entries)1138 void FExact::recordPath(double path_length, int path_freq, Array<PastPathLength, Smart>& past_entries)
1139 {
1140   // Search for past path within m_tolerance and add observed frequency to it
1141   double test1 = path_length - m_tolerance;
1142   double test2 = path_length + m_tolerance;
1143 
1144   int j = 0;
1145   int old_j = 0;
1146   while (true) {
1147     double test_path = past_entries[j].value;
1148     if (test_path < test1) {
1149       old_j = j;
1150       j = past_entries[j].next_left;
1151       if (j >= 0) continue;
1152     } else if (test_path > test2) {
1153       old_j = j;
1154       j = past_entries[j].next_right;
1155       if (j >= 0) continue;
1156     } else {
1157       past_entries[j].observed += path_freq;
1158       return;
1159     }
1160     break;
1161   }
1162 
1163   // If no path within m_tolerance is found, add new past path length to the node
1164   int new_idx = past_entries.GetSize();
1165   past_entries.Push(PastPathLength(path_length, path_freq));
1166 
1167   double test_path = past_entries[old_j].value;
1168   if (test_path < test1) {
1169     past_entries[old_j].next_left = new_idx;
1170   } else if (test_path > test2) {
1171     past_entries[old_j].next_right = new_idx;
1172   } else {
1173     assert(false);
1174   }
1175 }
1176 
1177 
unpackMarginals(int key,MarginalArray & row_marginals)1178 void FExact::unpackMarginals(int key, MarginalArray& row_marginals)
1179 {
1180   for (int i = row_marginals.GetSize() - 1; i > 0; i--) {
1181     row_marginals[i] = key / m_key_multipliers[i];
1182     key -= row_marginals[i] * m_key_multipliers[i];
1183   }
1184   row_marginals[0] = key;
1185 }
1186 
1187 
1188 
1189 // FExact Longest Path Methods
1190 // --------------------------------------------------------------------------------------------------------------
1191 
longestPath(const MarginalArray::Slice & row_marginals,const MarginalArray::Slice & col_marginals,int marginal_total)1192 double FExact::longestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, int marginal_total)
1193 {
1194   class ValueHashTable
1195   {
1196   private:
1197     Array<Pair<int, double> >* m_table;
1198     Array<int> m_stack;
1199     int m_entry_count;
1200 
1201   public:
1202     inline ValueHashTable(int size = 200) : m_table(new Array<Pair<int, double> >(size)), m_stack(size) { ClearTable(); }
1203     inline ~ValueHashTable() { delete m_table; }
1204 
1205     int GetEntryCount() const { return m_entry_count; }
1206 
1207     bool Find(int key, int& idx)
1208     {
1209       int init = key % m_table->GetSize();
1210       idx = init;
1211       for (; idx < m_table->GetSize(); idx++) {
1212         if ((*m_table)[idx].Value1() < 0) {
1213           m_stack[m_entry_count] = idx;
1214           (*m_table)[idx].Value1() = key;
1215           m_entry_count++;
1216           return false;
1217         } else if ((*m_table)[idx].Value1() == key) {
1218           return true;
1219         }
1220       }
1221       for (idx = 0; idx < init; idx++) {
1222         if ((*m_table)[idx].Value1() < 0) {
1223           m_stack[m_entry_count] = idx;
1224           (*m_table)[idx].Value1() = key;
1225           m_entry_count++;
1226           return false;
1227         } else if ((*m_table)[idx].Value1() == key) {
1228           return true;
1229         }
1230       }
1231       Rehash(key, idx);
1232       return false;
1233     }
1234 
1235     inline double& operator[](int idx) { return (*m_table)[idx].Value2(); }
1236 
1237     Pair<int, double> Pop()
1238     {
1239       Pair<int, double> tmp = (*m_table)[m_stack[--m_entry_count]];
1240       (*m_table)[m_stack[m_entry_count]].Value1() = -1;
1241       return tmp;
1242     }
1243 
1244     inline void ClearTable()
1245     {
1246       m_entry_count = 0;
1247       for (int i = 0; i < m_table->GetSize(); i++) (*m_table)[i].Value1() = -1;
1248     }
1249 
1250   private:
1251     void Rehash(int key, int& idx)
1252     {
1253       Array<Pair<int, double> >* old_table = m_table;
1254       m_table = new Array<Pair<int, double> >(old_table->GetSize() * 2);
1255       for (int i = 0; i < m_table->GetSize(); i++) (*m_table)[i].Value1() = -1;
1256       m_stack.Resize(old_table->GetSize() * 2);
1257       for (int i = 0; i < old_table->GetSize(); i++) {
1258         int t_idx;
1259         Find((*old_table)[i].Value1(), t_idx);
1260         (*m_table)[t_idx].Value2() = (*old_table)[i].Value2();
1261       }
1262       Find(key, idx);
1263       delete old_table;
1264     }
1265   };
1266 
1267   // 1 x c
1268   if (row_marginals.GetSize() <= 1) {
1269     double longest_path = 0.0;
1270     for (int i = 0; i < col_marginals.GetSize(); i++) longest_path -= m_facts[col_marginals[i]];
1271     return longest_path;
1272   }
1273 
1274   // r x 1
1275   if (col_marginals.GetSize() <= 1) {
1276     double longest_path = 0.0;
1277     for (int i = 0; i < row_marginals.GetSize(); i++) longest_path -= m_facts[row_marginals[i]];
1278     return longest_path;
1279   }
1280 
1281   // 2 x 2
1282   if (row_marginals.GetSize() == 2 && col_marginals.GetSize() == 2) {
1283     int n11 = (row_marginals[0] + 1) * (col_marginals[0] + 1) / (marginal_total + 2);
1284     int n12 = row_marginals[0] - n11;
1285     return -m_facts[n11] - m_facts[n12] - m_facts[col_marginals[0] - n11] - m_facts[col_marginals[1] - n12];
1286   }
1287 
1288   double val = 0.0;
1289   bool min = false;
1290   if (row_marginals[row_marginals.GetSize() - 1] <= row_marginals[0] + col_marginals.GetSize()) {
1291     min = longestPathSpecial(row_marginals, col_marginals, val);
1292   }
1293   if (!min && col_marginals[col_marginals.GetSize() - 1] <= col_marginals[0] + row_marginals.GetSize()) {
1294     min = longestPathSpecial(col_marginals, row_marginals, val);
1295   }
1296 
1297   if (min) {
1298     return -val;
1299   }
1300 
1301 
1302   int ntot = marginal_total;
1303   MarginalArray lrow;
1304   MarginalArray lcol;
1305 
1306   if (row_marginals.GetSize() >= col_marginals.GetSize()) {
1307     lrow.EnhancedSmart<int>::operator=(row_marginals);
1308     lcol.EnhancedSmart<int>::operator=(col_marginals);
1309   } else {
1310     lrow.EnhancedSmart<int>::operator=(col_marginals);
1311     lcol.EnhancedSmart<int>::operator=(row_marginals);
1312   }
1313 
1314   Array<int> nt(lcol.GetSize());
1315   nt[0] = ntot - lcol[0];
1316   for (int i = 1; i < lcol.GetSize(); i++) nt[i] = nt[i - 1] - lcol[i];
1317 
1318 
1319   Array<double> alen(col_marginals.GetSize() + 1);
1320   alen.SetAll(0.0);
1321 
1322   ValueHashTable vht[2];
1323   int active_vht = 0;
1324 
1325   double vmn = 1.0e10;
1326   int nc1s = lcol.GetSize() - 2;
1327   int kyy = lcol[lcol.GetSize() - 1] + 1;
1328 
1329   Array<int> lb(lrow.GetSize());
1330   Array<int> nu(lrow.GetSize());
1331   Array<int> nr(lrow.GetSize());
1332 
1333 
1334   while (true) {
1335     bool continue_main = false;
1336 
1337     // Setup to generate new node
1338     int lev = 0;
1339     int nr1 = lrow.GetSize() - 1;
1340     int nrt = lrow[0];
1341     int nct = lcol[0];
1342     lb[0] = (int)((((double)nrt + 1.0) * (nct + 1)) / (double)(ntot + nr1 * (nc1s + 1) + 1) - m_tolerance) - 1;
1343     nu[0] = (int)((((double)nrt + nc1s + 1.0) * (nct + nr1)) / (double)(ntot + nr1 + nc1s + 1)) - lb[0] + 1;
1344     nr[0] = nrt - lb[0];
1345 
1346     while (true) {
1347       do {
1348         nu[lev]--;
1349         if (nu[lev] == 0) {
1350           if (lev == 0) {
1351             do {
1352               if (vht[(active_vht) ? 0 : 1].GetEntryCount()) {
1353                 Pair<int, double> entry = vht[(active_vht) ? 0 : 1].Pop();
1354                 val = entry.Value2();
1355                 int key = entry.Value1();
1356 
1357                 // Compute Marginals
1358                 for (int i = lcol.GetSize() - 1; i > 0; i--) {
1359                   lcol[i] = key % kyy;
1360                   key = key / kyy;
1361                 }
1362                 lcol[0] = key;
1363 
1364                 // Set up nt array
1365                 nt[0] = ntot - lcol[0];
1366                 for (int i = 1; i < lcol.GetSize(); i++) nt[i] = nt[i - 1] - lcol[i];
1367 
1368                 min = false;
1369                 if (lrow[lrow.GetSize() - 1] <= lrow[0] + lcol.GetSize()) {
1370                   min = longestPathSpecial(lrow.GetSlice(0, lrow.GetSize() - 1), lcol.GetSlice(0, lcol.GetSize() - 1), val);
1371                 }
1372                 if (!min && lcol[lcol.GetSize() - 1] <= lcol[0] + lrow.GetSize()) {
1373                   min = longestPathSpecial(lrow.GetSlice(0, lrow.GetSize() - 1), lcol.GetSlice(0, lcol.GetSize() - 1), val);
1374                 }
1375 
1376                 if (min) {
1377                   if (val < vmn)
1378                     vmn = val;
1379                   continue;
1380                 }
1381                 continue_main = true;
1382               } else if (lrow.GetSize() > 2 && vht[active_vht].GetEntryCount()) {
1383                 // Go to next level
1384                 ntot -= lrow[0];
1385                 Array<int> tmp(lrow);
1386                 lrow.ResizeClear(lrow.GetSize() - 1);
1387                 for (int i = 0; i < lrow.GetSize(); i++) lrow[i] = tmp[i + 1];
1388                 active_vht = (active_vht) ? 0 : 1;
1389                 continue;
1390               }
1391               break;
1392             } while (true);
1393             if (!continue_main)
1394               return -vmn;
1395           }
1396           if (continue_main) break;
1397           lev--;
1398           continue;
1399         }
1400         break;
1401       } while (true);
1402       if (continue_main) break;
1403 
1404       lb[lev]++;
1405       nr[lev]--;
1406 
1407       for (alen[lev + 1] = alen[lev] + m_facts[lb[lev]]; lev < nc1s; alen[lev + 1] = alen[lev] + m_facts[lb[lev]]) {
1408         int nn1 = nt[lev];
1409         int nrt = nr[lev];
1410         lev++;
1411         int nc1 = lcol.GetSize() - lev - 1;
1412         int nct = lcol[lev];
1413         lb[lev] = (int)((double)((nrt + 1) * (nct + 1)) / (double)(nn1 + nr1 * nc1 + 1) - m_tolerance);
1414         nu[lev] = (int)((double)((nrt + nc1) * (nct + nr1)) / (double)(nn1 + nr1 + nc1) - lb[lev] + 1);
1415         nr[lev] = nrt - lb[lev];
1416       }
1417       alen[lcol.GetSize()] = alen[lev + 1] + m_facts[nr[lev]];
1418       lb[lcol.GetSize() - 1] = nr[lev];
1419 
1420       double v = val + alen[lcol.GetSize()];
1421       if (lrow.GetSize() == 2) {
1422         for (int i = 0; i < lcol.GetSize(); i++) v += m_facts[lcol[i] - lb[i]];
1423         if (v < vmn)
1424           vmn = v;
1425       } else if (lrow.GetSize() == 3 && lcol.GetSize() == 2) {
1426         int nn1 = ntot - lrow[0] + 2;
1427         int ic1 = lcol[0] - lb[0];
1428         int ic2 = lcol[1] - lb[1];
1429         int n11 = (lrow[1] + 1) * (ic1 + 1) / nn1;
1430         int n12 = lrow[1] - n11;
1431         v += m_facts[n11] + m_facts[n12] + m_facts[ic1 - n11] + m_facts[ic2 - n12];
1432         if (v < vmn)
1433           vmn = v;
1434       } else {
1435         Array<int> it(lcol.GetSize());
1436         for (int i = 0; i < lcol.GetSize(); i++) it[i] = lcol[i] - lb[i];
1437 
1438         if (lcol.GetSize() == 2) {
1439           if (it[0] > it[1]) it.Swap(0, 1);
1440         } else {
1441           QSort(it);
1442         }
1443 
1444         // Compute hash value
1445         int key = it[0] * kyy + it[1];
1446         for (int i = 2; i < lcol.GetSize(); i++) key = it[i] + key * kyy;
1447 
1448         // Put onto stack (or update stack entry as necessary)
1449         int t_idx;
1450         if (vht[active_vht].Find(key, t_idx)) {
1451           if (v < vht[active_vht][t_idx]) vht[active_vht][t_idx] = v;
1452         } else {
1453           vht[active_vht][t_idx] = v;
1454         }
1455       }
1456     }
1457   }
1458 
1459 
1460 
1461   return 0.0;
1462 }
1463 
1464 
longestPathSpecial(const MarginalArray::Slice & row_marginals,const MarginalArray::Slice & col_marginals,double & val)1465 bool FExact::longestPathSpecial(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& val)
1466 {
1467   Array<int> nd(row_marginals.GetSize() - 1);
1468   Array<int> ne(col_marginals.GetSize());
1469   Array<int> m(col_marginals.GetSize());
1470 
1471   nd.SetAll(0);
1472   int is = col_marginals[0] / row_marginals.GetSize();
1473   ne[0] = is;
1474   int ix = col_marginals[0] - row_marginals.GetSize() * is;
1475   m[0] = ix;
1476   if (ix != 0) nd[ix - 1] = 1;
1477 
1478   for (int i = 1; i < col_marginals.GetSize(); i++) {
1479     ix = col_marginals[i] / row_marginals.GetSize();
1480     ne[i] = ix;
1481     is += ix;
1482     ix = col_marginals[i] - row_marginals.GetSize() * ix;
1483     m[i] = ix;
1484     if (ix != 0) nd[ix - 1]++;
1485   }
1486 
1487   for (int i = nd.GetSize() - 2; i >= 0; i--) nd[i] += nd[i + 1];
1488 
1489   ix = 0;
1490   int nrow1 = row_marginals.GetSize() - 1;
1491   for (int i = (row_marginals.GetSize() - 1); i > 0; i--) {
1492     ix += is + nd[nrow1 - i] - row_marginals[i];
1493     if (ix < 0) return false;
1494   }
1495 
1496   val = 0.0;
1497   for (int i = 0; i < col_marginals.GetSize(); i++) {
1498     ix = ne[i];
1499     is = m[i];
1500     val += is * m_facts[ix + 1] + (row_marginals.GetSize() - is) * m_facts[ix];
1501   }
1502 
1503   return true;
1504 }
1505 
1506 
1507 
1508 
1509 // FExact Shortest Path Methods
1510 // --------------------------------------------------------------------------------------------------------------
1511 
removeFromVector(const Array<int,ManualBuffer> & src,int idx_remove,Array<int,ManualBuffer> & dest)1512 inline void FExact::removeFromVector(const Array<int, ManualBuffer>& src, int idx_remove, Array<int, ManualBuffer>& dest)
1513 {
1514   dest.Resize(src.GetSize() - 1);
1515   for (int i = 0; i < idx_remove; i++) dest[i] = src[i];
1516   for (int i = idx_remove + 1; i < src.GetSize(); i++) dest[i - 1] = src[i];
1517 }
1518 
1519 
reduceZeroInVector(const Array<int,ManualBuffer> & src,int value,int idx_start,Array<int,ManualBuffer> & dest)1520 inline void FExact::reduceZeroInVector(const Array<int, ManualBuffer>& src, int value, int idx_start, Array<int, ManualBuffer>& dest)
1521 {
1522   dest.Resize(src.GetSize());
1523 
1524   int i = 0;
1525   for (; i < idx_start; i++) dest[i] = src[i];
1526 
1527   for (; i < (src.GetSize() - 1); i++) {
1528     if (value >= src[i + 1]) {
1529       break;
1530     }
1531     dest[i] = src[i + 1];
1532   }
1533   dest[i] = value;
1534 
1535   for (++i; i < src.GetSize(); i++) dest[i] = src[i];
1536 }
1537 
1538 
shortestPath(const MarginalArray::Slice & row_marginals,const MarginalArray::Slice & col_marginals,double & shortest_path)1539 void FExact::shortestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& shortest_path)
1540 {
1541   // Take care of easy cases first
1542 
1543   // 1 x c
1544   if (row_marginals.GetSize() == 1) {
1545     for (int i = 0; i < col_marginals.GetSize(); i++) shortest_path -= m_facts[col_marginals[i]];
1546     return;
1547   }
1548 
1549   // r x 1
1550   if (col_marginals.GetSize() == 1) {
1551     for (int i = 0; i < row_marginals.GetSize(); i++) shortest_path -= m_facts[row_marginals[i]];
1552     return;
1553   }
1554 
1555   // 2 x 2
1556   if (row_marginals.GetSize() == 2 && col_marginals.GetSize() == 2) {
1557     if (row_marginals[1] <= col_marginals[1]) {
1558       shortest_path += -m_facts[row_marginals[1]] - m_facts[col_marginals[0]] - m_facts[col_marginals[1] - row_marginals[1]];
1559     } else {
1560       shortest_path += -m_facts[col_marginals[1]] - m_facts[row_marginals[0]] - m_facts[row_marginals[1] - col_marginals[1]];
1561     }
1562     return;
1563   }
1564 
1565   // General Case
1566 
1567 
1568   const int ROW_BUFFER_SIZE = (row_marginals.GetSize() + col_marginals.GetSize() + 1) * row_marginals.GetSize();
1569   const int COL_BUFFER_SIZE = (row_marginals.GetSize() + col_marginals.GetSize() + 1) * col_marginals.GetSize();
1570   SmartPtr<int, NoCopy, ArrayStorage> row_data(new int[ROW_BUFFER_SIZE]);
1571   SmartPtr<int, NoCopy, ArrayStorage> col_data(new int[COL_BUFFER_SIZE]);
1572   Array<Array<int, ManualBuffer> > row_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1573   Array<Array<int, ManualBuffer> > col_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1574   for (int i = 0; i < row_stack.GetSize(); i++) {
1575     row_stack[i].SetBuffer(GetInternalPtr(row_data) + (i * row_marginals.GetSize()));
1576     col_stack[i].SetBuffer(GetInternalPtr(col_data) + (i * col_marginals.GetSize()));
1577   }
1578 
1579   row_stack[0].Resize(row_marginals.GetSize());
1580   for (int i = 0; i < row_marginals.GetSize(); i++) row_stack[0][i] = row_marginals[row_marginals.GetSize() - i - 1];
1581   col_stack[0].Resize(col_marginals.GetSize());
1582   for (int i = 0; i < col_marginals.GetSize(); i++) col_stack[0][i] = col_marginals[col_marginals.GetSize() - i - 1];
1583 
1584   int istk = 0;
1585 
1586   Array<double> y_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1587   Array<int> l_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1588   Array<int> m_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1589   Array<int> n_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
1590   y_stack[0] = 0.0;
1591   double y = 0.0;
1592 
1593   int l = 0;
1594   double amx = 0.0;
1595 
1596   int m, n, jrow, jcol;
1597 
1598   do {
1599     int row1 = row_stack[istk][0];
1600     int col1 = col_stack[istk][0];
1601     if (row1 > col1) {
1602       if (row_stack[istk].GetSize() >= col_stack[istk].GetSize()) {
1603         m = col_stack[istk].GetSize() - 1;
1604         n = 2;
1605       } else {
1606         m = row_stack[istk].GetSize();
1607         n = 1;
1608       }
1609     } else if (row1 < col1) {
1610       if (row_stack[istk].GetSize() <= col_stack[istk].GetSize()) {
1611         m = row_stack[istk].GetSize() - 1;
1612         n = 1;
1613       } else {
1614         m = col_stack[istk].GetSize();
1615         n = 2;
1616       }
1617     } else {
1618       if (row_stack[istk].GetSize() <= col_stack[istk].GetSize()) {
1619         m = row_stack[istk].GetSize() - 1;
1620         n = 1;
1621       } else {
1622         m = col_stack[istk].GetSize() - 1;
1623         n = 2;
1624       }
1625     }
1626 
1627     do {
1628       if (n == 1) {
1629         jrow = l;
1630         jcol = 0;
1631       } else {
1632         jrow = 0;
1633         jcol = l;
1634       }
1635 
1636       int rowt = row_stack[istk][jrow];
1637       int colt = col_stack[istk][jcol];
1638       int mn = (rowt > colt) ? colt : rowt;
1639       y += m_facts[mn];
1640       if (rowt == colt) {
1641         removeFromVector(row_stack[istk], jrow, row_stack[istk + 1]);
1642         removeFromVector(col_stack[istk], jcol, col_stack[istk + 1]);
1643       } else if (rowt > colt) {
1644         removeFromVector(col_stack[istk], jcol, col_stack[istk + 1]);
1645         reduceZeroInVector(row_stack[istk], rowt - colt, jrow, row_stack[istk + 1]);
1646       } else {
1647         removeFromVector(row_stack[istk], jrow, row_stack[istk + 1]);
1648         reduceZeroInVector(col_stack[istk], colt - rowt, jcol, col_stack[istk + 1]);
1649       }
1650 
1651       if (row_stack[istk + 1].GetSize() == 1 || col_stack[istk + 1].GetSize() == 1) {
1652         if (row_stack[istk + 1].GetSize() == 1) {
1653           for (int i = 0; i < col_stack[istk + 1].GetSize(); i++) y += m_facts[col_stack[istk + 1][i]];
1654         }
1655         if (col_stack[istk + 1].GetSize() == 1) {
1656           for (int i = 0; i < row_stack[istk + 1].GetSize(); i++) y += m_facts[row_stack[istk + 1][i]];
1657         }
1658 
1659         if (y > amx) {
1660           amx = y;
1661           if (shortest_path - amx <= m_tolerance) {
1662             shortest_path = 0.0;
1663             return;
1664           }
1665         }
1666 
1667         bool continue_outer = false;
1668         for (--istk; istk >= 0; istk--) {
1669           l = l_stack[istk] + 1;
1670           for (; l < m_stack[istk]; l++) {
1671             n = n_stack[istk];
1672             y = y_stack[istk];
1673             if (n == 1) {
1674               if (row_stack[istk][l] < row_stack[istk][l - 1]) {
1675                 continue_outer = true;
1676                 break;
1677               }
1678             } else if (n == 2) {
1679               if (col_stack[istk][l] < col_stack[istk][l - 1]) {
1680                 continue_outer = true;
1681                 break;
1682               }
1683             }
1684           }
1685           if (continue_outer) break;
1686         }
1687         if (continue_outer) continue;
1688 
1689         shortest_path -= amx;
1690         if (shortest_path - amx <= m_tolerance) shortest_path = 0.0;
1691         return;
1692       } else {
1693         break;
1694       }
1695     } while (true);
1696 
1697     l_stack[istk] = l;
1698     m_stack[istk] = m;
1699     n_stack[istk] = n;
1700     istk++;
1701     y_stack[istk] = y;
1702     l = 0;
1703   } while (true);
1704 
1705 }
1706 
1707 
1708 
1709 // Internal Function Definitions
1710 // --------------------------------------------------------------------------------------------------------------
1711 
cummulativeGamma(double q,double alpha,bool & fault)1712 double cummulativeGamma(double q, double alpha, bool& fault)
1713 {
1714   if (q <= 0.0 || alpha <= 0.0) {
1715     fault = true;
1716     return 0.0;
1717   }
1718 
1719   double f = exp(alpha * log(q) - logGamma(alpha + 1.0, fault) - q); // no need to test logGamma fail as an error is impossible
1720   if (f == 0.0) {
1721     fault = true;
1722     return 0.0;
1723   }
1724 
1725   fault = false;
1726 
1727   double c = 1.0;
1728   double ret_val = 1.0;
1729   double a = alpha;
1730 
1731   do {
1732     a += 1.0;
1733     c = c * q / a;
1734     ret_val += c;
1735   } while (c / ret_val > (1e-6));
1736   ret_val *= f;
1737 
1738   return ret_val;
1739 }
1740 
1741 
logGamma(double x,bool & fault)1742 double logGamma(double x, bool& fault)
1743 {
1744   const double a1 = .918938533204673;
1745   const double a2 = 5.95238095238e-4;
1746   const double a3 = 7.93650793651e-4;
1747   const double a4 = .002777777777778;
1748   const double a5 = .083333333333333;
1749 
1750   if (x < 0.0) {
1751     fault = true;
1752     return 0.0;
1753   }
1754 
1755   fault = false;
1756 
1757   double f = 0.0;
1758 
1759   if (x < 7.0) {
1760     f = x;
1761 
1762     x += 1.0;
1763     while (x < 7.0) {
1764       f *= x;
1765       x += 1.0;
1766     }
1767 
1768     f = -log(f);
1769   }
1770 
1771   double z = 1 / (x * x);
1772   return f + (x - .5) * log(x) - x + a1 + (((-a2 * z + a3) * z - a4) * z + a5) / x;
1773 }
1774