1 use core;
/// Fixed-point probability / weight type used by the mixer (16-bit unsigned).
pub type Prob = u16;

/// Number of fractional bits in the blended weight reported by
/// `Weights::norm_weight` (a value of `1 << 15` represents a share of 1.0).
pub const BLEND_FIXED_POINT_PRECISION: i8 = 15;

/// Base-2 logarithm of the probability scale: in the weight-update code a
/// probability of `1 << LOG2_SCALE` corresponds to 1.0.
// The lint allowance is kept because the constant may be unreferenced in some builds.
#[allow(dead_code)]
pub const LOG2_SCALE: i32 = 15;
7 #[derive(Debug,Copy, Clone)]
8 pub struct Weights {
9     model_weights: [i32;2],
10     mixing_param: u8,
11     normalized_weight: Prob,
12 }
13 impl Default for Weights {
default() -> Self14     fn default() -> Self {
15         Self::new()
16     }
17 }
18 impl Weights {
new() -> Self19     pub fn new() -> Self {
20         Weights {
21             model_weights:[1;2],
22             mixing_param: 1,
23             normalized_weight: 1 << (BLEND_FIXED_POINT_PRECISION - 1),
24         }
25     }
26     #[allow(unused)]
27     #[inline(always)]
update(&mut self, model_probs: [Prob; 2], weighted_prob: Prob)28     pub fn update(&mut self, model_probs: [Prob; 2], weighted_prob: Prob) {
29         debug_assert!(self.mixing_param != 0);
30         normalize_weights(&mut self.model_weights);
31         let w0new = compute_new_weight(model_probs,
32                                        weighted_prob,
33                                        self.model_weights,
34                                        false,
35                                        self.mixing_param - 1);
36         let w1new = compute_new_weight(model_probs,
37                                        weighted_prob,
38                                        self.model_weights,
39                                        true,
40                                        self.mixing_param - 1);
41         self.model_weights = [w0new, w1new];
42         self.normalized_weight = compute_normalized_weight(self.model_weights);
43     }
44     #[allow(dead_code)]
45     #[inline(always)]
norm_weight(&self) -> Prob46     pub fn norm_weight(&self) -> Prob {
47         self.normalized_weight
48     }
49 }
50 
51 #[allow(dead_code)]
52 #[inline(always)]
compute_normalized_weight(model_weights: [i32;2]) -> Prob53 fn compute_normalized_weight(model_weights: [i32;2]) -> Prob {
54     let total = i64::from(model_weights[0]) + i64::from(model_weights[1]);
55     let leading_zeros = total.leading_zeros();
56     let shift = core::cmp::max(56 - (leading_zeros as i8), 0);
57     let total_8bit = total >> shift;
58     /*::probability::numeric::fast_divide_16bit_by_8bit(
59         ((model_weights[0] >> shift) as u16)<< 8,
60         ::probability::numeric::lookup_divisor8(total_8bit as u8)) << (BLEND_FIXED_POINT_PRECISION - 8)
61         */
62     ((((model_weights[0] >> shift) as u16)<< 8) / total_8bit as u16/*fixme??*/) << (BLEND_FIXED_POINT_PRECISION - 8)
63 }
64 
/// Rescales both weights by the same power of two so that the wider of the
/// two has a bit length of at most 24. A negative weight counts as 32 bits wide.
///
/// Weights that are already narrow enough are left untouched. Arithmetic
/// right shifts are used, so the ratio between the weights is kept up to
/// truncation of the discarded low bits.
// Lint allowance kept from the original: may be unused in some builds.
#[allow(dead_code)]
#[cold]
fn fix_weights(weights: &mut [i32;2]) {
    const MAX_BITS: u32 = 24;
    let widest = 32 - core::cmp::min(weights[0].leading_zeros(), weights[1].leading_zeros());
    if let Some(excess) = widest.checked_sub(MAX_BITS) {
        for w in weights.iter_mut() {
            *w >>= excess;
        }
    }
}
76 
77 #[allow(dead_code)]
78 #[inline(always)]
normalize_weights(weights: &mut [i32;2])79 fn normalize_weights(weights: &mut [i32;2]) {
80     if ((weights[0]|weights[1])&0x7f000000) != 0 {
81         fix_weights(weights);
82     }
83 }
84 
85 #[allow(dead_code)]
86 #[cfg(features="floating_point_context_mixing")]
compute_new_weight(probs: [Prob; 2], weighted_prob: Prob, weights: [i32;2], index_equal_1: bool, _speed: u8) -> i3287 fn compute_new_weight(probs: [Prob; 2],
88                       weighted_prob: Prob,
89                       weights: [i32;2],
90                       index_equal_1: bool,
91                       _speed: u8) -> i32{ // speed ranges from 1 to 14 inclusive
92     let index = index_equal_1 as usize;
93     let n1i = probs[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
94     //let n0i = 1.0f64 - n1i;
95     let ni = 1.0f64;
96     let s1 = weighted_prob as f64 / ((1i64 << LOG2_SCALE) as f64);
97     let s0 = 1.0f64 - s1;
98     let s = 1.0f64;
99     //let p0 = s0;
100     let p1 = s1;
101     let wi = weights[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
102     let mut wi_new = wi + (1.0 - p1) * (s * n1i - s1 * ni) / (s0 * s1);
103     let eps = 0.00001f64;
104     if !(wi_new > eps) {
105         wi_new = eps;
106     }
107     (wi_new * ((1i64 << LOG2_SCALE) as f64)) as i32
108 }
109 
110 #[allow(dead_code)]
111 #[cfg(not(features="floating_point_context_mixing"))]
112 #[inline(always)]
compute_new_weight(probs: [Prob; 2], weighted_prob: Prob, weights: [i32;2], index_equal_1: bool, _speed: u8) -> i32113 fn compute_new_weight(probs: [Prob; 2],
114                       weighted_prob: Prob,
115                       weights: [i32;2],
116                       index_equal_1: bool,
117                       _speed: u8) -> i32{ // speed ranges from 1 to 14 inclusive
118     let index = index_equal_1 as usize;
119     let full_model_sum_p1 = i64::from(weighted_prob);
120     let full_model_total = 1i64 << LOG2_SCALE;
121     let full_model_sum_p0 = full_model_total.wrapping_sub(i64::from(weighted_prob));
122     let n1i = i64::from(probs[index]);
123     let ni = 1i64 << LOG2_SCALE;
124     let error = full_model_total.wrapping_sub(full_model_sum_p1);
125     let wi = i64::from(weights[index]);
126     let efficacy = full_model_total.wrapping_mul(n1i) - full_model_sum_p1.wrapping_mul(ni);
127     //let geometric_probabilities = full_model_sum_p1 * full_model_sum_p0;
128     let log_geometric_probabilities = 64 - (full_model_sum_p1.wrapping_mul(full_model_sum_p0)).leading_zeros();
129     //let scaled_geometric_probabilities = geometric_probabilities * S;
130     //let new_weight_adj = (error * efficacy) >> log_geometric_probabilities;// / geometric_probabilities;
131     //let new_weight_adj = (error * efficacy)/(full_model_sum_p1 * full_model_sum_p0);
132     let new_weight_adj = (error.wrapping_mul(efficacy)) >> log_geometric_probabilities;
133 //    assert!(wi + new_weight_adj < (1i64 << 31));
134     //print!("{} -> {} due to {:?} vs {}\n", wi as f64 / (weights[0] + weights[1]) as f64, (wi + new_weight_adj) as f64 /(weights[0] as i64 + new_weight_adj as i64 + weights[1] as i64) as f64, probs[index], weighted_prob);
135     core::cmp::max(1,wi.wrapping_add(new_weight_adj) as i32)
136 }
137