1 // Copyright (C) 2003 Davis E. King (davis@dlib.net) 2 // License: Boost Software License See LICENSE.txt for the full license. 3 #ifndef DLIB_CONDITIONING_CLASS_KERNEl_2_ 4 #define DLIB_CONDITIONING_CLASS_KERNEl_2_ 5 6 #include "conditioning_class_kernel_abstract.h" 7 #include "../assert.h" 8 #include "../algs.h" 9 10 namespace dlib 11 { 12 13 template < 14 unsigned long alphabet_size 15 > 16 class conditioning_class_kernel_2 17 { 18 /*! 19 INITIAL VALUE 20 total == 1 21 symbols == pointer to array of alphabet_size data structs 22 for all i except i == alphabet_size-1: symbols[i].count == 0 23 symbols[i].left_count == 0 24 25 symbols[alphabet_size-1].count == 1 26 symbols[alpahbet_size-1].left_count == 0 27 28 CONVENTION 29 symbols == pointer to array of alphabet_size data structs 30 get_total() == total 31 get_count(symbol) == symbols[symbol].count 32 33 symbols is organized as a tree with symbols[0] as the root. 34 35 the left subchild of symbols[i] is symbols[i*2+1] and 36 the right subchild is symbols[i*2+2]. 37 the partent of symbols[i] == symbols[(i-1)/2] 38 39 symbols[i].left_count == the sum of the counts of all the 40 symbols to the left of symbols[i] 41 42 get_memory_usage() == global_state.memory_usage 43 !*/ 44 45 public: 46 47 class global_state_type 48 { 49 public: global_state_type()50 global_state_type () : memory_usage(0) {} 51 private: 52 unsigned long memory_usage; 53 54 friend class conditioning_class_kernel_2<alphabet_size>; 55 }; 56 57 conditioning_class_kernel_2 ( 58 global_state_type& global_state_ 59 ); 60 61 ~conditioning_class_kernel_2 ( 62 ); 63 64 void clear( 65 ); 66 67 bool increment_count ( 68 unsigned long symbol, 69 unsigned short amount = 1 70 ); 71 72 unsigned long get_count ( 73 unsigned long symbol 74 ) const; 75 76 inline unsigned long get_total ( 77 ) const; 78 79 unsigned long get_range ( 80 unsigned long symbol, 81 unsigned long& low_count, 82 unsigned long& high_count, 83 unsigned long& total_count 84 ) const; 85 86 void get_symbol ( 87 unsigned long target, 88 unsigned long& symbol, 89 unsigned long& low_count, 90 unsigned long& high_count 91 ) const; 92 93 unsigned long get_memory_usage ( 94 ) const; 95 96 global_state_type& get_global_state ( 97 ); 98 99 static unsigned long get_alphabet_size ( 100 ); 101 102 private: 103 104 // restricted functions 105 conditioning_class_kernel_2(conditioning_class_kernel_2<alphabet_size>&); // copy constructor 106 conditioning_class_kernel_2& operator=(conditioning_class_kernel_2<alphabet_size>&); // assignment operator 107 108 // data members 109 unsigned short total; 110 struct data 111 { 112 unsigned short count; 113 unsigned short left_count; 114 }; 115 116 data* symbols; 117 global_state_type& global_state; 118 119 }; 120 121 // ---------------------------------------------------------------------------------------- 122 // ---------------------------------------------------------------------------------------- 123 // member function definitions 124 // ---------------------------------------------------------------------------------------- 125 // ---------------------------------------------------------------------------------------- 126 127 template < 128 unsigned long alphabet_size 129 > 130 conditioning_class_kernel_2<alphabet_size>:: conditioning_class_kernel_2(global_state_type & global_state_)131 conditioning_class_kernel_2 ( 132 global_state_type& global_state_ 133 ) : 134 total(1), 135 symbols(new data[alphabet_size]), 136 global_state(global_state_) 137 { 138 COMPILE_TIME_ASSERT( 1 < alphabet_size && alphabet_size < 65536 ); 139 140 data* start = symbols; 141 data* end = symbols + alphabet_size-1; 142 143 while (start != end) 144 { 145 start->count = 0; 146 start->left_count = 0; 147 ++start; 148 } 149 150 start->count = 1; 151 start->left_count = 0; 152 153 154 // update the left_counts for the symbol alphabet_size-1 155 unsigned short temp; 156 unsigned long symbol = alphabet_size-1; 157 while (symbol != 0) 158 { 159 // temp will be 1 if symbol is odd, 0 if it is even 160 temp = static_cast<unsigned short>(symbol&0x1); 161 162 // set symbol to its parent 163 symbol = (symbol-1)>>1; 164 165 // note that all left subchidren are odd and also that 166 // if symbol was a left subchild then we want to increment 167 // its parents left_count 168 if (temp) 169 ++symbols[symbol].left_count; 170 } 171 172 global_state.memory_usage += sizeof(data)*alphabet_size + 173 sizeof(conditioning_class_kernel_2); 174 } 175 176 // ---------------------------------------------------------------------------------------- 177 178 template < 179 unsigned long alphabet_size 180 > 181 conditioning_class_kernel_2<alphabet_size>:: ~conditioning_class_kernel_2()182 ~conditioning_class_kernel_2 ( 183 ) 184 { 185 delete [] symbols; 186 global_state.memory_usage -= sizeof(data)*alphabet_size + 187 sizeof(conditioning_class_kernel_2); 188 } 189 190 // ---------------------------------------------------------------------------------------- 191 192 template < 193 unsigned long alphabet_size 194 > 195 void conditioning_class_kernel_2<alphabet_size>:: clear()196 clear( 197 ) 198 { 199 data* start = symbols; 200 data* end = symbols + alphabet_size-1; 201 202 total = 1; 203 204 while (start != end) 205 { 206 start->count = 0; 207 start->left_count = 0; 208 ++start; 209 } 210 211 start->count = 1; 212 start->left_count = 0; 213 214 // update the left_counts 215 unsigned short temp; 216 unsigned long symbol = alphabet_size-1; 217 while (symbol != 0) 218 { 219 // temp will be 1 if symbol is odd, 0 if it is even 220 temp = static_cast<unsigned short>(symbol&0x1); 221 222 // set symbol to its parent 223 symbol = (symbol-1)>>1; 224 225 // note that all left subchidren are odd and also that 226 // if symbol was a left subchild then we want to increment 227 // its parents left_count 228 symbols[symbol].left_count += temp; 229 } 230 } 231 232 // ---------------------------------------------------------------------------------------- 233 234 template < 235 unsigned long alphabet_size 236 > 237 unsigned long conditioning_class_kernel_2<alphabet_size>:: get_memory_usage()238 get_memory_usage( 239 ) const 240 { 241 return global_state.memory_usage; 242 } 243 244 // ---------------------------------------------------------------------------------------- 245 246 template < 247 unsigned long alphabet_size 248 > 249 typename conditioning_class_kernel_2<alphabet_size>::global_state_type& conditioning_class_kernel_2<alphabet_size>:: get_global_state()250 get_global_state( 251 ) 252 { 253 return global_state; 254 } 255 256 // ---------------------------------------------------------------------------------------- 257 258 template < 259 unsigned long alphabet_size 260 > 261 bool conditioning_class_kernel_2<alphabet_size>:: increment_count(unsigned long symbol,unsigned short amount)262 increment_count ( 263 unsigned long symbol, 264 unsigned short amount 265 ) 266 { 267 // if we need to renormalize then do so 268 if (static_cast<unsigned long>(total)+static_cast<unsigned long>(amount) >= 65536) 269 { 270 unsigned long s; 271 unsigned short temp; 272 for (unsigned short i = 0; i < alphabet_size-1; ++i) 273 { 274 s = i; 275 276 // divide the count for this symbol by 2 277 symbols[i].count >>= 1; 278 279 symbols[i].left_count = 0; 280 281 // bubble this change up though the tree 282 while (s != 0) 283 { 284 // temp will be 1 if symbol is odd, 0 if it is even 285 temp = static_cast<unsigned short>(s&0x1); 286 287 // set s to its parent 288 s = (s-1)>>1; 289 290 // note that all left subchidren are odd and also that 291 // if s was a left subchild then we want to increment 292 // its parents left_count 293 if (temp) 294 symbols[s].left_count += symbols[i].count; 295 } 296 } 297 298 // update symbols alphabet_size-1 299 { 300 s = alphabet_size-1; 301 302 // divide alphabet_size-1 symbol by 2 if it's > 1 303 if (symbols[alphabet_size-1].count > 1) 304 symbols[alphabet_size-1].count >>= 1; 305 306 // bubble this change up though the tree 307 while (s != 0) 308 { 309 // temp will be 1 if symbol is odd, 0 if it is even 310 temp = static_cast<unsigned short>(s&0x1); 311 312 // set s to its parent 313 s = (s-1)>>1; 314 315 // note that all left subchidren are odd and also that 316 // if s was a left subchild then we want to increment 317 // its parents left_count 318 if (temp) 319 symbols[s].left_count += symbols[alphabet_size-1].count; 320 } 321 } 322 323 324 325 326 327 328 // calculate the new total 329 total = 0; 330 unsigned long m = 0; 331 while (m < alphabet_size) 332 { 333 total += symbols[m].count + symbols[m].left_count; 334 m = (m<<1) + 2; 335 } 336 337 } 338 339 340 341 342 // increment the count for the specified symbol 343 symbols[symbol].count += amount;; 344 total += amount; 345 346 347 unsigned short temp; 348 while (symbol != 0) 349 { 350 // temp will be 1 if symbol is odd, 0 if it is even 351 temp = static_cast<unsigned short>(symbol&0x1); 352 353 // set symbol to its parent 354 symbol = (symbol-1)>>1; 355 356 // note that all left subchidren are odd and also that 357 // if symbol was a left subchild then we want to increment 358 // its parents left_count 359 if (temp) 360 symbols[symbol].left_count += amount; 361 } 362 363 return true; 364 } 365 366 // ---------------------------------------------------------------------------------------- 367 368 template < 369 unsigned long alphabet_size 370 > 371 unsigned long conditioning_class_kernel_2<alphabet_size>:: get_count(unsigned long symbol)372 get_count ( 373 unsigned long symbol 374 ) const 375 { 376 return symbols[symbol].count; 377 } 378 379 // ---------------------------------------------------------------------------------------- 380 381 template < 382 unsigned long alphabet_size 383 > 384 unsigned long conditioning_class_kernel_2<alphabet_size>:: get_alphabet_size()385 get_alphabet_size ( 386 ) 387 { 388 return alphabet_size; 389 } 390 391 // ---------------------------------------------------------------------------------------- 392 393 template < 394 unsigned long alphabet_size 395 > 396 unsigned long conditioning_class_kernel_2<alphabet_size>:: get_total()397 get_total ( 398 ) const 399 { 400 return total; 401 } 402 403 // ---------------------------------------------------------------------------------------- 404 405 template < 406 unsigned long alphabet_size 407 > 408 unsigned long conditioning_class_kernel_2<alphabet_size>:: get_range(unsigned long symbol,unsigned long & low_count,unsigned long & high_count,unsigned long & total_count)409 get_range ( 410 unsigned long symbol, 411 unsigned long& low_count, 412 unsigned long& high_count, 413 unsigned long& total_count 414 ) const 415 { 416 if (symbols[symbol].count == 0) 417 return 0; 418 419 unsigned long current = symbol; 420 total_count = total; 421 unsigned long high_count_temp = 0; 422 bool came_from_right = true; 423 while (true) 424 { 425 if (came_from_right) 426 { 427 high_count_temp += symbols[current].count + symbols[current].left_count; 428 } 429 430 // note that if current is even then it is a right child 431 came_from_right = !(current&0x1); 432 433 if (current == 0) 434 break; 435 436 // set current to its parent 437 current = (current-1)>>1 ; 438 } 439 440 441 low_count = high_count_temp - symbols[symbol].count; 442 high_count = high_count_temp; 443 444 return symbols[symbol].count; 445 } 446 447 // ---------------------------------------------------------------------------------------- 448 449 template < 450 unsigned long alphabet_size 451 > 452 void conditioning_class_kernel_2<alphabet_size>:: get_symbol(unsigned long target,unsigned long & symbol,unsigned long & low_count,unsigned long & high_count)453 get_symbol ( 454 unsigned long target, 455 unsigned long& symbol, 456 unsigned long& low_count, 457 unsigned long& high_count 458 ) const 459 { 460 unsigned long current = 0; 461 unsigned long low_count_temp = 0; 462 463 while (true) 464 { 465 if (static_cast<unsigned short>(target) < symbols[current].left_count) 466 { 467 // we should go left 468 current = (current<<1) + 1; 469 } 470 else 471 { 472 target -= symbols[current].left_count; 473 low_count_temp += symbols[current].left_count; 474 if (static_cast<unsigned short>(target) < symbols[current].count) 475 { 476 // we have found our target 477 symbol = current; 478 high_count = low_count_temp + symbols[current].count; 479 low_count = low_count_temp; 480 break; 481 } 482 else 483 { 484 // go right 485 target -= symbols[current].count; 486 low_count_temp += symbols[current].count; 487 current = (current<<1) + 2; 488 } 489 } 490 491 } 492 493 } 494 495 // ---------------------------------------------------------------------------------------- 496 497 } 498 499 #endif // DLIB_CONDITIONING_CLASS_KERNEl_1_ 500 501