use core;
pub type Prob = u16;

pub const BLEND_FIXED_POINT_PRECISION: i8 = 15;
#[allow(dead_code)]
pub const LOG2_SCALE: i32 = 15;
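/// Adaptive mixer for two binary models. `model_weights` holds the raw
/// per-model weights, `mixing_param` (minus one) is the update speed fed
/// to `compute_new_weight`, and `normalized_weight` caches model 0's
/// share of the total as a fixed-point fraction with
/// `BLEND_FIXED_POINT_PRECISION` bits.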
#[derive(Debug, Copy, Clone)]
pub struct Weights {
    model_weights: [i32; 2],
    mixing_param: u8,
    normalized_weight: Prob,
}
impl Default for Weights {
    fn default() -> Self {
        Self::new()
    }
}
impl Weights {
    pub fn new() -> Self {
        Weights {
            model_weights: [1; 2],
            mixing_param: 1,
            normalized_weight: 1 << (BLEND_FIXED_POINT_PRECISION - 1),
        }
    }
    #[allow(unused)]
    #[inline(always)]
    pub fn update(&mut self, model_probs: [Prob; 2], weighted_prob: Prob) {
        debug_assert!(self.mixing_param != 0);
        normalize_weights(&mut self.model_weights);
        let w0new = compute_new_weight(model_probs,
                                       weighted_prob,
                                       self.model_weights,
                                       false,
                                       self.mixing_param - 1);
        let w1new = compute_new_weight(model_probs,
                                       weighted_prob,
                                       self.model_weights,
                                       true,
                                       self.mixing_param - 1);
        self.model_weights = [w0new, w1new];
        self.normalized_weight = compute_normalized_weight(self.model_weights);
    }
    #[allow(dead_code)]
    #[inline(always)]
    pub fn norm_weight(&self) -> Prob {
        self.normalized_weight
    }
}

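// compute_normalized_weight reduces the weight total to at most 8
// significant bits so model 0's share can be computed with a cheap 16/8
// division, then rescales the quotient to BLEND_FIXED_POINT_PRECISION.
// Worked example (probe values, not from the original): weights
// [1, 10845] give total 10846 (14 significant bits), so shift = 6 and
// ((1 >> 6) << 8) / (10846 >> 6) == 0, i.e. model 0 holds essentially
// none of the mixture.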
#[allow(dead_code)]
#[inline(always)]
fn compute_normalized_weight(model_weights: [i32; 2]) -> Prob {
    let total = i64::from(model_weights[0]) + i64::from(model_weights[1]);
    let leading_zeros = total.leading_zeros();
    let shift = core::cmp::max(56 - (leading_zeros as i8), 0);
    let total_8bit = total >> shift;
    /*::probability::numeric::fast_divide_16bit_by_8bit(
        ((model_weights[0] >> shift) as u16) << 8,
        ::probability::numeric::lookup_divisor8(total_8bit as u8)) << (BLEND_FIXED_POINT_PRECISION - 8)
    */
    ((((model_weights[0] >> shift) as u16) << 8) / total_8bit as u16/*fixme??*/) << (BLEND_FIXED_POINT_PRECISION - 8)
}

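// fix_weights rescales both raw weights once either grows to 24 or more
// significant bits: each is shifted down by the same amount so the larger
// weight is back at 24 bits, preserving their ratio up to rounding.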
#[allow(dead_code)]
#[cold]
fn fix_weights(weights: &mut [i32; 2]) {
    let ilog = 32 - core::cmp::min(weights[0].leading_zeros(),
                                   weights[1].leading_zeros());
    let max_log = 24;
    if ilog >= max_log {
        weights[0] >>= ilog - max_log;
        weights[1] >>= ilog - max_log;
    }
}

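// Hot-path guard for fix_weights: the 0x7f000000 mask covers bits 24..=30,
// so the #[cold] rescale only runs once a weight has reached 2^24.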
#[allow(dead_code)]
#[inline(always)]
fn normalize_weights(weights: &mut [i32; 2]) {
    if ((weights[0] | weights[1]) & 0x7f000000) != 0 {
        fix_weights(weights);
    }
}

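// The floating-point variant below takes a gradient-style step on the
// mixing weight: with s = ni = 1 and p1 = s1, the update line reduces to
//   wi_new = wi + (n1i - s1) / s1,
// so a model whose probability n1i exceeds the current blended
// probability s1 gains weight, and one below it loses weight. The eps
// floor keeps a starved model from being frozen out permanently.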
#[allow(dead_code)]
#[cfg(feature="floating_point_context_mixing")]
fn compute_new_weight(probs: [Prob; 2],
                      weighted_prob: Prob,
                      weights: [i32; 2],
                      index_equal_1: bool,
                      _speed: u8) -> i32 { // speed ranges from 1 to 14 inclusive
    let index = index_equal_1 as usize;
    let n1i = probs[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
    //let n0i = 1.0f64 - n1i;
    let ni = 1.0f64;
    let s1 = weighted_prob as f64 / ((1i64 << LOG2_SCALE) as f64);
    let s0 = 1.0f64 - s1;
    let s = 1.0f64;
    //let p0 = s0;
    let p1 = s1;
    let wi = weights[index] as f64 / ((1i64 << LOG2_SCALE) as f64);
    let mut wi_new = wi + (1.0 - p1) * (s * n1i - s1 * ni) / (s0 * s1);
    let eps = 0.00001f64;
    if !(wi_new > eps) {
        wi_new = eps;
    }
    (wi_new * ((1i64 << LOG2_SCALE) as f64)) as i32
}

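// The integer variant below mirrors the same step at LOG2_SCALE fixed
// point: error is (1 - s1) and efficacy is (n1i - s1), each scaled by
// 2^15, and the division by s1 * (1 - s1) is approximated by a right
// shift of floor(log2(p1 * p0)) + 1 bits (the 64 - leading_zeros term),
// which stays within a factor of two of the true quotient and needs no
// hardware divide.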
#[allow(dead_code)]
#[cfg(not(feature="floating_point_context_mixing"))]
#[inline(always)]
fn compute_new_weight(probs: [Prob; 2],
                      weighted_prob: Prob,
                      weights: [i32; 2],
                      index_equal_1: bool,
                      _speed: u8) -> i32 { // speed ranges from 1 to 14 inclusive
    let index = index_equal_1 as usize;
    let full_model_sum_p1 = i64::from(weighted_prob);
    let full_model_total = 1i64 << LOG2_SCALE;
    let full_model_sum_p0 = full_model_total.wrapping_sub(i64::from(weighted_prob));
    let n1i = i64::from(probs[index]);
    let ni = 1i64 << LOG2_SCALE;
    let error = full_model_total.wrapping_sub(full_model_sum_p1);
    let wi = i64::from(weights[index]);
    let efficacy = full_model_total.wrapping_mul(n1i) - full_model_sum_p1.wrapping_mul(ni);
    //let geometric_probabilities = full_model_sum_p1 * full_model_sum_p0;
    let log_geometric_probabilities = 64 - (full_model_sum_p1.wrapping_mul(full_model_sum_p0)).leading_zeros();
    //let new_weight_adj = (error * efficacy)/(full_model_sum_p1 * full_model_sum_p0);
    let new_weight_adj = (error.wrapping_mul(efficacy)) >> log_geometric_probabilities;
    // assert!(wi + new_weight_adj < (1i64 << 31));
    core::cmp::max(1, wi.wrapping_add(new_weight_adj) as i32)
}

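// A minimal sanity check (added here as a sketch; the probe values are
// arbitrary, not from the original): after an update in which model 1's
// probability (20000/2^15) sits above the blended probability
// (15000/2^15), model 0's normalized share should fall below its initial
// midpoint of 1 << (BLEND_FIXED_POINT_PRECISION - 1).
#[cfg(test)]
mod tests {
    use super::{Weights, Prob, BLEND_FIXED_POINT_PRECISION};

    #[test]
    fn update_shifts_weight_toward_better_model() {
        let mut w = Weights::new();
        let midpoint: Prob = 1 << (BLEND_FIXED_POINT_PRECISION - 1);
        assert_eq!(w.norm_weight(), midpoint);
        w.update([10_000, 20_000], 15_000);
        assert!(w.norm_weight() < midpoint);
    }
}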