1 /*
2  * Copyright 2015 Philip Taylor <philip@zaynar.co.uk>
3  * Copyright 2018 Advanced Micro Devices, Inc.
4  *
5  * Permission is hereby granted, free of charge, to any person obtaining a
6  * copy of this software and associated documentation files (the "Software"),
7  * to deal in the Software without restriction, including without limitation
8  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
9  * and/or sell copies of the Software, and to permit persons to whom the
10  * Software is furnished to do so, subject to the following conditions:
11  *
12  * The above copyright notice and this permission notice (including the next
13  * paragraph) shall be included in all copies or substantial portions of the
14  * Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
19  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
21  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
22  * DEALINGS IN THE SOFTWARE.
23  */
24 
25 /**
26  * \file texcompress_astc.c
27  *
28  * Decompression code for GL_KHR_texture_compression_astc_ldr, which is just
29  * ASTC 2D LDR.
30  *
31  * The ASTC 2D LDR decoder (without the sRGB part) was copied from the OASTC
32  * library written by Philip Taylor. I added sRGB support and adjusted it for
33  * Mesa. - Marek
34  */
35 
36 #include "texcompress_astc.h"
37 #include "macros.h"
38 #include "util/half_float.h"
39 #include <stdio.h>
40 #include <cstdlib>  // for abort() on windows
41 
/* Debug switches.
 * VERBOSE_DECODE: the Block decoding functions print the block bits they parse
 *                 (through InputBitVector::printf_bits()).
 * VERBOSE_WRITE:  OutputBitVector prints every append/skip it performs.
 */
static bool VERBOSE_DECODE = false;
static bool VERBOSE_WRITE = false;
44 
/* Convert a 16-bit value to an 8-bit UNORM by going through a half float,
 * using the helpers from util/half_float.h.
 * NOTE(review): judging by the helper's name the input is treated as
 * v / 65536 -- confirm against util/half_float.h.
 */
static inline uint8_t
uint16_div_64k_to_half_to_unorm8(uint16_t v)
{
   return _mesa_half_to_unorm8(_mesa_uint16_div_64k_to_half(v));
}
50 
/**
 * Result codes of the block decoder. Any value other than 'ok' makes
 * Decoder::decode() fill the whole block with the error colour.
 */
class decode_error
{
public:
   enum type {
      ok,
      unsupported_hdr_void_extent,   /* void-extent block with bit 9 set */
      reserved_block_mode_1,         /* block mode xx111xxxx00, not a void extent */
      reserved_block_mode_2,         /* block mode xxxxxxx0000 */
      dual_plane_and_too_many_partitions,
      invalid_range_in_void_extent,  /* void extent with min >= max in s or t */
      weight_grid_exceeds_block_size,
      invalid_colour_endpoints_size,
      invalid_colour_endpoints_count,
      invalid_weight_bits,
      invalid_num_weights,
   };
};
68 
69 
/* One quantisation range for colour endpoint values.
 *   max  - largest quantised value (values are 0..max)
 *   t, q - 1 when the encoding uses a trit / a quint, else 0
 *   b    - number of plain bits per value
 * Every entry satisfies max + 1 == (t ? 3 : 1) * (q ? 5 : 1) * 2^b.
 */
struct cem_range {
   uint8_t max;
   uint8_t t, q, b;
};

/* Based on the Color Unquantization Parameters table,
 * plus the bit-only representations, sorted by increasing size
 */
static cem_range cem_ranges[] = {
   { 5, 1, 0, 1 },
   { 7, 0, 0, 3 },
   { 9, 0, 1, 1 },
   { 11, 1, 0, 2 },
   { 15, 0, 0, 4 },
   { 19, 0, 1, 2 },
   { 23, 1, 0, 3 },
   { 31, 0, 0, 5 },
   { 39, 0, 1, 3 },
   { 47, 1, 0, 4 },
   { 63, 0, 0, 6 },
   { 79, 0, 1, 4 },
   { 95, 1, 0, 5 },
   { 127, 0, 0, 7 },
   { 159, 0, 1, 5 },
   { 191, 1, 0, 6 },
   { 255, 0, 0, 8 },
};
97 
/* Concatenate single-bit values, most significant first:
 * CAT_BITS_n(x[n-1], ..., x[0]) yields the integer whose bit i is x[i].
 * Each wider form appends one more bit to the result of the narrower one.
 */
#define CAT_BITS_2(a, b)          ( ((a) << 1) | (b) )
#define CAT_BITS_3(a, b, c)       ( (CAT_BITS_2(a, b) << 1) | (c) )
#define CAT_BITS_4(a, b, c, d)    ( (CAT_BITS_3(a, b, c) << 1) | (d) )
#define CAT_BITS_5(a, b, c, d, e) ( (CAT_BITS_4(a, b, c, d) << 1) | (e) )
102 
/**
 * Unpack 5n+8 bits from 'in' into 5 output values.
 * If n <= 4 then T should be uint32_t, else it must be uint64_t.
 *
 * Bit layout of 'in', LSB first. m0..m4 are n-bit fields; T0..T7 are the
 * eight bits that jointly encode the five trits:
 *
 *    m0 T0 T1 m1 T2 T3 m2 T4 m3 T5 T6 m4 T7
 *
 * out[k] = (tk << n) | mk, where each trit tk is in 0..2.
 */
template <typename T>
static void unpack_trit_block(int n, T in, uint8_t *out)
{
   assert(n <= 6); /* else output will overflow uint8_t */

   /* Scattered T bits and the five plain n-bit fields, per the layout above */
   uint8_t T0 = (in >> (n)) & 0x1;
   uint8_t T1 = (in >> (n+1)) & 0x1;
   uint8_t T2 = (in >> (2*n+2)) & 0x1;
   uint8_t T3 = (in >> (2*n+3)) & 0x1;
   uint8_t T4 = (in >> (3*n+4)) & 0x1;
   uint8_t T5 = (in >> (4*n+5)) & 0x1;
   uint8_t T6 = (in >> (4*n+6)) & 0x1;
   uint8_t T7 = (in >> (5*n+7)) & 0x1;
   uint8_t mmask = (1 << n) - 1;
   uint8_t m0 = (in >> (0)) & mmask;
   uint8_t m1 = (in >> (n+2)) & mmask;
   uint8_t m2 = (in >> (2*n+4)) & mmask;
   uint8_t m3 = (in >> (3*n+5)) & mmask;
   uint8_t m4 = (in >> (4*n+7)) & mmask;

   /* Decode t4/t3 and the 5-bit value C that encodes t2..t0 */
   uint8_t C;
   uint8_t t4, t3, t2, t1, t0;
   if (CAT_BITS_3(T4, T3, T2) == 0x7) {
      C = CAT_BITS_5(T7, T6, T5, T1, T0);
      t4 = t3 = 2;
   } else {
      C = CAT_BITS_5(T4, T3, T2, T1, T0);
      if (CAT_BITS_2(T6, T5) == 0x3) {
         t4 = 2;
         t3 = T7;
      } else {
         t4 = T7;
         t3 = CAT_BITS_2(T6, T5);
      }
   }

   /* Decode t2..t0 from C. '& ~Cx' clears the low bit when Cx is set, so
    * each t0 result stays in 0..2 */
   if ((C & 0x3) == 0x3) {
      t2 = 2;
      t1 = (C >> 4) & 0x1;
      uint8_t C3 = (C >> 3) & 0x1;
      uint8_t C2 = (C >> 2) & 0x1;
      t0 = (C3 << 1) | (C2 & ~C3);
   } else if (((C >> 2) & 0x3) == 0x3) {
      t2 = 2;
      t1 = 2;
      t0 = C & 0x3;
   } else {
      t2 = (C >> 4) & 0x1;
      t1 = (C >> 2) & 0x3;
      uint8_t C1 = (C >> 1) & 0x1;
      uint8_t C0 = (C >> 0) & 0x1;
      t0 = (C1 << 1) | (C0 & ~C1);
   }

   /* Trit goes above the n plain bits */
   out[0] = (t0 << n) | m0;
   out[1] = (t1 << n) | m1;
   out[2] = (t2 << n) | m2;
   out[3] = (t3 << n) | m3;
   out[4] = (t4 << n) | m4;
}
167 
/**
 * Unpack 3n+7 bits from 'in' into 3 output values
 *
 * Bit layout of 'in', LSB first. m0..m2 are n-bit fields; Q0..Q6 are the
 * seven bits that jointly encode the three quints:
 *
 *    m0 Q0 Q1 Q2 m1 Q3 Q4 m2 Q5 Q6
 *
 * out[k] = (qk << n) | mk, where each quint qk is in 0..4.
 */
static void unpack_quint_block(int n, uint32_t in, uint8_t *out)
{
   assert(n <= 5); /* else output will overflow uint8_t */

   /* Scattered Q bits and the three plain n-bit fields, per the layout above */
   uint8_t Q0 = (in >> (n)) & 0x1;
   uint8_t Q1 = (in >> (n+1)) & 0x1;
   uint8_t Q2 = (in >> (n+2)) & 0x1;
   uint8_t Q3 = (in >> (2*n+3)) & 0x1;
   uint8_t Q4 = (in >> (2*n+4)) & 0x1;
   uint8_t Q5 = (in >> (3*n+5)) & 0x1;
   uint8_t Q6 = (in >> (3*n+6)) & 0x1;
   uint8_t mmask = (1 << n) - 1;
   uint8_t m0 = (in >> (0)) & mmask;
   uint8_t m1 = (in >> (n+3)) & mmask;
   uint8_t m2 = (in >> (2*n+5)) & mmask;

   uint8_t C;
   uint8_t q2, q1, q0;
   if (CAT_BITS_4(Q6, Q5, Q2, Q1) == 0x3) {
      /* Q2Q1 == 11 and Q6Q5 == 00: q1 = q0 = 4, q2 from Q0/Q4/Q3 (C unused) */
      q2 = CAT_BITS_3(Q0, Q4 & ~Q0, Q3 & ~Q0);
      q1 = 4;
      q0 = 4;
   } else {
      /* Build the 5-bit value C that encodes q1/q0 */
      if (CAT_BITS_2(Q2, Q1) == 0x3) {
         q2 = 4;
         C = CAT_BITS_5(Q4, Q3, 0x1 & ~Q6, 0x1 & ~Q5, Q0);
      } else {
         q2 = CAT_BITS_2(Q6, Q5);
         C = CAT_BITS_5(Q4, Q3, Q2, Q1, Q0);
      }
      /* C[2:0] == 101 marks q1 == 4 */
      if ((C & 0x7) == 0x5) {
         q1 = 4;
         q0 = (C >> 3) & 0x3;
      } else {
         q1 = (C >> 3) & 0x3;
         q0 = C & 0x7;
      }
   }
   /* Quint goes above the n plain bits */
   out[0] = (q0 << n) | m0;
   out[1] = (q1 << n) | m1;
   out[2] = (q2 << n) | m2;
}
213 
214 
215 struct uint8x4_t
216 {
217    uint8_t v[4];
218 
uint8x4_tuint8x4_t219    uint8x4_t() { }
220 
uint8x4_tuint8x4_t221    uint8x4_t(int a, int b, int c, int d)
222    {
223       assert(0 <= a && a <= 255);
224       assert(0 <= b && b <= 255);
225       assert(0 <= c && c <= 255);
226       assert(0 <= d && d <= 255);
227       v[0] = a;
228       v[1] = b;
229       v[2] = c;
230       v[3] = d;
231    }
232 
clampeduint8x4_t233    static uint8x4_t clamped(int a, int b, int c, int d)
234    {
235       uint8x4_t r;
236       r.v[0] = MAX2(0, MIN2(255, a));
237       r.v[1] = MAX2(0, MIN2(255, b));
238       r.v[2] = MAX2(0, MIN2(255, c));
239       r.v[3] = MAX2(0, MIN2(255, d));
240       return r;
241    }
242 };
243 
blue_contract(int r,int g,int b,int a)244 static uint8x4_t blue_contract(int r, int g, int b, int a)
245 {
246    return uint8x4_t((r+b) >> 1, (g+b) >> 1, b, a);
247 }
248 
blue_contract_clamped(int r,int g,int b,int a)249 static uint8x4_t blue_contract_clamped(int r, int g, int b, int a)
250 {
251    return uint8x4_t::clamped((r+b) >> 1, (g+b) >> 1, b, a);
252 }
253 
/**
 * Shift 'b' right by one and move bit 7 of 'a' into its top bit; then
 * replace 'a' with bits 1..6 of its old value, read as a signed 6-bit
 * number in [-32, 31].
 */
static void bit_transfer_signed(int &a, int &b)
{
   b = (b >> 1) | (a & 0x80);   /* uses 'a' before it is rewritten */
   a = (a >> 1) & 0x3f;
   if (a & 0x20)
      a -= 0x40;                /* sign-extend from 6 bits */
}
263 
/* 32-bit integer mixing function used by select_partition(). The steps must
 * run in exactly this order, with unsigned 32-bit wraparound. */
static uint32_t hash52(uint32_t p)
{
   uint32_t h = p;
   h = h ^ (h >> 15);
   h = h - (h << 17);
   h = h + (h << 7);
   h = h + (h << 4);
   h = h ^ (h >> 5);
   h = h + (h << 16);
   h = h ^ (h >> 7);
   h = h ^ (h >> 3);
   h = h ^ (h << 6);
   h = h ^ (h >> 17);
   return h;
}
278 
/**
 * Return the partition that texel (x, y, z) belongs to, for the partition
 * pattern selected by 'seed' (presumably the block's 10-bit partition index
 * -- confirm at the call site) and 'partitioncount'.
 *
 * For partitioncount in 2..4 the result is in [0, partitioncount - 1],
 * because the scores of unused partitions (c, d) are forced to 0 below.
 * When small_block is non-zero the coordinates are doubled first.
 */
static int select_partition(int seed, int x, int y, int z, int partitioncount,
                            int small_block)
{
   if (small_block) {
      x <<= 1;
      y <<= 1;
      z <<= 1;
   }
   seed += (partitioncount - 1) * 1024;
   uint32_t rnum = hash52(seed);

   /* Twelve 4-bit values taken from the hash (seed12 wraps around bit 31) */
   uint8_t seed1 = rnum & 0xF;
   uint8_t seed2 = (rnum >> 4) & 0xF;
   uint8_t seed3 = (rnum >> 8) & 0xF;
   uint8_t seed4 = (rnum >> 12) & 0xF;
   uint8_t seed5 = (rnum >> 16) & 0xF;
   uint8_t seed6 = (rnum >> 20) & 0xF;
   uint8_t seed7 = (rnum >> 24) & 0xF;
   uint8_t seed8 = (rnum >> 28) & 0xF;
   uint8_t seed9 = (rnum >> 18) & 0xF;
   uint8_t seed10 = (rnum >> 22) & 0xF;
   uint8_t seed11 = (rnum >> 26) & 0xF;
   uint8_t seed12 = ((rnum >> 30) | (rnum << 2)) & 0xF;

   /* Squares fit in uint8_t: 15 * 15 = 225 */
   seed1 *= seed1;
   seed2 *= seed2;
   seed3 *= seed3;
   seed4 *= seed4;
   seed5 *= seed5;
   seed6 *= seed6;
   seed7 *= seed7;
   seed8 *= seed8;
   seed9 *= seed9;
   seed10 *= seed10;
   seed11 *= seed11;
   seed12 *= seed12;

   /* Shift amounts depend on the low seed bits and on partitioncount */
   int sh1, sh2, sh3;
   if (seed & 1) {
      sh1 = (seed & 2 ? 4 : 5);
      sh2 = (partitioncount == 3 ? 6 : 5);
   } else {
      sh1 = (partitioncount == 3 ? 6 : 5);
      sh2 = (seed & 2 ? 4 : 5);
   }
   sh3 = (seed & 0x10) ? sh1 : sh2;

   seed1 >>= sh1;
   seed2 >>= sh2;
   seed3 >>= sh1;
   seed4 >>= sh2;
   seed5 >>= sh1;
   seed6 >>= sh2;
   seed7 >>= sh1;
   seed8 >>= sh2;
   seed9 >>= sh3;
   seed10 >>= sh3;
   seed11 >>= sh3;
   seed12 >>= sh3;

   /* One 6-bit score per partition */
   int a = seed1 * x + seed2 * y + seed11 * z + (rnum >> 14);
   int b = seed3 * x + seed4 * y + seed12 * z + (rnum >> 10);
   int c = seed5 * x + seed6 * y + seed9 * z + (rnum >> 6);
   int d = seed7 * x + seed8 * y + seed10 * z + (rnum >> 2);

   a &= 0x3F;
   b &= 0x3F;
   c &= 0x3F;
   d &= 0x3F;

   /* Unused partitions get score 0 so they never win */
   if (partitioncount < 4)
      d = 0;
   if (partitioncount < 3)
      c = 0;

   /* Highest score wins; ties go to the lower partition index */
   if (a >= b && a >= c && a >= d)
      return 0;
   else if (b >= c && b >= d)
      return 1;
   else if (c >= d)
      return 2;
   else
      return 3;
}
362 
363 
/**
 * Read-only view of a 128-bit block stored as four 32-bit words. Bit i of
 * the block is bit (i & 31) of data[i >> 5], so bits are numbered upwards
 * from the least significant bit of data[0].
 */
struct InputBitVector
{
   uint32_t data[4];

   /**
    * Debug helper: print the block as 128 characters, most significant bit
    * first. Bits [offset, offset + count) are printed as '0'/'1' and every
    * other position as '.'. The line ends with a printf-style label built
    * from 'fmt' and the variadic arguments.
    */
   void printf_bits(int offset, int count, const char *fmt = "", ...) const
   {
      char out[129];
      memset(out, '.', 128);
      out[128] = '\0';
      for (int i = 0; i < count; ++i) {
         int idx = offset + i;
         /* Ignore bits outside the block instead of writing outside 'out'. */
         if (idx < 0 || idx >= 128)
            continue;
         out[127 - idx] = ((data[idx >> 5] >> (idx & 31)) & 1) ? '1' : '0';
      }
      printf("%s ", out);
      va_list ap;
      va_start(ap, fmt);
      vprintf(fmt, ap);
      va_end(ap);
      printf("\n");
   }

   /**
    * Return the 'count' bits (0 <= count < 32) that start at bit 'offset',
    * right-aligned. Bits past the end of the block read as zero.
    */
   uint32_t get_bits(int offset, int count) const
   {
      assert(count >= 0 && count < 32);

      uint32_t out = 0;
      if (offset < 32)
         out |= data[0] >> offset;

      if (0 < offset && offset <= 32)
         out |= data[1] << (32 - offset);
      if (32 < offset && offset < 64)
         out |= data[1] >> (offset - 32);

      if (32 < offset && offset <= 64)
         out |= data[2] << (64 - offset);
      if (64 < offset && offset < 96)
         out |= data[2] >> (offset - 64);

      if (64 < offset && offset <= 96)
         out |= data[3] << (96 - offset);
      if (96 < offset && offset < 128)
         out |= data[3] >> (offset - 96);

      /* Unsigned 1: with a signed 1, count == 31 would overflow int. */
      out &= (1u << count) - 1;
      return out;
   }

   /**
    * 64-bit variant of get_bits(): return the 'count' bits (0 <= count < 64)
    * that start at bit 'offset'. Bits past the end of the block read as zero.
    */
   uint64_t get_bits64(int offset, int count) const
   {
      assert(count >= 0 && count < 64);

      uint64_t out = 0;
      if (offset < 32)
         out |= data[0] >> offset;

      if (offset <= 32)
         out |= (uint64_t)data[1] << (32 - offset);
      if (32 < offset && offset < 64)
         out |= data[1] >> (offset - 32);

      /* offset == 0 is excluded: a 64-bit shift would be undefined. */
      if (0 < offset && offset <= 64)
         out |= (uint64_t)data[2] << (64 - offset);
      if (64 < offset && offset < 96)
         out |= data[2] >> (offset - 64);

      if (32 < offset && offset <= 96)
         out |= (uint64_t)data[3] << (96 - offset);
      if (96 < offset && offset < 128)
         out |= data[3] >> (offset - 96);

      out &= ((uint64_t)1 << count) - 1;
      return out;
   }

   /**
    * Return the 'count' bits that end just below bit 'offset' (that is,
    * bits [offset - count, offset)), with their order reversed.
    */
   uint32_t get_bits_rev(int offset, int count) const
   {
      assert(offset >= count);
      uint32_t tmp = get_bits(offset - count, count);
      uint32_t out = 0;
      for (int i = 0; i < count; ++i)
         out |= ((tmp >> i) & 1) << (count - 1 - i);
      return out;
   }
};
450 
451 struct OutputBitVector
452 {
453    uint32_t data[4];
454    int offset;
455 
OutputBitVectorOutputBitVector456    OutputBitVector()
457       : offset(0)
458    {
459       memset(data, 0, sizeof(data));
460    }
461 
appendOutputBitVector462    void append(uint32_t value, int size)
463    {
464       if (VERBOSE_WRITE)
465          printf("append offset=%d size=%d values=0x%x\n", offset, size, value);
466 
467       assert(offset + size <= 128);
468 
469       assert(size <= 32);
470       if (size < 32)
471          assert((value >> size) == 0);
472 
473       while (size) {
474          int c = MIN2(size, 32 - (offset & 31));
475          data[offset >> 5] |= (value << (offset & 31));
476          offset += c;
477          size -= c;
478          value >>= c;
479       }
480    }
481 
append64OutputBitVector482    void append64(uint64_t value, int size)
483    {
484       if (VERBOSE_WRITE)
485          printf("append offset=%d size=%d values=0x%llx\n", offset, size, (unsigned long long)value);
486 
487       assert(offset + size <= 128);
488 
489       assert(size <= 64);
490       if (size < 64)
491          assert((value >> size) == 0);
492 
493       while (size) {
494          int c = MIN2(size, 32 - (offset & 31));
495          data[offset >> 5] |= (value << (offset & 31));
496          offset += c;
497          size -= c;
498          value >>= c;
499       }
500    }
501 
appendOutputBitVector502    void append(OutputBitVector &v, int size)
503    {
504       if (VERBOSE_WRITE)
505          printf("append vector offset=%d size=%d\n", offset, size);
506 
507       assert(offset + size <= 128);
508       int i = 0;
509       while (size >= 32) {
510          append(v.data[i++], 32);
511          size -= 32;
512       }
513       if (size > 0)
514          append(v.data[i] & ((1 << size) - 1), size);
515    }
516 
append_endOutputBitVector517    void append_end(OutputBitVector &v, int size)
518    {
519       for (int i = 0; i < size; ++i)
520          data[(127 - i) >> 5] |= ((v.data[i >> 5] >> (i & 31)) & 1) << ((127 - i) & 31);
521    }
522 
523    /* Insert the given number of '1' bits. (We could use 0s instead, but 1s are
524     * more likely to flush out bugs where we accidentally read undefined bits.)
525     */
skipOutputBitVector526    void skip(int size)
527    {
528       if (VERBOSE_WRITE)
529          printf("skip offset=%d size=%d\n", offset, size);
530 
531       assert(offset + size <= 128);
532       while (size >= 32) {
533          append(0xffffffff, 32);
534          size -= 32;
535       }
536       if (size > 0)
537          append(0xffffffff >> (32 - size), size);
538    }
539 };
540 
541 
/**
 * ASTC block decoder for a fixed footprint of block_w x block_h x block_d
 * texels.
 *
 * decode() writes four uint16_t channels per texel. With output_unorm8 the
 * channels hold 8-bit UNORM values (the error colour is written as 0xff/0).
 * Without it, the error path writes FP16_ONE/FP16_ZERO, so the output is
 * presumably half floats. sRGB is only supported together with
 * output_unorm8; decode() asserts this on its error path.
 */
class Decoder
{
public:
   Decoder(int block_w, int block_h, int block_d, bool srgb, bool output_unorm8)
      : block_w(block_w), block_h(block_h), block_d(block_d), srgb(srgb),
        output_unorm8(output_unorm8) {}

   decode_error::type decode(const uint8_t *in, uint16_t *output) const;

   int block_w, block_h, block_d;
   bool srgb, output_unorm8;
};
554 
/**
 * Working state for decoding one block. The member functions below fill it
 * in stage by stage; where this file shows which function computes a group
 * of fields, the comment on that group says so.
 */
struct Block
{
   bool is_error;
   bool bogus_colour_endpoints;
   bool bogus_weights;

   /* Block mode, set by decode_block_mode(): high_prec and dual_plane come
    * from bits 9 and 10 (forced to 0 by the layout that reuses those bits),
    * wt_range is the 3-bit weight range index and wt_w x wt_h is the weight
    * grid size. */
   int high_prec;
   int dual_plane;
   int colour_component_selector;
   int wt_range;
   int wt_w, wt_h, wt_d;
   int num_parts;
   /* 10-bit partition seed, or -1 for single-partition blocks (decode_cem()) */
   int partition_index;

   /* Set by decode_void_extent(): bit 9 flag, the four 13-bit extent
    * coordinates and the four 16-bit colour channels. */
   bool is_void_extent;
   int void_extent_d;
   int void_extent_min_s;
   int void_extent_max_s;
   int void_extent_min_t;
   int void_extent_max_t;
   uint16_t void_extent_colour_r;
   uint16_t void_extent_colour_g;
   uint16_t void_extent_colour_b;
   uint16_t void_extent_colour_a;

   /* Set by decode_cem(). cems[] holds one colour endpoint mode per
    * partition (-1 for unused entries); colour_endpoint_data_offset is the
    * bit at which colour endpoint data starts. */
   bool is_multi_cem;
   int num_extra_cem_bits;
   int colour_endpoint_data_offset;
   int extra_cem_bits;
   int cem_base_class;
   int cems[4];

   int num_cem_values;

   /* Calculated by unpack_weights(): */
   uint8_t weights_quant[64 + 4]; /* max 64 values, plus padding for overflows in trit parsing */

   /* Calculated by unquantise_weights(): */
   uint8_t weights[64 + 18]; /* max 64 values, plus padding for the infill interpolation */

   /* Calculated by unpack_colour_endpoints(): */
   uint8_t colour_endpoints_quant[18 + 4]; /* max 18 values, plus padding for overflows in trit parsing */

   /* Calculated by unquantise_colour_endpoints(): */
   uint8_t colour_endpoints[18];

   /* Calculated by calculate_from_weights(): */
   int wt_trits;
   int wt_quints;
   int wt_bits;
   int wt_max;
   int num_weights;
   int weight_bits;

   /* Calculated by calculate_remaining_bits(): */
   int remaining_bits;

   /* Calculated by calculate_colour_endpoints_size(): */
   int colour_endpoint_bits;
   int ce_max;
   int ce_trits;
   int ce_quints;
   int ce_bits;

   /* Calculated by compute_infill_weights(); */
   uint8_t infill_weights[2][216]; /* large enough for 6x6x6 */

   /* Calculated by decode_colour_endpoints(); */
   uint8x4_t endpoints_decoded[2][4];

   void calculate_from_weights();
   void calculate_remaining_bits();
   decode_error::type calculate_colour_endpoints_size();

   void unquantise_weights();
   void unquantise_colour_endpoints();

   decode_error::type decode(const Decoder &decoder, InputBitVector in);

   decode_error::type decode_block_mode(InputBitVector in);
   decode_error::type decode_void_extent(InputBitVector in);
   void decode_cem(InputBitVector in);
   void unpack_colour_endpoints(InputBitVector in);
   void decode_colour_endpoints();
   void unpack_weights(InputBitVector in);
   void compute_infill_weights(int block_w, int block_h, int block_d);

   void write_decoded(const Decoder &decoder, uint16_t *output);
};
644 
645 
decode(const uint8_t * in,uint16_t * output) const646 decode_error::type Decoder::decode(const uint8_t *in, uint16_t *output) const
647 {
648    Block blk;
649    InputBitVector in_vec;
650    memcpy(&in_vec.data, in, 16);
651    decode_error::type err = blk.decode(*this, in_vec);
652    if (err == decode_error::ok) {
653       blk.write_decoded(*this, output);
654    } else {
655       /* Fill output with the error colour */
656       for (int i = 0; i < block_w * block_h * block_d; ++i) {
657          if (output_unorm8) {
658             output[i*4+0] = 0xff;
659             output[i*4+1] = 0;
660             output[i*4+2] = 0xff;
661             output[i*4+3] = 0xff;
662          } else {
663             assert(!srgb); /* srgb must use unorm8 */
664 
665             output[i*4+0] = FP16_ONE;
666             output[i*4+1] = FP16_ZERO;
667             output[i*4+2] = FP16_ONE;
668             output[i*4+3] = FP16_ONE;
669          }
670       }
671    }
672    return err;
673 }
674 
675 
decode_void_extent(InputBitVector block)676 decode_error::type Block::decode_void_extent(InputBitVector block)
677 {
678    /* TODO: 3D */
679 
680    is_void_extent = true;
681    void_extent_d = block.get_bits(9, 1);
682    void_extent_min_s = block.get_bits(12, 13);
683    void_extent_max_s = block.get_bits(25, 13);
684    void_extent_min_t = block.get_bits(38, 13);
685    void_extent_max_t = block.get_bits(51, 13);
686    void_extent_colour_r = block.get_bits(64, 16);
687    void_extent_colour_g = block.get_bits(80, 16);
688    void_extent_colour_b = block.get_bits(96, 16);
689    void_extent_colour_a = block.get_bits(112, 16);
690 
691    /* TODO: maybe we should do something useful with the extent coordinates? */
692 
693    if (void_extent_d) {
694       return decode_error::unsupported_hdr_void_extent;
695    }
696 
697    if (void_extent_min_s == 0x1fff && void_extent_max_s == 0x1fff
698        && void_extent_min_t == 0x1fff && void_extent_max_t == 0x1fff) {
699 
700       /* No extents */
701 
702    } else {
703 
704       /* Check for illegal encoding */
705       if (void_extent_min_s >= void_extent_max_s || void_extent_min_t >= void_extent_max_t) {
706          return decode_error::invalid_range_in_void_extent;
707       }
708    }
709 
710    return decode_error::ok;
711 }
712 
/**
 * Decode the 11-bit block mode (bits 0..10) into the weight grid size
 * (wt_w x wt_h), the weight range index (wt_range), dual_plane and high_prec.
 *
 * Blocks whose low 9 bits are 0x1fc are void-extent blocks; they are passed
 * to decode_void_extent() and its result is returned. The reserved encodings
 * below return reserved_block_mode_1 or reserved_block_mode_2. Every other
 * mode returns ok.
 *
 * Each printf_bits() label shows its layout from bit 10 down to bit 0.
 * D = dual plane, H = high precision, R = weight range bits, A/B = grid size
 * fields, and literal 0/1 = the fixed bits that select the layout.
 */
decode_error::type Block::decode_block_mode(InputBitVector in)
{
   dual_plane = in.get_bits(10, 1);
   high_prec = in.get_bits(9, 1);

   if (in.get_bits(0, 2) != 0x0) {
      /* R lives in bits 0..1 and bit 4: range = (bits 1..0) << 1 | bit 4 */
      wt_range = (in.get_bits(0, 2) << 1) | in.get_bits(4, 1);
      int a = in.get_bits(5, 2);
      int b = in.get_bits(7, 2);
      switch (in.get_bits(2, 2)) {
      case 0x0:
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "DHBBAAR00RR");
         wt_w = b + 4;
         wt_h = a + 2;
         break;
      case 0x1:
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "DHBBAAR01RR");
         wt_w = b + 8;
         wt_h = a + 2;
         break;
      case 0x2:
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "DHBBAAR10RR");
         wt_w = a + 2;
         wt_h = b + 8;
         break;
      case 0x3:
         /* Bit 8 selects the sub-layout; only bit 7 of B is a size field */
         if ((b & 0x2) == 0) {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "DH0BAAR11RR");
            wt_w = a + 2;
            wt_h = b + 6;
         } else {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "DH1BAAR11RR");
            wt_w = (b & 0x1) + 2;
            wt_h = a + 2;
         }
         break;
      }
   } else {
      /* Bits 0..1 are 00 here */
      if (in.get_bits(6, 3) == 0x7) {
         if (in.get_bits(0, 9) == 0x1fc) {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "xx111111100 (void extent)");
            return decode_void_extent(in);
         } else {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "xx111xxxx00");
            return decode_error::reserved_block_mode_1;
         }
      }
      if (in.get_bits(0, 4) == 0x0) {
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "xxxxxxx0000");
         return decode_error::reserved_block_mode_2;
      }

      /* Bit 1 is 0 here, so get_bits(1, 3) == get_bits(2, 2) << 1 and this
       * yields (bits 3..2) << 1 | bit 4, the same encoding as above. */
      wt_range = in.get_bits(1, 3) | in.get_bits(4, 1);
      int a = in.get_bits(5, 2);
      int b;

      switch (in.get_bits(7, 2)) {
      case 0x0:
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "DH00AARRR00");
         wt_w = 12;
         wt_h = a + 2;
         break;
      case 0x1:
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "DH01AARRR00");
         wt_w = a + 2;
         wt_h = 12;
         break;
      case 0x3:
         /* Bit 6 is 0 here (bits 6..8 == 111 was handled above) */
         if (in.get_bits(5, 1) == 0) {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "DH1100RRR00");
            wt_w = 6;
            wt_h = 10;
         } else {
            if (VERBOSE_DECODE)
               in.printf_bits(0, 11, "DH1101RRR00");
            wt_w = 10;
            wt_h = 6;
         }
         break;
      case 0x2:
         /* B occupies bits 9..10, so this layout has no D or H bit */
         if (VERBOSE_DECODE)
            in.printf_bits(0, 11, "BB10AARRR00");
         b = in.get_bits(9, 2);
         wt_w = a + 6;
         wt_h = b + 6;
         dual_plane = 0;
         high_prec = 0;
         break;
      }
   }
   return decode_error::ok;
}
816 
/**
 * Decode the colour endpoint mode (CEM) field.
 *
 * Fields set:
 *  - cems[]: one CEM per partition; unused entries are -1.
 *  - cem_base_class and is_multi_cem.
 *  - partition_index: the 10-bit seed, or -1 for a single partition.
 *  - num_extra_cem_bits and extra_cem_bits.
 *  - colour_endpoint_data_offset: the bit where colour endpoint data starts
 *    (29 with several partitions, 17 otherwise).
 *
 * Reads num_parts. For multi-CEM encodings it also reads weight_bits,
 * because some mode bits sit just below bit 128 - weight_bits. Both must
 * already be valid when this is called.
 */
void Block::decode_cem(InputBitVector in)
{
   cems[0] = cems[1] = cems[2] = cems[3] = -1;

   num_extra_cem_bits = 0;
   extra_cem_bits = 0;

   if (num_parts > 1) {

      partition_index = in.get_bits(13, 10);
      if (VERBOSE_DECODE)
         in.printf_bits(13, 10, "partition ID (%d)", partition_index);

      uint32_t cem = in.get_bits(23, 6);

      if ((cem & 0x3) == 0x0) {
         /* Low two bits 00: all partitions share the 4-bit CEM in bits 25..28 */
         cem >>= 2;
         cem_base_class = cem >> 2;
         is_multi_cem = false;

         for (int i = 0; i < num_parts; ++i)
            cems[i] = cem;

         if (VERBOSE_DECODE)
            in.printf_bits(23, 6, "CEM (single, %d)", cem);
      } else {
         /* Each partition's class is cem_base_class + its C bit. Its 2-bit M
          * value is split between bits 25..28 and the extra bits below
          * 'offset'. */
         cem_base_class = (cem & 0x3) - 1;
         is_multi_cem = true;

         if (VERBOSE_DECODE)
            in.printf_bits(23, 6, "CEM (multi, base class %d)", cem_base_class);

         /* Extra CEM bits are read from just below this bit */
         int offset = 128 - weight_bits;

         if (num_parts == 2) {
            /* C0, C1 at bits 25, 26; M0 at 27..28; M1 in the 2 extra bits */
            if (VERBOSE_DECODE) {
               in.printf_bits(25, 4, "M0M0 C1 C0");
               in.printf_bits(offset - 2, 2, "M1M1");
            }

            uint32_t c0 = in.get_bits(25, 1);
            uint32_t c1 = in.get_bits(26, 1);

            /* Number of partitions using class cem_base_class + 1 */
            extra_cem_bits = c0 + c1;

            num_extra_cem_bits = 2;

            uint32_t m0 = in.get_bits(27, 2);
            uint32_t m1 = in.get_bits(offset - 2, 2);

            cems[0] = ((cem_base_class + c0) << 2) | m0;
            cems[1] = ((cem_base_class + c1) << 2) | m1;

         } else if (num_parts == 3) {
            /* C0..C2 at bits 25..27; low bit of M0 at bit 28; high bit of M0,
             * then M1 and M2, in the 5 extra bits */
            if (VERBOSE_DECODE) {
               in.printf_bits(25, 4, "M0 C2 C1 C0");
               in.printf_bits(offset - 5, 5, "M2M2 M1M1 M0");
            }

            uint32_t c0 = in.get_bits(25, 1);
            uint32_t c1 = in.get_bits(26, 1);
            uint32_t c2 = in.get_bits(27, 1);

            extra_cem_bits = c0 + c1 + c2;

            num_extra_cem_bits = 5;

            /* 128 - weight_bits - 5 == offset - 5 */
            uint32_t m0 = in.get_bits(28, 1) | (in.get_bits(128 - weight_bits - 5, 1) << 1);
            uint32_t m1 = in.get_bits(offset - 4, 2);
            uint32_t m2 = in.get_bits(offset - 2, 2);

            cems[0] = ((cem_base_class + c0) << 2) | m0;
            cems[1] = ((cem_base_class + c1) << 2) | m1;
            cems[2] = ((cem_base_class + c2) << 2) | m2;

         } else if (num_parts == 4) {
            /* C0..C3 at bits 25..28; M0..M3 in the 8 extra bits */
            if (VERBOSE_DECODE) {
               in.printf_bits(25, 4, "C3 C2 C1 C0");
               in.printf_bits(offset - 8, 8, "M3M3 M2M2 M1M1 M0M0");
            }

            uint32_t c0 = in.get_bits(25, 1);
            uint32_t c1 = in.get_bits(26, 1);
            uint32_t c2 = in.get_bits(27, 1);
            uint32_t c3 = in.get_bits(28, 1);

            extra_cem_bits = c0 + c1 + c2 + c3;

            num_extra_cem_bits = 8;

            uint32_t m0 = in.get_bits(offset - 8, 2);
            uint32_t m1 = in.get_bits(offset - 6, 2);
            uint32_t m2 = in.get_bits(offset - 4, 2);
            uint32_t m3 = in.get_bits(offset - 2, 2);

            cems[0] = ((cem_base_class + c0) << 2) | m0;
            cems[1] = ((cem_base_class + c1) << 2) | m1;
            cems[2] = ((cem_base_class + c2) << 2) | m2;
            cems[3] = ((cem_base_class + c3) << 2) | m3;
         } else {
            unreachable("");
         }
      }

      colour_endpoint_data_offset = 29;

   } else {
      /* Single partition: 4-bit CEM at bits 13..16 */
      uint32_t cem = in.get_bits(13, 4);

      cem_base_class = cem >> 2;
      is_multi_cem = false;

      cems[0] = cem;

      partition_index = -1;

      if (VERBOSE_DECODE)
         in.printf_bits(13, 4, "CEM = %d (class %d)", cem, cem_base_class);

      colour_endpoint_data_offset = 17;
   }
}
940 
unpack_colour_endpoints(InputBitVector in)941 void Block::unpack_colour_endpoints(InputBitVector in)
942 {
943    if (ce_trits) {
944       int offset = colour_endpoint_data_offset;
945       int bits_left = colour_endpoint_bits;
946       for (int i = 0; i < num_cem_values; i += 5) {
947          int bits_to_read = MIN2(bits_left, 8 + ce_bits * 5);
948          /* If ce_trits then ce_bits <= 6, so bits_to_read <= 38 and we have to use uint64_t */
949          uint64_t raw = in.get_bits64(offset, bits_to_read);
950          unpack_trit_block(ce_bits, raw, &colour_endpoints_quant[i]);
951 
952          if (VERBOSE_DECODE)
953             in.printf_bits(offset, bits_to_read,
954                            "trits [%d,%d,%d,%d,%d]",
955                            colour_endpoints_quant[i+0], colour_endpoints_quant[i+1],
956                   colour_endpoints_quant[i+2], colour_endpoints_quant[i+3],
957                   colour_endpoints_quant[i+4]);
958 
959          offset += 8 + ce_bits * 5;
960          bits_left -= 8 + ce_bits * 5;
961       }
962    } else if (ce_quints) {
963       int offset = colour_endpoint_data_offset;
964       int bits_left = colour_endpoint_bits;
965       for (int i = 0; i < num_cem_values; i += 3) {
966          int bits_to_read = MIN2(bits_left, 7 + ce_bits * 3);
967          /* If ce_quints then ce_bits <= 5, so bits_to_read <= 22 and we can use uint32_t */
968          uint32_t raw = in.get_bits(offset, bits_to_read);
969          unpack_quint_block(ce_bits, raw, &colour_endpoints_quant[i]);
970 
971          if (VERBOSE_DECODE)
972             in.printf_bits(offset, bits_to_read,
973                            "quints [%d,%d,%d]",
974                            colour_endpoints_quant[i], colour_endpoints_quant[i+1], colour_endpoints_quant[i+2]);
975 
976          offset += 7 + ce_bits * 3;
977          bits_left -= 7 + ce_bits * 3;
978       }
979    } else {
980       assert((colour_endpoint_bits % ce_bits) == 0);
981       int offset = colour_endpoint_data_offset;
982       for (int i = 0; i < num_cem_values; i++) {
983          colour_endpoints_quant[i] = in.get_bits(offset, ce_bits);
984 
985          if (VERBOSE_DECODE)
986             in.printf_bits(offset, ce_bits, "bits [%d]", colour_endpoints_quant[i]);
987 
988          offset += ce_bits;
989       }
990    }
991 }
992 
decode_colour_endpoints()993 void Block::decode_colour_endpoints()
994 {
995    int cem_values_idx = 0;
996    for (int part = 0; part < num_parts; ++part) {
997       uint8_t *v = &colour_endpoints[cem_values_idx];
998       int v0 = v[0];
999       int v1 = v[1];
1000       int v2 = v[2];
1001       int v3 = v[3];
1002       int v4 = v[4];
1003       int v5 = v[5];
1004       int v6 = v[6];
1005       int v7 = v[7];
1006       cem_values_idx += ((cems[part] >> 2) + 1) * 2;
1007 
1008       uint8x4_t e0, e1;
1009       int s0, s1, L0, L1;
1010 
1011       switch (cems[part])
1012       {
1013       case 0:
1014          e0 = uint8x4_t(v0, v0, v0, 0xff);
1015          e1 = uint8x4_t(v1, v1, v1, 0xff);
1016          break;
1017       case 1:
1018          L0 = (v0 >> 2) | (v1 & 0xc0);
1019          L1 = L0 + (v1 & 0x3f);
1020          if (L1 > 0xff)
1021             L1 = 0xff;
1022          e0 = uint8x4_t(L0, L0, L0, 0xff);
1023          e1 = uint8x4_t(L1, L1, L1, 0xff);
1024          break;
1025       case 4:
1026          e0 = uint8x4_t(v0, v0, v0, v2);
1027          e1 = uint8x4_t(v1, v1, v1, v3);
1028          break;
1029       case 5:
1030          bit_transfer_signed(v1, v0);
1031          bit_transfer_signed(v3, v2);
1032          e0 = uint8x4_t(v0, v0, v0, v2);
1033          e1 = uint8x4_t::clamped(v0+v1, v0+v1, v0+v1, v2+v3);
1034          break;
1035       case 6:
1036          e0 = uint8x4_t(v0*v3 >> 8, v1*v3 >> 8, v2*v3 >> 8, 0xff);
1037          e1 = uint8x4_t(v0, v1, v2, 0xff);
1038          break;
1039       case 8:
1040          s0 = v0 + v2 + v4;
1041          s1 = v1 + v3 + v5;
1042          if (s1 >= s0) {
1043             e0 = uint8x4_t(v0, v2, v4, 0xff);
1044             e1 = uint8x4_t(v1, v3, v5, 0xff);
1045          } else {
1046             e0 = blue_contract(v1, v3, v5, 0xff);
1047             e1 = blue_contract(v0, v2, v4, 0xff);
1048          }
1049          break;
1050       case 9:
1051          bit_transfer_signed(v1, v0);
1052          bit_transfer_signed(v3, v2);
1053          bit_transfer_signed(v5, v4);
1054          if (v1 + v3 + v5 >= 0) {
1055             e0 = uint8x4_t(v0, v2, v4, 0xff);
1056             e1 = uint8x4_t::clamped(v0+v1, v2+v3, v4+v5, 0xff);
1057          } else {
1058             e0 = blue_contract_clamped(v0+v1, v2+v3, v4+v5, 0xff);
1059             e1 = blue_contract(v0, v2, v4, 0xff);
1060          }
1061          break;
1062       case 10:
1063          e0 = uint8x4_t(v0*v3 >> 8, v1*v3 >> 8, v2*v3 >> 8, v4);
1064          e1 = uint8x4_t(v0, v1, v2, v5);
1065          break;
1066       case 12:
1067          s0 = v0 + v2 + v4;
1068          s1 = v1 + v3 + v5;
1069          if (s1 >= s0) {
1070             e0 = uint8x4_t(v0, v2, v4, v6);
1071             e1 = uint8x4_t(v1, v3, v5, v7);
1072          } else {
1073             e0 = blue_contract(v1, v3, v5, v7);
1074             e1 = blue_contract(v0, v2, v4, v6);
1075          }
1076          break;
1077       case 13:
1078          bit_transfer_signed(v1, v0);
1079          bit_transfer_signed(v3, v2);
1080          bit_transfer_signed(v5, v4);
1081          bit_transfer_signed(v7, v6);
1082          if (v1 + v3 + v5 >= 0) {
1083             e0 = uint8x4_t(v0, v2, v4, v6);
1084             e1 = uint8x4_t::clamped(v0+v1, v2+v3, v4+v5, v6+v7);
1085          } else {
1086             e0 = blue_contract_clamped(v0+v1, v2+v3, v4+v5, v6+v7);
1087             e1 = blue_contract(v0, v2, v4, v6);
1088          }
1089          break;
1090       default:
1091          /* HDR endpoints not supported; return error colour */
1092          e0 = uint8x4_t(255, 0, 255, 255);
1093          e1 = uint8x4_t(255, 0, 255, 255);
1094          break;
1095       }
1096 
1097       endpoints_decoded[0][part] = e0;
1098       endpoints_decoded[1][part] = e1;
1099 
1100       if (VERBOSE_DECODE) {
1101          printf("cems[%d]=%d v=[", part, cems[part]);
1102          for (int i = 0; i < (cems[part] >> 2) + 1; ++i) {
1103             if (i)
1104                printf(", ");
1105             printf("%3d", v[i]);
1106          }
1107          printf("] e0=[%3d,%4d,%4d,%4d] e1=[%3d,%4d,%4d,%4d]\n",
1108                 e0.v[0], e0.v[1], e0.v[2], e0.v[3],
1109                e1.v[0], e1.v[1], e1.v[2], e1.v[3]);
1110       }
1111    }
1112 }
1113 
unpack_weights(InputBitVector in)1114 void Block::unpack_weights(InputBitVector in)
1115 {
1116    if (wt_trits) {
1117       int offset = 128;
1118       int bits_left = weight_bits;
1119       for (int i = 0; i < num_weights; i += 5) {
1120          int bits_to_read = MIN2(bits_left, 8 + 5*wt_bits);
1121          /* If wt_trits then wt_bits <= 3, so bits_to_read <= 23 and we can use uint32_t */
1122          uint32_t raw = in.get_bits_rev(offset, bits_to_read);
1123          unpack_trit_block(wt_bits, raw, &weights_quant[i]);
1124 
1125          if (VERBOSE_DECODE)
1126             in.printf_bits(offset - bits_to_read, bits_to_read, "weight trits [%d,%d,%d,%d,%d]",
1127                            weights_quant[i+0], weights_quant[i+1],
1128                   weights_quant[i+2], weights_quant[i+3],
1129                   weights_quant[i+4]);
1130 
1131          offset -= 8 + wt_bits * 5;
1132          bits_left -= 8 + wt_bits * 5;
1133       }
1134 
1135    } else if (wt_quints) {
1136 
1137       int offset = 128;
1138       int bits_left = weight_bits;
1139       for (int i = 0; i < num_weights; i += 3) {
1140          int bits_to_read = MIN2(bits_left, 7 + 3*wt_bits);
1141          /* If wt_quints then wt_bits <= 2, so bits_to_read <= 13 and we can use uint32_t */
1142          uint32_t raw = in.get_bits_rev(offset, bits_to_read);
1143          unpack_quint_block(wt_bits, raw, &weights_quant[i]);
1144 
1145          if (VERBOSE_DECODE)
1146             in.printf_bits(offset - bits_to_read, bits_to_read, "weight quints [%d,%d,%d]",
1147                            weights_quant[i], weights_quant[i+1], weights_quant[i+2]);
1148 
1149          offset -= 7 + wt_bits * 3;
1150          bits_left -= 7 + wt_bits * 3;
1151       }
1152 
1153    } else {
1154       int offset = 128;
1155       assert((weight_bits % wt_bits) == 0);
1156       for (int i = 0; i < num_weights; ++i) {
1157          weights_quant[i] = in.get_bits_rev(offset, wt_bits);
1158 
1159          if (VERBOSE_DECODE)
1160             in.printf_bits(offset - wt_bits, wt_bits, "weight bits [%d]", weights_quant[i]);
1161 
1162          offset -= wt_bits;
1163       }
1164    }
1165 }
1166 
unquantise_weights()1167 void Block::unquantise_weights()
1168 {
1169    assert(num_weights <= (int)ARRAY_SIZE(weights_quant));
1170    assert(num_weights <= (int)ARRAY_SIZE(weights));
1171 
1172    memset(weights, 0, sizeof(weights));
1173 
1174    for (int i = 0; i < num_weights; ++i) {
1175 
1176       uint8_t v = weights_quant[i];
1177       uint8_t w;
1178 
1179       if (wt_trits) {
1180 
1181          if (wt_bits == 0) {
1182             w = v * 32;
1183          } else {
1184             uint8_t A, B, C, D;
1185             A = (v & 0x1) ? 0x7F : 0x00;
1186             switch (wt_bits) {
1187             case 1:
1188                B = 0;
1189                C = 50;
1190                D = v >> 1;
1191                break;
1192             case 2:
1193                B = (v & 0x2) ? 0x45 : 0x00;
1194                C = 23;
1195                D = v >> 2;
1196                break;
1197             case 3:
1198                B = ((v & 0x6) >> 1) | ((v & 0x6) << 4);
1199                C = 11;
1200                D = v >> 3;
1201                break;
1202             default:
1203                unreachable("");
1204             }
1205             uint16_t T = D * C + B;
1206             T = T ^ A;
1207             T = (A & 0x20) | (T >> 2);
1208             assert(T < 64);
1209             if (T > 32)
1210                T++;
1211             w = T;
1212          }
1213 
1214       } else if (wt_quints) {
1215 
1216          if (wt_bits == 0) {
1217             w = v * 16;
1218          } else {
1219             uint8_t A, B, C, D;
1220             A = (v & 0x1) ? 0x7F : 0x00;
1221             switch (wt_bits) {
1222             case 1:
1223                B = 0;
1224                C = 28;
1225                D = v >> 1;
1226                break;
1227             case 2:
1228                B = (v & 0x2) ? 0x42 : 0x00;
1229                C = 13;
1230                D = v >> 2;
1231                break;
1232             default:
1233                unreachable("");
1234             }
1235             uint16_t T = D * C + B;
1236             T = T ^ A;
1237             T = (A & 0x20) | (T >> 2);
1238             assert(T < 64);
1239             if (T > 32)
1240                T++;
1241             w = T;
1242          }
1243          weights[i] = w;
1244 
1245       } else {
1246 
1247          switch (wt_bits) {
1248          case 1: w = v ? 0x3F : 0x00; break;
1249          case 2: w = v | (v << 2) | (v << 4); break;
1250          case 3: w = v | (v << 3); break;
1251          case 4: w = (v >> 2) | (v << 2); break;
1252          case 5: w = (v >> 4) | (v << 1); break;
1253          default: unreachable("");
1254          }
1255          assert(w < 64);
1256          if (w > 32)
1257             w++;
1258       }
1259       weights[i] = w;
1260    }
1261 }
1262 
compute_infill_weights(int block_w,int block_h,int block_d)1263 void Block::compute_infill_weights(int block_w, int block_h, int block_d)
1264 {
1265    int Ds = block_w <= 1 ? 0 : (1024 + block_w / 2) / (block_w - 1);
1266    int Dt = block_h <= 1 ? 0 : (1024 + block_h / 2) / (block_h - 1);
1267    int Dr = block_d <= 1 ? 0 : (1024 + block_d / 2) / (block_d - 1);
1268    for (int r = 0; r < block_d; ++r) {
1269       for (int t = 0; t < block_h; ++t) {
1270          for (int s = 0; s < block_w; ++s) {
1271             int cs = Ds * s;
1272             int ct = Dt * t;
1273             int cr = Dr * r;
1274             int gs = (cs * (wt_w - 1) + 32) >> 6;
1275             int gt = (ct * (wt_h - 1) + 32) >> 6;
1276             int gr = (cr * (wt_d - 1) + 32) >> 6;
1277             assert(gs >= 0 && gs <= 176);
1278             assert(gt >= 0 && gt <= 176);
1279             assert(gr >= 0 && gr <= 176);
1280             int js = gs >> 4;
1281             int fs = gs & 0xf;
1282             int jt = gt >> 4;
1283             int ft = gt & 0xf;
1284             int jr = gr >> 4;
1285             int fr = gr & 0xf;
1286 
1287             /* TODO: 3D */
1288             (void)jr;
1289             (void)fr;
1290 
1291             int w11 = (fs * ft + 8) >> 4;
1292             int w10 = ft - w11;
1293             int w01 = fs - w11;
1294             int w00 = 16 - fs - ft + w11;
1295 
1296             if (dual_plane) {
1297                int p00, p01, p10, p11, i0, i1;
1298                int v0 = js + jt * wt_w;
1299                p00 = weights[(v0) * 2];
1300                p01 = weights[(v0 + 1) * 2];
1301                p10 = weights[(v0 + wt_w) * 2];
1302                p11 = weights[(v0 + wt_w + 1) * 2];
1303                i0 = (p00*w00 + p01*w01 + p10*w10 + p11*w11 + 8) >> 4;
1304                p00 = weights[(v0) * 2 + 1];
1305                p01 = weights[(v0 + 1) * 2 + 1];
1306                p10 = weights[(v0 + wt_w) * 2 + 1];
1307                p11 = weights[(v0 + wt_w + 1) * 2 + 1];
1308                assert((v0 + wt_w + 1) * 2 + 1 < (int)ARRAY_SIZE(weights));
1309                i1 = (p00*w00 + p01*w01 + p10*w10 + p11*w11 + 8) >> 4;
1310                assert(0 <= i0 && i0 <= 64);
1311                infill_weights[0][s + t*block_w + r*block_w*block_h] = i0;
1312                infill_weights[1][s + t*block_w + r*block_w*block_h] = i1;
1313             } else {
1314                int p00, p01, p10, p11, i;
1315                int v0 = js + jt * wt_w;
1316                p00 = weights[v0];
1317                p01 = weights[v0 + 1];
1318                p10 = weights[v0 + wt_w];
1319                p11 = weights[v0 + wt_w + 1];
1320                assert(v0 + wt_w + 1 < (int)ARRAY_SIZE(weights));
1321                i = (p00*w00 + p01*w01 + p10*w10 + p11*w11 + 8) >> 4;
1322                assert(0 <= i && i <= 64);
1323                infill_weights[0][s + t*block_w + r*block_w*block_h] = i;
1324             }
1325          }
1326       }
1327    }
1328 }
1329 
unquantise_colour_endpoints()1330 void Block::unquantise_colour_endpoints()
1331 {
1332    assert(num_cem_values <= (int)ARRAY_SIZE(colour_endpoints_quant));
1333    assert(num_cem_values <= (int)ARRAY_SIZE(colour_endpoints));
1334 
1335    for (int i = 0; i < num_cem_values; ++i) {
1336       uint8_t v = colour_endpoints_quant[i];
1337 
1338       if (ce_trits) {
1339          uint16_t A, B, C, D;
1340          uint16_t t;
1341          A = (v & 0x1) ? 0x1FF : 0x000;
1342          switch (ce_bits) {
1343          case 1:
1344             B = 0;
1345             C = 204;
1346             D = v >> 1;
1347             break;
1348          case 2:
1349             B = (v & 0x2) ? 0x116 : 0x000;
1350             C = 93;
1351             D = v >> 2;
1352             break;
1353          case 3:
1354             t = ((v >> 1) & 0x3);
1355             B = t | (t << 2) | (t << 7);
1356             C = 44;
1357             D = v >> 3;
1358             break;
1359          case 4:
1360             t = ((v >> 1) & 0x7);
1361             B = t | (t << 6);
1362             C = 22;
1363             D = v >> 4;
1364             break;
1365          case 5:
1366             t = ((v >> 1) & 0xF);
1367             B = (t >> 2) | (t << 5);
1368             C = 11;
1369             D = v >> 5;
1370             break;
1371          case 6:
1372             B = ((v & 0x3E) << 3) | ((v >> 5) & 0x1);
1373             C = 5;
1374             D = v >> 6;
1375             break;
1376          default:
1377             unreachable("");
1378          }
1379          uint16_t T = D * C + B;
1380          T = T ^ A;
1381          T = (A & 0x80) | (T >> 2);
1382          assert(T < 256);
1383          colour_endpoints[i] = T;
1384       } else if (ce_quints) {
1385          uint16_t A, B, C, D;
1386          uint16_t t;
1387          A = (v & 0x1) ? 0x1FF : 0x000;
1388          switch (ce_bits) {
1389          case 1:
1390             B = 0;
1391             C = 113;
1392             D = v >> 1;
1393             break;
1394          case 2:
1395             B = (v & 0x2) ? 0x10C : 0x000;
1396             C = 54;
1397             D = v >> 2;
1398             break;
1399          case 3:
1400             t = ((v >> 1) & 0x3);
1401             B = (t >> 1) | (t << 1) | (t << 7);
1402             C = 26;
1403             D = v >> 3;
1404             break;
1405          case 4:
1406             t = ((v >> 1) & 0x7);
1407             B = (t >> 1) | (t << 6);
1408             C = 13;
1409             D = v >> 4;
1410             break;
1411          case 5:
1412             t = ((v >> 1) & 0xF);
1413             B = (t >> 4) | (t << 5);
1414             C = 6;
1415             D = v >> 5;
1416             break;
1417          default:
1418             unreachable("");
1419          }
1420          uint16_t T = D * C + B;
1421          T = T ^ A;
1422          T = (A & 0x80) | (T >> 2);
1423          assert(T < 256);
1424          colour_endpoints[i] = T;
1425       } else {
1426          switch (ce_bits) {
1427          case 1: v = v ? 0xFF : 0x00; break;
1428          case 2: v = (v << 6) | (v << 4) | (v << 2) | v; break;
1429          case 3: v = (v << 5) | (v << 2) | (v >> 1); break;
1430          case 4: v = (v << 4) | v; break;
1431          case 5: v = (v << 3) | (v >> 2); break;
1432          case 6: v = (v << 2) | (v >> 4); break;
1433          case 7: v = (v << 1) | (v >> 6); break;
1434          case 8: break;
1435          default: unreachable("");
1436          }
1437          colour_endpoints[i] = v;
1438       }
1439    }
1440 }
1441 
decode(const Decoder & decoder,InputBitVector in)1442 decode_error::type Block::decode(const Decoder &decoder, InputBitVector in)
1443 {
1444    decode_error::type err;
1445 
1446    is_error = false;
1447    bogus_colour_endpoints = false;
1448    bogus_weights = false;
1449    is_void_extent = false;
1450 
1451    wt_d = 1;
1452    /* TODO: 3D */
1453 
1454    /* TODO: test for all the illegal encodings */
1455 
1456    if (VERBOSE_DECODE)
1457       in.printf_bits(0, 128);
1458 
1459    err = decode_block_mode(in);
1460    if (err != decode_error::ok)
1461       return err;
1462 
1463    if (is_void_extent)
1464       return decode_error::ok;
1465 
1466    /* TODO: 3D */
1467 
1468    calculate_from_weights();
1469 
1470    if (VERBOSE_DECODE)
1471       printf("weights_grid=%dx%dx%d dual_plane=%d num_weights=%d high_prec=%d r=%d range=0..%d (%dt %dq %db) weight_bits=%d\n",
1472              wt_w, wt_h, wt_d, dual_plane, num_weights, high_prec, wt_range, wt_max, wt_trits, wt_quints, wt_bits, weight_bits);
1473 
1474    if (wt_w > decoder.block_w || wt_h > decoder.block_h || wt_d > decoder.block_d)
1475       return decode_error::weight_grid_exceeds_block_size;
1476 
1477    num_parts = in.get_bits(11, 2) + 1;
1478 
1479    if (VERBOSE_DECODE)
1480       in.printf_bits(11, 2, "partitions = %d", num_parts);
1481 
1482    if (dual_plane && num_parts > 3)
1483       return decode_error::dual_plane_and_too_many_partitions;
1484 
1485    decode_cem(in);
1486 
1487    if (VERBOSE_DECODE)
1488       printf("cem=[%d,%d,%d,%d] base_cem_class=%d\n", cems[0], cems[1], cems[2], cems[3], cem_base_class);
1489 
1490    int num_cem_pairs = (cem_base_class + 1) * num_parts + extra_cem_bits;
1491    num_cem_values = num_cem_pairs * 2;
1492 
1493    calculate_remaining_bits();
1494    err = calculate_colour_endpoints_size();
1495    if (err != decode_error::ok)
1496       return err;
1497 
1498    if (VERBOSE_DECODE)
1499       in.printf_bits(colour_endpoint_data_offset, colour_endpoint_bits,
1500                      "endpoint data (%d bits, %d vals, %dt %dq %db)",
1501                      colour_endpoint_bits, num_cem_values, ce_trits, ce_quints, ce_bits);
1502 
1503    unpack_colour_endpoints(in);
1504 
1505    if (VERBOSE_DECODE) {
1506       printf("cem values raw =[");
1507       for (int i = 0; i < num_cem_values; i++) {
1508          if (i)
1509             printf(", ");
1510          printf("%3d", colour_endpoints_quant[i]);
1511       }
1512       printf("]\n");
1513    }
1514 
1515    if (num_cem_values > 18)
1516       return decode_error::invalid_colour_endpoints_count;
1517 
1518    unquantise_colour_endpoints();
1519 
1520    if (VERBOSE_DECODE) {
1521       printf("cem values norm=[");
1522       for (int i = 0; i < num_cem_values; i++) {
1523          if (i)
1524             printf(", ");
1525          printf("%3d", colour_endpoints[i]);
1526       }
1527       printf("]\n");
1528    }
1529 
1530    decode_colour_endpoints();
1531 
1532    if (dual_plane) {
1533       int ccs_offset = 128 - weight_bits - num_extra_cem_bits - 2;
1534       colour_component_selector = in.get_bits(ccs_offset, 2);
1535 
1536       if (VERBOSE_DECODE)
1537          in.printf_bits(ccs_offset, 2, "colour component selector = %d", colour_component_selector);
1538    } else {
1539       colour_component_selector = 0;
1540    }
1541 
1542 
1543    if (VERBOSE_DECODE)
1544       in.printf_bits(128 - weight_bits, weight_bits, "weights (%d bits)", weight_bits);
1545 
1546    if (num_weights > 64)
1547       return decode_error::invalid_num_weights;
1548 
1549    if (weight_bits < 24 || weight_bits > 96)
1550       return decode_error::invalid_weight_bits;
1551 
1552    unpack_weights(in);
1553 
1554    unquantise_weights();
1555 
1556    if (VERBOSE_DECODE) {
1557       printf("weights=[");
1558       for (int i = 0; i < num_weights; ++i) {
1559          if (i)
1560             printf(", ");
1561          printf("%d", weights[i]);
1562       }
1563       printf("]\n");
1564 
1565       for (int plane = 0; plane <= dual_plane; ++plane) {
1566          printf("weights (plane %d):\n", plane);
1567          int i = 0;
1568          (void)i;
1569 
1570          for (int r = 0; r < wt_d; ++r) {
1571             for (int t = 0; t < wt_h; ++t) {
1572                for (int s = 0; s < wt_w; ++s) {
1573                   printf("%3d", weights[i++ * (1 + dual_plane) + plane]);
1574                }
1575                printf("\n");
1576             }
1577             if (r < wt_d - 1)
1578                printf("\n");
1579          }
1580       }
1581    }
1582 
1583    compute_infill_weights(decoder.block_w, decoder.block_h, decoder.block_d);
1584 
1585    if (VERBOSE_DECODE) {
1586       for (int plane = 0; plane <= dual_plane; ++plane) {
1587          printf("infilled weights (plane %d):\n", plane);
1588          int i = 0;
1589          (void)i;
1590 
1591          for (int r = 0; r < decoder.block_d; ++r) {
1592             for (int t = 0; t < decoder.block_h; ++t) {
1593                for (int s = 0; s < decoder.block_w; ++s) {
1594                   printf("%3d", infill_weights[plane][i++]);
1595                }
1596                printf("\n");
1597             }
1598             if (r < decoder.block_d - 1)
1599                printf("\n");
1600          }
1601       }
1602    }
1603    if (VERBOSE_DECODE)
1604       printf("\n");
1605 
1606    return decode_error::ok;
1607 }
1608 
write_decoded(const Decoder & decoder,uint16_t * output)1609 void Block::write_decoded(const Decoder &decoder, uint16_t *output)
1610 {
1611    /* sRGB can only be stored as unorm8. */
1612    assert(!decoder.srgb || decoder.output_unorm8);
1613 
1614    if (is_void_extent) {
1615       for (int idx = 0; idx < decoder.block_w*decoder.block_h*decoder.block_d; ++idx) {
1616          if (decoder.output_unorm8) {
1617             if (decoder.srgb) {
1618                output[idx*4+0] = void_extent_colour_r >> 8;
1619                output[idx*4+1] = void_extent_colour_g >> 8;
1620                output[idx*4+2] = void_extent_colour_b >> 8;
1621             } else {
1622                output[idx*4+0] = uint16_div_64k_to_half_to_unorm8(void_extent_colour_r);
1623                output[idx*4+1] = uint16_div_64k_to_half_to_unorm8(void_extent_colour_g);
1624                output[idx*4+2] = uint16_div_64k_to_half_to_unorm8(void_extent_colour_b);
1625             }
1626             output[idx*4+3] = uint16_div_64k_to_half_to_unorm8(void_extent_colour_a);
1627          } else {
1628             /* Store the color as FP16. */
1629             output[idx*4+0] = _mesa_uint16_div_64k_to_half(void_extent_colour_r);
1630             output[idx*4+1] = _mesa_uint16_div_64k_to_half(void_extent_colour_g);
1631             output[idx*4+2] = _mesa_uint16_div_64k_to_half(void_extent_colour_b);
1632             output[idx*4+3] = _mesa_uint16_div_64k_to_half(void_extent_colour_a);
1633          }
1634       }
1635       return;
1636    }
1637 
1638    int small_block = (decoder.block_w * decoder.block_h * decoder.block_d) < 31;
1639 
1640    int idx = 0;
1641    for (int z = 0; z < decoder.block_d; ++z) {
1642       for (int y = 0; y < decoder.block_h; ++y) {
1643          for (int x = 0; x < decoder.block_w; ++x) {
1644 
1645             int partition;
1646             if (num_parts > 1) {
1647                partition = select_partition(partition_index, x, y, z, num_parts, small_block);
1648                assert(partition < num_parts);
1649             } else {
1650                partition = 0;
1651             }
1652 
1653             /* TODO: HDR */
1654 
1655             uint8x4_t e0 = endpoints_decoded[0][partition];
1656             uint8x4_t e1 = endpoints_decoded[1][partition];
1657             uint16_t c0[4], c1[4];
1658 
1659             /* Expand to 16 bits. */
1660             if (decoder.srgb) {
1661                c0[0] = (uint16_t)((e0.v[0] << 8) | 0x80);
1662                c0[1] = (uint16_t)((e0.v[1] << 8) | 0x80);
1663                c0[2] = (uint16_t)((e0.v[2] << 8) | 0x80);
1664                c0[3] = (uint16_t)((e0.v[3] << 8) | 0x80);
1665 
1666                c1[0] = (uint16_t)((e1.v[0] << 8) | 0x80);
1667                c1[1] = (uint16_t)((e1.v[1] << 8) | 0x80);
1668                c1[2] = (uint16_t)((e1.v[2] << 8) | 0x80);
1669                c1[3] = (uint16_t)((e1.v[3] << 8) | 0x80);
1670             } else {
1671                c0[0] = (uint16_t)((e0.v[0] << 8) | e0.v[0]);
1672                c0[1] = (uint16_t)((e0.v[1] << 8) | e0.v[1]);
1673                c0[2] = (uint16_t)((e0.v[2] << 8) | e0.v[2]);
1674                c0[3] = (uint16_t)((e0.v[3] << 8) | e0.v[3]);
1675 
1676                c1[0] = (uint16_t)((e1.v[0] << 8) | e1.v[0]);
1677                c1[1] = (uint16_t)((e1.v[1] << 8) | e1.v[1]);
1678                c1[2] = (uint16_t)((e1.v[2] << 8) | e1.v[2]);
1679                c1[3] = (uint16_t)((e1.v[3] << 8) | e1.v[3]);
1680             }
1681 
1682             int w[4];
1683             if (dual_plane) {
1684                int w0 = infill_weights[0][idx];
1685                int w1 = infill_weights[1][idx];
1686                w[0] = w[1] = w[2] = w[3] = w0;
1687                w[colour_component_selector] = w1;
1688             } else {
1689                int w0 = infill_weights[0][idx];
1690                w[0] = w[1] = w[2] = w[3] = w0;
1691             }
1692 
1693             /* Interpolate to produce UNORM16, applying weights. */
1694             uint16_t c[4] = {
1695                (uint16_t)((c0[0] * (64 - w[0]) + c1[0] * w[0] + 32) >> 6),
1696                (uint16_t)((c0[1] * (64 - w[1]) + c1[1] * w[1] + 32) >> 6),
1697                (uint16_t)((c0[2] * (64 - w[2]) + c1[2] * w[2] + 32) >> 6),
1698                (uint16_t)((c0[3] * (64 - w[3]) + c1[3] * w[3] + 32) >> 6),
1699             };
1700 
1701             if (decoder.output_unorm8) {
1702                if (decoder.srgb) {
1703                   output[idx*4+0] = c[0] >> 8;
1704                   output[idx*4+1] = c[1] >> 8;
1705                   output[idx*4+2] = c[2] >> 8;
1706                } else {
1707                   output[idx*4+0] = c[0] == 65535 ? 0xff : uint16_div_64k_to_half_to_unorm8(c[0]);
1708                   output[idx*4+1] = c[1] == 65535 ? 0xff : uint16_div_64k_to_half_to_unorm8(c[1]);
1709                   output[idx*4+2] = c[2] == 65535 ? 0xff : uint16_div_64k_to_half_to_unorm8(c[2]);
1710                }
1711                output[idx*4+3] = c[3] == 65535 ? 0xff : uint16_div_64k_to_half_to_unorm8(c[3]);
1712             } else {
1713                /* Store the color as FP16. */
1714                output[idx*4+0] = c[0] == 65535 ? FP16_ONE : _mesa_uint16_div_64k_to_half(c[0]);
1715                output[idx*4+1] = c[1] == 65535 ? FP16_ONE : _mesa_uint16_div_64k_to_half(c[1]);
1716                output[idx*4+2] = c[2] == 65535 ? FP16_ONE : _mesa_uint16_div_64k_to_half(c[2]);
1717                output[idx*4+3] = c[3] == 65535 ? FP16_ONE : _mesa_uint16_div_64k_to_half(c[3]);
1718             }
1719 
1720             idx++;
1721          }
1722       }
1723    }
1724 }
1725 
calculate_from_weights()1726 void Block::calculate_from_weights()
1727 {
1728    wt_trits = 0;
1729    wt_quints = 0;
1730    wt_bits = 0;
1731    switch (high_prec) {
1732    case 0:
1733       switch (wt_range) {
1734       case 0x2: wt_max = 1; wt_bits = 1; break;
1735       case 0x3: wt_max = 2; wt_trits = 1; break;
1736       case 0x4: wt_max = 3; wt_bits = 2; break;
1737       case 0x5: wt_max = 4; wt_quints = 1; break;
1738       case 0x6: wt_max = 5; wt_trits = 1; wt_bits = 1; break;
1739       case 0x7: wt_max = 7; wt_bits = 3; break;
1740       default: abort();
1741       }
1742       break;
1743    case 1:
1744       switch (wt_range) {
1745       case 0x2: wt_max = 9; wt_quints = 1; wt_bits = 1; break;
1746       case 0x3: wt_max = 11; wt_trits = 1; wt_bits = 2; break;
1747       case 0x4: wt_max = 15; wt_bits = 4; break;
1748       case 0x5: wt_max = 19; wt_quints = 1; wt_bits = 2; break;
1749       case 0x6: wt_max = 23; wt_trits = 1; wt_bits = 3; break;
1750       case 0x7: wt_max = 31; wt_bits = 5; break;
1751       default: abort();
1752       }
1753       break;
1754    }
1755 
1756    assert(wt_trits || wt_quints || wt_bits);
1757 
1758    num_weights = wt_w * wt_h * wt_d;
1759 
1760    if (dual_plane)
1761       num_weights *= 2;
1762 
1763    weight_bits =
1764          (num_weights * 8 * wt_trits + 4) / 5
1765          + (num_weights * 7 * wt_quints + 2) / 3
1766          +  num_weights * wt_bits;
1767 }
1768 
calculate_remaining_bits()1769 void Block::calculate_remaining_bits()
1770 {
1771    int config_bits;
1772    if (num_parts > 1) {
1773       if (!is_multi_cem)
1774          config_bits = 29;
1775       else
1776          config_bits = 25 + 3 * num_parts;
1777    } else {
1778       config_bits = 17;
1779    }
1780 
1781    if (dual_plane)
1782       config_bits += 2;
1783 
1784    remaining_bits = 128 - config_bits - weight_bits;
1785 }
1786 
calculate_colour_endpoints_size()1787 decode_error::type Block::calculate_colour_endpoints_size()
1788 {
1789    /* Specified as illegal */
1790    if (remaining_bits < (13 * num_cem_values + 4) / 5) {
1791       colour_endpoint_bits = ce_max = ce_trits = ce_quints = ce_bits = 0;
1792       return decode_error::invalid_colour_endpoints_size;
1793    }
1794 
1795    /* Find the largest cem_ranges that fits within remaining_bits */
1796    for (int i = ARRAY_SIZE(cem_ranges)-1; i >= 0; --i) {
1797       int cem_bits;
1798       cem_bits = (num_cem_values * 8 * cem_ranges[i].t + 4) / 5
1799                  + (num_cem_values * 7 * cem_ranges[i].q + 2) / 3
1800                  +  num_cem_values * cem_ranges[i].b;
1801 
1802       if (cem_bits <= remaining_bits)
1803       {
1804          colour_endpoint_bits = cem_bits;
1805          ce_max = cem_ranges[i].max;
1806          ce_trits = cem_ranges[i].t;
1807          ce_quints = cem_ranges[i].q;
1808          ce_bits = cem_ranges[i].b;
1809          return decode_error::ok;
1810       }
1811    }
1812 
1813    assert(0);
1814    return decode_error::invalid_colour_endpoints_size;
1815 }
1816 
1817 /**
1818  * Decode ASTC 2D LDR texture data.
1819  *
1820  * \param src_width in pixels
1821  * \param src_height in pixels
1822  * \param dst_stride in bytes
1823  */
1824 extern "C" void
_mesa_unpack_astc_2d_ldr(uint8_t * dst_row,unsigned dst_stride,const uint8_t * src_row,unsigned src_stride,unsigned src_width,unsigned src_height,mesa_format format)1825 _mesa_unpack_astc_2d_ldr(uint8_t *dst_row,
1826                          unsigned dst_stride,
1827                          const uint8_t *src_row,
1828                          unsigned src_stride,
1829                          unsigned src_width,
1830                          unsigned src_height,
1831                          mesa_format format)
1832 {
1833    assert(_mesa_is_format_astc_2d(format));
1834    bool srgb = _mesa_is_format_srgb(format);
1835 
1836    unsigned blk_w, blk_h;
1837    _mesa_get_format_block_size(format, &blk_w, &blk_h);
1838 
1839    const unsigned block_size = 16;
1840    unsigned x_blocks = (src_width + blk_w - 1) / blk_w;
1841    unsigned y_blocks = (src_height + blk_h - 1) / blk_h;
1842 
1843    Decoder dec(blk_w, blk_h, 1, srgb, true);
1844 
1845    for (unsigned y = 0; y < y_blocks; ++y) {
1846       for (unsigned x = 0; x < x_blocks; ++x) {
1847          /* Same size as the largest block. */
1848          uint16_t block_out[12 * 12 * 4];
1849 
1850          dec.decode(src_row + x * block_size, block_out);
1851 
1852          /* This can be smaller with NPOT dimensions. */
1853          unsigned dst_blk_w = MIN2(blk_w, src_width  - x*blk_w);
1854          unsigned dst_blk_h = MIN2(blk_h, src_height - y*blk_h);
1855 
1856          for (unsigned sub_y = 0; sub_y < dst_blk_h; ++sub_y) {
1857             for (unsigned sub_x = 0; sub_x < dst_blk_w; ++sub_x) {
1858                uint8_t *dst = dst_row + sub_y * dst_stride +
1859                               (x * blk_w + sub_x) * 4;
1860                const uint16_t *src = &block_out[(sub_y * blk_w + sub_x) * 4];
1861 
1862                dst[0] = src[0];
1863                dst[1] = src[1];
1864                dst[2] = src[2];
1865                dst[3] = src[3];
1866             }
1867          }
1868       }
1869       src_row += src_stride;
1870       dst_row += dst_stride * blk_h;
1871    }
1872 }
1873