1 #include <Rcpp.h>
2 using namespace Rcpp;
3
4
// Compute partial sums with a static binary tree of dyadic intervals
// (segment-tree style; node numbering is fixed, there is no rebalancing as in AVL).
// Powers of two are pre-computed to save repeated calculations.
7
8
9 IntegerVector containerNodes (int y, IntegerVector pwr2, IntegerVector psum);
10 NumericVector gamma1_direct(IntegerVector y, NumericVector z);
11 IntegerVector p2sum(IntegerVector pwr2);
12 IntegerVector powers2 (int L);
13 NumericVector rowsumsDist(NumericVector x, NumericVector sorted, IntegerVector ranks);
14 IntegerVector subNodes (int y, IntegerVector pwr2, IntegerVector psum);
15
16
17
18 // [[Rcpp::export]]
Btree_sum(IntegerVector y,NumericVector z)19 NumericVector Btree_sum (IntegerVector y, NumericVector z) {
20 //
21 // y is a permutation of the integers 1:n
22 // z is a numeric vector of length n
23 // compute gamma1(i) = sum(j<i; y_j<y_i) z(j)
24 //
25
26 int n = y.length(), L = ceil(log2(n));
27 int i, node, p;
28 IntegerVector pwr2 = powers2(L);
29 IntegerVector psum = p2sum(pwr2);
30 IntegerVector nodes(L);
31 NumericVector sums(2 * pwr2(L - 1));
32 NumericVector gamma1(n);
33
34 for (i = 1; i < n; i++) {
35 // update container sums for y(i - 1)
36 nodes = containerNodes(y(i - 1), pwr2, psum);
37 for (p = 0; p < L; p++)
38 sums(nodes(p)) += z(i - 1);
39
40 // get nodes below y(i) and update gamma(i)
41 nodes = subNodes(y(i) - 1, pwr2, psum);
42 for (p = 0; p < L; p++) {
43 node = nodes(p);
44 if (node > 0)
45 gamma1(i) += sums(node);
46 }
47 }
48 return gamma1;
49 }
50
containerNodes(int y,IntegerVector pwr2,IntegerVector psum)51 IntegerVector containerNodes (int y, IntegerVector pwr2, IntegerVector psum) {
52 /*
53 * get the indices of all nodes of binary tree whose closed
54 * intervals contain integer y
55 */
56 int i, L = pwr2.length();
57 IntegerVector nodes(L);
58
59 nodes(0) = y;
60 for (i = 0; i < L-1; i++) {
61 nodes(i+1) = ceil((double) y / pwr2(i)) + psum(i);
62 }
63 return nodes;
64 }
65
66
subNodes(int y,IntegerVector pwr2,IntegerVector psum)67 IntegerVector subNodes (int y, IntegerVector pwr2, IntegerVector psum) {
68 /*
69 * get indices of nodes whose intervals disjoint union is 1:y
70 */
71 int L = psum.length();
72 int idx, k, level, p2;
73 IntegerVector nodes(L);
74
75 std::fill(nodes.begin(), nodes.end(), -1L);
76
77 k = y;
78 for (level = L - 1; level > 0; level --) {
79 p2 = pwr2(level - 1);
80 if (k >= p2) {
81 // at index of left node plus an offset
82 idx = psum(level - 1) + (y / p2);
83 nodes(L - level - 1) = idx;
84 k -= p2;
85 }
86 }
87 if (k > 0)
88 nodes(L - 1) = y;
89 return nodes;
90 }
91
92
powers2(int L)93 IntegerVector powers2 (int L) {
94 // (2, 4, 8, ..., 2^L, 2^(L+1))
95 int k;
96 IntegerVector pwr2(L);
97
98 pwr2(0) = 2;
99 for (k = 1; k < L; k++)
100 pwr2(k) = pwr2(k-1) * 2;
101 return pwr2;
102 }
103
p2sum(IntegerVector pwr2)104 IntegerVector p2sum(IntegerVector pwr2) {
105 // computes the cumsum of 2^L, 2^(L-1), ..., 2^2, 2
106 int i, L = pwr2.length();
107 IntegerVector psum(L);
108
109 std::fill(psum.begin(), psum.end(), pwr2(L-1));
110 for (i = 1; i < L; i++)
111 psum(i) = psum(i-1) + pwr2(L-i-1);
112 return psum;
113 }
114
115
gamma1_direct(IntegerVector y,NumericVector z)116 NumericVector gamma1_direct(IntegerVector y, NumericVector z) {
117 // utility: direct computation of the sum gamm1
118 // for the purpose of testing and benchmarks
119
120 int n = y.length();
121 int i, j;
122 NumericVector gamma1(n);
123
124 for (i = 1; i < n; i++) {
125 for (j = 0; j < i; j++) {
126 if (y(j) < y(i)) {
127 gamma1(i) += z(j);
128 }
129 }
130 }
131 return gamma1;
132 }
133