1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
#include "lib/jxl/enc_cluster.h"

#include <algorithm>
#include <cmath>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <tuple>
#include <utility>
16
17 #undef HWY_TARGET_INCLUDE
18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
19 #include <hwy/foreach_target.h>
20 #include <hwy/highway.h>
21
22 #include "lib/jxl/ac_context.h"
23 #include "lib/jxl/base/profiler.h"
24 #include "lib/jxl/fast_math-inl.h"
25 HWY_BEFORE_NAMESPACE();
26 namespace jxl {
27 namespace HWY_NAMESPACE {
28
29 template <class V>
Entropy(V count,V inv_total,V total)30 V Entropy(V count, V inv_total, V total) {
31 const HWY_CAPPED(float, Histogram::kRounding) d;
32 const auto zero = Set(d, 0.0f);
33 return IfThenZeroElse(count == total,
34 zero - count * FastLog2f(d, inv_total * count));
35 }
36
HistogramEntropy(const Histogram & a)37 void HistogramEntropy(const Histogram& a) {
38 a.entropy_ = 0.0f;
39 if (a.total_count_ == 0) return;
40
41 const HWY_CAPPED(float, Histogram::kRounding) df;
42 const HWY_CAPPED(int32_t, Histogram::kRounding) di;
43
44 const auto inv_tot = Set(df, 1.0f / a.total_count_);
45 auto entropy_lanes = Zero(df);
46 auto total = Set(df, a.total_count_);
47
48 for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
49 const auto counts = LoadU(di, &a.data_[i]);
50 entropy_lanes += Entropy(ConvertTo(df, counts), inv_tot, total);
51 }
52 a.entropy_ += GetLane(SumOfLanes(entropy_lanes));
53 }
54
HistogramDistance(const Histogram & a,const Histogram & b)55 float HistogramDistance(const Histogram& a, const Histogram& b) {
56 if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
57
58 const HWY_CAPPED(float, Histogram::kRounding) df;
59 const HWY_CAPPED(int32_t, Histogram::kRounding) di;
60
61 const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
62 auto distance_lanes = Zero(df);
63 auto total = Set(df, a.total_count_ + b.total_count_);
64
65 for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
66 i += Lanes(di)) {
67 const auto a_counts =
68 a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
69 const auto b_counts =
70 b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
71 const auto counts = ConvertTo(df, a_counts + b_counts);
72 distance_lanes += Entropy(counts, inv_tot, total);
73 }
74 const float total_distance = GetLane(SumOfLanes(distance_lanes));
75 return total_distance - a.entropy_ - b.entropy_;
76 }
77
78 // First step of a k-means clustering with a fancy distance metric.
FastClusterHistograms(const std::vector<Histogram> & in,const size_t num_contexts,size_t max_histograms,float min_distance,std::vector<Histogram> * out,std::vector<uint32_t> * histogram_symbols)79 void FastClusterHistograms(const std::vector<Histogram>& in,
80 const size_t num_contexts, size_t max_histograms,
81 float min_distance, std::vector<Histogram>* out,
82 std::vector<uint32_t>* histogram_symbols) {
83 PROFILER_FUNC;
84 size_t largest_idx = 0;
85 for (size_t i = 0; i < num_contexts; i++) {
86 HistogramEntropy(in[i]);
87 if (in[i].total_count_ > in[largest_idx].total_count_) {
88 largest_idx = i;
89 }
90 }
91 out->clear();
92 out->reserve(max_histograms);
93 std::vector<float> dists(num_contexts, std::numeric_limits<float>::max());
94 histogram_symbols->clear();
95 histogram_symbols->resize(num_contexts, max_histograms);
96
97 while (out->size() < max_histograms && out->size() < num_contexts) {
98 (*histogram_symbols)[largest_idx] = out->size();
99 out->push_back(in[largest_idx]);
100 largest_idx = 0;
101 for (size_t i = 0; i < num_contexts; i++) {
102 dists[i] = std::min(HistogramDistance(in[i], out->back()), dists[i]);
103 // Avoid repeating histograms
104 if ((*histogram_symbols)[i] != max_histograms) continue;
105 if (dists[i] > dists[largest_idx]) largest_idx = i;
106 }
107 if (dists[largest_idx] < min_distance) break;
108 }
109
110 for (size_t i = 0; i < num_contexts; i++) {
111 if ((*histogram_symbols)[i] != max_histograms) continue;
112 size_t best = 0;
113 float best_dist = HistogramDistance(in[i], (*out)[best]);
114 for (size_t j = 1; j < out->size(); j++) {
115 float dist = HistogramDistance(in[i], (*out)[j]);
116 if (dist < best_dist) {
117 best = j;
118 best_dist = dist;
119 }
120 }
121 (*out)[best].AddHistogram(in[i]);
122 HistogramEntropy((*out)[best]);
123 (*histogram_symbols)[i] = best;
124 }
125 }
126
127 // NOLINTNEXTLINE(google-readability-namespace-comments)
128 } // namespace HWY_NAMESPACE
129 } // namespace jxl
130 HWY_AFTER_NAMESPACE();
131
132 #if HWY_ONCE
133 namespace jxl {
134 HWY_EXPORT(FastClusterHistograms); // Local function
135 HWY_EXPORT(HistogramEntropy); // Local function
136
ShannonEntropy() const137 float Histogram::ShannonEntropy() const {
138 HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
139 return entropy_;
140 }
141
142 namespace {
143 // -----------------------------------------------------------------------------
144 // Histogram refinement
145
146 // Reorder histograms in *out so that the new symbols in *symbols come in
147 // increasing order.
HistogramReindex(std::vector<Histogram> * out,std::vector<uint32_t> * symbols)148 void HistogramReindex(std::vector<Histogram>* out,
149 std::vector<uint32_t>* symbols) {
150 std::vector<Histogram> tmp(*out);
151 std::map<int, int> new_index;
152 int next_index = 0;
153 for (uint32_t symbol : *symbols) {
154 if (new_index.find(symbol) == new_index.end()) {
155 new_index[symbol] = next_index;
156 (*out)[next_index] = tmp[symbol];
157 ++next_index;
158 }
159 }
160 out->resize(next_index);
161 for (uint32_t& symbol : *symbols) {
162 symbol = new_index[symbol];
163 }
164 }
165
166 } // namespace
167
168 // Clusters similar histograms in 'in' together, the selected histograms are
169 // placed in 'out', and for each index in 'in', *histogram_symbols will
170 // indicate which of the 'out' histograms is the best approximation.
ClusterHistograms(const HistogramParams params,const std::vector<Histogram> & in,const size_t num_contexts,size_t max_histograms,std::vector<Histogram> * out,std::vector<uint32_t> * histogram_symbols)171 void ClusterHistograms(const HistogramParams params,
172 const std::vector<Histogram>& in,
173 const size_t num_contexts, size_t max_histograms,
174 std::vector<Histogram>* out,
175 std::vector<uint32_t>* histogram_symbols) {
176 constexpr float kMinDistanceForDistinctFast = 64.0f;
177 constexpr float kMinDistanceForDistinctBest = 16.0f;
178 max_histograms = std::min(max_histograms, params.max_histograms);
179 if (params.clustering == HistogramParams::ClusteringType::kFastest) {
180 HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
181 (in, num_contexts, 4, kMinDistanceForDistinctFast, out, histogram_symbols);
182 } else if (params.clustering == HistogramParams::ClusteringType::kFast) {
183 HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
184 (in, num_contexts, max_histograms, kMinDistanceForDistinctFast, out,
185 histogram_symbols);
186 } else {
187 PROFILER_FUNC;
188 HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
189 (in, num_contexts, max_histograms, kMinDistanceForDistinctBest, out,
190 histogram_symbols);
191 for (size_t i = 0; i < out->size(); i++) {
192 (*out)[i].entropy_ =
193 ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size());
194 }
195 uint32_t next_version = 2;
196 std::vector<uint32_t> version(out->size(), 1);
197 std::vector<uint32_t> renumbering(out->size());
198 std::iota(renumbering.begin(), renumbering.end(), 0);
199
200 // Try to pair up clusters if doing so reduces the total cost.
201
202 struct HistogramPair {
203 // validity of a pair: p.version == max(version[i], version[j])
204 float cost;
205 uint32_t first;
206 uint32_t second;
207 uint32_t version;
208 // We use > because priority queues sort in *decreasing* order, but we
209 // want lower cost elements to appear first.
210 bool operator<(const HistogramPair& other) const {
211 return std::make_tuple(cost, first, second, version) >
212 std::make_tuple(other.cost, other.first, other.second,
213 other.version);
214 }
215 };
216
217 // Create list of all pairs by increasing merging cost.
218 std::priority_queue<HistogramPair> pairs_to_merge;
219 for (uint32_t i = 0; i < out->size(); i++) {
220 for (uint32_t j = i + 1; j < out->size(); j++) {
221 Histogram histo;
222 histo.AddHistogram((*out)[i]);
223 histo.AddHistogram((*out)[j]);
224 float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
225 (*out)[i].entropy_ - (*out)[j].entropy_;
226 // Avoid enqueueing pairs that are not advantageous to merge.
227 if (cost >= 0) continue;
228 pairs_to_merge.push(
229 HistogramPair{cost, i, j, std::max(version[i], version[j])});
230 }
231 }
232
233 // Merge the best pair to merge, add new pairs that get formed as a
234 // consequence.
235 while (!pairs_to_merge.empty()) {
236 uint32_t first = pairs_to_merge.top().first;
237 uint32_t second = pairs_to_merge.top().second;
238 uint32_t ver = pairs_to_merge.top().version;
239 pairs_to_merge.pop();
240 if (ver != std::max(version[first], version[second]) ||
241 version[first] == 0 || version[second] == 0) {
242 continue;
243 }
244 (*out)[first].AddHistogram((*out)[second]);
245 (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
246 (*out)[first].data_.size());
247 for (size_t i = 0; i < renumbering.size(); i++) {
248 if (renumbering[i] == second) {
249 renumbering[i] = first;
250 }
251 }
252 version[second] = 0;
253 version[first] = next_version++;
254 for (uint32_t j = 0; j < out->size(); j++) {
255 if (j == first) continue;
256 if (version[j] == 0) continue;
257 Histogram histo;
258 histo.AddHistogram((*out)[first]);
259 histo.AddHistogram((*out)[j]);
260 float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
261 (*out)[first].entropy_ - (*out)[j].entropy_;
262 // Avoid enqueueing pairs that are not advantageous to merge.
263 if (cost >= 0) continue;
264 pairs_to_merge.push(
265 HistogramPair{cost, std::min(first, j), std::max(first, j),
266 std::max(version[first], version[j])});
267 }
268 }
269 std::vector<uint32_t> reverse_renumbering(out->size(), -1);
270 size_t num_alive = 0;
271 for (size_t i = 0; i < out->size(); i++) {
272 if (version[i] == 0) continue;
273 (*out)[num_alive++] = (*out)[i];
274 reverse_renumbering[i] = num_alive - 1;
275 }
276 out->resize(num_alive);
277 for (size_t i = 0; i < histogram_symbols->size(); i++) {
278 (*histogram_symbols)[i] =
279 reverse_renumbering[renumbering[(*histogram_symbols)[i]]];
280 }
281 }
282
283 // Convert the context map to a canonical form.
284 HistogramReindex(out, histogram_symbols);
285 }
286
287 } // namespace jxl
288 #endif // HWY_ONCE
289