#!/usr/bin/python
##
## Copyright (c) 2016, Alliance for Open Media. All rights reserved
##
## This source code is subject to the terms of the BSD 2 Clause License and
## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
## was not distributed with this source code in the LICENSE file, you can
## obtain it at www.aomedia.org/license/software. If the Alliance for Open
## Media Patent License 1.0 was not distributed with this source code in the
## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
##
"""Generate the probability model for the constrained token set.

Model obtained from a 2-sided zero-centered distribution derived
from a Pareto distribution. The cdf of the distribution is:
cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]

For a given beta and a given probability of the 1-node, the alpha
is first solved, and then the {alpha, beta} pair is used to generate
the probabilities for the rest of the nodes.

Usage: the optional first command line argument is the number of bits
of the quantized probabilities (default 15), the optional second one is
the first token of the set (default 1). One C initializer row is printed
for each of the 255 probabilities q / 256, q = 1..255.
"""

import heapq
import sys
import numpy as np
import scipy.optimize
import scipy.stats


# Magnitudes at which the cdf is sampled by get_spareto(); consecutive
# entries bound one probability bin each.
_BIN_EDGES = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)


def cdf_spareto(x, xm, beta):
  """Evaluate the cdf of the 2-sided zero-centered Pareto distribution.

  Args:
    x: point at which the cdf is evaluated.
    xm: scale of the distribution (the "alpha" of the module docstring).
    beta: exponent of the distribution.

  Returns:
    0.5 + 0.5 * sgn(x) * (1 - (xm / (|x| + xm)) ** beta)
  """
  p = 1 - (xm / (np.abs(x) + xm))**beta
  p = 0.5 + 0.5 * np.sign(x) * p
  return p


def get_spareto(p, beta):
  """Return the bin probabilities of the model matching the 1-node probability.

  The scale alpha is found numerically so that
  (cdf(1.5) - cdf(0.5)) / (1 - cdf(0.5)) is as close as possible to p.

  Args:
    p: target value of the ratio above.
    beta: exponent of the distribution.

  Returns:
    Array of len(_BIN_EDGES) + 1 == 11 floats. Entry 0 is the mass of
    |x| < 0.5, entry k is the mass of |x| between _BIN_EDGES[k - 1] and
    _BIN_EDGES[k], and the last entry is the mass of |x| > 66.5.
  """
  cdf = cdf_spareto

  def func(x):
    return ((cdf(1.5, x, beta) - cdf(0.5, x, beta)) /
            (1 - cdf(0.5, x, beta)) - p)**2

  alpha = scipy.optimize.fminbound(func, 1e-12, 10000, xtol=1e-12)
  # Evaluate the cdf once per edge with scalar arguments; every bin below
  # is twice the one-sided mass because the distribution is symmetric.
  edge_cdf = [cdf(edge, alpha, beta) for edge in _BIN_EDGES]
  parray = np.zeros(len(_BIN_EDGES) + 1)
  parray[0] = 2 * (edge_cdf[0] - 0.5)
  for k in range(1, len(_BIN_EDGES)):
    parray[k] = 2 * (edge_cdf[k] - edge_cdf[k - 1])
  parray[-1] = 2 * (1. - edge_cdf[-1])
  return parray


def quantize_probs(p, save_first_bin, bits):
  """Quantize probability precisely.

  Quantize probabilities minimizing dH (Kullback-Leibler divergence)
  approximated by: sum (p_i-q_i)^2/p_i.
  References:
  https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
  https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

  Args:
    p: array of probabilities to quantize.
    save_first_bin: when true, the rounded value of the first entry is never
      changed by the correction step.
    bits: the quantized values sum to 2**bits.

  Returns:
    Float array of integral values, same size as p, summing to 2**bits.

  Raises:
    ValueError: if no entry can be adjusted any further while the sum still
      differs from 2**bits (e.g. more symbols than 2**bits).
  """
  p = np.asarray(p, dtype=float)
  num_sym = p.size
  p = np.clip(p, 1e-16, 1)
  L = 2**bits
  pL = p * L
  ip = 1. / p  # inverse probability
  q = np.clip(np.round(pL), 1, L + 1 - num_sym)
  quant_err = (pL - q)**2 * ip
  sgn = np.sign(L - q.sum())  # direction of correction
  if sgn != 0:  # correction is needed
    v = []  # heap of adjustment results (adjustment err, index) of each symbol

    def push_adjustment(i):
      # Queue the extra error of moving symbol i one more step, provided
      # the result stays strictly between 0 and L.
      q_adj = q[i] + sgn
      if 0 < q_adj < L:
        adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
        heapq.heappush(v, (adj_err, i))

    for i in range(1 if save_first_bin else 0, num_sym):
      push_adjustment(i)
    while q.sum() != L:
      if not v:
        raise ValueError(
            'cannot quantize %d probabilities to a total of %d' % (num_sym, L))
      # apply lowest error adjustment
      adj_err, i = heapq.heappop(v)
      quant_err[i] += adj_err
      q[i] += sgn
      # calculate the cost of adjusting this symbol again
      push_adjustment(i)
  return q


def get_quantized_spareto(p, beta, bits, first_token):
  """Return the quantized model probabilities as an integer array.

  Args:
    p: probability of the 1-node, passed to get_spareto().
    beta: exponent of the distribution.
    bits: the returned values sum to 2**bits.
    first_token: 1 drops the first bin of the model and keeps the rounded
      value of the new first entry fixed while quantizing; a value greater
      than 1 drops the first two bins instead.

  Returns:
    Integer array summing to 2**bits.
  """
  parray = get_spareto(p, beta)
  # Drop the leading bin and renormalize the remaining ones.
  parray = parray[1:] / (1 - parray[0])
  # CONFIG_NEW_TOKENSET
  if first_token > 1:
    parray = parray[1:] / (1 - parray[0])
  qarray = quantize_probs(parray, first_token == 1, bits)
  # np.int was only an alias of the builtin int and no longer exists in
  # current NumPy releases.
  return qarray.astype(int)


def main(bits=15, first_token=1):
  """Print one '{ ... },' row of quantized probabilities per q in 1..255."""
  beta = 8
  for q in range(1, 256):
    parray = get_quantized_spareto(q / 256., beta, bits, first_token)
    assert parray.sum() == 2**bits
    # A single formatted argument prints the same text under Python 2 and 3.
    print('{ %s },' % ', '.join('%d' % i for i in parray))


if __name__ == '__main__':
  # Up to two optional integer arguments: bits, then first_token.
  main(*[int(arg) for arg in sys.argv[1:3]])