1# This file is part of Hypothesis, which may be found at
2# https://github.com/HypothesisWorks/hypothesis/
3#
4# Most of this work is copyright (C) 2013-2021 David R. MacIver
5# (david@drmaciver.com), but it contains contributions by others. See
6# CONTRIBUTING.rst for a full list of people who may hold copyright, and
7# consult the git log if you need to determine who owns an individual
8# contribution.
9#
10# This Source Code Form is subject to the terms of the Mozilla Public License,
11# v. 2.0. If a copy of the MPL was not distributed with this file, You can
12# obtain one at https://mozilla.org/MPL/2.0/.
13#
14# END HEADER
15
16from array import array
17
18from hypothesis.internal.conjecture.utils import calc_label_from_name
19from hypothesis.internal.floats import float_to_int, int_to_float
20
21"""
22This module implements support for arbitrary floating point numbers in
23Conjecture. It doesn't make any attempt to get a good distribution, only to
24get a format that will shrink well.
25
26It works by defining an encoding of non-negative floating point numbers
27(including NaN values with a zero sign bit) that has good lexical shrinking
28properties.
29
This encoding is a tagged union of two separate encodings for floating point
numbers, with the tag being the first (most significant) of the 64 bits and the
remaining 63 bits being the payload.
33
If the tag bit is 0, the next 7 bits are ignored, and the remaining 7 bytes are
interpreted as a 7-byte integer in big-endian order and then converted to a
float. (There is some redundancy here: 7 * 8 = 56 bits is more than the 53 bits
of integer precision a double can represent exactly, so multiple encodings may
map to the same float.)
39
40If the tag bit is 1, we instead use something that is closer to the normal
41representation of floats (and can represent every non-negative float exactly)
42but has a better ordering:
43
441. NaNs are ordered after everything else.
452. Infinity is ordered after every finite number.
463. The sign is ignored unless two floating point numbers are identical in
47   absolute magnitude. In that case, the positive is ordered before the
48   negative.
4. Positive floating point numbers are ordered first by their integer part,
   i.e. encoding(x) < encoding(y) if int(x) < int(y).
515. If int(x) == int(y) then x and y are sorted towards lower denominators of
52   their fractional parts.
53
54The format of this encoding of floating point goes as follows:
55
56    [exponent] [mantissa]
57
Each of these is the same size as its equivalent in IEEE floating point (11 and
52 bits respectively), but is in a different format.
60
61We translate exponents as follows:
62
63    1. The maximum exponent (2 ** 11 - 1) is left unchanged.
64    2. We reorder the remaining exponents so that all of the positive exponents
65       are first, in increasing order, followed by all of the negative
66       exponents in decreasing order (where positive/negative is done by the
67       unbiased exponent e - 1023).
68
69We translate the mantissa as follows:
70
71    1. If the unbiased exponent is <= 0 we reverse it bitwise.
72    2. If the unbiased exponent is >= 52 we leave it alone.
73    3. If the unbiased exponent is in the range [1, 51] then we reverse the
74       low k bits, where k is 52 - unbiased exponent.
75
76The low bits correspond to the fractional part of the floating point number.
77Reversing it bitwise means that we try to minimize the low bits, which kills
78off the higher powers of 2 in the fraction first.
79"""
80
81
# Largest value of the 11-bit double-precision exponent field (all bits set);
# exponent_key orders it after every other exponent.
MAX_EXPONENT = (1 << 11) - 1

# The two exponent field values that IEEE 754 reserves for special meanings
# (zero/subnormals and infinity/NaN).
SPECIAL_EXPONENTS = (0, MAX_EXPONENT)

# Offset between a stored exponent field and its unbiased exponent.
BIAS = 1023
# Largest unbiased exponent reachable by a finite float, whose stored
# exponent field is at most MAX_EXPONENT - 1.
MAX_POSITIVE_EXPONENT = (MAX_EXPONENT - 1) - BIAS
88
# Label for the example that draw_float wraps around its bit draws.
DRAW_FLOAT_LABEL = calc_label_from_name(
    "drawing a float"
)
90
91
def exponent_key(e):
    """Sort key that arranges IEEE exponent fields into shrink order.

    Non-negative unbiased exponents come first, in increasing order. Negative
    unbiased exponents follow, ordered by increasing magnitude. The all-ones
    exponent field MAX_EXPONENT sorts after everything else.
    """
    if e == MAX_EXPONENT:
        return float("inf")
    unbiased = e - BIAS
    # Adding 10000 pushes every negative exponent past all non-negative ones
    # (which stay well below 10000). Negating keeps the negatives ordered by
    # magnitude.
    return unbiased if unbiased >= 0 else 10000 - unbiased
100
101
# ENCODING_TABLE[k] is the IEEE exponent field found at lexical position k,
# i.e. every possible exponent field listed in exponent_key order.
ENCODING_TABLE = array("H", sorted(range(MAX_EXPONENT + 1), key=exponent_key))

# DECODING_TABLE is the inverse permutation of ENCODING_TABLE, so
# DECODING_TABLE[ENCODING_TABLE[k]] == k. ENCODING_TABLE is a permutation of
# 0..MAX_EXPONENT, so sorting positions by the value stored at them (an
# argsort) yields exactly that inverse.
DECODING_TABLE = array(
    "H", sorted(range(len(ENCODING_TABLE)), key=ENCODING_TABLE.__getitem__)
)
109
110
def decode_exponent(e):
    """Map an 11-bit lexical exponent position to an IEEE exponent field.

    The lookup goes through ENCODING_TABLE. Its ordering means that
    lexicographically smaller positions give simpler floats.
    """
    assert MAX_EXPONENT >= e >= 0
    exponent_field = ENCODING_TABLE[e]
    return exponent_field
116
117
def encode_exponent(e):
    """Map an IEEE exponent field back to its 11-bit lexical position.

    This is the inverse of decode_exponent, looked up in DECODING_TABLE.
    """
    assert MAX_EXPONENT >= e >= 0
    lexical_position = DECODING_TABLE[e]
    return lexical_position
123
124
def reverse_byte(b):
    """Return the 8-bit value whose bits are the low 8 bits of ``b`` reversed.

    Bits above the lowest 8 are ignored. For example, 0b00000001 becomes
    0b10000000.
    """
    result = 0
    for shift in range(8):
        if (b >> shift) & 1:
            # Bit ``shift`` of the input lands at the mirrored position.
            result |= 1 << (7 - shift)
    return result


# Table mapping every byte value to the byte with its bits reversed, e.g.
# 1 = 0b00000001 maps to 0b10000000 = 0x80 = 128. Precomputing it lets a
# longer integer be bit-reversed one byte at a time.
REVERSE_BITS_TABLE = bytearray(reverse_byte(byte) for byte in range(256))
139
140
def reverse64(v):
    """Return ``v`` with the order of its 64 bits reversed.

    ``v`` is split into its eight bytes. Each byte is reversed through
    REVERSE_BITS_TABLE and placed at the mirrored byte position: byte k,
    counting from the least significant, ends up as byte 7 - k. The pieces
    do not overlap, so OR-ing them together gives the full reversal.
    """
    assert v.bit_length() <= 64
    result = 0
    for k in range(8):
        byte = (v >> (8 * k)) & 0xFF
        result |= REVERSE_BITS_TABLE[byte] << (8 * (7 - k))
    return result
163
164
# Selects the low 52 bits of a double's bit pattern (the mantissa field).
MANTISSA_MASK = 2 ** 52 - 1
166
167
def reverse_bits(x, n):
    """Reverse the low ``n`` bits of ``x``.

    Requires ``x.bit_length() <= n <= 64``. The whole 64-bit word is reversed,
    then shifted right so that only the ``n`` reversed bits remain.
    """
    assert x.bit_length() <= n <= 64
    return reverse64(x) >> (64 - n)
173
174
def update_mantissa(unbiased_exponent, mantissa):
    """Convert a 52-bit mantissa between IEEE and lexical form.

    How much is reversed depends on the unbiased exponent:

    - ``<= 0``: all 52 bits are bit-reversed.
    - ``1..51``: only the low ``52 - unbiased_exponent`` bits (the fractional
      part) are bit-reversed; the integral bits above them are kept.
    - ``> 51``: the mantissa is returned unchanged.

    Reversing a fixed-width field twice restores it, so the same function
    serves both lex_to_float and base_float_to_lex.
    """
    if unbiased_exponent <= 0:
        return reverse_bits(mantissa, 52)
    if unbiased_exponent > 51:
        return mantissa
    width = 52 - unbiased_exponent
    low_mask = (1 << width) - 1
    integral_bits = mantissa & ~low_mask
    return integral_bits | reverse_bits(mantissa & low_mask, width)
184
185
def lex_to_float(i):
    """Decode a 64-bit lexical value into a non-negative float.

    The top bit (bit 63) selects the encoding:

    - 0: the low 56 bits are read as an unsigned integer and converted with
      ``float()``. Bits 56-62 are ignored.
    - 1: bits 52-62 hold the reordered exponent and the low 52 bits hold the
      transformed mantissa. Both are mapped back through decode_exponent and
      update_mantissa, then reassembled into an IEEE bit pattern.
    """
    assert i.bit_length() <= 64
    if not i >> 63:
        return float(i & ((1 << 56) - 1))
    exponent = decode_exponent((i >> 52) & ((1 << 11) - 1))
    mantissa = update_mantissa(exponent - BIAS, i & MANTISSA_MASK)
    assert mantissa.bit_length() <= 52
    return int_to_float((exponent << 52) | mantissa)
201
202
def float_to_lex(f):
    """Encode ``f`` as the 64-bit lexical value that lex_to_float decodes.

    Integral values accepted by is_simple use the compact integer encoding
    (tag bit 0); these must be non-negative. Everything else goes through
    base_float_to_lex, which drops the sign bit.
    """
    if not is_simple(f):
        return base_float_to_lex(f)
    assert f >= 0
    return int(f)
208
209
def base_float_to_lex(f):
    """Encode ``f`` in the tagged (top bit 1) exponent/mantissa form.

    The sign bit of ``f``'s bit pattern is discarded. The exponent field is
    reordered with encode_exponent, and the mantissa is transformed with
    update_mantissa.
    """
    bits = float_to_int(f) & ((1 << 63) - 1)
    exponent = bits >> 52
    mantissa = update_mantissa(exponent - BIAS, bits & MANTISSA_MASK)
    encoded_exponent = encode_exponent(exponent)
    assert mantissa.bit_length() <= 52
    return (1 << 63) | (encoded_exponent << 52) | mantissa
220
221
def is_simple(f):
    """Return True if ``f`` equals an integer whose bit length is at most 56.

    NaN and infinities are rejected by ``int()`` and so are never simple.
    """
    try:
        as_int = int(f)
    except (ValueError, OverflowError):
        return False
    return as_int == f and as_int.bit_length() <= 56
230
231
def draw_float(data):
    """Draw a float from ``data`` inside an example labelled DRAW_FLOAT_LABEL.

    64 bits are drawn and decoded with lex_to_float. One further bit is then
    drawn, and if it is set the result is negated. This is the same layout
    that write_float produces.
    """
    # Open the example before entering the ``try``. If start_example itself
    # fails, no example was opened, so stop_example must not run to close it.
    data.start_example(DRAW_FLOAT_LABEL)
    try:
        f = lex_to_float(data.draw_bits(64))
        if data.draw_bits(1):
            f = -f
        return f
    finally:
        data.stop_example()
241
242
def write_float(data, f):
    """Force ``data`` to record ``f`` in the layout that draw_float reads.

    The magnitude is written as 64 forced lexical bits. Then one forced bit
    is written, taken from the top bit of float_to_int(f).
    """
    lexical = float_to_lex(abs(f))
    data.draw_bits(64, forced=lexical)
    sign_bit = float_to_int(f) >> 63
    data.draw_bits(1, forced=sign_bit)
247