1 // Copyright (c) the JPEG XL Project Authors. All rights reserved. 2 // 3 // Use of this source code is governed by a BSD-style 4 // license that can be found in the LICENSE file. 5 6 #ifndef LIB_JXL_QUANT_WEIGHTS_H_ 7 #define LIB_JXL_QUANT_WEIGHTS_H_ 8 9 #include <stdint.h> 10 #include <string.h> 11 12 #include <array> 13 #include <hwy/aligned_allocator.h> 14 #include <utility> 15 #include <vector> 16 17 #include "lib/jxl/ac_strategy.h" 18 #include "lib/jxl/aux_out_fwd.h" 19 #include "lib/jxl/base/cache_aligned.h" 20 #include "lib/jxl/base/compiler_specific.h" 21 #include "lib/jxl/base/span.h" 22 #include "lib/jxl/base/status.h" 23 #include "lib/jxl/common.h" 24 #include "lib/jxl/dec_bit_reader.h" 25 #include "lib/jxl/image.h" 26 27 namespace jxl { 28 29 template <typename T, size_t N> 30 constexpr T ArraySum(T (&a)[N], size_t i = N - 1) { 31 static_assert(N > 0, "Trying to compute the sum of an empty array"); 32 return i == 0 ? a[0] : a[i] + ArraySum(a, i - 1); 33 } 34 35 static constexpr size_t kMaxQuantTableSize = AcStrategy::kMaxCoeffArea; 36 static constexpr size_t kNumPredefinedTables = 1; 37 static constexpr size_t kCeilLog2NumPredefinedTables = 0; 38 static constexpr size_t kLog2NumQuantModes = 3; 39 40 struct DctQuantWeightParams { 41 static constexpr size_t kLog2MaxDistanceBands = 4; 42 static constexpr size_t kMaxDistanceBands = 1 + (1 << kLog2MaxDistanceBands); 43 typedef std::array<std::array<float, kMaxDistanceBands>, 3> 44 DistanceBandsArray; 45 46 size_t num_distance_bands = 0; 47 DistanceBandsArray distance_bands = {}; 48 DctQuantWeightParamsDctQuantWeightParams49 constexpr DctQuantWeightParams() : num_distance_bands(0) {} 50 DctQuantWeightParamsDctQuantWeightParams51 constexpr DctQuantWeightParams(const DistanceBandsArray& dist_bands, 52 size_t num_dist_bands) 53 : num_distance_bands(num_dist_bands), distance_bands(dist_bands) {} 54 55 template <size_t num_dist_bands> DctQuantWeightParamsDctQuantWeightParams56 explicit DctQuantWeightParams(const float dist_bands[3][num_dist_bands]) { 57 num_distance_bands = num_dist_bands; 58 for (size_t c = 0; c < 3; c++) { 59 memcpy(distance_bands[c].data(), dist_bands[c], 60 sizeof(float) * num_dist_bands); 61 } 62 } 63 }; 64 65 // NOLINTNEXTLINE(clang-analyzer-optin.performance.Padding) 66 struct QuantEncodingInternal { 67 enum Mode { 68 kQuantModeLibrary, 69 kQuantModeID, 70 kQuantModeDCT2, 71 kQuantModeDCT4, 72 kQuantModeDCT4X8, 73 kQuantModeAFV, 74 kQuantModeDCT, 75 kQuantModeRAW, 76 }; 77 78 template <Mode mode> 79 struct Tag {}; 80 81 typedef std::array<std::array<float, 3>, 3> IdWeights; 82 typedef std::array<std::array<float, 6>, 3> DCT2Weights; 83 typedef std::array<std::array<float, 2>, 3> DCT4Multipliers; 84 typedef std::array<std::array<float, 9>, 3> AFVWeights; 85 typedef std::array<float, 3> DCT4x8Multipliers; 86 LibraryQuantEncodingInternal87 static constexpr QuantEncodingInternal Library(uint8_t predefined) { 88 return ((predefined < kNumPredefinedTables) || 89 JXL_ABORT("Assert predefined < kNumPredefinedTables")), 90 QuantEncodingInternal(Tag<kQuantModeLibrary>(), predefined); 91 } QuantEncodingInternalQuantEncodingInternal92 constexpr QuantEncodingInternal(Tag<kQuantModeLibrary> /* tag */, 93 uint8_t predefined) 94 : mode(kQuantModeLibrary), predefined(predefined) {} 95 96 // Identity 97 // xybweights is an array of {xweights, yweights, bweights}. IdentityQuantEncodingInternal98 static constexpr QuantEncodingInternal Identity(const IdWeights& xybweights) { 99 return QuantEncodingInternal(Tag<kQuantModeID>(), xybweights); 100 } QuantEncodingInternalQuantEncodingInternal101 constexpr QuantEncodingInternal(Tag<kQuantModeID> /* tag */, 102 const IdWeights& xybweights) 103 : mode(kQuantModeID), idweights(xybweights) {} 104 105 // DCT2 DCT2QuantEncodingInternal106 static constexpr QuantEncodingInternal DCT2(const DCT2Weights& xybweights) { 107 return QuantEncodingInternal(Tag<kQuantModeDCT2>(), xybweights); 108 } QuantEncodingInternalQuantEncodingInternal109 constexpr QuantEncodingInternal(Tag<kQuantModeDCT2> /* tag */, 110 const DCT2Weights& xybweights) 111 : mode(kQuantModeDCT2), dct2weights(xybweights) {} 112 113 // DCT4 DCT4QuantEncodingInternal114 static constexpr QuantEncodingInternal DCT4( 115 const DctQuantWeightParams& params, const DCT4Multipliers& xybmul) { 116 return QuantEncodingInternal(Tag<kQuantModeDCT4>(), params, xybmul); 117 } QuantEncodingInternalQuantEncodingInternal118 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4> /* tag */, 119 const DctQuantWeightParams& params, 120 const DCT4Multipliers& xybmul) 121 : mode(kQuantModeDCT4), dct_params(params), dct4multipliers(xybmul) {} 122 123 // DCT4x8 DCT4X8QuantEncodingInternal124 static constexpr QuantEncodingInternal DCT4X8( 125 const DctQuantWeightParams& params, const DCT4x8Multipliers& xybmul) { 126 return QuantEncodingInternal(Tag<kQuantModeDCT4X8>(), params, xybmul); 127 } QuantEncodingInternalQuantEncodingInternal128 constexpr QuantEncodingInternal(Tag<kQuantModeDCT4X8> /* tag */, 129 const DctQuantWeightParams& params, 130 const DCT4x8Multipliers& xybmul) 131 : mode(kQuantModeDCT4X8), dct_params(params), dct4x8multipliers(xybmul) {} 132 133 // DCT DCTQuantEncodingInternal134 static constexpr QuantEncodingInternal DCT( 135 const DctQuantWeightParams& params) { 136 return QuantEncodingInternal(Tag<kQuantModeDCT>(), params); 137 } QuantEncodingInternalQuantEncodingInternal138 constexpr QuantEncodingInternal(Tag<kQuantModeDCT> /* tag */, 139 const DctQuantWeightParams& params) 140 : mode(kQuantModeDCT), dct_params(params) {} 141 142 // AFV AFVQuantEncodingInternal143 static constexpr QuantEncodingInternal AFV( 144 const DctQuantWeightParams& params4x8, 145 const DctQuantWeightParams& params4x4, const AFVWeights& weights) { 146 return QuantEncodingInternal(Tag<kQuantModeAFV>(), params4x8, params4x4, 147 weights); 148 } QuantEncodingInternalQuantEncodingInternal149 constexpr QuantEncodingInternal(Tag<kQuantModeAFV> /* tag */, 150 const DctQuantWeightParams& params4x8, 151 const DctQuantWeightParams& params4x4, 152 const AFVWeights& weights) 153 : mode(kQuantModeAFV), 154 dct_params(params4x8), 155 afv_weights(weights), 156 dct_params_afv_4x4(params4x4) {} 157 158 // This constructor is not constexpr so it can't be used in any of the 159 // constexpr cases above. QuantEncodingInternalQuantEncodingInternal160 explicit QuantEncodingInternal(Mode mode) : mode(mode) {} 161 162 Mode mode; 163 164 // Weights for DCT4+ tables. 165 DctQuantWeightParams dct_params; 166 167 union { 168 // Weights for identity. 169 IdWeights idweights; 170 171 // Weights for DCT2. 172 DCT2Weights dct2weights; 173 174 // Extra multipliers for coefficients 01/10 and 11 for DCT4 and AFV. 175 DCT4Multipliers dct4multipliers; 176 177 // Weights for AFV. {0, 1} are used directly for coefficients (0, 1) and (1, 178 // 0); {2, 3, 4} are used directly corner DC, (1,0) - (0,1) and (0, 1) + 179 // (1, 0) - (0, 0) inside the AFV block. Values from 5 to 8 are interpolated 180 // as in GetQuantWeights for DC and are used for other coefficients. 181 AFVWeights afv_weights = {}; 182 183 // Extra multipliers for coefficients 01 or 10 for DCT4X8 and DCT8X4. 184 DCT4x8Multipliers dct4x8multipliers; 185 186 // Only used in kQuantModeRAW mode. 187 struct { 188 // explicit quantization table (like in JPEG) 189 std::vector<int>* qtable = nullptr; 190 float qtable_den = 1.f / (8 * 255); 191 } qraw; 192 }; 193 194 // Weights for 4x4 sub-block in AFV. 195 DctQuantWeightParams dct_params_afv_4x4; 196 197 union { 198 // Which predefined table to use. Only used if mode is kQuantModeLibrary. 199 uint8_t predefined = 0; 200 201 // Which other quant table to copy; must copy from a table that comes before 202 // the current one. Only used if mode is kQuantModeCopy. 203 uint8_t source; 204 }; 205 }; 206 207 class QuantEncoding final : public QuantEncodingInternal { 208 public: QuantEncoding(const QuantEncoding & other)209 QuantEncoding(const QuantEncoding& other) 210 : QuantEncodingInternal( 211 static_cast<const QuantEncodingInternal&>(other)) { 212 if (mode == kQuantModeRAW && qraw.qtable) { 213 // Need to make a copy of the passed *qtable. 214 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 215 } 216 } QuantEncoding(QuantEncoding && other)217 QuantEncoding(QuantEncoding&& other) noexcept 218 : QuantEncodingInternal( 219 static_cast<const QuantEncodingInternal&>(other)) { 220 // Steal the qtable from the other object if any. 221 if (mode == kQuantModeRAW) { 222 other.qraw.qtable = nullptr; 223 } 224 } 225 QuantEncoding& operator=(const QuantEncoding& other) { 226 if (mode == kQuantModeRAW && qraw.qtable) { 227 delete qraw.qtable; 228 } 229 *static_cast<QuantEncodingInternal*>(this) = 230 QuantEncodingInternal(static_cast<const QuantEncodingInternal&>(other)); 231 if (mode == kQuantModeRAW && qraw.qtable) { 232 // Need to make a copy of the passed *qtable. 233 qraw.qtable = new std::vector<int>(*other.qraw.qtable); 234 } 235 return *this; 236 } 237 ~QuantEncoding()238 ~QuantEncoding() { 239 if (mode == kQuantModeRAW && qraw.qtable) { 240 delete qraw.qtable; 241 } 242 } 243 244 // Wrappers of the QuantEncodingInternal:: static functions that return a 245 // QuantEncoding instead. This is using the explicit and private cast from 246 // QuantEncodingInternal to QuantEncoding, which would be inlined anyway. 247 // In general, you should use this wrappers. The only reason to directly 248 // create a QuantEncodingInternal instance is if you need a constexpr version 249 // of this class. Note that RAW() is not supported in that case since it uses 250 // a std::vector. Library(uint8_t predefined)251 static QuantEncoding Library(uint8_t predefined) { 252 return QuantEncoding(QuantEncodingInternal::Library(predefined)); 253 } Identity(const IdWeights & xybweights)254 static QuantEncoding Identity(const IdWeights& xybweights) { 255 return QuantEncoding(QuantEncodingInternal::Identity(xybweights)); 256 } DCT2(const DCT2Weights & xybweights)257 static QuantEncoding DCT2(const DCT2Weights& xybweights) { 258 return QuantEncoding(QuantEncodingInternal::DCT2(xybweights)); 259 } DCT4(const DctQuantWeightParams & params,const DCT4Multipliers & xybmul)260 static QuantEncoding DCT4(const DctQuantWeightParams& params, 261 const DCT4Multipliers& xybmul) { 262 return QuantEncoding(QuantEncodingInternal::DCT4(params, xybmul)); 263 } DCT4X8(const DctQuantWeightParams & params,const DCT4x8Multipliers & xybmul)264 static QuantEncoding DCT4X8(const DctQuantWeightParams& params, 265 const DCT4x8Multipliers& xybmul) { 266 return QuantEncoding(QuantEncodingInternal::DCT4X8(params, xybmul)); 267 } DCT(const DctQuantWeightParams & params)268 static QuantEncoding DCT(const DctQuantWeightParams& params) { 269 return QuantEncoding(QuantEncodingInternal::DCT(params)); 270 } AFV(const DctQuantWeightParams & params4x8,const DctQuantWeightParams & params4x4,const AFVWeights & weights)271 static QuantEncoding AFV(const DctQuantWeightParams& params4x8, 272 const DctQuantWeightParams& params4x4, 273 const AFVWeights& weights) { 274 return QuantEncoding( 275 QuantEncodingInternal::AFV(params4x8, params4x4, weights)); 276 } 277 278 // RAW, note that this one is not a constexpr one. 279 static QuantEncoding RAW(const std::vector<int>& qtable, int shift = 0) { 280 QuantEncoding encoding(kQuantModeRAW); 281 encoding.qraw.qtable = new std::vector<int>(); 282 *encoding.qraw.qtable = qtable; 283 encoding.qraw.qtable_den = (1 << shift) * (1.f / (8 * 255)); 284 return encoding; 285 } 286 287 private: QuantEncoding(const QuantEncodingInternal & other)288 explicit QuantEncoding(const QuantEncodingInternal& other) 289 : QuantEncodingInternal(other) {} 290 QuantEncoding(QuantEncodingInternal::Mode mode)291 explicit QuantEncoding(QuantEncodingInternal::Mode mode) 292 : QuantEncodingInternal(mode) {} 293 }; 294 295 // A constexpr QuantEncodingInternal instance is often downcasted to the 296 // QuantEncoding subclass even if the instance wasn't an instance of the 297 // subclass. This is safe because user will upcast to QuantEncodingInternal to 298 // access any of its members. 299 static_assert(sizeof(QuantEncoding) == sizeof(QuantEncodingInternal), 300 "Don't add any members to QuantEncoding"); 301 302 // Let's try to keep these 2**N for possible future simplicity. 303 const float kInvDCQuant[3] = { 304 4096.0f, 305 512.0f, 306 256.0f, 307 }; 308 309 const float kDCQuant[3] = { 310 1.0f / kInvDCQuant[0], 311 1.0f / kInvDCQuant[1], 312 1.0f / kInvDCQuant[2], 313 }; 314 315 class ModularFrameEncoder; 316 class ModularFrameDecoder; 317 318 class DequantMatrices { 319 public: 320 enum QuantTable : size_t { 321 DCT = 0, 322 IDENTITY, 323 DCT2X2, 324 DCT4X4, 325 DCT16X16, 326 DCT32X32, 327 // DCT16X8 328 DCT8X16, 329 // DCT32X8 330 DCT8X32, 331 // DCT32X16 332 DCT16X32, 333 DCT4X8, 334 // DCT8X4 335 AFV0, 336 // AFV1 337 // AFV2 338 // AFV3 339 DCT64X64, 340 // DCT64X32, 341 DCT32X64, 342 DCT128X128, 343 // DCT128X64, 344 DCT64X128, 345 DCT256X256, 346 // DCT256X128, 347 DCT128X256, 348 kNum 349 }; 350 351 static constexpr QuantTable kQuantTable[] = { 352 QuantTable::DCT, QuantTable::IDENTITY, QuantTable::DCT2X2, 353 QuantTable::DCT4X4, QuantTable::DCT16X16, QuantTable::DCT32X32, 354 QuantTable::DCT8X16, QuantTable::DCT8X16, QuantTable::DCT8X32, 355 QuantTable::DCT8X32, QuantTable::DCT16X32, QuantTable::DCT16X32, 356 QuantTable::DCT4X8, QuantTable::DCT4X8, QuantTable::AFV0, 357 QuantTable::AFV0, QuantTable::AFV0, QuantTable::AFV0, 358 QuantTable::DCT64X64, QuantTable::DCT32X64, QuantTable::DCT32X64, 359 QuantTable::DCT128X128, QuantTable::DCT64X128, QuantTable::DCT64X128, 360 QuantTable::DCT256X256, QuantTable::DCT128X256, QuantTable::DCT128X256, 361 }; 362 static_assert(AcStrategy::kNumValidStrategies == 363 sizeof(kQuantTable) / sizeof *kQuantTable, 364 "Update this array when adding or removing AC strategies."); 365 366 DequantMatrices(); 367 368 static const QuantEncoding* Library(); 369 370 typedef std::array<QuantEncodingInternal, kNumPredefinedTables * kNum> 371 DequantLibraryInternal; 372 // Return the array of library kNumPredefinedTables QuantEncoding entries as 373 // a constexpr array. Use Library() to obtain a pointer to the copy in the 374 // .cc file. 375 static const DequantLibraryInternal LibraryInit(); 376 377 // Returns aligned memory. Matrix(size_t quant_kind,size_t c)378 JXL_INLINE const float* Matrix(size_t quant_kind, size_t c) const { 379 JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); 380 JXL_DASSERT((1 << quant_kind) & computed_mask_); 381 return &table_[table_offsets_[quant_kind * 3 + c]]; 382 } 383 InvMatrix(size_t quant_kind,size_t c)384 JXL_INLINE const float* InvMatrix(size_t quant_kind, size_t c) const { 385 JXL_DASSERT(quant_kind < AcStrategy::kNumValidStrategies); 386 JXL_DASSERT((1 << quant_kind) & computed_mask_); 387 return &inv_table_[table_offsets_[quant_kind * 3 + c]]; 388 } 389 390 // DC quants are used in modular mode for XYB multipliers. DCQuant(size_t c)391 JXL_INLINE float DCQuant(size_t c) const { return dc_quant_[c]; } DCQuants()392 JXL_INLINE const float* DCQuants() const { return dc_quant_; } 393 InvDCQuant(size_t c)394 JXL_INLINE float InvDCQuant(size_t c) const { return inv_dc_quant_[c]; } 395 396 // For encoder. SetEncodings(const std::vector<QuantEncoding> & encodings)397 void SetEncodings(const std::vector<QuantEncoding>& encodings) { 398 encodings_ = encodings; 399 computed_mask_ = 0; 400 } 401 402 // For encoder. SetDCQuant(const float dc[3])403 void SetDCQuant(const float dc[3]) { 404 for (size_t c = 0; c < 3; c++) { 405 dc_quant_[c] = 1.0f / dc[c]; 406 inv_dc_quant_[c] = dc[c]; 407 } 408 } 409 410 Status Decode(BitReader* br, 411 ModularFrameDecoder* modular_frame_decoder = nullptr); 412 Status DecodeDC(BitReader* br); 413 encodings()414 const std::vector<QuantEncoding>& encodings() const { return encodings_; } 415 416 static constexpr size_t required_size_x[] = {1, 1, 1, 1, 2, 4, 1, 1, 2, 417 1, 1, 8, 4, 16, 8, 32, 16}; 418 static_assert(kNum == sizeof(required_size_x) / sizeof(*required_size_x), 419 "Update this array when adding or removing quant tables."); 420 421 static constexpr size_t required_size_y[] = {1, 1, 1, 1, 2, 4, 2, 4, 4, 422 1, 1, 8, 8, 16, 16, 32, 32}; 423 static_assert(kNum == sizeof(required_size_y) / sizeof(*required_size_y), 424 "Update this array when adding or removing quant tables."); 425 426 Status EnsureComputed(uint32_t kind_mask); 427 428 private: 429 static constexpr size_t required_size_[] = { 430 1, 1, 1, 1, 4, 16, 2, 4, 8, 1, 1, 64, 32, 256, 128, 1024, 512}; 431 static_assert(kNum == sizeof(required_size_) / sizeof(*required_size_), 432 "Update this array when adding or removing quant tables."); 433 static constexpr size_t kTotalTableSize = 434 ArraySum(required_size_) * kDCTBlockSize * 3; 435 436 uint32_t computed_mask_ = 0; 437 // kTotalTableSize entries followed by kTotalTableSize for inv_table 438 hwy::AlignedFreeUniquePtr<float[]> table_storage_; 439 const float* table_; 440 const float* inv_table_; 441 float dc_quant_[3] = {kDCQuant[0], kDCQuant[1], kDCQuant[2]}; 442 float inv_dc_quant_[3] = {kInvDCQuant[0], kInvDCQuant[1], kInvDCQuant[2]}; 443 size_t table_offsets_[AcStrategy::kNumValidStrategies * 3]; 444 std::vector<QuantEncoding> encodings_; 445 }; 446 447 } // namespace jxl 448 449 #endif // LIB_JXL_QUANT_WEIGHTS_H_ 450