1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #ifndef LIB_JXL_DEC_PATCH_DICTIONARY_H_
7 #define LIB_JXL_DEC_PATCH_DICTIONARY_H_
8 
9 // Chooses reference patches, and avoids encoding them once per occurrence.
10 
11 #include <stddef.h>
12 #include <string.h>
13 #include <sys/types.h>
14 
15 #include <tuple>
16 #include <vector>
17 
18 #include "lib/jxl/base/status.h"
19 #include "lib/jxl/common.h"
20 #include "lib/jxl/dec_bit_reader.h"
21 #include "lib/jxl/image.h"
22 #include "lib/jxl/opsin_params.h"
23 
24 namespace jxl {
25 
// Maximum patch size. QuantizedPatch allocates kMaxPatchSize * kMaxPatchSize
// entries for each of its per-channel buffers.
constexpr size_t kMaxPatchSize = 32;
27 
// How the samples of a patch ("new") are combined with the samples that are
// already present ("old"). Stored in a single byte.
enum class PatchBlendMode : uint8_t {
  // The new values are the old ones. Useful to skip some channels.
  kNone = 0,
  // The new values (in the crop) replace the old ones: sample = new
  kReplace = 1,
  // The new values (in the crop) get added to the old ones: sample = old + new
  kAdd = 2,
  // The new values (in the crop) get multiplied by the old ones:
  // sample = old * new
  // This blend mode is only supported if BlendColorSpace is kEncoded. The
  // range of the new value matters for multiplication purposes, and its
  // nominal range of 0..1 is computed the same way as this is done for the
  // alpha values in kBlend and kAlphaWeightedAdd.
  kMul = 3,
  // The new values (in the crop) replace the old ones if alpha>0:
  // For first alpha channel:
  // alpha = old + new * (1 - old)
  // For other channels if !alpha_associated:
  // sample = ((1 - new_alpha) * old * old_alpha + new_alpha * new) / alpha
  // For other channels if alpha_associated:
  // sample = (1 - new_alpha) * old + new
  // The alpha formula applies to the alpha used for the division in the other
  // channels formula, and applies to the alpha channel itself if its
  // blend_channel value matches itself.
  // If using kBlendAbove, new is the patch and old is the original image; if
  // using kBlendBelow, the meaning is inverted.
  kBlendAbove = 4,
  kBlendBelow = 5,
  // The new values (in the crop) are added to the old ones if alpha>0:
  // For first alpha channel: sample = old + new * (1 - old)
  // For other channels: sample = old + alpha * new
  kAlphaWeightedAddAbove = 6,
  kAlphaWeightedAddBelow = 7,
  // Count of the blend modes listed above; not a blend mode itself.
  kNumBlendModes,
};
63 
UsesAlpha(PatchBlendMode mode)64 inline bool UsesAlpha(PatchBlendMode mode) {
65   return mode == PatchBlendMode::kBlendAbove ||
66          mode == PatchBlendMode::kBlendBelow ||
67          mode == PatchBlendMode::kAlphaWeightedAddAbove ||
68          mode == PatchBlendMode::kAlphaWeightedAddBelow;
69 }
UsesClamp(PatchBlendMode mode)70 inline bool UsesClamp(PatchBlendMode mode) {
71   return UsesAlpha(mode) || mode == PatchBlendMode::kMul;
72 }
73 
// A blend mode together with the extra settings it may need.
struct PatchBlending {
  // How new and old samples are combined.
  PatchBlendMode mode;
  // NOTE(review): presumably the index of the channel that supplies alpha
  // for modes where UsesAlpha(mode) is true - confirm against the decoder.
  uint32_t alpha_channel;
  // NOTE(review): presumably whether values are clamped when blending, for
  // modes where UsesClamp(mode) is true - confirm against the decoder.
  bool clamp;
};
79 
80 struct QuantizedPatch {
81   size_t xsize;
82   size_t ysize;
QuantizedPatchQuantizedPatch83   QuantizedPatch() {
84     for (size_t i = 0; i < 3; i++) {
85       pixels[i].resize(kMaxPatchSize * kMaxPatchSize);
86       fpixels[i].resize(kMaxPatchSize * kMaxPatchSize);
87     }
88   }
89   std::vector<int8_t> pixels[3] = {};
90   // Not compared. Used only to retrieve original pixels to construct the
91   // reference image.
92   std::vector<float> fpixels[3] = {};
93   bool operator==(const QuantizedPatch& other) const {
94     if (xsize != other.xsize) return false;
95     if (ysize != other.ysize) return false;
96     for (size_t c = 0; c < 3; c++) {
97       if (memcmp(pixels[c].data(), other.pixels[c].data(),
98                  sizeof(int8_t) * xsize * ysize) != 0)
99         return false;
100     }
101     return true;
102   }
103 
104   bool operator<(const QuantizedPatch& other) const {
105     if (xsize != other.xsize) return xsize < other.xsize;
106     if (ysize != other.ysize) return ysize < other.ysize;
107     for (size_t c = 0; c < 3; c++) {
108       int cmp = memcmp(pixels[c].data(), other.pixels[c].data(),
109                        sizeof(int8_t) * xsize * ysize);
110       if (cmp > 0) return false;
111       if (cmp < 0) return true;
112     }
113     return false;
114   }
115 };
116 
// Pair (patch, vector of occurrences). Each occurrence is itself a pair of
// uint32_t values.
// NOTE(review): presumably the (x, y) position of the occurrence - confirm.
using PatchInfo =
    std::pair<QuantizedPatch, std::vector<std::pair<uint32_t, uint32_t>>>;
120 
// Location of a patch inside a reference frame: the reference index `ref`,
// the top-left corner (x0, y0) and the dimensions xsize x ysize.
struct PatchReferencePosition {
  size_t ref, x0, y0, xsize, ysize;

  // Lexicographic order over (ref, x0, y0, xsize, ysize).
  bool operator<(const PatchReferencePosition& rhs) const {
    const auto lhs_key = std::tie(ref, x0, y0, xsize, ysize);
    const auto rhs_key =
        std::tie(rhs.ref, rhs.x0, rhs.y0, rhs.xsize, rhs.ysize);
    return lhs_key < rhs_key;
  }

  // Equal exactly when neither position orders before the other, i.e. when
  // all five fields match.
  bool operator==(const PatchReferencePosition& rhs) const {
    const auto lhs_key = std::tie(ref, x0, y0, xsize, ysize);
    const auto rhs_key =
        std::tie(rhs.ref, rhs.x0, rhs.y0, rhs.xsize, rhs.ysize);
    return lhs_key == rhs_key;
  }
};
132 
133 struct PatchPosition {
134   // Position of top-left corner of the patch in the image.
135   size_t x, y;
136   // Different blend mode for color and extra channels.
137   std::vector<PatchBlending> blending;
138   PatchReferencePosition ref_pos;
139   bool operator<(const PatchPosition& oth) const {
140     return std::make_tuple(ref_pos, x, y) <
141            std::make_tuple(oth.ref_pos, oth.x, oth.y);
142   }
143 };
144 
145 struct PassesSharedState;
146 
147 // Encoder-side helper class to encode the PatchesDictionary.
148 class PatchDictionaryEncoder;
149 
150 class PatchDictionary {
151  public:
152   PatchDictionary() = default;
153 
SetPassesSharedState(const PassesSharedState * shared)154   void SetPassesSharedState(const PassesSharedState* shared) {
155     shared_ = shared;
156   }
157 
HasAny()158   bool HasAny() const { return !positions_.empty(); }
159 
160   Status Decode(BitReader* br, size_t xsize, size_t ysize,
161                 bool* uses_extra_channels);
162 
Clear()163   void Clear() {
164     positions_.clear();
165     ComputePatchCache();
166   }
167 
168   // Only adds patches that belong to the `image_rect` area of the decoded
169   // image, writing them to the `opsin_rect` area of `opsin`.
170   Status AddTo(Image3F* opsin, const Rect& opsin_rect,
171                float* const* extra_channels, const Rect& image_rect) const;
172 
173   // Returns dependencies of this patch dictionary on reference frame ids as a
174   // bit mask: bits 0-3 indicate reference frame 0-3.
175   int GetReferences() const;
176 
177  private:
178   friend class PatchDictionaryEncoder;
179 
180   const PassesSharedState* shared_;
181   std::vector<PatchPosition> positions_;
182 
183   // Patch occurrences sorted by y.
184   std::vector<size_t> sorted_patches_;
185   // Index of the first patch for each y value.
186   std::vector<size_t> patch_starts_;
187 
188   // Patch IDs in position [patch_starts_[y], patch_start_[y+1]) of
189   // sorted_patches_ are all the patches that intersect the horizontal line at
190   // y.
191   // The relative order of patches that affect the same pixels is the same -
192   // important when applying patches is noncommutative.
193 
194   // Compute patches_by_y_ after updating positions_.
195   void ComputePatchCache();
196 };
197 
198 }  // namespace jxl
199 
200 #endif  // LIB_JXL_DEC_PATCH_DICTIONARY_H_
201