1 /*
2 * FishersExact.cc
3 * Apto
4 *
5 * Created by David on 2/15/11.
6 * Copyright 2011 David Michael Bryson. All rights reserved.
7 * http://programerror.com/software/apto
8 *
9 * Redistribution and use in source and binary forms, with or without modification, are permitted provided that the
10 * following conditions are met:
11 *
12 * 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the
13 * following disclaimer.
14 * 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the
15 * following disclaimer in the documentation and/or other materials provided with the distribution.
16 * 3. Neither the name of David Michael Bryson, nor the names of contributors may be used to endorse or promote
17 * products derived from this software without specific prior written permission.
18 *
19 * THIS SOFTWARE IS PROVIDED BY DAVID MICHAEL BRYSON AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
20 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
21 * DISCLAIMED. IN NO EVENT SHALL DAVID MICHAEL BRYSON OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
22 * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
23 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
24 * WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *
27 * Authors: David M. Bryson <david@programerror.com>
28 *
 * Fisher's Exact Test for r x c Contingency Tables
 * Based on:
 * "A Network Algorithm for Performing Fisher's Exact Test in r x c Contingency Tables" by Mehta and Patel
 * Journal of the American Statistical Association, June 1983, Volume 78, Number 382, Pages 427-434
 * "Algorithm 643: FEXACT: A FORTRAN Subroutine for Fisher's Exact Test on Unordered r x c Contingency Tables"
 * by Mehta and Patel, ACM Transactions on Mathematical Software, June 1986, Volume 12, No. 2, Pages 154-161
 * "A Remark on Algorithm 643: FEXACT: A FORTRAN Subroutine for Fisher's Exact Test on Unordered r x c Contingency Tables"
 * by Clarkson, Fan, and Joe, ACM Transactions on Mathematical Software, Dec. 1993, Volume 19, No. 4, Pages 484-488
37 *
38 */
39
40 #include "apto/platform.h"
41
42 #include "apto/stat/ContingencyTable.h"
43 #include "apto/stat/Functions.h"
44
45 #include "apto/core/Array.h"
46 #include "apto/core/ArrayUtils.h"
47 #include "apto/core/ConditionVariable.h"
48 #include "apto/core/Mutex.h"
49 #include "apto/core/Pair.h"
50 #include "apto/core/RefCount.h"
51 #include "apto/core/SmartPtr.h"
52 #include "apto/core/Thread.h"
53
54 #include <cmath>
55 #include <limits>
56
57 #include <iostream>
58 #include <stdio.h>
59
60 using namespace Apto;
61
62
63 // Constant Declarations
64 // --------------------------------------------------------------------------------------------------------------
65
// Floating point comparison tolerance, as used in Algorithm 643
static const double TOLERANCE = 3.4525e-7;
// Tables with more than this many cells are evaluated with the threaded calculation
static const int THREADING_THRESHOLD = 25;
// Initial capacity of the internal hash tables (grown by rehashing as needed)
static const int DEFAULT_TABLE_SIZE = 300;
69
70
71 // Exported Function Declarations
72 // --------------------------------------------------------------------------------------------------------------
73
namespace Apto {
  namespace Stat {
    // Fisher's Exact Test p-value for an r x c contingency table (defined below)
    double FishersExact(const ContingencyTable& table);
  };
};
79
80
81
82 // Internal Function Declarations
83 // --------------------------------------------------------------------------------------------------------------
84
// Cumulative (incomplete) gamma integral; 'fault' is set on error.
// Declared here, defined later in this file (definition not shown in this chunk);
// the "cummulative" spelling must match that definition.
static double cummulativeGamma(double q, double alpha, bool& fault);
// Natural log of the gamma function; 'fault' is set on error (defined later in this file)
static double logGamma(double x, bool& fault);
87
88
89
90 // Internal Class/Struct Definitions
91 // --------------------------------------------------------------------------------------------------------------
92
// Array storage policy that performs no allocation of its own: the backing
// buffer is supplied externally via SetBuffer(). This lets fixed scratch
// memory be overlaid as an Array without any heap traffic.
//
// Fix: removed the redundant bare default constructor, which (a) left
// m_data/m_size uninitialized and (b) made zero-argument construction
// ambiguous with the explicit constructor's default argument.
template <class T> class ManualBuffer
{
private:
  T* m_data;   // Externally owned data buffer (never allocated or freed here)
  int m_size;  // Logical number of elements in use

protected:
  typedef T StoredType;

  explicit ManualBuffer(int size = 0) : m_data(NULL), m_size(size) { ; }
  ManualBuffer(const ManualBuffer& rhs) : m_data(NULL), m_size(0) { this->operator=(rhs); }

  int GetSize() const { return m_size; }

  // Resizing only adjusts the logical size; the caller must guarantee that
  // the buffer supplied via SetBuffer() is large enough.
  void ResizeClear(const int in_size) { m_size = in_size; }
  void Resize(int new_size) { m_size = new_size; }

  ManualBuffer& operator=(const ManualBuffer& rhs)
  {
    // NOTE(review): requires SetBuffer() to have provided a buffer of at
    // least rhs.m_size elements when rhs is non-empty -- confirm at call sites.
    m_size = rhs.m_size;
    for (int i = 0; i < rhs.m_size; i++) m_data[i] = rhs.m_data[i];
    return *this;
  }

  inline T& operator[](const int index) { return m_data[index]; }
  inline const T& operator[](const int index) const { return m_data[index]; }

  // Exchange the values at the two indices
  void Swap(int idx1, int idx2)
  {
    T v = m_data[idx1];
    m_data[idx1] = m_data[idx2];
    m_data[idx2] = v;
  }

public:
  void SetBuffer(T* data) { m_data = data; }
};
129
130
131 template <class T> class EnhancedSmart : public Smart<T>
132 {
133 public:
EnhancedSmart(int size=0)134 EnhancedSmart(int size = 0) : Smart<T>(size) { ; }
EnhancedSmart(const EnhancedSmart & rhs)135 EnhancedSmart(const EnhancedSmart& rhs) : Smart<T>(rhs) { ; }
136
137 class Slice
138 {
139 friend class EnhancedSmart;
140 private:
141 const T* m_data;
142 const int m_size;
143
Slice(const T * data,const int size)144 Slice(const T* data, const int size) : m_data(data), m_size(size) { ; }
145
146 public:
GetSize() const147 inline int GetSize() const { return m_size; }
148
operator [](const int index) const149 inline const T& operator[](const int index) const
150 {
151 assert(index >= 0); // Lower Bounds Error
152 assert(index < m_size); // Upper Bounds Error
153 return m_data[index];
154 }
155 };
156
GetSlice(int start,int end) const157 inline Slice GetSlice(int start, int end) const
158 {
159 assert(start >= 0);
160 assert(end < Smart<T>::m_active);
161
162 return Slice(Smart<T>::m_data + start, end - start + 1);
163 }
164
operator =(const Slice & rhs)165 EnhancedSmart& operator=(const Slice& rhs)
166 {
167 if (Smart<T>::m_active != rhs.GetSize()) Smart<T>::ResizeClear(rhs.GetSize());
168 for (int i = 0; i < rhs.GetSize(); i++) Smart<T>::m_data[i] = rhs.m_data[i];
169 return *this;
170 }
171 };
172
// Row/column marginal totals, with slice support for viewing sub-tables
typedef Array<int, EnhancedSmart> MarginalArray;
174
175
// Network-algorithm implementation of Fisher's Exact Test for r x c tables
// (Mehta & Patel; see the file header for references). One instance is
// created per table by Apto::Stat::FishersExact() below.
class FExact
{
public:
  // FExact Public Methods
  // ------------------------------------------------------------------------------------------------------------

  FExact(const Stat::ContingencyTable& table, double tolerance);

  // Single-threaded evaluation; returns the p-value
  double Calculate();
  // Multi-threaded evaluation (path extremes computed by worker threads); returns the p-value
  double ThreadedCalculate();


private:
  // FExact Internal Type Declarations
  // ------------------------------------------------------------------------------------------------------------

  // Main Node
  class FExactNode;
  struct PastPathLength;
  class NodeHashTable;
  typedef SmartPtr<FExactNode, InternalRCObject> NodePtr;


  // Path Extremes
  struct PathExtremes;
  class PathExtremesHashTable;
  struct PendingPathExtremes;
  class PendingPathExtremesTable;
  struct PendingPathNode;
  class PathExtremesCalc;


private:
  // FExact Internal Methods
  // ------------------------------------------------------------------------------------------------------------

  // log of the multinomial coefficient numerator! / (denominator[0]! * denominator[1]! * ...)
  inline double logMultinomial(int numerator, const MarginalArray& denominator);
  inline double logMultinomial(int numerator, const MarginalArray::Slice& denominator);

  // Core Algorithm (definitions for several of these are outside this chunk)
  inline bool generateFirstDaughter(const MarginalArray& row_marginals, int n, MarginalArray& row_diff, int& kmax, int& kd);
  bool generateNewDaughter(int kmax, const MarginalArray& row_marginals, MarginalArray& row_diff, int& idx_dec, int& idx_inc);
  NodePtr handlePastPaths(NodePtr& cur_node, double obs2, double obs3, double ddf, double drn, int kval, NodeHashTable& nht);
  void recordPath(double path_length, int path_freq, Array<PastPathLength, Smart>& past_entries);

  void handleNode(int k, NodePtr cur_node);
  // Reverse of the key packing performed with m_key_multipliers
  inline void unpackMarginals(int key, MarginalArray& row_marginals);


  // Longest Path
  double longestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, int marginal_total);
  bool longestPathSpecial(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& val);


  // Shortest Path
  void shortestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& shortest_path);
  inline void removeFromVector(const Array<int, ManualBuffer>& src, int idx_remove, Array<int, ManualBuffer>& dest);
  inline void reduceZeroInVector(const Array<int, ManualBuffer>& src, int value, int idx_start, Array<int, ManualBuffer>& dest);


private:
  // FExact Internal Type Definitions
  // ------------------------------------------------------------------------------------------------------------
239
240 struct PastPathLength
241 {
242 double value;
243 int observed;
244 int next_left;
245 int next_right;
246
PastPathLengthFExact::PastPathLength247 PastPathLength(double in_value = 0.0) : value(in_value), observed(1), next_left(-1), next_right(-1) { ; }
PastPathLengthFExact::PastPathLength248 PastPathLength(double in_value, int in_freq) : value(in_value), observed(in_freq), next_left(-1), next_right(-1) { ; }
249 };
250
  // Hash node for the network algorithm: a packed row-marginal key plus the
  // lengths (and observation counts) of all past paths reaching this node.
  class FExactNode : public RefCountObject
  {
  public:
    int key;                                    // packed row marginals (see unpackMarginals)
    Array<PastPathLength, Smart> past_entries;  // tree of observed past path lengths

    FExactNode(int in_key) : key(in_key) { ; }
    virtual ~FExactNode() { ; }
  };
260
  // Open-addressed (linear probe) table of FExactNodes keyed by packed row
  // marginals. Find() allocates and inserts a new node when the key is
  // absent; Pop() drains the table in slot order.
  class NodeHashTable
  {
  public:
    typedef SmartPtr<FExactNode, InternalRCObject> NodePtr;
  private:
    Array<NodePtr> m_table;  // slots; an empty slot holds a NULL NodePtr
    int m_last;              // cursor of the last slot returned by Pop() (-1 == start)

  public:
    inline NodeHashTable(int size = DEFAULT_TABLE_SIZE) : m_table(size), m_last(-1) { ; }

    // Locate 'key', probing linearly from its hash slot with wraparound.
    // Returns true when a node with this key exists; otherwise a new node is
    // created in the first empty slot (growing the table when completely
    // full) and false is returned. 'idx' receives the node's slot.
    bool Find(int key, int& idx)
    {
      int init = key % m_table.GetSize();
      idx = init;
      for (; idx < m_table.GetSize(); idx++) {
        if (!m_table[idx]) {
          m_table[idx] = NodePtr(new FExactNode(key));
          return false;
        } else if (m_table[idx]->key == key) {
          return true;
        }
      }
      // Wrap around to the slots before the initial hash position
      for (idx = 0; idx < init; idx++) {
        if (!m_table[idx]) {
          m_table[idx] = NodePtr(new FExactNode(key));
          return false;
        } else if (m_table[idx]->key == key) {
          return true;
        }
      }
      // Table is completely full -- grow it and insert there
      Rehash(key, idx);
      return false;
    }

    inline FExactNode& operator[](int idx) { return *m_table[idx]; }
    inline NodePtr Get(int idx) { return m_table[idx]; }

    // Remove and return the next non-empty node after the previous Pop();
    // returns NULL and resets the cursor once the table has been drained.
    NodePtr Pop()
    {
      for (++m_last; m_last < m_table.GetSize(); m_last++) {
        if (m_table[m_last]) {
          NodePtr tmp = m_table[m_last];
          m_table[m_last] = NodePtr(NULL);
          return tmp;
        }
      }
      m_last = -1;
      return NodePtr(NULL);
    }
  private:
    // Double the table and reinsert all nodes, then place 'key'. Only ever
    // called when every slot is occupied, so old entries are never NULL.
    void Rehash(int key, int& idx)
    {
      Array<NodePtr> old_table(m_table);
      m_table.ResizeClear(old_table.GetSize() * 2);
      for (int i = 0; i < old_table.GetSize(); i++) {
        int t_idx;
        Find(old_table[i]->key, t_idx);  // allocates a placeholder node...
        m_table[t_idx] = old_table[i];   // ...immediately replaced by the original
      }
      Find(key, idx);
    }
  };
324
325
326
327 struct PathExtremes {
328 int key;
329 double longest_path;
330 double shortest_path;
PathExtremesFExact::PathExtremes331 PathExtremes() : key(-1) { ; }
332
SetFExact::PathExtremes333 void Set(int in_key, double lp, double sp) { key = in_key; longest_path = lp; shortest_path = sp; }
334 };
335
336 class PathExtremesHashTable
337 {
338 private:
339 Array<PathExtremes> m_table;
340
341 public:
PathExtremesHashTable(int size=DEFAULT_TABLE_SIZE)342 inline PathExtremesHashTable(int size = DEFAULT_TABLE_SIZE) : m_table(size) { ClearTable(); }
343
Find(int key,int & idx)344 bool Find(int key, int& idx)
345 {
346 int init = key % m_table.GetSize();
347 idx = init;
348 for (; idx < m_table.GetSize(); idx++) {
349 if (m_table[idx].key < 0) {
350 m_table[idx].key = key;
351 return false;
352 } else if (m_table[idx].key == key) {
353 return true;
354 }
355 }
356 for (idx = 0; idx < init; idx++) {
357 if (m_table[idx].key < 0) {
358 m_table[idx].key = key;
359 return false;
360 } else if (m_table[idx].key == key) {
361 return true;
362 }
363 }
364 Rehash(key, idx);
365 return false;
366 }
367
operator [](int idx)368 inline PathExtremes& operator[](int idx) { return m_table[idx]; }
369
ClearTable()370 inline void ClearTable()
371 {
372 for (int i = 0; i < m_table.GetSize(); i++) m_table[i].key = -1;
373 }
374
375 private:
Rehash(int key,int & idx)376 void Rehash(int key, int& idx)
377 {
378 Array<PathExtremes> old_table(m_table);
379 m_table.ResizeClear(old_table.GetSize() * 2);
380 for (int i = 0; i < old_table.GetSize(); i++) {
381 int t_idx;
382 Find(old_table[i].key, t_idx);
383 m_table[t_idx].longest_path = old_table[i].longest_path;
384 m_table[t_idx].shortest_path = old_table[i].shortest_path;
385 }
386 Find(key, idx);
387 }
388 };
389
390
391 struct PendingPathExtremes {
392 int key;
393 bool inprogress;
394 MarginalArray rows;
395 MarginalArray cols;
396 double obs2;
397 double ddf;
398 int ntot;
PendingPathExtremesFExact::PendingPathExtremes399 PendingPathExtremes() : key(-1), inprogress(false) { ; }
400
SetFExact::PendingPathExtremes401 inline void Set(const MarginalArray::Slice& in_rows, const MarginalArray::Slice& in_cols,
402 double in_obs2, double in_ddf, int in_ntot)
403 {
404 rows.ResizeClear(in_rows.GetSize());
405 for (int i = 0; i < in_rows.GetSize(); i++) rows[i] = in_rows[i];
406 cols.ResizeClear(in_cols.GetSize());
407 for (int i = 0; i < in_cols.GetSize(); i++) cols[i] = in_cols[i];
408 ntot = in_ntot;
409 obs2 = in_obs2;
410 ddf = in_ddf;
411 }
412 };
413
414 class PendingPathExtremesTable
415 {
416 private:
417 Array<PendingPathExtremes> m_table;
418 int m_size;
419
420 public:
PendingPathExtremesTable(int size=DEFAULT_TABLE_SIZE)421 inline PendingPathExtremesTable(int size = DEFAULT_TABLE_SIZE) : m_table(size), m_size(0) { ; }
422
GetSize()423 int GetSize() { return m_size; }
424
Find(int key,int & idx)425 bool Find(int key, int& idx)
426 {
427 int init = key % m_table.GetSize();
428 idx = init;
429 for (; idx < m_table.GetSize(); idx++) {
430 if (m_table[idx].key < 0) {
431 m_size++;
432 m_table[idx].key = key;
433 return false;
434 } else if (m_table[idx].key == key) {
435 return true;
436 }
437 }
438 for (idx = 0; idx < init; idx++) {
439 if (m_table[idx].key < 0) {
440 m_size++;
441 m_table[idx].key = key;
442 return false;
443 } else if (m_table[idx].key == key) {
444 return true;
445 }
446 }
447 Rehash(key, idx);
448 return false;
449 }
450
Pop()451 int Pop()
452 {
453 assert(m_size > 0);
454 for (int i = 0; i < m_table.GetSize(); i++) if (m_table[i].key >= 0 && !m_table[i].inprogress) {
455 m_table[i].inprogress = true;
456 return i;
457 }
458 return -1;
459 }
Remove(int idx)460 void Remove(int idx) { m_table[idx].key = -1; m_table[idx].inprogress = false; m_size--; }
461
operator [](int idx)462 inline PendingPathExtremes& operator[](int idx) { return m_table[idx]; }
463
464 private:
Rehash(int key,int & idx)465 void Rehash(int key, int& idx)
466 {
467 Array<PendingPathExtremes> old_table(m_table);
468 m_table.ResizeClear(old_table.GetSize() * 2);
469 for (int i = 0; i < old_table.GetSize(); i++) {
470 int t_idx;
471 Find(old_table[i].key, t_idx);
472 m_table[t_idx].rows = old_table[i].rows;
473 m_table[t_idx].cols = old_table[i].cols;
474 }
475 Find(key, idx);
476 }
477 };
478
479
  // A daughter node whose past-path handling is deferred until the path
  // extremes for its sub-table are available. When k1 is true (two-column
  // sub-table), the extremes lookup is skipped and obs3 == obs2.
  struct PendingPathNode {
    NodePtr node;   // node awaiting processing (reset to NULL once handled)
    double obs2;    // observed path component
    double drn;     // log path-length contribution of this daughter
    double ddf;     // log multinomial of the consumed column
    int kval;       // packed row-marginal key for the extremes lookup
    bool k1;        // true == extremes not required for this entry
    bool handled;   // set once handlePastPaths() has been run for this entry

    PendingPathNode() : node(NULL), handled(false) { ; }
    inline void Set(NodePtr in_node, double in_obs2, double in_drn, double in_ddf, int in_kval, bool in_k1 = false)
    {
      node = in_node; obs2 = in_obs2; drn = in_drn; ddf = in_ddf; kval = in_kval; k1 = in_k1;
    }
  };
495
496
  // Worker thread that drains the pending path-extremes queue, computing
  // longest/shortest path bounds for queued sub-tables (see Run() below).
  class PathExtremesCalc : public Thread
  {
  private:
    FExact* m_fexact;  // owning calculation (not owned by this thread)
    bool m_run;        // cleared by SetFinished() to request shutdown

  public:
    PathExtremesCalc(FExact* fexact) : m_fexact(fexact), m_run(true) { ; }

    // Request termination; the owner must also broadcast the queue condition
    // variable so that blocked workers wake and observe the flag
    void SetFinished() { m_run = false; }
  protected:
    void Run();
  };
510
511
private:
  // FExact Class Variable Declarations
  // ------------------------------------------------------------------------------------------------------------

  // Constants
  const double m_tolerance;  // floating point comparison tolerance (Algorithm 643)


  // Pre-calculated Values
  MarginalArray m_row_marginals;  // sorted row marginals (table transposed in ctor so rows <= cols)
  MarginalArray m_col_marginals;  // sorted column marginals
  Array<double> m_facts;          // Log factorials: m_facts[i] == log(i!)
  Array<int> m_key_multipliers;   // mixed-radix place values for packing row marginals into an int key
  double m_observed_path;         // log numerator of the observed table's path
  double m_den_observed_path;     // log denominator (multinomial of the row marginals)


  // Main Calculated Value
  double m_pvalue;  // accumulated p-value, returned by Calculate()/ThreadedCalculate()


  // Core Algorithm Support


  // Threaded Core Algorithm Support (all arrays indexed by stage number k)
  Array<NodeHashTable> m_nht;                                  // per-stage node hash tables
  Array<Array<PendingPathNode, Smart> > m_pending_path_nodes;  // per-stage deferred daughter nodes
  Array<PathExtremesHashTable> m_path_extremes;                // per-stage cache of computed extremes

  // Path Extremes Threading Support
  Mutex m_path_extremes_mutex;                       // guards the queue/pending/completed structures below
  ConditionVariable m_path_extremes_cond;            // waited on by workers for queued work (broadcast at shutdown)
  ConditionVariable m_path_extremes_cond_complete;   // waited on by ThreadedCalculate() for completed work items
  Array<PendingPathExtremesTable> m_pending_path_extremes;       // per-stage pending work items
  Array<Array<PathExtremes, Smart> > m_completed_path_extremes;  // per-stage completed, not-yet-merged results
  Array<int, Smart> m_path_extremes_queue;           // stage numbers that have queued work
  bool m_pending_paths_term;  // NOTE(review): not referenced anywhere in the visible code -- confirm before removing
};
550
551
552
553 // Exported Function Definitions
554 // --------------------------------------------------------------------------------------------------------------
555
FishersExact(const ContingencyTable & table)556 double Apto::Stat::FishersExact(const ContingencyTable& table)
557 {
558 if (table.MarginalTotal() == 0.0) return std::numeric_limits<double>::quiet_NaN(); // All elements are 0
559
560 FExact fe(table, TOLERANCE);
561
562 // Use threaded calculate for larger tables
563 if (table.NumRows() * table.NumCols() > THREADING_THRESHOLD) return fe.ThreadedCalculate();
564
565 return fe.Calculate();
566 }
567
568
569
570 // FExact Constructor and Support Methods
571 // --------------------------------------------------------------------------------------------------------------
572
// Construct the calculation state for 'table': orient the table so the row
// dimension is the smaller one, sort the marginals, build the key-packing
// multipliers and log-factorial table, and compute the observed table's
// path length (numerator and denominator).
FExact::FExact(const Stat::ContingencyTable& table, double tolerance)
  : m_tolerance(tolerance)
  , m_facts(table.MarginalTotal() + 1)
  , m_pvalue(0.0)
{
  // Store table marginals for use in calculation, transposing when needed
  // so that rows <= cols
  if (table.NumRows() > table.NumCols()) {
    m_row_marginals = table.ColMarginals();
    m_col_marginals = table.RowMarginals();
  } else {
    m_row_marginals = table.RowMarginals();
    m_col_marginals = table.ColMarginals();
  }


  // Sort row and column marginals
  QSort(m_row_marginals);
  QSort(m_col_marginals);


  // Set up key multipliers -- mixed-radix place values used to pack an
  // entire set of row marginals into a single integer hash key
  m_key_multipliers.Resize(m_row_marginals.GetSize());
  m_key_multipliers[0] = 1;
  for (int i = 1; i < m_row_marginals.GetSize(); i++)
    m_key_multipliers[i] = m_key_multipliers[i - 1] * (m_row_marginals[i - 1] + 1);


  // Ensure that the maximum key will fit within an integer
  assert((m_row_marginals[m_row_marginals.GetSize() - 2] + 1) <=
         (std::numeric_limits<int>::max() / m_key_multipliers[m_row_marginals.GetSize() - 2]));


  // Pre-calculate log factorials: m_facts[i] = log(i!)
  const int marginal_total = table.MarginalTotal();
  m_facts[0] = 0.0;
  m_facts[1] = 0.0;
  if (marginal_total > 1) {
    m_facts[2] = log(2.0);
    for (int i = 3; i <= marginal_total; i++) {
      m_facts[i] = m_facts[i - 1] + log((double)i);
      // For the following (even) value use log(i) = log(2) + log(i / 2),
      // where log(i / 2) == m_facts[i / 2] - m_facts[i / 2 - 1]
      if (++i <= marginal_total) m_facts[i] = m_facts[i - 1] + m_facts[2] + m_facts[i / 2] - m_facts[i / 2 - 1];
    }
  }


  // Calculate Observed Path Numerator, seeded with the tolerance:
  // sum over columns of log(col_total!) - sum of log(cell!)
  m_observed_path = m_tolerance;
  for (int j = 0; j < m_col_marginals.GetSize(); j++) {
    double dd = 0.0;
    for (int i = 0; i < m_row_marginals.GetSize(); i++) {
      if (m_row_marginals.GetSize() == table.NumRows()) {
        dd += m_facts[table[i][j]];  // table in original orientation
      } else {
        dd += m_facts[table[j][i]];  // table was transposed above
      }
    }
    m_observed_path += m_facts[m_col_marginals[j]] - dd;
  }

  // Calculate Observed Path Denominator
  m_den_observed_path = logMultinomial(marginal_total, m_row_marginals);


  //
  //  double prt = exp(m_observed_path - m_den_observed_path);
  //  std::cout << "prt = " << prt << std::endl;
}
640
641
logMultinomial(int numerator,const MarginalArray & denominator)642 inline double FExact::logMultinomial(int numerator, const MarginalArray& denominator)
643 {
644 double ret_val = m_facts[numerator];
645 for (int i = 0; i < denominator.GetSize(); i++) ret_val -= m_facts[denominator[i]];
646 return ret_val;
647 }
648
649
logMultinomial(int numerator,const MarginalArray::Slice & denominator)650 inline double FExact::logMultinomial(int numerator, const MarginalArray::Slice& denominator)
651 {
652 double ret_val = m_facts[numerator];
653 for (int i = 0; i < denominator.GetSize(); i++) ret_val -= m_facts[denominator[i]];
654 return ret_val;
655 }
656
657
658
659 // FExact Core Algorithm Calculate Method
660 // --------------------------------------------------------------------------------------------------------------
661
// Single-threaded network algorithm. The table is processed column by
// column (stage k == number of columns remaining): for each node, every
// "daughter" row-marginal configuration is enumerated, pruned using cached
// longest/shortest path bounds, and passed to handlePastPaths(), which
// accumulates qualifying path probabilities into m_pvalue (defined
// elsewhere in this file).
double FExact::Calculate()
{
  NodeHashTable nht[2];  // node tables, alternating between stages via (k & 1)
  PathExtremesHashTable path_extremes;

  int k = m_col_marginals.GetSize();
  NodePtr cur_node(new FExactNode(0));
  cur_node->past_entries.Push(PastPathLength(0));

  MarginalArray row_diff(m_row_marginals.GetSize());
  MarginalArray irn(m_row_marginals.GetSize());

  while (true) {
    int kb = m_col_marginals.GetSize() - k;  // index of the column consumed at this stage
    int ks = -1;
    int kmax;
    int kd;

    if (generateFirstDaughter(m_row_marginals, m_col_marginals[kb], row_diff, kmax, kd)) {
      // Total count remaining in the columns after kb
      int ntot = 0;
      for (int i = kb + 1; i < m_col_marginals.GetSize(); i++) ntot += m_col_marginals[i];

      do {
        // Row marginals of the reduced sub-table
        for (int i = 0; i < m_row_marginals.GetSize(); i++) irn[i] = m_row_marginals[i] - row_diff[i];

        int nrb;
        if (k > 2) {
          // Sort the reduced marginals into canonical order for keying
          if (irn.GetSize() == 2) {
            if (irn[0] > irn[1]) irn.Swap(0, 1);
          } else {
            QSort(irn);
          }

          // Adjust for zero start
          int i = 0;
          for (; i < irn.GetSize(); i++) if (irn[i] != 0) break;
          nrb = i;
        } else {
          nrb = 0;
        }

        // Build adjusted row array
        MarginalArray::Slice sub_rows = irn.GetSlice(nrb, irn.GetSize() - 1);
        MarginalArray::Slice sub_cols = m_col_marginals.GetSlice(kb + 1, m_col_marginals.GetSize() - 1);

        double ddf = logMultinomial(m_col_marginals[kb], row_diff);
        double drn = logMultinomial(ntot, sub_rows) - m_den_observed_path + ddf;

        int kval = 0;
        int path_idx = -1;
        double obs2, obs3;

        if (k > 2) {
          // compute hash table key for current table
          kval = irn[0] + irn[1] * m_key_multipliers[1];
          for (int i = 2; i < irn.GetSize(); i++) kval += irn[i] * m_key_multipliers[i];

          // Newly seen sub-table: mark its extremes as not-yet-computed.
          // Valid log path lengths are clamped to <= 0 below, so a positive
          // longest_path can only be this "unset" sentinel.
          if (!path_extremes.Find(kval, path_idx)) {
            path_extremes[path_idx].longest_path = 1.0;
          }

          obs2 = m_observed_path - m_facts[m_col_marginals[kb + 1]] - m_facts[m_col_marginals[kb + 2]] - ddf;
          for (int i = 3; i <= (k - 1); i++) obs2 -= m_facts[m_col_marginals[kb + i]];

          if (path_extremes[path_idx].longest_path > 0.0) {
            // Compute and cache the extremes for this sub-table
            path_extremes[path_idx].longest_path = longestPath(sub_rows, sub_cols, ntot);
            if (path_extremes[path_idx].longest_path > 0.0) path_extremes[path_idx].longest_path = 0.0;

            double dspt = m_observed_path - obs2 - ddf;
            path_extremes[path_idx].shortest_path = dspt;
            shortestPath(sub_rows, sub_cols, path_extremes[path_idx].shortest_path);
            path_extremes[path_idx].shortest_path -= dspt;
            if (path_extremes[path_idx].shortest_path > 0.0) path_extremes[path_idx].shortest_path = 0.0;
          }
          obs3 = obs2 - path_extremes[path_idx].longest_path;
          obs2 = obs2 - path_extremes[path_idx].shortest_path;
        } else {
          // Two-column sub-table: extremes are not needed
          obs2 = m_observed_path - drn - m_den_observed_path;
          obs3 = obs2;
        }

        handlePastPaths(cur_node, obs2, obs3, ddf, drn, kval, nht[k & 0x1]);
      } while (generateNewDaughter(kmax, m_row_marginals, row_diff, kd, ks));
    }

    // Fetch the next node for this stage; when the stage's table is drained,
    // drop to the next stage (and return once fewer than 2 columns remain)
    do {
      cur_node = nht[(k + 1) & 0x1].Pop();
      if (!cur_node) {
        k--;
        path_extremes.ClearTable();
        if (k < 2) return m_pvalue;
      }
    } while (!cur_node);

    // Unpack node row marginals from key
    int kval = cur_node->key;
    for (int i = m_row_marginals.GetSize() - 1; i > 0; i--) {
      m_row_marginals[i] = kval / m_key_multipliers[i];
      kval -= m_row_marginals[i] * m_key_multipliers[i];
    }
    m_row_marginals[0] = kval;
  }

  return m_pvalue;  // unreachable; the loop exits via the return above
}
771
772
773
774 // FExact Threaded Core Algorithm Calculate Method
775 // --------------------------------------------------------------------------------------------------------------
776
ThreadedCalculate()777 double FExact::ThreadedCalculate()
778 {
779 int k = m_col_marginals.GetSize();
780
781 // Setup Multi-Stage Data Structures
782 m_nht.Resize(k + 1);
783 m_pending_path_nodes.Resize(k + 1);
784 m_path_extremes.Resize(k + 1);
785 m_pending_path_extremes.Resize(k + 1);
786 m_completed_path_extremes.Resize(k + 1);
787
788
789 // Setup Path Calculation Worker Thread
790 Array<PathExtremesCalc*> path_calcs(Platform::AvailableCPUs());
791 for (int i = 0; i < path_calcs.GetSize(); i++) {
792 path_calcs[i] = new PathExtremesCalc(this);
793 path_calcs[i]->Start();
794 }
795
796
797 // Single Stage Support Variables
798 int kval = m_row_marginals[0] + m_row_marginals[1] * m_key_multipliers[1];
799 for (int i = 2; i < m_row_marginals.GetSize(); i++) kval += m_row_marginals[i] * m_key_multipliers[i];
800 NodePtr cur_node(new FExactNode(kval));
801 cur_node->past_entries.Push(PastPathLength(0));
802
803
804 handleNode(k, cur_node);
805
806 Array<PathExtremes, Smart> temp_completed;
807 for (; k >= 1; k--) {
808 // printf("k = %d\n", k);
809 if (m_pending_path_nodes[k].GetSize()) {
810 m_path_extremes_mutex.Lock();
811 while (m_pending_path_extremes[k].GetSize()) {
812 m_path_extremes_cond_complete.Wait(m_path_extremes_mutex);
813 if (m_completed_path_extremes[k].GetSize()) {
814 temp_completed = m_completed_path_extremes[k];
815 m_completed_path_extremes[k].Resize(0);
816 m_path_extremes_mutex.Unlock();
817 for (int i = 0; i < temp_completed.GetSize(); i++) {
818 PathExtremes& path = temp_completed[i];
819 int path_idx;
820 m_path_extremes[k].Find(path.key, path_idx);
821 m_path_extremes[k][path_idx].longest_path = path.longest_path;
822 m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
823
824 for (int i = 0; i < m_pending_path_nodes[k].GetSize(); i++) {
825 PendingPathNode& p = m_pending_path_nodes[k][i];
826 if (p.kval != path.key) continue;
827 if (p.handled) break;
828 double obs2, obs3;
829 if (p.k1) {
830 obs2 = p.obs2;
831 obs3 = p.obs2;
832 } else {
833 int path_idx;
834 m_path_extremes[k].Find(p.kval, path_idx);
835 obs3 = p.obs2 - m_path_extremes[k][path_idx].longest_path;
836 obs2 = p.obs2 - m_path_extremes[k][path_idx].shortest_path;
837 }
838 cur_node = handlePastPaths(p.node, obs2, obs3, p.ddf, p.drn, p.kval, m_nht[k - 1]);
839 p.node = NodePtr(NULL);
840 if (cur_node) handleNode(k - 1, cur_node);
841 p.handled = true;
842 }
843 }
844 m_path_extremes_mutex.Lock();
845 }
846 }
847 for (int i = 0; i < m_completed_path_extremes[k].GetSize(); i++) {
848 PathExtremes& path = m_completed_path_extremes[k][i];
849 int path_idx;
850 m_path_extremes[k].Find(path.key, path_idx);
851 m_path_extremes[k][path_idx].longest_path = path.longest_path;
852 m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
853 }
854 m_completed_path_extremes[k].Resize(0);
855 m_path_extremes_mutex.Unlock();
856
857 for (int i = 0; i < m_pending_path_nodes[k].GetSize(); i++) {
858 PendingPathNode& p = m_pending_path_nodes[k][i];
859 if (p.handled) continue;
860 double obs2, obs3;
861 if (p.k1) {
862 obs2 = p.obs2;
863 obs3 = p.obs2;
864 } else {
865 int path_idx;
866 m_path_extremes[k].Find(p.kval, path_idx);
867 obs3 = p.obs2 - m_path_extremes[k][path_idx].longest_path;
868 obs2 = p.obs2 - m_path_extremes[k][path_idx].shortest_path;
869 }
870 cur_node = handlePastPaths(p.node, obs2, obs3, p.ddf, p.drn, p.kval, m_nht[k - 1]);
871 p.node = NodePtr(NULL);
872 if (cur_node) handleNode(k - 1, cur_node);
873 }
874 m_pending_path_nodes[k].Resize(0);
875 }
876 }
877
878 for (int i = 0; i < path_calcs.GetSize(); i++) path_calcs[i]->SetFinished();
879 m_path_extremes_cond.Broadcast();
880 for (int i = 0; i < path_calcs.GetSize(); i++) path_calcs[i]->Join();
881
882 return m_pvalue;
883 }
884
885
handleNode(int k,NodePtr cur_node)886 void FExact::handleNode(int k, NodePtr cur_node)
887 {
888 MarginalArray row_marginals(m_row_marginals.GetSize());
889 MarginalArray row_diff(row_marginals.GetSize());
890 MarginalArray irn(row_marginals.GetSize());
891
892 unpackMarginals(cur_node->key, row_marginals);
893
894 int kb = m_col_marginals.GetSize() - k;
895 int ks = -1;
896 int kmax;
897 int kd;
898
899 if (!generateFirstDaughter(row_marginals, m_col_marginals[kb], row_diff, kmax, kd)) return;
900
901 int ntot = 0;
902 for (int i = kb + 1; i < m_col_marginals.GetSize(); i++) ntot += m_col_marginals[i];
903
904 do {
905 for (int i = 0; i < row_marginals.GetSize(); i++) irn[i] = row_marginals[i] - row_diff[i];
906
907 int nrb = 0;
908 if (k > 2) {
909 QSort(irn);
910
911 // Adjust for zero start
912 for (; nrb < irn.GetSize(); nrb++) if (irn[nrb] != 0) break;
913 }
914
915 // Build adjusted row array
916 MarginalArray::Slice sub_rows = irn.GetSlice(nrb, irn.GetSize() - 1);
917 MarginalArray::Slice sub_cols = m_col_marginals.GetSlice(kb + 1, m_col_marginals.GetSize() - 1);
918
919 double ddf = logMultinomial(m_col_marginals[kb], row_diff);
920 double drn = logMultinomial(ntot, sub_rows) - m_den_observed_path + ddf;
921
922 int kval = 0;
923 int path_idx = -1;
924 double obs2;
925
926 if (k > 2) {
927 // compute hash table key for current table
928 kval = irn[0] + irn[1] * m_key_multipliers[1];
929 for (int i = 2; i < irn.GetSize(); i++) kval += irn[i] * m_key_multipliers[i];
930
931 obs2 = m_observed_path - m_facts[m_col_marginals[kb + 1]] - m_facts[m_col_marginals[kb + 2]] - ddf;
932 for (int i = 3; i <= (k - 1); i++) obs2 -= m_facts[m_col_marginals[kb + i]];
933
934 if (!m_path_extremes[k].Find(kval, path_idx)) {
935 m_path_extremes_mutex.Lock();
936 {
937 // Move completed paths to the local table
938 bool found = false;
939 for (int i = 0; i < m_completed_path_extremes[k].GetSize(); i++) {
940 PathExtremes& path = m_completed_path_extremes[k][i];
941 m_path_extremes[k].Find(path.key, path_idx);
942 m_path_extremes[k][path_idx].longest_path = path.longest_path;
943 m_path_extremes[k][path_idx].shortest_path = path.shortest_path;
944 if (path.key == kval) found = true;
945 }
946 m_completed_path_extremes[0].Resize(0);
947
948 // Add new pending path if not found and not already in pending table
949 if (!found && !m_pending_path_extremes[k].Find(kval, path_idx)) {
950 m_pending_path_extremes[k][path_idx].Set(sub_rows, sub_cols, obs2, ddf, ntot);
951 m_path_extremes_queue.Push(k);
952 }
953 }
954 m_path_extremes_mutex.Unlock();
955 m_path_extremes_cond.Signal();
956 }
957
958 // Push node onto pending stack
959 int pending_idx = m_pending_path_nodes[k].GetSize();
960 m_pending_path_nodes[k].Resize(pending_idx + 1);
961 m_pending_path_nodes[k][pending_idx].Set(cur_node, obs2, drn, ddf, kval);
962 } else {
963 obs2 = m_observed_path - drn - m_den_observed_path;
964 int pending_idx = m_pending_path_nodes[k].GetSize();
965 m_pending_path_nodes[k].Resize(pending_idx + 1);
966 m_pending_path_nodes[k][pending_idx].Set(cur_node, obs2, drn, ddf, kval, true);
967 }
968
969 } while (generateNewDaughter(kmax, row_marginals, row_diff, kd, ks));
970
971 }
972
973 // Path Extremes Threading Methods
974 // --------------------------------------------------------------------------------------------------------------
975
// Worker thread entry point: consumes queued sub-tables and computes their
// path extremes (longest and shortest path lengths), publishing results into
// the completed table for the main algorithm to pick up.
void FExact::PathExtremesCalc::Run()
{
  while (true) {
    int k;        // selects which per-size pending/completed table the work item lives in
    int path_idx; // index of the claimed entry within m_pending_path_extremes[k]

    m_fexact->m_path_extremes_mutex.Lock();
    {
      // Sleep until work is queued or shutdown is requested
      while (m_fexact->m_path_extremes_queue.GetSize() == 0 && m_run) {
        m_fexact->m_path_extremes_cond.Wait(m_fexact->m_path_extremes_mutex);
      }
      if (!m_run) {
        m_fexact->m_path_extremes_mutex.Unlock();
        return;
      }
      k = m_fexact->m_path_extremes_queue.Pop();
      path_idx = m_fexact->m_pending_path_extremes[k].Pop();
    }
    m_fexact->m_path_extremes_mutex.Unlock();

    // Negative index means the entry was already handled/removed; skip it
    if (path_idx < 0) continue;

    // NOTE(review): p is referenced outside the mutex; this presumes the
    // pending entry stays valid (no reallocation/removal) until this thread
    // removes it below — confirm against the pending table implementation.
    PendingPathExtremes& p = m_fexact->m_pending_path_extremes[k][path_idx];

    // Longest path, clamped to a maximum of zero
    double longest_path = m_fexact->longestPath(p.rows.GetSlice(0, p.rows.GetSize() - 1), p.cols.GetSlice(0, p.cols.GetSize() - 1), p.ntot);
    if (longest_path > 0.0) longest_path = 0.0;

    // Shortest path, computed relative to dspt and clamped to a maximum of zero
    double dspt = m_fexact->m_observed_path - p.obs2 - p.ddf;
    double shortest_path = dspt;
    m_fexact->shortestPath(p.rows.GetSlice(0, p.rows.GetSize() - 1), p.cols.GetSlice(0, p.cols.GetSize() - 1), shortest_path);
    shortest_path -= dspt;
    if (shortest_path > 0.0) shortest_path = 0.0;

    m_fexact->m_path_extremes_mutex.Lock();
    {
      // Record completed calculation
      int completed_idx = m_fexact->m_completed_path_extremes[k].GetSize();
      m_fexact->m_completed_path_extremes[k].Resize(completed_idx + 1);
      m_fexact->m_completed_path_extremes[k][completed_idx].Set(p.key, longest_path, shortest_path);

      // Clear out the pending record
      m_fexact->m_pending_path_extremes[k].Remove(path_idx);
    }
    m_fexact->m_path_extremes_mutex.Unlock();
    m_fexact->m_path_extremes_cond_complete.Signal();
  }
}
1023
1024
1025
1026 // FExact Core Algorithm Support Methods
1027 // --------------------------------------------------------------------------------------------------------------
1028
generateFirstDaughter(const MarginalArray & row_marginals,int n,MarginalArray & row_diff,int & kmax,int & kd)1029 inline bool FExact::generateFirstDaughter(const MarginalArray& row_marginals, int n, MarginalArray& row_diff, int& kmax, int& kd)
1030 {
1031 row_diff.SetAll(0);
1032
1033 kmax = row_marginals.GetSize() - 1;
1034 kd = row_marginals.GetSize();
1035 do {
1036 kd--;
1037 int ntot = (n < row_marginals[kd]) ? n : row_marginals[kd];
1038 row_diff[kd] = ntot;
1039 if (row_diff[kmax] == 0) kmax--;
1040 n -= ntot;
1041 } while (n > 0 && kd > 0);
1042
1043 if (n != 0) return false;
1044
1045 return true;
1046 }
1047
1048
// Generates the next daughter allocation (row_diff) of the current column
// total among the rows, advancing the enumeration used by the network
// algorithm.  Returns false once the enumeration is exhausted.
//
// @param kmax          index of the last row considered (caller's kmax)
// @param row_marginals row marginals bounding each row_diff entry
// @param row_diff      in/out: current allocation, updated in place
// @param idx_dec       in/out: index of the entry to decrement (caller's kd)
// @param idx_inc       in/out: index of the entry to increment (caller's ks);
//                      -1 signals the first call for this table
bool FExact::generateNewDaughter(int kmax, const MarginalArray& row_marginals, MarginalArray& row_diff, int& idx_dec, int& idx_inc)
{
  if (idx_inc == -1) {
    // First call: advance to the first row not already filled to its marginal
    while (row_diff[++idx_inc] == row_marginals[idx_inc]) ;
  }

  // Find node to decrement
  if (row_diff[idx_dec] > 0 && idx_dec > idx_inc) {
    row_diff[idx_dec]--;
    // Skip rows whose marginal is zero (they can never receive units)
    while (row_marginals[--idx_dec] == 0) ;
    int m = idx_dec;

    // Find node to increment: first row at or below idx_dec with headroom
    while (row_diff[m] >= row_marginals[m]) m--;
    row_diff[m]++;

    // If the incremented row just became saturated, move the increment index
    if (m == idx_inc && row_diff[m] == row_marginals[m]) idx_inc = idx_dec;
  } else {
    int idx = 0;
    do {
      // Check for finish: look for a non-zero entry beyond idx_dec
      idx = idx_dec + 1;
      bool found = false;
      for (; idx < row_diff.GetSize(); idx++) {
        if (row_diff[idx] > 0) {
          found = true;
          break;
        }
      }
      if (!found) return false;

      // Collect units allocated at or below idx_dec (plus one unit that will
      // be borrowed from row_diff[idx] after this loop) for redistribution
      int marginal_total = 1;
      for (int i = 0; i <= idx_dec; i++) {
        marginal_total += row_diff[i];
        row_diff[i] = 0;
      }
      idx_dec = idx;
      // Redistribute greedily into the rows just below idx
      do {
        idx_dec--;
        int m = (marginal_total < row_marginals[idx_dec]) ? marginal_total : row_marginals[idx_dec];
        row_diff[idx_dec] = m;
        marginal_total -= m;
      } while (marginal_total > 0 && idx_dec != 0);

      if (marginal_total > 0) {
        // Could not place everything; retry from the next position unless
        // the last considered row has been reached (enumeration exhausted)
        if (idx != (kmax)) {
          idx_dec = idx;
          continue;
        }
        return false;
      } else {
        break;
      }
    } while (true);
    // Borrow the unit accounted for in marginal_total above
    row_diff[idx]--;
    // Reset the increment index to the first row with remaining headroom
    for (idx_inc = 0; row_diff[idx_inc] >= row_marginals[idx_inc]; idx_inc++) if (idx_inc > idx_dec) break;
  }

  return true;
}
1109
1110
handlePastPaths(NodePtr & cur_node,double obs2,double obs3,double ddf,double drn,int kval,NodeHashTable & nht)1111 FExact::NodePtr FExact::handlePastPaths(NodePtr& cur_node, double obs2, double obs3, double ddf, double drn, int kval,
1112 NodeHashTable& nht)
1113 {
1114 NodePtr new_node(NULL);
1115 for (int i = 0; i < cur_node->past_entries.GetSize(); i++) {
1116 double past_path = cur_node->past_entries[i].value;
1117 int path_freq = cur_node->past_entries[i].observed;
1118 if (past_path <= obs3) {
1119 // Path shorter than longest path, add to the pvalue and continue
1120 m_pvalue += (double)(path_freq) * exp(past_path + drn);
1121 } else if (past_path < obs2) {
1122 int nht_idx;
1123 double new_path = past_path + ddf;
1124 if (nht.Find(kval, nht_idx)) {
1125 // Existing Node was found
1126 recordPath(new_path, path_freq, nht[nht_idx].past_entries);
1127 } else {
1128 // New Node added, insert this observed path
1129 new_node = nht.Get(nht_idx);
1130 new_node->past_entries.Push(PastPathLength(new_path, path_freq));
1131 }
1132 }
1133 }
1134 return new_node;
1135 }
1136
1137
recordPath(double path_length,int path_freq,Array<PastPathLength,Smart> & past_entries)1138 void FExact::recordPath(double path_length, int path_freq, Array<PastPathLength, Smart>& past_entries)
1139 {
1140 // Search for past path within m_tolerance and add observed frequency to it
1141 double test1 = path_length - m_tolerance;
1142 double test2 = path_length + m_tolerance;
1143
1144 int j = 0;
1145 int old_j = 0;
1146 while (true) {
1147 double test_path = past_entries[j].value;
1148 if (test_path < test1) {
1149 old_j = j;
1150 j = past_entries[j].next_left;
1151 if (j >= 0) continue;
1152 } else if (test_path > test2) {
1153 old_j = j;
1154 j = past_entries[j].next_right;
1155 if (j >= 0) continue;
1156 } else {
1157 past_entries[j].observed += path_freq;
1158 return;
1159 }
1160 break;
1161 }
1162
1163 // If no path within m_tolerance is found, add new past path length to the node
1164 int new_idx = past_entries.GetSize();
1165 past_entries.Push(PastPathLength(path_length, path_freq));
1166
1167 double test_path = past_entries[old_j].value;
1168 if (test_path < test1) {
1169 past_entries[old_j].next_left = new_idx;
1170 } else if (test_path > test2) {
1171 past_entries[old_j].next_right = new_idx;
1172 } else {
1173 assert(false);
1174 }
1175 }
1176
1177
unpackMarginals(int key,MarginalArray & row_marginals)1178 void FExact::unpackMarginals(int key, MarginalArray& row_marginals)
1179 {
1180 for (int i = row_marginals.GetSize() - 1; i > 0; i--) {
1181 row_marginals[i] = key / m_key_multipliers[i];
1182 key -= row_marginals[i] * m_key_multipliers[i];
1183 }
1184 row_marginals[0] = key;
1185 }
1186
1187
1188
1189 // FExact Longest Path Methods
1190 // --------------------------------------------------------------------------------------------------------------
1191
longestPath(const MarginalArray::Slice & row_marginals,const MarginalArray::Slice & col_marginals,int marginal_total)1192 double FExact::longestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, int marginal_total)
1193 {
1194 class ValueHashTable
1195 {
1196 private:
1197 Array<Pair<int, double> >* m_table;
1198 Array<int> m_stack;
1199 int m_entry_count;
1200
1201 public:
1202 inline ValueHashTable(int size = 200) : m_table(new Array<Pair<int, double> >(size)), m_stack(size) { ClearTable(); }
1203 inline ~ValueHashTable() { delete m_table; }
1204
1205 int GetEntryCount() const { return m_entry_count; }
1206
1207 bool Find(int key, int& idx)
1208 {
1209 int init = key % m_table->GetSize();
1210 idx = init;
1211 for (; idx < m_table->GetSize(); idx++) {
1212 if ((*m_table)[idx].Value1() < 0) {
1213 m_stack[m_entry_count] = idx;
1214 (*m_table)[idx].Value1() = key;
1215 m_entry_count++;
1216 return false;
1217 } else if ((*m_table)[idx].Value1() == key) {
1218 return true;
1219 }
1220 }
1221 for (idx = 0; idx < init; idx++) {
1222 if ((*m_table)[idx].Value1() < 0) {
1223 m_stack[m_entry_count] = idx;
1224 (*m_table)[idx].Value1() = key;
1225 m_entry_count++;
1226 return false;
1227 } else if ((*m_table)[idx].Value1() == key) {
1228 return true;
1229 }
1230 }
1231 Rehash(key, idx);
1232 return false;
1233 }
1234
1235 inline double& operator[](int idx) { return (*m_table)[idx].Value2(); }
1236
1237 Pair<int, double> Pop()
1238 {
1239 Pair<int, double> tmp = (*m_table)[m_stack[--m_entry_count]];
1240 (*m_table)[m_stack[m_entry_count]].Value1() = -1;
1241 return tmp;
1242 }
1243
1244 inline void ClearTable()
1245 {
1246 m_entry_count = 0;
1247 for (int i = 0; i < m_table->GetSize(); i++) (*m_table)[i].Value1() = -1;
1248 }
1249
1250 private:
1251 void Rehash(int key, int& idx)
1252 {
1253 Array<Pair<int, double> >* old_table = m_table;
1254 m_table = new Array<Pair<int, double> >(old_table->GetSize() * 2);
1255 for (int i = 0; i < m_table->GetSize(); i++) (*m_table)[i].Value1() = -1;
1256 m_stack.Resize(old_table->GetSize() * 2);
1257 for (int i = 0; i < old_table->GetSize(); i++) {
1258 int t_idx;
1259 Find((*old_table)[i].Value1(), t_idx);
1260 (*m_table)[t_idx].Value2() = (*old_table)[i].Value2();
1261 }
1262 Find(key, idx);
1263 delete old_table;
1264 }
1265 };
1266
1267 // 1 x c
1268 if (row_marginals.GetSize() <= 1) {
1269 double longest_path = 0.0;
1270 for (int i = 0; i < col_marginals.GetSize(); i++) longest_path -= m_facts[col_marginals[i]];
1271 return longest_path;
1272 }
1273
1274 // r x 1
1275 if (col_marginals.GetSize() <= 1) {
1276 double longest_path = 0.0;
1277 for (int i = 0; i < row_marginals.GetSize(); i++) longest_path -= m_facts[row_marginals[i]];
1278 return longest_path;
1279 }
1280
1281 // 2 x 2
1282 if (row_marginals.GetSize() == 2 && col_marginals.GetSize() == 2) {
1283 int n11 = (row_marginals[0] + 1) * (col_marginals[0] + 1) / (marginal_total + 2);
1284 int n12 = row_marginals[0] - n11;
1285 return -m_facts[n11] - m_facts[n12] - m_facts[col_marginals[0] - n11] - m_facts[col_marginals[1] - n12];
1286 }
1287
1288 double val = 0.0;
1289 bool min = false;
1290 if (row_marginals[row_marginals.GetSize() - 1] <= row_marginals[0] + col_marginals.GetSize()) {
1291 min = longestPathSpecial(row_marginals, col_marginals, val);
1292 }
1293 if (!min && col_marginals[col_marginals.GetSize() - 1] <= col_marginals[0] + row_marginals.GetSize()) {
1294 min = longestPathSpecial(col_marginals, row_marginals, val);
1295 }
1296
1297 if (min) {
1298 return -val;
1299 }
1300
1301
1302 int ntot = marginal_total;
1303 MarginalArray lrow;
1304 MarginalArray lcol;
1305
1306 if (row_marginals.GetSize() >= col_marginals.GetSize()) {
1307 lrow.EnhancedSmart<int>::operator=(row_marginals);
1308 lcol.EnhancedSmart<int>::operator=(col_marginals);
1309 } else {
1310 lrow.EnhancedSmart<int>::operator=(col_marginals);
1311 lcol.EnhancedSmart<int>::operator=(row_marginals);
1312 }
1313
1314 Array<int> nt(lcol.GetSize());
1315 nt[0] = ntot - lcol[0];
1316 for (int i = 1; i < lcol.GetSize(); i++) nt[i] = nt[i - 1] - lcol[i];
1317
1318
1319 Array<double> alen(col_marginals.GetSize() + 1);
1320 alen.SetAll(0.0);
1321
1322 ValueHashTable vht[2];
1323 int active_vht = 0;
1324
1325 double vmn = 1.0e10;
1326 int nc1s = lcol.GetSize() - 2;
1327 int kyy = lcol[lcol.GetSize() - 1] + 1;
1328
1329 Array<int> lb(lrow.GetSize());
1330 Array<int> nu(lrow.GetSize());
1331 Array<int> nr(lrow.GetSize());
1332
1333
1334 while (true) {
1335 bool continue_main = false;
1336
1337 // Setup to generate new node
1338 int lev = 0;
1339 int nr1 = lrow.GetSize() - 1;
1340 int nrt = lrow[0];
1341 int nct = lcol[0];
1342 lb[0] = (int)((((double)nrt + 1.0) * (nct + 1)) / (double)(ntot + nr1 * (nc1s + 1) + 1) - m_tolerance) - 1;
1343 nu[0] = (int)((((double)nrt + nc1s + 1.0) * (nct + nr1)) / (double)(ntot + nr1 + nc1s + 1)) - lb[0] + 1;
1344 nr[0] = nrt - lb[0];
1345
1346 while (true) {
1347 do {
1348 nu[lev]--;
1349 if (nu[lev] == 0) {
1350 if (lev == 0) {
1351 do {
1352 if (vht[(active_vht) ? 0 : 1].GetEntryCount()) {
1353 Pair<int, double> entry = vht[(active_vht) ? 0 : 1].Pop();
1354 val = entry.Value2();
1355 int key = entry.Value1();
1356
1357 // Compute Marginals
1358 for (int i = lcol.GetSize() - 1; i > 0; i--) {
1359 lcol[i] = key % kyy;
1360 key = key / kyy;
1361 }
1362 lcol[0] = key;
1363
1364 // Set up nt array
1365 nt[0] = ntot - lcol[0];
1366 for (int i = 1; i < lcol.GetSize(); i++) nt[i] = nt[i - 1] - lcol[i];
1367
1368 min = false;
1369 if (lrow[lrow.GetSize() - 1] <= lrow[0] + lcol.GetSize()) {
1370 min = longestPathSpecial(lrow.GetSlice(0, lrow.GetSize() - 1), lcol.GetSlice(0, lcol.GetSize() - 1), val);
1371 }
1372 if (!min && lcol[lcol.GetSize() - 1] <= lcol[0] + lrow.GetSize()) {
1373 min = longestPathSpecial(lrow.GetSlice(0, lrow.GetSize() - 1), lcol.GetSlice(0, lcol.GetSize() - 1), val);
1374 }
1375
1376 if (min) {
1377 if (val < vmn)
1378 vmn = val;
1379 continue;
1380 }
1381 continue_main = true;
1382 } else if (lrow.GetSize() > 2 && vht[active_vht].GetEntryCount()) {
1383 // Go to next level
1384 ntot -= lrow[0];
1385 Array<int> tmp(lrow);
1386 lrow.ResizeClear(lrow.GetSize() - 1);
1387 for (int i = 0; i < lrow.GetSize(); i++) lrow[i] = tmp[i + 1];
1388 active_vht = (active_vht) ? 0 : 1;
1389 continue;
1390 }
1391 break;
1392 } while (true);
1393 if (!continue_main)
1394 return -vmn;
1395 }
1396 if (continue_main) break;
1397 lev--;
1398 continue;
1399 }
1400 break;
1401 } while (true);
1402 if (continue_main) break;
1403
1404 lb[lev]++;
1405 nr[lev]--;
1406
1407 for (alen[lev + 1] = alen[lev] + m_facts[lb[lev]]; lev < nc1s; alen[lev + 1] = alen[lev] + m_facts[lb[lev]]) {
1408 int nn1 = nt[lev];
1409 int nrt = nr[lev];
1410 lev++;
1411 int nc1 = lcol.GetSize() - lev - 1;
1412 int nct = lcol[lev];
1413 lb[lev] = (int)((double)((nrt + 1) * (nct + 1)) / (double)(nn1 + nr1 * nc1 + 1) - m_tolerance);
1414 nu[lev] = (int)((double)((nrt + nc1) * (nct + nr1)) / (double)(nn1 + nr1 + nc1) - lb[lev] + 1);
1415 nr[lev] = nrt - lb[lev];
1416 }
1417 alen[lcol.GetSize()] = alen[lev + 1] + m_facts[nr[lev]];
1418 lb[lcol.GetSize() - 1] = nr[lev];
1419
1420 double v = val + alen[lcol.GetSize()];
1421 if (lrow.GetSize() == 2) {
1422 for (int i = 0; i < lcol.GetSize(); i++) v += m_facts[lcol[i] - lb[i]];
1423 if (v < vmn)
1424 vmn = v;
1425 } else if (lrow.GetSize() == 3 && lcol.GetSize() == 2) {
1426 int nn1 = ntot - lrow[0] + 2;
1427 int ic1 = lcol[0] - lb[0];
1428 int ic2 = lcol[1] - lb[1];
1429 int n11 = (lrow[1] + 1) * (ic1 + 1) / nn1;
1430 int n12 = lrow[1] - n11;
1431 v += m_facts[n11] + m_facts[n12] + m_facts[ic1 - n11] + m_facts[ic2 - n12];
1432 if (v < vmn)
1433 vmn = v;
1434 } else {
1435 Array<int> it(lcol.GetSize());
1436 for (int i = 0; i < lcol.GetSize(); i++) it[i] = lcol[i] - lb[i];
1437
1438 if (lcol.GetSize() == 2) {
1439 if (it[0] > it[1]) it.Swap(0, 1);
1440 } else {
1441 QSort(it);
1442 }
1443
1444 // Compute hash value
1445 int key = it[0] * kyy + it[1];
1446 for (int i = 2; i < lcol.GetSize(); i++) key = it[i] + key * kyy;
1447
1448 // Put onto stack (or update stack entry as necessary)
1449 int t_idx;
1450 if (vht[active_vht].Find(key, t_idx)) {
1451 if (v < vht[active_vht][t_idx]) vht[active_vht][t_idx] = v;
1452 } else {
1453 vht[active_vht][t_idx] = v;
1454 }
1455 }
1456 }
1457 }
1458
1459
1460
1461 return 0.0;
1462 }
1463
1464
longestPathSpecial(const MarginalArray::Slice & row_marginals,const MarginalArray::Slice & col_marginals,double & val)1465 bool FExact::longestPathSpecial(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& val)
1466 {
1467 Array<int> nd(row_marginals.GetSize() - 1);
1468 Array<int> ne(col_marginals.GetSize());
1469 Array<int> m(col_marginals.GetSize());
1470
1471 nd.SetAll(0);
1472 int is = col_marginals[0] / row_marginals.GetSize();
1473 ne[0] = is;
1474 int ix = col_marginals[0] - row_marginals.GetSize() * is;
1475 m[0] = ix;
1476 if (ix != 0) nd[ix - 1] = 1;
1477
1478 for (int i = 1; i < col_marginals.GetSize(); i++) {
1479 ix = col_marginals[i] / row_marginals.GetSize();
1480 ne[i] = ix;
1481 is += ix;
1482 ix = col_marginals[i] - row_marginals.GetSize() * ix;
1483 m[i] = ix;
1484 if (ix != 0) nd[ix - 1]++;
1485 }
1486
1487 for (int i = nd.GetSize() - 2; i >= 0; i--) nd[i] += nd[i + 1];
1488
1489 ix = 0;
1490 int nrow1 = row_marginals.GetSize() - 1;
1491 for (int i = (row_marginals.GetSize() - 1); i > 0; i--) {
1492 ix += is + nd[nrow1 - i] - row_marginals[i];
1493 if (ix < 0) return false;
1494 }
1495
1496 val = 0.0;
1497 for (int i = 0; i < col_marginals.GetSize(); i++) {
1498 ix = ne[i];
1499 is = m[i];
1500 val += is * m_facts[ix + 1] + (row_marginals.GetSize() - is) * m_facts[ix];
1501 }
1502
1503 return true;
1504 }
1505
1506
1507
1508
1509 // FExact Shortest Path Methods
1510 // --------------------------------------------------------------------------------------------------------------
1511
removeFromVector(const Array<int,ManualBuffer> & src,int idx_remove,Array<int,ManualBuffer> & dest)1512 inline void FExact::removeFromVector(const Array<int, ManualBuffer>& src, int idx_remove, Array<int, ManualBuffer>& dest)
1513 {
1514 dest.Resize(src.GetSize() - 1);
1515 for (int i = 0; i < idx_remove; i++) dest[i] = src[i];
1516 for (int i = idx_remove + 1; i < src.GetSize(); i++) dest[i - 1] = src[i];
1517 }
1518
1519
reduceZeroInVector(const Array<int,ManualBuffer> & src,int value,int idx_start,Array<int,ManualBuffer> & dest)1520 inline void FExact::reduceZeroInVector(const Array<int, ManualBuffer>& src, int value, int idx_start, Array<int, ManualBuffer>& dest)
1521 {
1522 dest.Resize(src.GetSize());
1523
1524 int i = 0;
1525 for (; i < idx_start; i++) dest[i] = src[i];
1526
1527 for (; i < (src.GetSize() - 1); i++) {
1528 if (value >= src[i + 1]) {
1529 break;
1530 }
1531 dest[i] = src[i + 1];
1532 }
1533 dest[i] = value;
1534
1535 for (++i; i < src.GetSize(); i++) dest[i] = src[i];
1536 }
1537
1538
// Computes the shortest path length for the sub-table with the given row and
// column marginals, subtracting the maximal sum of log factorials found from
// shortest_path in place.  shortest_path is clamped to 0.0 when the bound is
// met within m_tolerance.
void FExact::shortestPath(const MarginalArray::Slice& row_marginals, const MarginalArray::Slice& col_marginals, double& shortest_path)
{
  // Take care of easy cases first

  // 1 x c
  if (row_marginals.GetSize() == 1) {
    for (int i = 0; i < col_marginals.GetSize(); i++) shortest_path -= m_facts[col_marginals[i]];
    return;
  }

  // r x 1
  if (col_marginals.GetSize() == 1) {
    for (int i = 0; i < row_marginals.GetSize(); i++) shortest_path -= m_facts[row_marginals[i]];
    return;
  }

  // 2 x 2
  if (row_marginals.GetSize() == 2 && col_marginals.GetSize() == 2) {
    if (row_marginals[1] <= col_marginals[1]) {
      shortest_path += -m_facts[row_marginals[1]] - m_facts[col_marginals[0]] - m_facts[col_marginals[1] - row_marginals[1]];
    } else {
      shortest_path += -m_facts[col_marginals[1]] - m_facts[row_marginals[0]] - m_facts[row_marginals[1] - col_marginals[1]];
    }
    return;
  }

  // General Case
  //
  // Depth-first search over successively reduced tables.  Each stack level
  // pairs off the leading row/column marginal, reducing the table by one
  // entry; backtracking resumes a level at the next candidate index.

  // Raw buffers backing the per-level marginal arrays (one slot per possible
  // stack depth); ManualBuffer arrays view into these without allocating
  const int ROW_BUFFER_SIZE = (row_marginals.GetSize() + col_marginals.GetSize() + 1) * row_marginals.GetSize();
  const int COL_BUFFER_SIZE = (row_marginals.GetSize() + col_marginals.GetSize() + 1) * col_marginals.GetSize();
  SmartPtr<int, NoCopy, ArrayStorage> row_data(new int[ROW_BUFFER_SIZE]);
  SmartPtr<int, NoCopy, ArrayStorage> col_data(new int[COL_BUFFER_SIZE]);
  Array<Array<int, ManualBuffer> > row_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  Array<Array<int, ManualBuffer> > col_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  for (int i = 0; i < row_stack.GetSize(); i++) {
    row_stack[i].SetBuffer(GetInternalPtr(row_data) + (i * row_marginals.GetSize()));
    col_stack[i].SetBuffer(GetInternalPtr(col_data) + (i * col_marginals.GetSize()));
  }

  // Level 0 holds the input marginals reversed (largest first — assumes the
  // inputs arrive sorted ascending; TODO confirm against callers)
  row_stack[0].Resize(row_marginals.GetSize());
  for (int i = 0; i < row_marginals.GetSize(); i++) row_stack[0][i] = row_marginals[row_marginals.GetSize() - i - 1];
  col_stack[0].Resize(col_marginals.GetSize());
  for (int i = 0; i < col_marginals.GetSize(); i++) col_stack[0][i] = col_marginals[col_marginals.GetSize() - i - 1];

  int istk = 0;  // current stack depth

  // Per-level bookkeeping: partial sum (y), candidate index (l), candidate
  // limit (m), and dimension being advanced (n: 1 = row, 2 = column)
  Array<double> y_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  Array<int> l_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  Array<int> m_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  Array<int> n_stack(row_marginals.GetSize() + col_marginals.GetSize() + 1);
  y_stack[0] = 0.0;
  double y = 0.0;

  int l = 0;
  double amx = 0.0;  // best (maximum) path sum found so far

  int m, n, jrow, jcol;

  do {
    // Choose which dimension to iterate at this level and the candidate limit,
    // based on the relative sizes of the leading marginals and dimensions
    int row1 = row_stack[istk][0];
    int col1 = col_stack[istk][0];
    if (row1 > col1) {
      if (row_stack[istk].GetSize() >= col_stack[istk].GetSize()) {
        m = col_stack[istk].GetSize() - 1;
        n = 2;
      } else {
        m = row_stack[istk].GetSize();
        n = 1;
      }
    } else if (row1 < col1) {
      if (row_stack[istk].GetSize() <= col_stack[istk].GetSize()) {
        m = row_stack[istk].GetSize() - 1;
        n = 1;
      } else {
        m = col_stack[istk].GetSize();
        n = 2;
      }
    } else {
      if (row_stack[istk].GetSize() <= col_stack[istk].GetSize()) {
        m = row_stack[istk].GetSize() - 1;
        n = 1;
      } else {
        m = col_stack[istk].GetSize() - 1;
        n = 2;
      }
    }

    do {
      // Candidate l indexes into the chosen dimension; the other index is 0
      if (n == 1) {
        jrow = l;
        jcol = 0;
      } else {
        jrow = 0;
        jcol = l;
      }

      // Pair the chosen row/column entries; the smaller is consumed entirely
      int rowt = row_stack[istk][jrow];
      int colt = col_stack[istk][jcol];
      int mn = (rowt > colt) ? colt : rowt;
      y += m_facts[mn];
      if (rowt == colt) {
        removeFromVector(row_stack[istk], jrow, row_stack[istk + 1]);
        removeFromVector(col_stack[istk], jcol, col_stack[istk + 1]);
      } else if (rowt > colt) {
        removeFromVector(col_stack[istk], jcol, col_stack[istk + 1]);
        reduceZeroInVector(row_stack[istk], rowt - colt, jrow, row_stack[istk + 1]);
      } else {
        removeFromVector(row_stack[istk], jrow, row_stack[istk + 1]);
        reduceZeroInVector(col_stack[istk], colt - rowt, jcol, col_stack[istk + 1]);
      }

      if (row_stack[istk + 1].GetSize() == 1 || col_stack[istk + 1].GetSize() == 1) {
        // Reduced to a single row or column: close out the path sum directly
        if (row_stack[istk + 1].GetSize() == 1) {
          for (int i = 0; i < col_stack[istk + 1].GetSize(); i++) y += m_facts[col_stack[istk + 1][i]];
        }
        if (col_stack[istk + 1].GetSize() == 1) {
          for (int i = 0; i < row_stack[istk + 1].GetSize(); i++) y += m_facts[row_stack[istk + 1][i]];
        }

        if (y > amx) {
          amx = y;
          // Bound reached within tolerance: result is zero, stop searching
          if (shortest_path - amx <= m_tolerance) {
            shortest_path = 0.0;
            return;
          }
        }

        // Backtrack: pop levels, resuming at the next candidate whose value
        // differs from its predecessor (equal values repeat identical work)
        bool continue_outer = false;
        for (--istk; istk >= 0; istk--) {
          l = l_stack[istk] + 1;
          for (; l < m_stack[istk]; l++) {
            n = n_stack[istk];
            y = y_stack[istk];
            if (n == 1) {
              if (row_stack[istk][l] < row_stack[istk][l - 1]) {
                continue_outer = true;
                break;
              }
            } else if (n == 2) {
              if (col_stack[istk][l] < col_stack[istk][l - 1]) {
                continue_outer = true;
                break;
              }
            }
          }
          if (continue_outer) break;
        }
        if (continue_outer) continue;

        // Search exhausted: apply the best sum found
        shortest_path -= amx;
        if (shortest_path - amx <= m_tolerance) shortest_path = 0.0;
        return;
      } else {
        break;
      }
    } while (true);

    // Push this level's state and descend
    l_stack[istk] = l;
    m_stack[istk] = m;
    n_stack[istk] = n;
    istk++;
    y_stack[istk] = y;
    l = 0;
  } while (true);

}
1706
1707
1708
1709 // Internal Function Definitions
1710 // --------------------------------------------------------------------------------------------------------------
1711
cummulativeGamma(double q,double alpha,bool & fault)1712 double cummulativeGamma(double q, double alpha, bool& fault)
1713 {
1714 if (q <= 0.0 || alpha <= 0.0) {
1715 fault = true;
1716 return 0.0;
1717 }
1718
1719 double f = exp(alpha * log(q) - logGamma(alpha + 1.0, fault) - q); // no need to test logGamma fail as an error is impossible
1720 if (f == 0.0) {
1721 fault = true;
1722 return 0.0;
1723 }
1724
1725 fault = false;
1726
1727 double c = 1.0;
1728 double ret_val = 1.0;
1729 double a = alpha;
1730
1731 do {
1732 a += 1.0;
1733 c = c * q / a;
1734 ret_val += c;
1735 } while (c / ret_val > (1e-6));
1736 ret_val *= f;
1737
1738 return ret_val;
1739 }
1740
1741
// Computes ln(Gamma(x)) for x > 0 using a Stirling-series approximation.
// For x < 7 the argument is first shifted upward via the recurrence
// Gamma(x+1) = x*Gamma(x) so the asymptotic series is evaluated at x >= 7.
//
// @param x     evaluation point; must be strictly positive
// @param fault set to true (with 0.0 returned) when x <= 0 — Gamma has a
//              pole at zero and this routine does not handle x < 0;
//              false on success
double logGamma(double x, bool& fault)
{
  const double a1 = .918938533204673;  // ln(2*pi)/2
  const double a2 = 5.95238095238e-4;  // Stirling series coefficient 1/1680
  const double a3 = 7.93650793651e-4;  // 1/1260
  const double a4 = .002777777777778;  // 1/360
  const double a5 = .083333333333333;  // 1/12

  // Domain check.  Previously only x < 0 was rejected, so x == 0 fell
  // through to -log(0.0) below, yielding +inf with fault left false; the
  // pole at zero is now reported as a fault as well.
  if (x <= 0.0) {
    fault = true;
    return 0.0;
  }

  fault = false;

  double f = 0.0;

  if (x < 7.0) {
    // Accumulate the product x*(x+1)*...*(x+k-1) needed to shift the
    // argument up to at least 7, then subtract its log at the end
    f = x;

    x += 1.0;
    while (x < 7.0) {
      f *= x;
      x += 1.0;
    }

    f = -log(f);
  }

  // Stirling's approximation with a four-term correction series in 1/x^2
  double z = 1 / (x * x);
  return f + (x - .5) * log(x) - x + a1 + (((-a2 * z + a3) * z - a4) * z + a5) / x;
}
1774