1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_patch_dictionary.h"
7 
8 #include <stdint.h>
9 #include <stdlib.h>
10 #include <sys/types.h>
11 
12 #include <algorithm>
13 #include <random>
14 #include <string>
15 #include <tuple>
16 #include <utility>
17 #include <vector>
18 
19 #include "lib/jxl/ans_params.h"
20 #include "lib/jxl/base/compiler_specific.h"
21 #include "lib/jxl/base/override.h"
22 #include "lib/jxl/base/status.h"
23 #include "lib/jxl/chroma_from_luma.h"
24 #include "lib/jxl/color_management.h"
25 #include "lib/jxl/common.h"
26 #include "lib/jxl/dec_cache.h"
27 #include "lib/jxl/dec_frame.h"
28 #include "lib/jxl/enc_ans.h"
29 #include "lib/jxl/enc_cache.h"
30 #include "lib/jxl/enc_dot_dictionary.h"
31 #include "lib/jxl/enc_frame.h"
32 #include "lib/jxl/entropy_coder.h"
33 #include "lib/jxl/frame_header.h"
34 #include "lib/jxl/image.h"
35 #include "lib/jxl/image_bundle.h"
36 #include "lib/jxl/image_ops.h"
37 #include "lib/jxl/patch_dictionary_internal.h"
38 
39 namespace jxl {
40 
41 // static
Encode(const PatchDictionary & pdic,BitWriter * writer,size_t layer,AuxOut * aux_out)42 void PatchDictionaryEncoder::Encode(const PatchDictionary& pdic,
43                                     BitWriter* writer, size_t layer,
44                                     AuxOut* aux_out) {
45   JXL_ASSERT(pdic.HasAny());
46   std::vector<std::vector<Token>> tokens(1);
47 
48   auto add_num = [&](int context, size_t num) {
49     tokens[0].emplace_back(context, num);
50   };
51   size_t num_ref_patch = 0;
52   for (size_t i = 0; i < pdic.positions_.size();) {
53     size_t i_start = i;
54     while (i < pdic.positions_.size() &&
55            pdic.positions_[i].ref_pos == pdic.positions_[i_start].ref_pos) {
56       i++;
57     }
58     num_ref_patch++;
59   }
60   add_num(kNumRefPatchContext, num_ref_patch);
61   for (size_t i = 0; i < pdic.positions_.size();) {
62     size_t i_start = i;
63     while (i < pdic.positions_.size() &&
64            pdic.positions_[i].ref_pos == pdic.positions_[i_start].ref_pos) {
65       i++;
66     }
67     size_t num = i - i_start;
68     JXL_ASSERT(num > 0);
69     add_num(kReferenceFrameContext, pdic.positions_[i_start].ref_pos.ref);
70     add_num(kPatchReferencePositionContext,
71             pdic.positions_[i_start].ref_pos.x0);
72     add_num(kPatchReferencePositionContext,
73             pdic.positions_[i_start].ref_pos.y0);
74     add_num(kPatchSizeContext, pdic.positions_[i_start].ref_pos.xsize - 1);
75     add_num(kPatchSizeContext, pdic.positions_[i_start].ref_pos.ysize - 1);
76     add_num(kPatchCountContext, num - 1);
77     for (size_t j = i_start; j < i; j++) {
78       const PatchPosition& pos = pdic.positions_[j];
79       if (j == i_start) {
80         add_num(kPatchPositionContext, pos.x);
81         add_num(kPatchPositionContext, pos.y);
82       } else {
83         add_num(kPatchOffsetContext,
84                 PackSigned(pos.x - pdic.positions_[j - 1].x));
85         add_num(kPatchOffsetContext,
86                 PackSigned(pos.y - pdic.positions_[j - 1].y));
87       }
88       JXL_ASSERT(pdic.shared_->metadata->m.extra_channel_info.size() + 1 ==
89                  pos.blending.size());
90       for (size_t i = 0;
91            i < pdic.shared_->metadata->m.extra_channel_info.size() + 1; i++) {
92         const PatchBlending& info = pos.blending[i];
93         add_num(kPatchBlendModeContext, static_cast<uint32_t>(info.mode));
94         if (UsesAlpha(info.mode) &&
95             pdic.shared_->metadata->m.extra_channel_info.size() > 1) {
96           add_num(kPatchAlphaChannelContext, info.alpha_channel);
97         }
98         if (UsesClamp(info.mode)) {
99           add_num(kPatchClampContext, info.clamp);
100         }
101       }
102     }
103   }
104 
105   EntropyEncodingData codes;
106   std::vector<uint8_t> context_map;
107   BuildAndEncodeHistograms(HistogramParams(), kNumPatchDictionaryContexts,
108                            tokens, &codes, &context_map, writer, layer,
109                            aux_out);
110   WriteTokens(tokens[0], codes, context_map, writer, layer, aux_out);
111 }
112 
113 // static
SubtractFrom(const PatchDictionary & pdic,Image3F * opsin)114 void PatchDictionaryEncoder::SubtractFrom(const PatchDictionary& pdic,
115                                           Image3F* opsin) {
116   // TODO(veluca): this can likely be optimized knowing it runs on full images.
117   for (size_t y = 0; y < opsin->ysize(); y++) {
118     if (y + 1 >= pdic.patch_starts_.size()) continue;
119     float* JXL_RESTRICT rows[3] = {
120         opsin->PlaneRow(0, y),
121         opsin->PlaneRow(1, y),
122         opsin->PlaneRow(2, y),
123     };
124     for (size_t id = pdic.patch_starts_[y]; id < pdic.patch_starts_[y + 1];
125          id++) {
126       const PatchPosition& pos = pdic.positions_[pdic.sorted_patches_[id]];
127       size_t by = pos.y;
128       size_t bx = pos.x;
129       size_t xsize = pos.ref_pos.xsize;
130       JXL_DASSERT(y >= by);
131       JXL_DASSERT(y < by + pos.ref_pos.ysize);
132       size_t iy = y - by;
133       size_t ref = pos.ref_pos.ref;
134       const float* JXL_RESTRICT ref_rows[3] = {
135           pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
136               0, pos.ref_pos.y0 + iy) +
137               pos.ref_pos.x0,
138           pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
139               1, pos.ref_pos.y0 + iy) +
140               pos.ref_pos.x0,
141           pdic.shared_->reference_frames[ref].frame->color()->ConstPlaneRow(
142               2, pos.ref_pos.y0 + iy) +
143               pos.ref_pos.x0,
144       };
145       for (size_t ix = 0; ix < xsize; ix++) {
146         for (size_t c = 0; c < 3; c++) {
147           if (pos.blending[0].mode == PatchBlendMode::kAdd) {
148             rows[c][bx + ix] -= ref_rows[c][ix];
149           } else if (pos.blending[0].mode == PatchBlendMode::kReplace) {
150             rows[c][bx + ix] = 0;
151           } else if (pos.blending[0].mode == PatchBlendMode::kNone) {
152             // Nothing to do.
153           } else {
154             JXL_ABORT("Blending mode %u not yet implemented",
155                       (uint32_t)pos.blending[0].mode);
156           }
157         }
158       }
159     }
160   }
161 }
162 
163 namespace {
164 
// Per-channel constants used to quantize and compare patch pixels. The
// constructor selects one of two constant sets depending on `is_xyb`.
//
// All query methods are const: they only read the tables filled in by the
// constructor.
struct PatchColorspaceInfo {
  // Quantization step of each channel (a value is divided by it).
  float kChannelDequant[3];
  // Weight of each channel in the distance computed by is_similar_v().
  float kChannelWeights[3];

  explicit PatchColorspaceInfo(bool is_xyb) {
    if (is_xyb) {
      kChannelDequant[0] = 0.01615;
      kChannelDequant[1] = 0.08875;
      kChannelDequant[2] = 0.1922;
      kChannelWeights[0] = 30.0;
      kChannelWeights[1] = 3.0;
      kChannelWeights[2] = 1.0;
    } else {
      kChannelDequant[0] = 20.0f / 255;
      kChannelDequant[1] = 22.0f / 255;
      kChannelDequant[2] = 20.0f / 255;
      kChannelWeights[0] = 0.017 * 255;
      kChannelWeights[1] = 0.02 * 255;
      kChannelWeights[2] = 0.017 * 255;
    }
  }

  // Returns `val` expressed in quantization steps of channel `c` (c < 3).
  float ScaleForQuantization(float val, size_t c) const {
    return val / kChannelDequant[c];
  }

  // Returns `val` quantized for channel `c`, rounding towards zero.
  int Quantize(float val, size_t c) const {
    return static_cast<int>(truncf(ScaleForQuantization(val, c)));
  }

  // Returns true if the weighted sum of absolute per-channel differences
  // between `v1` and `v2` does not exceed `threshold`.
  bool is_similar_v(const float v1[3], const float v2[3],
                    float threshold) const {
    float distance = 0;
    for (size_t c = 0; c < 3; c++) {
      distance += std::fabs(v1[c] - v2[c]) * kChannelWeights[c];
    }
    return distance <= threshold;
  }
};
203 
FindTextLikePatches(const Image3F & opsin,const PassesEncoderState * JXL_RESTRICT state,ThreadPool * pool,AuxOut * aux_out,bool is_xyb)204 std::vector<PatchInfo> FindTextLikePatches(
205     const Image3F& opsin, const PassesEncoderState* JXL_RESTRICT state,
206     ThreadPool* pool, AuxOut* aux_out, bool is_xyb) {
207   if (state->cparams.patches == Override::kOff) return {};
208 
209   PatchColorspaceInfo pci(is_xyb);
210   float kSimilarThreshold = 0.8f;
211 
212   auto is_similar_impl = [&pci](std::pair<uint32_t, uint32_t> p1,
213                                 std::pair<uint32_t, uint32_t> p2,
214                                 const float* JXL_RESTRICT rows[3],
215                                 size_t stride, float threshold) {
216     float v1[3], v2[3];
217     for (size_t c = 0; c < 3; c++) {
218       v1[c] = rows[c][p1.second * stride + p1.first];
219       v2[c] = rows[c][p2.second * stride + p2.first];
220     }
221     return pci.is_similar_v(v1, v2, threshold);
222   };
223 
224   std::atomic<bool> has_screenshot_areas{false};
225   const size_t opsin_stride = opsin.PixelsPerRow();
226   const float* JXL_RESTRICT opsin_rows[3] = {opsin.ConstPlaneRow(0, 0),
227                                              opsin.ConstPlaneRow(1, 0),
228                                              opsin.ConstPlaneRow(2, 0)};
229 
230   auto is_same = [&opsin_rows, opsin_stride](std::pair<uint32_t, uint32_t> p1,
231                                              std::pair<uint32_t, uint32_t> p2) {
232     for (size_t c = 0; c < 3; c++) {
233       float v1 = opsin_rows[c][p1.second * opsin_stride + p1.first];
234       float v2 = opsin_rows[c][p2.second * opsin_stride + p2.first];
235       if (std::fabs(v1 - v2) > 1e-4) {
236         return false;
237       }
238     }
239     return true;
240   };
241 
242   auto is_similar = [&](std::pair<uint32_t, uint32_t> p1,
243                         std::pair<uint32_t, uint32_t> p2) {
244     return is_similar_impl(p1, p2, opsin_rows, opsin_stride, kSimilarThreshold);
245   };
246 
247   constexpr int64_t kPatchSide = 4;
248   constexpr int64_t kExtraSide = 4;
249 
250   // Look for kPatchSide size squares, naturally aligned, that all have the same
251   // pixel values.
252   ImageB is_screenshot_like(DivCeil(opsin.xsize(), kPatchSide),
253                             DivCeil(opsin.ysize(), kPatchSide));
254   ZeroFillImage(&is_screenshot_like);
255   uint8_t* JXL_RESTRICT screenshot_row = is_screenshot_like.Row(0);
256   const size_t screenshot_stride = is_screenshot_like.PixelsPerRow();
257   const auto process_row = [&](uint64_t y, int _) {
258     for (uint64_t x = 0; x < opsin.xsize() / kPatchSide; x++) {
259       bool all_same = true;
260       for (size_t iy = 0; iy < static_cast<size_t>(kPatchSide); iy++) {
261         for (size_t ix = 0; ix < static_cast<size_t>(kPatchSide); ix++) {
262           size_t cx = x * kPatchSide + ix;
263           size_t cy = y * kPatchSide + iy;
264           if (!is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) {
265             all_same = false;
266             break;
267           }
268         }
269       }
270       if (!all_same) continue;
271       size_t num = 0;
272       size_t num_same = 0;
273       for (int64_t iy = -kExtraSide; iy < kExtraSide + kPatchSide; iy++) {
274         for (int64_t ix = -kExtraSide; ix < kExtraSide + kPatchSide; ix++) {
275           int64_t cx = x * kPatchSide + ix;
276           int64_t cy = y * kPatchSide + iy;
277           if (cx < 0 || static_cast<uint64_t>(cx) >= opsin.xsize() ||  //
278               cy < 0 || static_cast<uint64_t>(cy) >= opsin.ysize()) {
279             continue;
280           }
281           num++;
282           if (is_same({cx, cy}, {x * kPatchSide, y * kPatchSide})) num_same++;
283         }
284       }
285       // Too few equal pixels nearby.
286       if (num_same * 8 < num * 7) continue;
287       screenshot_row[y * screenshot_stride + x] = 1;
288       has_screenshot_areas = true;
289     }
290   };
291   RunOnPool(pool, 0, opsin.ysize() / kPatchSide, ThreadPool::SkipInit(),
292             process_row, "IsScreenshotLike");
293 
294   // TODO(veluca): also parallelize the rest of this function.
295   if (WantDebugOutput(aux_out)) {
296     aux_out->DumpPlaneNormalized("screenshot_like", is_screenshot_like);
297   }
298 
299   constexpr int kSearchRadius = 1;
300 
301   if (!ApplyOverride(state->cparams.patches, has_screenshot_areas)) {
302     return {};
303   }
304 
305   // Search for "similar enough" pixels near the screenshot-like areas.
306   ImageB is_background(opsin.xsize(), opsin.ysize());
307   ZeroFillImage(&is_background);
308   Image3F background(opsin.xsize(), opsin.ysize());
309   ZeroFillImage(&background);
310   constexpr size_t kDistanceLimit = 50;
311   float* JXL_RESTRICT background_rows[3] = {
312       background.PlaneRow(0, 0),
313       background.PlaneRow(1, 0),
314       background.PlaneRow(2, 0),
315   };
316   const size_t background_stride = background.PixelsPerRow();
317   uint8_t* JXL_RESTRICT is_background_row = is_background.Row(0);
318   const size_t is_background_stride = is_background.PixelsPerRow();
319   std::vector<
320       std::pair<std::pair<uint32_t, uint32_t>, std::pair<uint32_t, uint32_t>>>
321       queue;
322   size_t queue_front = 0;
323   for (size_t y = 0; y < opsin.ysize(); y++) {
324     for (size_t x = 0; x < opsin.xsize(); x++) {
325       if (!screenshot_row[screenshot_stride * (y / kPatchSide) +
326                           (x / kPatchSide)])
327         continue;
328       queue.push_back({{x, y}, {x, y}});
329     }
330   }
331   while (queue.size() != queue_front) {
332     std::pair<uint32_t, uint32_t> cur = queue[queue_front].first;
333     std::pair<uint32_t, uint32_t> src = queue[queue_front].second;
334     queue_front++;
335     if (is_background_row[cur.second * is_background_stride + cur.first])
336       continue;
337     is_background_row[cur.second * is_background_stride + cur.first] = 1;
338     for (size_t c = 0; c < 3; c++) {
339       background_rows[c][cur.second * background_stride + cur.first] =
340           opsin_rows[c][src.second * opsin_stride + src.first];
341     }
342     for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
343       for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
344         if (dx == 0 && dy == 0) continue;
345         int next_first = cur.first + dx;
346         int next_second = cur.second + dy;
347         if (next_first < 0 || next_second < 0 ||
348             static_cast<uint32_t>(next_first) >= opsin.xsize() ||
349             static_cast<uint32_t>(next_second) >= opsin.ysize()) {
350           continue;
351         }
352         if (static_cast<uint32_t>(
353                 std::abs(next_first - static_cast<int>(src.first)) +
354                 std::abs(next_second - static_cast<int>(src.second))) >
355             kDistanceLimit) {
356           continue;
357         }
358         std::pair<uint32_t, uint32_t> next{next_first, next_second};
359         if (is_similar(src, next)) {
360           if (!screenshot_row[next.second / kPatchSide * screenshot_stride +
361                               next.first / kPatchSide] ||
362               is_same(src, next)) {
363             if (!is_background_row[next.second * is_background_stride +
364                                    next.first])
365               queue.emplace_back(next, src);
366           }
367         }
368       }
369     }
370   }
371   queue.clear();
372 
373   ImageF ccs;
374   std::mt19937 rng;
375   std::uniform_real_distribution<float> dist(0.5, 1.0);
376   bool paint_ccs = false;
377   if (WantDebugOutput(aux_out)) {
378     aux_out->DumpPlaneNormalized("is_background", is_background);
379     if (is_xyb) {
380       aux_out->DumpXybImage("background", background);
381     } else {
382       aux_out->DumpImage("background", background);
383     }
384     ccs = ImageF(opsin.xsize(), opsin.ysize());
385     ZeroFillImage(&ccs);
386     paint_ccs = true;
387   }
388 
389   constexpr float kVerySimilarThreshold = 0.03f;
390   constexpr float kHasSimilarThreshold = 0.03f;
391 
392   const float* JXL_RESTRICT const_background_rows[3] = {
393       background_rows[0], background_rows[1], background_rows[2]};
394   auto is_similar_b = [&](std::pair<int, int> p1, std::pair<int, int> p2) {
395     return is_similar_impl(p1, p2, const_background_rows, background_stride,
396                            kVerySimilarThreshold);
397   };
398 
399   constexpr int kMinPeak = 2;
400   constexpr int kHasSimilarRadius = 2;
401 
402   std::vector<PatchInfo> info;
403 
404   // Find small CC outside the "similar enough" areas, compute bounding boxes,
405   // and run heuristics to exclude some patches.
406   ImageB visited(opsin.xsize(), opsin.ysize());
407   ZeroFillImage(&visited);
408   uint8_t* JXL_RESTRICT visited_row = visited.Row(0);
409   const size_t visited_stride = visited.PixelsPerRow();
410   std::vector<std::pair<uint32_t, uint32_t>> cc;
411   std::vector<std::pair<uint32_t, uint32_t>> stack;
412   for (size_t y = 0; y < opsin.ysize(); y++) {
413     for (size_t x = 0; x < opsin.xsize(); x++) {
414       if (is_background_row[y * is_background_stride + x]) continue;
415       cc.clear();
416       stack.clear();
417       stack.emplace_back(x, y);
418       size_t min_x = x;
419       size_t max_x = x;
420       size_t min_y = y;
421       size_t max_y = y;
422       std::pair<uint32_t, uint32_t> reference;
423       bool found_border = false;
424       bool all_similar = true;
425       while (!stack.empty()) {
426         std::pair<uint32_t, uint32_t> cur = stack.back();
427         stack.pop_back();
428         if (visited_row[cur.second * visited_stride + cur.first]) continue;
429         visited_row[cur.second * visited_stride + cur.first] = 1;
430         if (cur.first < min_x) min_x = cur.first;
431         if (cur.first > max_x) max_x = cur.first;
432         if (cur.second < min_y) min_y = cur.second;
433         if (cur.second > max_y) max_y = cur.second;
434         if (paint_ccs) {
435           cc.push_back(cur);
436         }
437         for (int dx = -kSearchRadius; dx <= kSearchRadius; dx++) {
438           for (int dy = -kSearchRadius; dy <= kSearchRadius; dy++) {
439             if (dx == 0 && dy == 0) continue;
440             int next_first = static_cast<int32_t>(cur.first) + dx;
441             int next_second = static_cast<int32_t>(cur.second) + dy;
442             if (next_first < 0 || next_second < 0 ||
443                 static_cast<uint32_t>(next_first) >= opsin.xsize() ||
444                 static_cast<uint32_t>(next_second) >= opsin.ysize()) {
445               continue;
446             }
447             std::pair<uint32_t, uint32_t> next{next_first, next_second};
448             if (!is_background_row[next.second * is_background_stride +
449                                    next.first]) {
450               stack.push_back(next);
451             } else {
452               if (!found_border) {
453                 reference = next;
454                 found_border = true;
455               } else {
456                 if (!is_similar_b(next, reference)) all_similar = false;
457               }
458             }
459           }
460         }
461       }
462       if (!found_border || !all_similar || max_x - min_x >= kMaxPatchSize ||
463           max_y - min_y >= kMaxPatchSize) {
464         continue;
465       }
466       size_t bpos = background_stride * reference.second + reference.first;
467       float ref[3] = {background_rows[0][bpos], background_rows[1][bpos],
468                       background_rows[2][bpos]};
469       bool has_similar = false;
470       for (size_t iy = std::max<int>(
471                static_cast<int32_t>(min_y) - kHasSimilarRadius, 0);
472            iy < std::min(max_y + kHasSimilarRadius + 1, opsin.ysize()); iy++) {
473         for (size_t ix = std::max<int>(
474                  static_cast<int32_t>(min_x) - kHasSimilarRadius, 0);
475              ix < std::min(max_x + kHasSimilarRadius + 1, opsin.xsize());
476              ix++) {
477           size_t opos = opsin_stride * iy + ix;
478           float px[3] = {opsin_rows[0][opos], opsin_rows[1][opos],
479                          opsin_rows[2][opos]};
480           if (pci.is_similar_v(ref, px, kHasSimilarThreshold)) {
481             has_similar = true;
482           }
483         }
484       }
485       if (!has_similar) continue;
486       info.emplace_back();
487       info.back().second.emplace_back(min_x, min_y);
488       QuantizedPatch& patch = info.back().first;
489       patch.xsize = max_x - min_x + 1;
490       patch.ysize = max_y - min_y + 1;
491       int max_value = 0;
492       for (size_t c : {1, 0, 2}) {
493         for (size_t iy = min_y; iy <= max_y; iy++) {
494           for (size_t ix = min_x; ix <= max_x; ix++) {
495             size_t offset = (iy - min_y) * patch.xsize + ix - min_x;
496             patch.fpixels[c][offset] =
497                 opsin_rows[c][iy * opsin_stride + ix] - ref[c];
498             int val = pci.Quantize(patch.fpixels[c][offset], c);
499             patch.pixels[c][offset] = val;
500             if (std::abs(val) > max_value) max_value = std::abs(val);
501           }
502         }
503       }
504       if (max_value < kMinPeak) {
505         info.pop_back();
506         continue;
507       }
508       if (paint_ccs) {
509         float cc_color = dist(rng);
510         for (std::pair<uint32_t, uint32_t> p : cc) {
511           ccs.Row(p.second)[p.first] = cc_color;
512         }
513       }
514     }
515   }
516 
517   if (paint_ccs) {
518     JXL_ASSERT(WantDebugOutput(aux_out));
519     aux_out->DumpPlaneNormalized("ccs", ccs);
520   }
521   if (info.empty()) {
522     return {};
523   }
524 
525   // Remove duplicates.
526   constexpr size_t kMinPatchOccurences = 2;
527   std::sort(info.begin(), info.end());
528   size_t unique = 0;
529   for (size_t i = 1; i < info.size(); i++) {
530     if (info[i].first == info[unique].first) {
531       info[unique].second.insert(info[unique].second.end(),
532                                  info[i].second.begin(), info[i].second.end());
533     } else {
534       if (info[unique].second.size() >= kMinPatchOccurences) {
535         unique++;
536       }
537       info[unique] = info[i];
538     }
539   }
540   if (info[unique].second.size() >= kMinPatchOccurences) {
541     unique++;
542   }
543   info.resize(unique);
544 
545   size_t max_patch_size = 0;
546 
547   for (size_t i = 0; i < info.size(); i++) {
548     size_t pixels = info[i].first.xsize * info[i].first.ysize;
549     if (pixels > max_patch_size) max_patch_size = pixels;
550   }
551 
552   // don't use patches if all patches are smaller than this
553   constexpr size_t kMinMaxPatchSize = 20;
554   if (max_patch_size < kMinMaxPatchSize) return {};
555 
556   // Ensure that the specified set of patches doesn't produce out-of-bounds
557   // pixels.
558   // TODO(veluca): figure out why this is still necessary even with RCTs that
559   // don't depend on bit depth.
560   if (state->cparams.modular_mode && state->cparams.quality_pair.first >= 100) {
561     constexpr size_t kMaxPatchArea = kMaxPatchSize * kMaxPatchSize;
562     std::vector<float> min_then_max_px(2 * kMaxPatchArea);
563     for (size_t i = 0; i < info.size(); i++) {
564       for (size_t c = 0; c < 3; c++) {
565         float* JXL_RESTRICT min_px = min_then_max_px.data();
566         float* JXL_RESTRICT max_px = min_px + kMaxPatchArea;
567         std::fill(min_px, min_px + kMaxPatchArea, 1);
568         std::fill(max_px, max_px + kMaxPatchArea, 0);
569         size_t xsize = info[i].first.xsize;
570         for (size_t j = 0; j < info[i].second.size(); j++) {
571           size_t bx = info[i].second[j].first;
572           size_t by = info[i].second[j].second;
573           for (size_t iy = 0; iy < info[i].first.ysize; iy++) {
574             for (size_t ix = 0; ix < xsize; ix++) {
575               float v = opsin_rows[c][(by + iy) * opsin_stride + bx + ix];
576               if (v < min_px[iy * xsize + ix]) min_px[iy * xsize + ix] = v;
577               if (v > max_px[iy * xsize + ix]) max_px[iy * xsize + ix] = v;
578             }
579           }
580         }
581         for (size_t iy = 0; iy < info[i].first.ysize; iy++) {
582           for (size_t ix = 0; ix < xsize; ix++) {
583             float smallest = min_px[iy * xsize + ix];
584             float biggest = max_px[iy * xsize + ix];
585             JXL_ASSERT(smallest <= biggest);
586             float& out = info[i].first.fpixels[c][iy * xsize + ix];
587             // Clamp fpixels so that subtracting the patch never creates a
588             // negative value, or a value above 1.
589             JXL_ASSERT(biggest - 1 <= smallest);
590             out = std::max(smallest, out);
591             out = std::min(biggest - 1.f, out);
592           }
593         }
594       }
595     }
596   }
597   return info;
598 }
599 
600 }  // namespace
601 
FindBestPatchDictionary(const Image3F & opsin,PassesEncoderState * JXL_RESTRICT state,ThreadPool * pool,AuxOut * aux_out,bool is_xyb)602 void FindBestPatchDictionary(const Image3F& opsin,
603                              PassesEncoderState* JXL_RESTRICT state,
604                              ThreadPool* pool, AuxOut* aux_out, bool is_xyb) {
605   state->shared.image_features.patches = PatchDictionary();
606   state->shared.image_features.patches.SetPassesSharedState(&state->shared);
607 
608   std::vector<PatchInfo> info =
609       FindTextLikePatches(opsin, state, pool, aux_out, is_xyb);
610 
611   // TODO(veluca): this doesn't work if both dots and patches are enabled.
612   // For now, since dots and patches are not likely to occur in the same kind of
613   // images, disable dots if some patches were found.
614   if (info.empty() &&
615       ApplyOverride(
616           state->cparams.dots,
617           state->cparams.speed_tier <= SpeedTier::kSquirrel &&
618               state->cparams.butteraugli_distance >= kMinButteraugliForDots)) {
619     info = FindDotDictionary(state->cparams, opsin, state->shared.cmap, pool);
620   }
621 
622   if (info.empty()) return;
623 
624   std::sort(
625       info.begin(), info.end(), [&](const PatchInfo& a, const PatchInfo& b) {
626         return a.first.xsize * a.first.ysize > b.first.xsize * b.first.ysize;
627       });
628 
629   size_t max_x_size = 0;
630   size_t max_y_size = 0;
631   size_t total_pixels = 0;
632 
633   for (size_t i = 0; i < info.size(); i++) {
634     size_t pixels = info[i].first.xsize * info[i].first.ysize;
635     if (max_x_size < info[i].first.xsize) max_x_size = info[i].first.xsize;
636     if (max_y_size < info[i].first.ysize) max_y_size = info[i].first.ysize;
637     total_pixels += pixels;
638   }
639 
640   // Bin-packing & conversion of patches.
641   constexpr float kBinPackingSlackness = 1.05f;
642   size_t ref_xsize = std::max<float>(max_x_size, std::sqrt(total_pixels));
643   size_t ref_ysize = std::max<float>(max_y_size, std::sqrt(total_pixels));
644   std::vector<std::pair<size_t, size_t>> ref_positions(info.size());
645   // TODO(veluca): allow partial overlaps of patches that have the same pixels.
646   size_t max_y = 0;
647   do {
648     max_y = 0;
649     // Increase packed image size.
650     ref_xsize = ref_xsize * kBinPackingSlackness + 1;
651     ref_ysize = ref_ysize * kBinPackingSlackness + 1;
652 
653     ImageB occupied(ref_xsize, ref_ysize);
654     ZeroFillImage(&occupied);
655     uint8_t* JXL_RESTRICT occupied_rows = occupied.Row(0);
656     size_t occupied_stride = occupied.PixelsPerRow();
657 
658     bool success = true;
659     // For every patch...
660     for (size_t patch = 0; patch < info.size(); patch++) {
661       size_t x0 = 0;
662       size_t y0 = 0;
663       size_t xsize = info[patch].first.xsize;
664       size_t ysize = info[patch].first.ysize;
665       bool found = false;
666       // For every possible start position ...
667       for (; y0 + ysize <= ref_ysize; y0++) {
668         x0 = 0;
669         for (; x0 + xsize <= ref_xsize; x0++) {
670           bool has_occupied_pixel = false;
671           size_t x = x0;
672           // Check if it is possible to place the patch in this position in the
673           // reference frame.
674           for (size_t y = y0; y < y0 + ysize; y++) {
675             x = x0;
676             for (; x < x0 + xsize; x++) {
677               if (occupied_rows[y * occupied_stride + x]) {
678                 has_occupied_pixel = true;
679                 break;
680               }
681             }
682           }  // end of positioning check
683           if (!has_occupied_pixel) {
684             found = true;
685             break;
686           }
687           x0 = x;  // Jump to next pixel after the occupied one.
688         }
689         if (found) break;
690       }  // end of start position checking
691 
692       // We didn't find a possible position: repeat from the beginning with a
693       // larger reference frame size.
694       if (!found) {
695         success = false;
696         break;
697       }
698 
699       // We found a position: mark the corresponding positions in the reference
700       // image as used.
701       ref_positions[patch] = {x0, y0};
702       for (size_t y = y0; y < y0 + ysize; y++) {
703         for (size_t x = x0; x < x0 + xsize; x++) {
704           occupied_rows[y * occupied_stride + x] = true;
705         }
706       }
707       max_y = std::max(max_y, y0 + ysize);
708     }
709 
710     if (success) break;
711   } while (true);
712 
713   JXL_ASSERT(ref_ysize >= max_y);
714 
715   ref_ysize = max_y;
716 
717   Image3F reference_frame(ref_xsize, ref_ysize);
718   // TODO(veluca): figure out a better way to fill the image.
719   ZeroFillImage(&reference_frame);
720   std::vector<PatchPosition> positions;
721   float* JXL_RESTRICT ref_rows[3] = {
722       reference_frame.PlaneRow(0, 0),
723       reference_frame.PlaneRow(1, 0),
724       reference_frame.PlaneRow(2, 0),
725   };
726   size_t ref_stride = reference_frame.PixelsPerRow();
727 
728   for (size_t i = 0; i < info.size(); i++) {
729     PatchReferencePosition ref_pos;
730     ref_pos.xsize = info[i].first.xsize;
731     ref_pos.ysize = info[i].first.ysize;
732     ref_pos.x0 = ref_positions[i].first;
733     ref_pos.y0 = ref_positions[i].second;
734     ref_pos.ref = 0;
735     for (size_t y = 0; y < ref_pos.ysize; y++) {
736       for (size_t x = 0; x < ref_pos.xsize; x++) {
737         for (size_t c = 0; c < 3; c++) {
738           ref_rows[c][(y + ref_pos.y0) * ref_stride + x + ref_pos.x0] =
739               info[i].first.fpixels[c][y * ref_pos.xsize + x];
740         }
741       }
742     }
743     // Add color channels, ignore other channels.
744     std::vector<PatchBlending> blending_info(
745         state->shared.metadata->m.extra_channel_info.size() + 1,
746         PatchBlending{PatchBlendMode::kNone, 0, false});
747     blending_info[0].mode = PatchBlendMode::kAdd;
748     for (const auto& pos : info[i].second) {
749       positions.emplace_back(
750           PatchPosition{pos.first, pos.second, blending_info, ref_pos});
751     }
752   }
753 
754   CompressParams cparams = state->cparams;
755   cparams.resampling = 1;
756   cparams.ec_resampling = 1;
757   // Recursive application of patches could create very weird issues.
758   cparams.patches = Override::kOff;
759   cparams.dots = Override::kOff;
760   cparams.noise = Override::kOff;
761   cparams.modular_mode = true;
762   cparams.responsive = 0;
763   cparams.progressive_dc = 0;
764   cparams.progressive_mode = false;
765   cparams.qprogressive_mode = false;
766   // Use gradient predictor and not Predictor::Best.
767   cparams.options.predictor = Predictor::Gradient;
768   // TODO(veluca): possibly change heuristics here.
769   if (!cparams.modular_mode) {
770     cparams.quality_pair.first = cparams.quality_pair.second =
771         80 - cparams.butteraugli_distance * 12;
772   } else {
773     cparams.quality_pair.first = (100 + 3 * cparams.quality_pair.first) * 0.25f;
774     cparams.quality_pair.second =
775         (100 + 3 * cparams.quality_pair.second) * 0.25f;
776   }
777   FrameInfo patch_frame_info;
778   patch_frame_info.save_as_reference = 0;  // always saved.
779   patch_frame_info.frame_type = FrameType::kReferenceOnly;
780   patch_frame_info.save_before_color_transform = true;
781 
782   ImageBundle ib(&state->shared.metadata->m);
783   // TODO(veluca): metadata.color_encoding is a lie: ib is in XYB, but there is
784   // no simple way to express that yet.
785   patch_frame_info.ib_needs_color_transform = false;
786   patch_frame_info.save_as_reference = 0;
787   ib.SetFromImage(std::move(reference_frame),
788                   state->shared.metadata->m.color_encoding);
789   if (!ib.metadata()->extra_channel_info.empty()) {
790     // Add dummy extra channels to the patch image: patches do not yet support
791     // extra channels, but the codec expects that the amount of extra channels
792     // in frames matches that in the metadata of the codestream.
793     std::vector<ImageF> extra_channels;
794     extra_channels.reserve(ib.metadata()->extra_channel_info.size());
795     for (size_t i = 0; i < ib.metadata()->extra_channel_info.size(); i++) {
796       extra_channels.emplace_back(ib.xsize(), ib.ysize());
797       // Must initialize the image with data to not affect blending with
798       // uninitialized memory.
799       // TODO(lode): patches must copy and use the real extra channels instead.
800       FillImage(1.0f, &extra_channels.back());
801     }
802     ib.SetExtraChannels(std::move(extra_channels));
803   }
804 
805   PassesEncoderState roundtrip_state;
806   auto special_frame = std::unique_ptr<BitWriter>(new BitWriter());
807   JXL_CHECK(EncodeFrame(cparams, patch_frame_info, state->shared.metadata, ib,
808                         &roundtrip_state, pool, special_frame.get(), nullptr));
809   const Span<const uint8_t> encoded = special_frame->GetSpan();
810   state->special_frames.emplace_back(std::move(special_frame));
811   if (cparams.butteraugli_distance < kMinButteraugliToSubtractOriginalPatches) {
812     BitReader br(encoded);
813     ImageBundle decoded(&state->shared.metadata->m);
814     PassesDecoderState dec_state;
815     JXL_CHECK(dec_state.output_encoding_info.Set(state->shared.metadata->m));
816     JXL_CHECK(DecodeFrame({}, &dec_state, pool, &br, &decoded,
817                           *state->shared.metadata, /*constraints=*/nullptr));
818     JXL_CHECK(br.Close());
819     state->shared.reference_frames[0] =
820         std::move(dec_state.shared_storage.reference_frames[0]);
821   } else {
822     state->shared.reference_frames[0].storage = std::move(ib);
823   }
824   state->shared.reference_frames[0].frame =
825       &state->shared.reference_frames[0].storage;
826   // TODO(veluca): this assumes that applying patches is commutative, which is
827   // not true for all blending modes. This code only produces kAdd patches, so
828   // this works out.
829   std::sort(positions.begin(), positions.end());
830   PatchDictionaryEncoder::SetPositions(&state->shared.image_features.patches,
831                                        std::move(positions));
832 }
833 }  // namespace jxl
834