1 #include <Rcpp.h>
2 using namespace Rcpp;
3 
4 
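// Compute the partial sums gamma1(i) = sum(j < i; y_j < y_i) z_j in O(n log n)
// time with a static complete binary tree over the values 1..2^L, whose nodes
// hold the sums of z over dyadic blocks of values (not an AVL/self-balancing tree).
// Powers of two and per-level index offsets are pre-computed once to avoid
// repeated calculations.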
7 
8 
9 IntegerVector containerNodes (int y, IntegerVector pwr2, IntegerVector psum);
10 NumericVector gamma1_direct(IntegerVector y, NumericVector z);
11 IntegerVector p2sum(IntegerVector pwr2);
12 IntegerVector powers2 (int L);
13 NumericVector rowsumsDist(NumericVector x, NumericVector sorted, IntegerVector ranks);
14 IntegerVector subNodes (int y, IntegerVector pwr2, IntegerVector psum);
15 
16 
17 
18 // [[Rcpp::export]]
Btree_sum(IntegerVector y,NumericVector z)19 NumericVector Btree_sum (IntegerVector y, NumericVector z) {
20   //
21   //  y is a permutation of the integers 1:n
22   //  z is a numeric vector of length n
23   //  compute gamma1(i) = sum(j<i; y_j<y_i) z(j)
24   //
25 
26   int n = y.length(), L = ceil(log2(n));
27   int i, node, p;
28   IntegerVector pwr2 = powers2(L);
29   IntegerVector psum = p2sum(pwr2);
30   IntegerVector nodes(L);
31   NumericVector sums(2 * pwr2(L - 1));
32   NumericVector gamma1(n);
33 
34   for (i = 1; i < n; i++) {
35     // update container sums for y(i - 1)
36     nodes = containerNodes(y(i - 1), pwr2, psum);
37     for (p = 0; p < L; p++)
38       sums(nodes(p)) += z(i - 1);
39 
40     // get nodes below y(i) and update gamma(i)
41     nodes = subNodes(y(i) - 1, pwr2, psum);
42     for (p = 0; p < L; p++) {
43       node = nodes(p);
44       if (node > 0)
45         gamma1(i) += sums(node);
46     }
47   }
48   return gamma1;
49 }
50 
containerNodes(int y,IntegerVector pwr2,IntegerVector psum)51 IntegerVector containerNodes (int y, IntegerVector pwr2, IntegerVector psum) {
52   /*
53   * get the indices of all nodes of binary tree whose closed
54   * intervals contain integer y
55   */
56   int i, L = pwr2.length();
57   IntegerVector nodes(L);
58 
59   nodes(0) = y;
60   for (i = 0; i < L-1; i++) {
61      nodes(i+1) = ceil((double) y / pwr2(i)) + psum(i);
62   }
63   return nodes;
64 }
65 
66 
subNodes(int y,IntegerVector pwr2,IntegerVector psum)67 IntegerVector subNodes (int y, IntegerVector pwr2, IntegerVector psum) {
68   /*
69    * get indices of nodes whose intervals disjoint union is 1:y
70    */
71   int L = psum.length();
72   int idx, k, level, p2;
73   IntegerVector nodes(L);
74 
75   std::fill(nodes.begin(), nodes.end(), -1L);
76 
77   k = y;
78   for (level = L - 1; level > 0; level --) {
79     p2 = pwr2(level - 1);
80     if (k >= p2) {
81       // at index of left node plus an offset
82       idx = psum(level - 1) + (y / p2);
83       nodes(L - level - 1) = idx;
84       k -= p2;
85     }
86   }
87   if (k > 0)
88     nodes(L - 1) = y;
89   return nodes;
90 }
91 
92 
powers2(int L)93 IntegerVector  powers2 (int L) {
94   //  (2, 4, 8, ..., 2^L, 2^(L+1))
95   int k;
96   IntegerVector pwr2(L);
97 
98   pwr2(0) = 2;
99   for (k = 1; k < L; k++)
100     pwr2(k) = pwr2(k-1) * 2;
101   return pwr2;
102 }
103 
p2sum(IntegerVector pwr2)104 IntegerVector p2sum(IntegerVector pwr2) {
105   // computes the cumsum of 2^L, 2^(L-1), ..., 2^2, 2
106   int i, L = pwr2.length();
107   IntegerVector psum(L);
108 
109   std::fill(psum.begin(), psum.end(), pwr2(L-1));
110   for (i = 1; i < L; i++)
111     psum(i) = psum(i-1) + pwr2(L-i-1);
112   return psum;
113 }
114 
115 
gamma1_direct(IntegerVector y,NumericVector z)116 NumericVector gamma1_direct(IntegerVector y, NumericVector z) {
117   // utility: direct computation of the sum gamm1
118   // for the purpose of testing and benchmarks
119 
120   int n = y.length();
121   int i, j;
122   NumericVector gamma1(n);
123 
124   for (i = 1; i < n; i++) {
125     for (j = 0; j < i; j++) {
126       if (y(j) < y(i)) {
127         gamma1(i) += z(j);
128       }
129     }
130   }
131   return gamma1;
132 }
133