1#!/usr/bin/python
2##
3## Copyright (c) 2016, Alliance for Open Media. All rights reserved
4##
5## This source code is subject to the terms of the BSD 2 Clause License and
6## the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7## was not distributed with this source code in the LICENSE file, you can
8## obtain it at www.aomedia.org/license/software. If the Alliance for Open
9## Media Patent License 1.0 was not distributed with this source code in the
10## PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11##
12"""Generate the probability model for the constrained token set.
13
14Model obtained from a 2-sided zero-centered distribution derived
15from a Pareto distribution. The cdf of the distribution is:
16cdf(x) = 0.5 + 0.5 * sgn(x) * [1 - {alpha/(alpha + |x|)} ^ beta]
17
For a given beta and a given probability of the 1-node (the probability
of a magnitude of one given that the value is non-zero), alpha is first
solved for numerically, and then the {alpha, beta} pair is used to
generate the probabilities of the remaining nodes.
21"""
22
23import heapq
24import sys
25import numpy as np
26import scipy.optimize
27import scipy.stats
28
29
def cdf_spareto(x, xm, beta):
  """Evaluate the two-sided, zero-centered Pareto-derived cdf at x.

  cdf(x) = 0.5 + 0.5 * sgn(x) * (1 - (xm / (xm + |x|)) ** beta)

  np.abs / np.sign are used, so x may be a scalar or a numpy array and the
  result is computed element-wise.
  """
  # P(|X| > |x|): the two-sided survival function of the distribution.
  tail = (xm / (np.abs(x) + xm))**beta
  return 0.5 + 0.5 * np.sign(x) * (1 - tail)
34
35
def get_spareto(p, beta):
  """Return the 11 magnitude-bin probabilities of the fitted spareto model.

  The scale parameter alpha is found by bounded scalar minimization over
  [1e-12, 10000] so that the probability of the ONE bin, conditioned on
  the value being non-zero, equals p. The distribution is then
  integrated over the magnitude bins with upper edges 0.5, 1.5, 2.5, 3.5,
  4.5, 6.5, 10.5, 18.5, 34.5, 66.5 plus an open tail above 66.5. Each
  bin's mass is doubled to fold the negative side onto the positive side.

  Args:
    p: target probability of the ONE bin given a non-zero value.
    beta: exponent (shape) parameter of the distribution.

  Returns:
    numpy array of 11 probabilities. Index 0 holds P(|x| < 0.5) (the
    zero bin); the last index holds P(|x| > 66.5).
  """
  def squared_miss(alpha_guess):
    # Squared gap between the model's P(ONE | non-zero) and the target p.
    below_one = cdf_spareto(0.5, alpha_guess, beta)
    up_to_one = cdf_spareto(1.5, alpha_guess, beta)
    return ((up_to_one - below_one) / (1 - below_one) - p)**2

  alpha = scipy.optimize.fminbound(squared_miss, 1e-12, 10000, xtol=1e-12)

  upper_edges = (0.5, 1.5, 2.5, 3.5, 4.5, 6.5, 10.5, 18.5, 34.5, 66.5)
  probs = np.zeros(len(upper_edges) + 1)
  lower_cdf = 0.5  # cdf(0) of a zero-centered symmetric distribution
  for k, edge in enumerate(upper_edges):
    upper_cdf = cdf_spareto(edge, alpha, beta)
    probs[k] = 2 * (upper_cdf - lower_cdf)
    lower_cdf = upper_cdf
  # Open-ended tail above the last edge.
  probs[-1] = 2 * (1. - lower_cdf)
  return probs
57
58
def quantize_probs(p, save_first_bin, bits):
  """Quantize probability precisely.

  Quantize probabilities minimizing dH (Kullback-Leibler divergence)
  approximated by: sum (p_i-q_i)^2/p_i.
  References:
  https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence
  https://github.com/JarekDuda/AsymmetricNumeralSystemsToolkit

  Every symbol gets an integer count in [1, 2**bits + 1 - num_sym], and
  the counts sum to exactly 2**bits. Counts are first rounded, then
  moved one unit at a time, always choosing the adjustment that increases
  the approximate divergence the least, until the total is correct.

  Args:
    p: 1-D numpy array of symbol probabilities.
    save_first_bin: if True, symbol 0 keeps its rounded count and is never
      adjusted.
    bits: precision; the counts sum to 2**bits.

  Returns:
    Float numpy array of integer-valued counts, same size as p.

  Raises:
    ValueError: if p is empty, has more symbols than 2**bits, or the total
      cannot be corrected (e.g. a single saved bin whose rounded count is
      not already 2**bits).
  """
  num_sym = p.size
  L = 2**bits
  if num_sym == 0:
    raise ValueError('cannot quantize an empty distribution')
  if num_sym > L:
    raise ValueError('%d symbols do not fit in %d bits' % (num_sym, bits))
  p = np.clip(p, 1e-16, 1)
  pL = p * L
  ip = 1. / p  # inverse probability
  # Every symbol needs at least one count, so no symbol may exceed q_max.
  q_max = L + 1 - num_sym
  q = np.clip(np.round(pL), 1, q_max)
  quant_err = (pL - q)**2 * ip
  deficit = L - q.sum()
  sgn = np.sign(deficit)  # direction of correction
  if sgn == 0:
    return q

  def push_adjustment(heap, i):
    # Queue the error change of moving symbol i one more unit, if it stays
    # within the same [1, q_max] range used by the initial clip.
    q_adj = q[i] + sgn
    if 1 <= q_adj <= q_max:
      adj_err = (pL[i] - q_adj)**2 * ip[i] - quant_err[i]
      heapq.heappush(heap, (adj_err, i))

  v = []  # heap of adjustment results (adjustment err, index) of each symbol
  for i in range(1 if save_first_bin else 0, num_sym):
    push_adjustment(v, i)
  # Each step moves the total by exactly one unit, so |deficit| steps are
  # needed; counting them avoids re-summing q on every iteration.
  for _ in range(int(abs(deficit))):
    if not v:
      raise ValueError('cannot adjust counts to sum to %d' % L)
    # apply lowest error adjustment
    (adj_err, i) = heapq.heappop(v)
    quant_err[i] += adj_err
    q[i] += sgn
    # calculate the cost of adjusting this symbol again
    push_adjustment(v, i)
  return q
94
95
def get_quantized_spareto(p, beta, bits, first_token):
  """Return the quantized token distribution for one 1-node probability.

  Args:
    p: target probability of the ONE bin given a non-zero value.
    beta: exponent (shape) parameter of the spareto model.
    bits: precision of the result; counts sum to 2**bits.
    first_token: if 1, the ONE bin is kept and its rounded count is never
      adjusted during quantization. If greater than 1, the ONE bin is
      dropped as well and the remaining bins are renormalized.

  Returns:
    Integer numpy array of counts summing to 2**bits.
  """
  parray = get_spareto(p, beta)
  # Drop the zero bin and renormalize over the non-zero bins.
  parray = parray[1:] / (1 - parray[0])
  # CONFIG_NEW_TOKENSET: for first_token > 1, also drop the ONE bin and
  # renormalize again.
  if first_token > 1:
    parray = parray[1:] / (1 - parray[0])
  qarray = quantize_probs(parray, first_token == 1, bits)
  # Builtin int: the np.int alias was removed in NumPy 1.24 and was the
  # same dtype anyway.
  return qarray.astype(int)
104
105
def main(bits=15, first_token=1):
  """Print one row of quantized counts for each 1-node probability q/256.

  For q = 1..255, writes a line of the form ``{ c0, c1, ... },`` (the
  format suggests a C table initializer). The counts come from
  get_quantized_spareto with beta fixed at 8.

  Args:
    bits: precision of each row; every row sums to 2**bits.
    first_token: forwarded to get_quantized_spareto.
  """
  beta = 8
  for q in range(1, 256):
    parray = get_quantized_spareto(q / 256., beta, bits, first_token)
    assert parray.sum() == 2**bits
    # sys.stdout.write produces byte-identical output under Python 2 and 3;
    # a print statement is a SyntaxError on Python 3.
    sys.stdout.write('{ %s },\n' % ', '.join('%d' % i for i in parray))
112
113
if __name__ == '__main__':
  # Optional positional arguments: bits, then first_token; any further
  # arguments are ignored and omitted ones fall back to main's defaults.
  main(*[int(arg) for arg in sys.argv[1:3]])
121