1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_
7 #define LIB_JXL_QUANT_WEIGHTS_H_
8 
9 #include <stdint.h>
10 #include <string.h>
11 
12 #include <array>
13 #include <hwy/aligned_allocator.h>
14 #include <utility>
15 #include <vector>
16 
17 #include "lib/jxl/ac_strategy.h"
18 #include "lib/jxl/aux_out_fwd.h"
19 #include "lib/jxl/base/cache_aligned.h"
20 #include "lib/jxl/base/compiler_specific.h"
21 #include "lib/jxl/base/span.h"
22 #include "lib/jxl/base/status.h"
23 #include "lib/jxl/common.h"
24 #include "lib/jxl/dec_bit_reader.h"
25 #include "lib/jxl/image.h"
26 
27 namespace jxl {
28 
// Compile-time sum of the elements a[0] .. a[i] of a fixed-size array. With
// the default value of `i` the whole array is summed. The body is a single
// recursive return expression so that it stays usable in C++11 constexpr
// contexts.
template <typename T, size_t N>
constexpr T ArraySum(T (&a)[N], size_t i = N - 1) {
  static_assert(N > 0, "Trying to compute the sum of an empty array");
  // Index 0 terminates the recursion; every other index adds its own element
  // to the sum of the elements before it.
  return (i != 0) ? (a[i] + ArraySum(a, i - 1)) : a[0];
}
34 
// Set equal to AcStrategy::kMaxCoeffArea; presumably the largest number of
// coefficients a single quantization table can hold -- confirm against
// AcStrategy.
static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea;
// Number of predefined tables; QuantEncodingInternal::Library() requires its
// argument to be smaller than this value.
static constexpr size_t kNumPredefinedTables = 1;
// ceil(log2(kNumPredefinedTables)). NOTE(review): presumably the number of
// bits used to signal a predefined table index -- confirm against the coder.
static constexpr size_t kCeilLog2NumPredefinedTables = 0;
// NOTE(review): 1 << 3 == 8 matches the number of enumerators in
// QuantEncodingInternal::Mode; presumably the bit width of an encoded mode.
static constexpr size_t kLog2NumQuantModes = 3;
39 
// Parameters describing DCT quantization weights: for each of 3 channels a
// fixed-capacity row of kMaxDistanceBands floats, together with the number of
// bands that were actually provided (num_distance_bands).
struct DctQuantWeightParams {
  static constexpr size_t kLog2MaxDistanceBands = 4;
  // Capacity of each per-channel row: 1 + 2^kLog2MaxDistanceBands == 17.
  static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands);
  typedef std::array<std::array<float, kMaxDistanceBands>, 3>
      DistanceBandsArray;

  size_t num_distance_bands = 0;
  DistanceBandsArray distance_bands = {};

  // Creates an empty parameter set (no bands, all values zero).
  constexpr DctQuantWeightParams() : num_distance_bands(0) {}

  // Creates a parameter set from an already full-sized array. The caller is
  // responsible for num_dist_bands not exceeding kMaxDistanceBands.
  constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands,
                                 size_t num_dist_bands)
      : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {}

  // Creates a parameter set from a plain [3][num_dist_bands] array; the band
  // count is deduced from the argument type. Entries past num_dist_bands keep
  // the zero value from the default member initializer.
  template <size_t num_dist_bands>
  explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) {
    // Each destination row only has room for kMaxDistanceBands floats, so a
    // wider input would make the memcpy below write out of bounds.
    static_assert(num_dist_bands <= kMaxDistanceBands,
                  "Too many distance bands for DctQuantWeightParams");
    num_distance_bands = num_dist_bands;
    for (size_t c = 0; c < 3; c++) {
      memcpy(distance_bands[c].data(), dist_bands[c],
             sizeof(float) * num_dist_bands);
    }
  }
};
64 
// Constexpr-constructible description of how one quantization table is
// encoded. `mode` selects which of the parameter members below are meaningful.
// Each static factory function forwards to the tag-dispatched constructor of
// the corresponding mode, which initializes only the members that mode uses.
// NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding)
struct QuantEncodingInternal {
  enum Mode {
    kQuantModeLibrary,
    kQuantModeID,
    kQuantModeDCT2,
    kQuantModeDCT4,
    kQuantModeDCT4X8,
    kQuantModeAFV,
    kQuantModeDCT,
    kQuantModeRAW,
  };

  // Empty tag type used to select a constructor overload per Mode.
  template <Mode mode>
  struct Tag {};

  // Per-mode parameter arrays; the outer dimension (3) is the channel.
  typedef std::array<std::array<float, 3>, 3> IdWeights;
  typedef std::array<std::array<float, 6>, 3> DCT2Weights;
  typedef std::array<std::array<float, 2>, 3> DCT4Multipliers;
  typedef std::array<std::array<float, 9>, 3> AFVWeights;
  typedef std::array<float, 3> DCT4x8Multipliers;

  // Library (predefined table).
  // Invokes JXL_ABORT unless predefined < kNumPredefinedTables. The check is
  // folded into the return expression using || and the comma operator so that
  // the function body remains a single return statement.
  static constexpr QuantEncodingInternal Library(uint8_t predefined) {
    return ((predefined < kNumPredefinedTables) ||
            JXL_ABORT("Assert predefined < kNumPredefinedTables")),
           QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */,
                                  uint8_t predefined)
      : mode(kQuantModeLibrary), predefined(predefined) {}

  // Identity
  // xybweights is an array of {xweights, yweights, bweights}.
  static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) {
    return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */,
                                  const IdWeights& xybweights)
      : mode(kQuantModeID), idweights(xybweights) {}

  // DCT2
  // Stores xybweights in dct2weights.
  static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) {
    return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */,
                                  const DCT2Weights& xybweights)
      : mode(kQuantModeDCT2), dct2weights(xybweights) {}

  // DCT4
  // Stores params in dct_params and xybmul in dct4multipliers.
  static constexpr QuantEncodingInternal DCT4(
      const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) {
    return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */,
                                  const DctQuantWeightParams& params,
                                  const DCT4Multipliers& xybmul)
      : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {}

  // DCT4x8
  // Stores params in dct_params and xybmul in dct4x8multipliers.
  static constexpr QuantEncodingInternal DCT4X8(
      const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) {
    return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */,
                                  const DctQuantWeightParams& params,
                                  const DCT4x8Multipliers& xybmul)
      : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {}

  // DCT
  // Stores params in dct_params; no additional multipliers.
  static constexpr QuantEncodingInternal DCT(
      const DctQuantWeightParams& params) {
    return QuantEncodingInternal(Tag<kQuantModeDCT>(), params);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */,
                                  const DctQuantWeightParams& params)
      : mode(kQuantModeDCT), dct_params(params) {}

  // AFV
  // params4x8 is stored in dct_params, params4x4 in dct_params_afv_4x4 and
  // weights in afv_weights.
  static constexpr QuantEncodingInternal AFV(
      const DctQuantWeightParams& params4x8,
      const DctQuantWeightParams& params4x4, const AFVWeights& weights) {
    return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4,
                                 weights);
  }
  constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */,
                                  const DctQuantWeightParams& params4x8,
                                  const DctQuantWeightParams& params4x4,
                                  const AFVWeights& weights)
      : mode(kQuantModeAFV),
        dct_params(params4x8),
        afv_weights(weights),
        dct_params_afv_4x4(params4x4) {}

  // This constructor is not constexpr so it can't be used in any of the
  // constexpr cases above. It only sets `mode`; all other members keep their
  // default member initializers.
  explicit QuantEncodingInternal(Mode mode) : mode(mode) {}

  Mode mode;

  // Weights for DCT4+ tables.
  DctQuantWeightParams dct_params;

  // Mode-specific parameters; which member is active depends on `mode`.
  union {
    // Weights for identity.
    IdWeights idweights;

    // Weights for DCT2.
    DCT2Weights dct2weights;

    // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV.
    DCT4Multipliers dct4multipliers;

    // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1,
    // 0);  {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) +
    // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated
    // as in GetQuantWeights for DC and are used for other coefficients.
    // This is the union member that carries the default member initializer.
    AFVWeights afv_weights = {};

    // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4.
    DCT4x8Multipliers dct4x8multipliers;

    // Only used in kQuantModeRAW mode.
    struct {
      // explicit quantization table (like in JPEG)
      std::vector<int>* qtable = nullptr;
      // Default value is 1 / (8 * 255); QuantEncoding::RAW() scales it by
      // (1 << shift).
      float qtable_den = 1.f / (8 * 255);
    } qraw;
  };

  // Weights for 4x4 sub-block in AFV.
  DctQuantWeightParams dct_params_afv_4x4;

  union {
    // Which predefined table to use. Only used if mode is kQuantModeLibrary.
    uint8_t predefined = 0;

    // Which other quant table to copy; must copy from a table that comes before
    // the current one. Only used if mode is kQuantModeCopy.
    // NOTE(review): the Mode enum in this struct has no kQuantModeCopy
    // enumerator and nothing in this header reads `source`; this looks like a
    // leftover -- confirm before relying on it.
    uint8_t source;
  };
};
206 
207 class QuantEncoding final : public QuantEncodingInternal {
208  public:
QuantEncoding(const QuantEncoding & other)209   QuantEncoding(const QuantEncoding& other)
210       : QuantEncodingInternal(
211             static_cast<const QuantEncodingInternal&>(other)) {
212     if (mode == kQuantModeRAW && qraw.qtable) {
213       // Need to make a copy of the passed *qtable.
214       qraw.qtable = new std::vector<int>(*other.qraw.qtable);
215     }
216   }
QuantEncoding(QuantEncoding && other)217   QuantEncoding(QuantEncoding&& other) noexcept
218       : QuantEncodingInternal(
219             static_cast<const QuantEncodingInternal&>(other)) {
220     // Steal the qtable from the other object if any.
221     if (mode == kQuantModeRAW) {
222       other.qraw.qtable = nullptr;
223     }
224   }
225   QuantEncoding& operator=(const QuantEncoding& other) {
226     if (mode == kQuantModeRAW && qraw.qtable) {
227       delete qraw.qtable;
228     }
229     *static_cast<QuantEncodingInternal*>(this) =
230         QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other));
231     if (mode == kQuantModeRAW && qraw.qtable) {
232       // Need to make a copy of the passed *qtable.
233       qraw.qtable = new std::vector<int>(*other.qraw.qtable);
234     }
235     return *this;
236   }
237 
~QuantEncoding()238   ~QuantEncoding() {
239     if (mode == kQuantModeRAW && qraw.qtable) {
240       delete qraw.qtable;
241     }
242   }
243 
244   // Wrappers of the QuantEncodingInternal:: static functions that return a
245   // QuantEncoding instead. This is using the explicit and private cast from
246   // QuantEncodingInternal to QuantEncoding, which would be inlined anyway.
247   // In general, you should use this wrappers. The only reason to directly
248   // create a QuantEncodingInternal instance is if you need a constexpr version
249   // of this class. Note that RAW() is not supported in that case since it uses
250   // a std::vector.
Library(uint8_t predefined)251   static QuantEncoding Library(uint8_t predefined) {
252     return QuantEncoding(QuantEncodingInternal::Library(predefined));
253   }
Identity(const IdWeights & xybweights)254   static QuantEncoding Identity(const IdWeights& xybweights) {
255     return QuantEncoding(QuantEncodingInternal::Identity(xybweights));
256   }
DCT2(const DCT2Weights & xybweights)257   static QuantEncoding DCT2(const DCT2Weights& xybweights) {
258     return QuantEncoding(QuantEncodingInternal::DCT2(xybweights));
259   }
DCT4(const DctQuantWeightParams & params,const DCT4Multipliers & xybmul)260   static QuantEncoding DCT4(const DctQuantWeightParams& params,
261                             const DCT4Multipliers& xybmul) {
262     return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul));
263   }
DCT4X8(const DctQuantWeightParams & params,const DCT4x8Multipliers & xybmul)264   static QuantEncoding DCT4X8(const DctQuantWeightParams& params,
265                               const DCT4x8Multipliers& xybmul) {
266     return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul));
267   }
DCT(const DctQuantWeightParams & params)268   static QuantEncoding DCT(const DctQuantWeightParams& params) {
269     return QuantEncoding(QuantEncodingInternal::DCT(params));
270   }
AFV(const DctQuantWeightParams & params4x8,const DctQuantWeightParams & params4x4,const AFVWeights & weights)271   static QuantEncoding AFV(const DctQuantWeightParams& params4x8,
272                            const DctQuantWeightParams& params4x4,
273                            const AFVWeights& weights) {
274     return QuantEncoding(
275         QuantEncodingInternal::AFV(params4x8, params4x4, weights));
276   }
277 
278   // RAW, note that this one is not a constexpr one.
279   static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) {
280     QuantEncoding encoding(kQuantModeRAW);
281     encoding.qraw.qtable = new std::vector<int>();
282     *encoding.qraw.qtable = qtable;
283     encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255));
284     return encoding;
285   }
286 
287  private:
QuantEncoding(const QuantEncodingInternal & other)288   explicit QuantEncoding(const QuantEncodingInternal& other)
289       : QuantEncodingInternal(other) {}
290 
QuantEncoding(QuantEncodingInternal::Mode mode)291   explicit QuantEncoding(QuantEncodingInternal::Mode mode)
292       : QuantEncodingInternal(mode) {}
293 };
294 
// A constexpr QuantEncodingInternal instance is often downcast to the
// QuantEncoding subclass even though it was not created as an instance of
// that subclass. This is considered safe because users upcast back to
// QuantEncodingInternal to access any of its members, and because
// QuantEncoding must not add any data members of its own (checked here).
static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal),
              "Don't add any members to QuantEncoding");
301 
// Default inverse DC quantization factors, one per channel.
// Let's try to keep these 2**N for possible future simplicity.
// constexpr (rather than plain const) so that the values are usable in
// constant expressions and are guaranteed to be constant-initialized.
static constexpr float kInvDCQuant[3] = {
    4096.0f,
    512.0f,
    256.0f,
};

// Default DC quantization factors: the reciprocals of kInvDCQuant. Because
// the inverse factors are powers of two, these reciprocals are exact.
static constexpr float kDCQuant[3] = {
    1.0f / kInvDCQuant[0],
    1.0f / kInvDCQuant[1],
    1.0f / kInvDCQuant[2],
};
314 
// Forward declarations. ModularFrameDecoder is only used as a pointer
// parameter of DequantMatrices::Decode() in this header.
class ModularFrameEncoder;
class ModularFrameDecoder;
317 
// Holds the quantization table encodings, the DC quantization factors, and
// pointers to the computed dequantization tables and their inverses.
class DequantMatrices {
 public:
  // Identifiers of the distinct quantization tables. Names in comments are
  // the additional block shapes listed next to each entry; kQuantTable below
  // maps several AC strategies onto the same table.
  enum QuantTable : size_t {
    DCT = 0,
    IDENTITY,
    DCT2X2,
    DCT4X4,
    DCT16X16,
    DCT32X32,
    // DCT16X8
    DCT8X16,
    // DCT32X8
    DCT8X32,
    // DCT32X16
    DCT16X32,
    DCT4X8,
    // DCT8X4
    AFV0,
    // AFV1
    // AFV2
    // AFV3
    DCT64X64,
    // DCT64X32,
    DCT32X64,
    DCT128X128,
    // DCT128X64,
    DCT64X128,
    DCT256X256,
    // DCT256X128,
    DCT128X256,
    kNum
  };

  // One QuantTable per AC strategy; the static_assert below ties the number
  // of entries to AcStrategy::kNumValidStrategies.
  static constexpr QuantTable kQuantTable[] = {
      QuantTable::DCT,        QuantTable::IDENTITY,   QuantTable::DCT2X2,
      QuantTable::DCT4X4,     QuantTable::DCT16X16,   QuantTable::DCT32X32,
      QuantTable::DCT8X16,    QuantTable::DCT8X16,    QuantTable::DCT8X32,
      QuantTable::DCT8X32,    QuantTable::DCT16X32,   QuantTable::DCT16X32,
      QuantTable::DCT4X8,     QuantTable::DCT4X8,     QuantTable::AFV0,
      QuantTable::AFV0,       QuantTable::AFV0,       QuantTable::AFV0,
      QuantTable::DCT64X64,   QuantTable::DCT32X64,   QuantTable::DCT32X64,
      QuantTable::DCT128X128, QuantTable::DCT64X128,  QuantTable::DCT64X128,
      QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256,
  };
  static_assert(AcStrategy::kNumValidStrategies ==
                    sizeof(kQuantTable) / sizeof *kQuantTable,
                "Update this array when adding or removing AC strategies.");

  DequantMatrices();

  static const QuantEncoding* Library();

  typedef std::array<QuantEncodingInternal, kNumPredefinedTables * kNum>
      DequantLibraryInternal;
  // Return the array of library kNumPredefinedTables QuantEncoding entries as
  // a constexpr array. Use Library() to obtain a pointer to the copy in the
  // .cc file.
  static const DequantLibraryInternal LibraryInit();

  // Returns aligned memory.
  // The result points into table_ at the offset stored for (quant_kind, c).
  // Debug builds assert that quant_kind is a valid strategy index and that
  // its bit is set in computed_mask_.
  JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const {
    JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies);
    JXL_DASSERT((1 << quant_kind) & computed_mask_);
    return &table_[table_offsets_[quant_kind * 3 + c]];
  }

  // Same as Matrix(), but points into inv_table_ instead of table_.
  JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const {
    JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies);
    JXL_DASSERT((1 << quant_kind) & computed_mask_);
    return &inv_table_[table_offsets_[quant_kind * 3 + c]];
  }

  // DC quants are used in modular mode for XYB multipliers.
  JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; }
  JXL_INLINE const float* DCQuants() const { return dc_quant_; }

  JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; }

  // For encoder.
  // Replaces the stored encodings and clears computed_mask_.
  // NOTE(review): presumably this forces the tables to be recomputed by the
  // next EnsureComputed() call -- confirm in the .cc file.
  void SetEncodings(const std::vector<QuantEncoding>& encodings) {
    encodings_ = encodings;
    computed_mask_ = 0;
  }

  // For encoder.
  // Stores dc[c] as the inverse DC quant and 1 / dc[c] as the DC quant for
  // each of the 3 channels.
  void SetDCQuant(const float dc[3]) {
    for (size_t c = 0; c < 3; c++) {
      dc_quant_[c] = 1.0f / dc[c];
      inv_dc_quant_[c] = dc[c];
    }
  }

  Status Decode(BitReader* br,
                ModularFrameDecoder* modular_frame_decoder = nullptr);
  Status DecodeDC(BitReader* br);

  const std::vector<QuantEncoding>& encodings() const { return encodings_; }

  // Per-QuantTable sizes along x, one entry per QuantTable value.
  // NOTE(review): the unit looks like 8x8 blocks (kTotalTableSize multiplies
  // by kDCTBlockSize) -- confirm.
  static constexpr size_t required_size_x[] = {1, 1, 1, 1, 2,  4, 1,  1, 2,
                                               1, 1, 8, 4, 16, 8, 32, 16};
  static_assert(kNum == sizeof(required_size_x) / sizeof(*required_size_x),
                "Update this array when adding or removing quant tables.");

  // Per-QuantTable sizes along y, one entry per QuantTable value.
  static constexpr size_t required_size_y[] = {1, 1, 1, 1, 2,  4,  2,  4, 4,
                                               1, 1, 8, 8, 16, 16, 32, 32};
  static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y),
                "Update this array when adding or removing quant tables.");

  Status EnsureComputed(uint32_t kind_mask);

 private:
  // Per-QuantTable total sizes; each entry equals
  // required_size_x[i] * required_size_y[i].
  static constexpr size_t required_size_[] = {
      1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512};
  static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_),
                "Update this array when adding or removing quant tables.");
  // Number of floats needed for all tables: the summed sizes times
  // kDCTBlockSize times 3 channels.
  static constexpr size_t kTotalTableSize =
      ArraySum(required_size_) * kDCTBlockSize * 3;

  // Bit mask checked by Matrix() / InvMatrix() (bit index == quant_kind);
  // reset to 0 by SetEncodings().
  uint32_t computed_mask_ = 0;
  // kTotalTableSize entries followed by kTotalTableSize for inv_table
  hwy::AlignedFreeUniquePtr<float[]> table_storage_;
  // NOTE(review): table_ and inv_table_ have no default member initializer;
  // they are presumably set by the constructor or EnsureComputed() in the .cc
  // file -- confirm.
  const float* table_;
  const float* inv_table_;
  float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]};
  float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]};
  // Offsets into table_ / inv_table_, indexed by quant_kind * 3 + channel.
  size_t table_offsets_[AcStrategy::kNumValidStrategies * 3];
  std::vector<QuantEncoding> encodings_;
};
446 
447 }  // namespace jxl
448 
449 #endif  // LIB_JXL_QUANT_WEIGHTS_H_
450