1 // Copyright (C) 2003  Davis E. King (davis@dlib.net)
2 // License: Boost Software License   See LICENSE.txt for the full license.
3 #ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_
4 #define DLIB_CONDITIONING_CLASS_KERNEl_2_
5 
6 #include "conditioning_class_kernel_abstract.h"
7 #include "../assert.h"
8 #include "../algs.h"
9 
10 namespace dlib
11 {
12 
13     template <
14         unsigned long alphabet_size
15         >
16     class conditioning_class_kernel_2
17     {
18         /*!
19             INITIAL VALUE
20                 total == 1
21                 symbols == pointer to array of alphabet_size data structs
22                 for all i except i == alphabet_size-1: symbols[i].count == 0
23                                                        symbols[i].left_count == 0
24 
25                 symbols[alphabet_size-1].count == 1
26                 symbols[alpahbet_size-1].left_count == 0
27 
28             CONVENTION
29                 symbols == pointer to array of alphabet_size data structs
30                 get_total() == total
31                 get_count(symbol) == symbols[symbol].count
32 
33                 symbols is organized as a tree with symbols[0] as the root.
34 
35                 the left subchild of symbols[i] is symbols[i*2+1] and
36                 the right subchild is symbols[i*2+2].
37                 the partent of symbols[i] == symbols[(i-1)/2]
38 
39                 symbols[i].left_count == the sum of the counts of all the
40                                          symbols to the left of symbols[i]
41 
42                 get_memory_usage() == global_state.memory_usage
43         !*/
44 
45     public:
46 
47         class global_state_type
48         {
49         public:
global_state_type()50             global_state_type () : memory_usage(0) {}
51         private:
52             unsigned long memory_usage;
53 
54             friend class conditioning_class_kernel_2<alphabet_size>;
55         };
56 
57         conditioning_class_kernel_2 (
58             global_state_type& global_state_
59         );
60 
61         ~conditioning_class_kernel_2 (
62         );
63 
64         void clear(
65         );
66 
67         bool increment_count (
68             unsigned long symbol,
69             unsigned short amount = 1
70         );
71 
72         unsigned long get_count (
73             unsigned long symbol
74         ) const;
75 
76         inline unsigned long get_total (
77         ) const;
78 
79         unsigned long get_range (
80             unsigned long symbol,
81             unsigned long& low_count,
82             unsigned long& high_count,
83             unsigned long& total_count
84         ) const;
85 
86         void get_symbol (
87             unsigned long target,
88             unsigned long& symbol,
89             unsigned long& low_count,
90             unsigned long& high_count
91         ) const;
92 
93         unsigned long get_memory_usage (
94         ) const;
95 
96         global_state_type& get_global_state (
97         );
98 
99         static unsigned long get_alphabet_size (
100         );
101 
102     private:
103 
104         // restricted functions
105         conditioning_class_kernel_2(conditioning_class_kernel_2<alphabet_size>&);        // copy constructor
106         conditioning_class_kernel_2& operator=(conditioning_class_kernel_2<alphabet_size>&);    // assignment operator
107 
108         // data members
109         unsigned short total;
110         struct data
111         {
112             unsigned short count;
113             unsigned short left_count;
114         };
115 
116         data* symbols;
117         global_state_type& global_state;
118 
119     };
120 
121 // ----------------------------------------------------------------------------------------
122 // ----------------------------------------------------------------------------------------
123     // member function definitions
124 // ----------------------------------------------------------------------------------------
125 // ----------------------------------------------------------------------------------------
126 
127     template <
128         unsigned long alphabet_size
129         >
130     conditioning_class_kernel_2<alphabet_size>::
conditioning_class_kernel_2(global_state_type & global_state_)131     conditioning_class_kernel_2 (
132         global_state_type& global_state_
133     ) :
134         total(1),
135         symbols(new data[alphabet_size]),
136         global_state(global_state_)
137     {
138         COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 );
139 
140         data* start = symbols;
141         data* end = symbols + alphabet_size-1;
142 
143         while (start != end)
144         {
145             start->count = 0;
146             start->left_count = 0;
147             ++start;
148         }
149 
150         start->count = 1;
151         start->left_count = 0;
152 
153 
154         // update the left_counts for the symbol alphabet_size-1
155         unsigned short temp;
156         unsigned long symbol = alphabet_size-1;
157         while (symbol != 0)
158         {
159             // temp will be 1 if symbol is odd, 0 if it is even
160             temp = static_cast<unsigned short>(symbol&0x1);
161 
162             // set symbol to its parent
163             symbol = (symbol-1)>>1;
164 
165             // note that all left subchidren are odd and also that
166             // if symbol was a left subchild then we want to increment
167             // its parents left_count
168             if (temp)
169                 ++symbols[symbol].left_count;
170         }
171 
172         global_state.memory_usage += sizeof(data)*alphabet_size +
173                                      sizeof(conditioning_class_kernel_2);
174     }
175 
176 // ----------------------------------------------------------------------------------------
177 
178     template <
179         unsigned long alphabet_size
180         >
181     conditioning_class_kernel_2<alphabet_size>::
~conditioning_class_kernel_2()182     ~conditioning_class_kernel_2 (
183     )
184     {
185         delete [] symbols;
186         global_state.memory_usage -= sizeof(data)*alphabet_size +
187                                      sizeof(conditioning_class_kernel_2);
188     }
189 
190 // ----------------------------------------------------------------------------------------
191 
192     template <
193         unsigned long alphabet_size
194         >
195     void conditioning_class_kernel_2<alphabet_size>::
clear()196     clear(
197     )
198     {
199         data* start = symbols;
200         data* end = symbols + alphabet_size-1;
201 
202         total = 1;
203 
204         while (start != end)
205         {
206             start->count = 0;
207             start->left_count = 0;
208             ++start;
209         }
210 
211         start->count = 1;
212         start->left_count = 0;
213 
214         // update the left_counts
215         unsigned short temp;
216         unsigned long symbol = alphabet_size-1;
217         while (symbol != 0)
218         {
219             // temp will be 1 if symbol is odd, 0 if it is even
220             temp = static_cast<unsigned short>(symbol&0x1);
221 
222             // set symbol to its parent
223             symbol = (symbol-1)>>1;
224 
225             // note that all left subchidren are odd and also that
226             // if symbol was a left subchild then we want to increment
227             // its parents left_count
228             symbols[symbol].left_count += temp;
229         }
230     }
231 
232 // ----------------------------------------------------------------------------------------
233 
234     template <
235         unsigned long alphabet_size
236         >
237     unsigned long conditioning_class_kernel_2<alphabet_size>::
get_memory_usage()238     get_memory_usage(
239     ) const
240     {
241         return global_state.memory_usage;
242     }
243 
244 // ----------------------------------------------------------------------------------------
245 
246     template <
247         unsigned long alphabet_size
248         >
249     typename conditioning_class_kernel_2<alphabet_size>::global_state_type& conditioning_class_kernel_2<alphabet_size>::
get_global_state()250     get_global_state(
251     )
252     {
253         return global_state;
254     }
255 
256 // ----------------------------------------------------------------------------------------
257 
258     template <
259         unsigned long alphabet_size
260         >
261     bool conditioning_class_kernel_2<alphabet_size>::
increment_count(unsigned long symbol,unsigned short amount)262     increment_count (
263         unsigned long symbol,
264         unsigned short amount
265     )
266     {
267         // if we need to renormalize then do so
268         if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536)
269         {
270             unsigned long s;
271             unsigned short temp;
272             for (unsigned short i = 0; i < alphabet_size-1; ++i)
273             {
274                 s = i;
275 
276                 // divide the count for this symbol by 2
277                 symbols[i].count >>= 1;
278 
279                 symbols[i].left_count = 0;
280 
281                 // bubble this change up though the tree
282                 while (s != 0)
283                 {
284                     // temp will be 1 if symbol is odd, 0 if it is even
285                     temp = static_cast<unsigned short>(s&0x1);
286 
287                     // set s to its parent
288                     s = (s-1)>>1;
289 
290                     // note that all left subchidren are odd and also that
291                     // if s was a left subchild then we want to increment
292                     // its parents left_count
293                     if (temp)
294                         symbols[s].left_count += symbols[i].count;
295                 }
296             }
297 
298             // update symbols alphabet_size-1
299             {
300                 s = alphabet_size-1;
301 
302                 // divide alphabet_size-1 symbol by 2 if it's > 1
303                 if (symbols[alphabet_size-1].count > 1)
304                     symbols[alphabet_size-1].count >>= 1;
305 
306                 // bubble this change up though the tree
307                 while (s != 0)
308                 {
309                     // temp will be 1 if symbol is odd, 0 if it is even
310                     temp = static_cast<unsigned short>(s&0x1);
311 
312                     // set s to its parent
313                     s = (s-1)>>1;
314 
315                     // note that all left subchidren are odd and also that
316                     // if s was a left subchild then we want to increment
317                     // its parents left_count
318                     if (temp)
319                         symbols[s].left_count += symbols[alphabet_size-1].count;
320                 }
321             }
322 
323 
324 
325 
326 
327 
328             // calculate the new total
329             total = 0;
330             unsigned long m = 0;
331             while (m < alphabet_size)
332             {
333                 total += symbols[m].count + symbols[m].left_count;
334                 m = (m<<1) + 2;
335             }
336 
337         }
338 
339 
340 
341 
342         // increment the count for the specified symbol
343         symbols[symbol].count += amount;;
344         total += amount;
345 
346 
347         unsigned short temp;
348         while (symbol != 0)
349         {
350             // temp will be 1 if symbol is odd, 0 if it is even
351             temp = static_cast<unsigned short>(symbol&0x1);
352 
353             // set symbol to its parent
354             symbol = (symbol-1)>>1;
355 
356             // note that all left subchidren are odd and also that
357             // if symbol was a left subchild then we want to increment
358             // its parents left_count
359             if (temp)
360                 symbols[symbol].left_count += amount;
361         }
362 
363         return true;
364     }
365 
366 // ----------------------------------------------------------------------------------------
367 
368     template <
369         unsigned long alphabet_size
370         >
371     unsigned long conditioning_class_kernel_2<alphabet_size>::
get_count(unsigned long symbol)372     get_count (
373         unsigned long symbol
374     ) const
375     {
376         return symbols[symbol].count;
377     }
378 
379 // ----------------------------------------------------------------------------------------
380 
381     template <
382         unsigned long alphabet_size
383         >
384     unsigned long conditioning_class_kernel_2<alphabet_size>::
get_alphabet_size()385     get_alphabet_size (
386     )
387     {
388         return alphabet_size;
389     }
390 
391 // ----------------------------------------------------------------------------------------
392 
393     template <
394         unsigned long alphabet_size
395         >
396     unsigned long conditioning_class_kernel_2<alphabet_size>::
get_total()397     get_total (
398     ) const
399     {
400         return total;
401     }
402 
403 // ----------------------------------------------------------------------------------------
404 
405     template <
406         unsigned long alphabet_size
407         >
408     unsigned long conditioning_class_kernel_2<alphabet_size>::
get_range(unsigned long symbol,unsigned long & low_count,unsigned long & high_count,unsigned long & total_count)409     get_range (
410         unsigned long symbol,
411         unsigned long& low_count,
412         unsigned long& high_count,
413         unsigned long& total_count
414     ) const
415     {
416         if (symbols[symbol].count == 0)
417             return 0;
418 
419         unsigned long current = symbol;
420         total_count = total;
421         unsigned long high_count_temp = 0;
422         bool came_from_right = true;
423         while (true)
424         {
425             if (came_from_right)
426             {
427                 high_count_temp += symbols[current].count + symbols[current].left_count;
428             }
429 
430             // note that if current is even then it is a right child
431             came_from_right = !(current&0x1);
432 
433             if (current == 0)
434                 break;
435 
436             // set current to its parent
437             current = (current-1)>>1 ;
438         }
439 
440 
441         low_count = high_count_temp - symbols[symbol].count;
442         high_count = high_count_temp;
443 
444         return symbols[symbol].count;
445     }
446 
447 // ----------------------------------------------------------------------------------------
448 
449     template <
450         unsigned long alphabet_size
451         >
452     void conditioning_class_kernel_2<alphabet_size>::
get_symbol(unsigned long target,unsigned long & symbol,unsigned long & low_count,unsigned long & high_count)453     get_symbol (
454         unsigned long target,
455         unsigned long& symbol,
456         unsigned long& low_count,
457         unsigned long& high_count
458     ) const
459     {
460         unsigned long current = 0;
461         unsigned long low_count_temp = 0;
462 
463         while (true)
464         {
465             if (static_cast<unsigned short>(target) < symbols[current].left_count)
466             {
467                 // we should go left
468                 current = (current<<1) + 1;
469             }
470             else
471             {
472                 target -= symbols[current].left_count;
473                 low_count_temp += symbols[current].left_count;
474                 if (static_cast<unsigned short>(target) < symbols[current].count)
475                 {
476                     // we have found our target
477                     symbol = current;
478                     high_count = low_count_temp + symbols[current].count;
479                     low_count = low_count_temp;
480                     break;
481                 }
482                 else
483                 {
484                     // go right
485                     target -= symbols[current].count;
486                     low_count_temp += symbols[current].count;
487                     current = (current<<1) + 2;
488                 }
489             }
490 
491         }
492 
493     }
494 
495 // ----------------------------------------------------------------------------------------
496 
497 }
498 
#endif // DLIB_CONDITIONING_CLASS_KERNEl_2_
500 
501