1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_cluster.h"
7 
#include <algorithm>
#include <cmath>
#include <limits>
#include <map>
#include <memory>
#include <numeric>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
16 
17 #undef HWY_TARGET_INCLUDE
18 #define HWY_TARGET_INCLUDE "lib/jxl/enc_cluster.cc"
19 #include <hwy/foreach_target.h>
20 #include <hwy/highway.h>
21 
22 #include "lib/jxl/ac_context.h"
23 #include "lib/jxl/base/profiler.h"
24 #include "lib/jxl/fast_math-inl.h"
25 HWY_BEFORE_NAMESPACE();
26 namespace jxl {
27 namespace HWY_NAMESPACE {
28 
29 template <class V>
Entropy(V count,V inv_total,V total)30 V Entropy(V count, V inv_total, V total) {
31   const HWY_CAPPED(float, Histogram::kRounding) d;
32   const auto zero = Set(d, 0.0f);
33   return IfThenZeroElse(count == total,
34                         zero - count * FastLog2f(d, inv_total * count));
35 }
36 
HistogramEntropy(const Histogram & a)37 void HistogramEntropy(const Histogram& a) {
38   a.entropy_ = 0.0f;
39   if (a.total_count_ == 0) return;
40 
41   const HWY_CAPPED(float, Histogram::kRounding) df;
42   const HWY_CAPPED(int32_t, Histogram::kRounding) di;
43 
44   const auto inv_tot = Set(df, 1.0f / a.total_count_);
45   auto entropy_lanes = Zero(df);
46   auto total = Set(df, a.total_count_);
47 
48   for (size_t i = 0; i < a.data_.size(); i += Lanes(di)) {
49     const auto counts = LoadU(di, &a.data_[i]);
50     entropy_lanes += Entropy(ConvertTo(df, counts), inv_tot, total);
51   }
52   a.entropy_ += GetLane(SumOfLanes(df, entropy_lanes));
53 }
54 
HistogramDistance(const Histogram & a,const Histogram & b)55 float HistogramDistance(const Histogram& a, const Histogram& b) {
56   if (a.total_count_ == 0 || b.total_count_ == 0) return 0;
57 
58   const HWY_CAPPED(float, Histogram::kRounding) df;
59   const HWY_CAPPED(int32_t, Histogram::kRounding) di;
60 
61   const auto inv_tot = Set(df, 1.0f / (a.total_count_ + b.total_count_));
62   auto distance_lanes = Zero(df);
63   auto total = Set(df, a.total_count_ + b.total_count_);
64 
65   for (size_t i = 0; i < std::max(a.data_.size(), b.data_.size());
66        i += Lanes(di)) {
67     const auto a_counts =
68         a.data_.size() > i ? LoadU(di, &a.data_[i]) : Zero(di);
69     const auto b_counts =
70         b.data_.size() > i ? LoadU(di, &b.data_[i]) : Zero(di);
71     const auto counts = ConvertTo(df, a_counts + b_counts);
72     distance_lanes += Entropy(counts, inv_tot, total);
73   }
74   const float total_distance = GetLane(SumOfLanes(df, distance_lanes));
75   return total_distance - a.entropy_ - b.entropy_;
76 }
77 
78 // First step of a k-means clustering with a fancy distance metric.
FastClusterHistograms(const std::vector<Histogram> & in,const size_t num_contexts_in,size_t max_histograms,float min_distance,std::vector<Histogram> * out,std::vector<uint32_t> * histogram_symbols)79 void FastClusterHistograms(const std::vector<Histogram>& in,
80                            const size_t num_contexts_in, size_t max_histograms,
81                            float min_distance, std::vector<Histogram>* out,
82                            std::vector<uint32_t>* histogram_symbols) {
83   PROFILER_FUNC;
84   size_t largest_idx = 0;
85   std::vector<uint32_t> nonempty_histograms;
86   nonempty_histograms.reserve(in.size());
87   for (size_t i = 0; i < num_contexts_in; i++) {
88     if (in[i].total_count_ == 0) continue;
89     HistogramEntropy(in[i]);
90     if (in[i].total_count_ > in[largest_idx].total_count_) {
91       largest_idx = i;
92     }
93     nonempty_histograms.push_back(i);
94   }
95   // No symbols.
96   if (nonempty_histograms.empty()) {
97     out->resize(1);
98     histogram_symbols->clear();
99     histogram_symbols->resize(in.size(), 0);
100     return;
101   }
102   largest_idx = std::find(nonempty_histograms.begin(),
103                           nonempty_histograms.end(), largest_idx) -
104                 nonempty_histograms.begin();
105   size_t num_contexts = nonempty_histograms.size();
106   out->clear();
107   out->reserve(max_histograms);
108   std::vector<float> dists(num_contexts, std::numeric_limits<float>::max());
109   histogram_symbols->clear();
110   histogram_symbols->resize(in.size(), max_histograms);
111 
112   while (out->size() < max_histograms && out->size() < num_contexts) {
113     (*histogram_symbols)[nonempty_histograms[largest_idx]] = out->size();
114     out->push_back(in[nonempty_histograms[largest_idx]]);
115     largest_idx = 0;
116     for (size_t i = 0; i < num_contexts; i++) {
117       dists[i] = std::min(
118           HistogramDistance(in[nonempty_histograms[i]], out->back()), dists[i]);
119       // Avoid repeating histograms
120       if ((*histogram_symbols)[nonempty_histograms[i]] != max_histograms) {
121         continue;
122       }
123       if (dists[i] > dists[largest_idx]) largest_idx = i;
124     }
125     if (dists[largest_idx] < min_distance) break;
126   }
127 
128   for (size_t i = 0; i < num_contexts_in; i++) {
129     if ((*histogram_symbols)[i] != max_histograms) continue;
130     if (in[i].total_count_ == 0) {
131       (*histogram_symbols)[i] = 0;
132       continue;
133     }
134     size_t best = 0;
135     float best_dist = HistogramDistance(in[i], (*out)[best]);
136     for (size_t j = 1; j < out->size(); j++) {
137       float dist = HistogramDistance(in[i], (*out)[j]);
138       if (dist < best_dist) {
139         best = j;
140         best_dist = dist;
141       }
142     }
143     (*out)[best].AddHistogram(in[i]);
144     HistogramEntropy((*out)[best]);
145     (*histogram_symbols)[i] = best;
146   }
147 }
148 
149 // NOLINTNEXTLINE(google-readability-namespace-comments)
150 }  // namespace HWY_NAMESPACE
151 }  // namespace jxl
152 HWY_AFTER_NAMESPACE();
153 
154 #if HWY_ONCE
155 namespace jxl {
156 HWY_EXPORT(FastClusterHistograms);  // Local function
157 HWY_EXPORT(HistogramEntropy);       // Local function
158 
ShannonEntropy() const159 float Histogram::ShannonEntropy() const {
160   HWY_DYNAMIC_DISPATCH(HistogramEntropy)(*this);
161   return entropy_;
162 }
163 
164 namespace {
165 // -----------------------------------------------------------------------------
166 // Histogram refinement
167 
168 // Reorder histograms in *out so that the new symbols in *symbols come in
169 // increasing order.
HistogramReindex(std::vector<Histogram> * out,std::vector<uint32_t> * symbols)170 void HistogramReindex(std::vector<Histogram>* out,
171                       std::vector<uint32_t>* symbols) {
172   std::vector<Histogram> tmp(*out);
173   std::map<int, int> new_index;
174   int next_index = 0;
175   for (uint32_t symbol : *symbols) {
176     if (new_index.find(symbol) == new_index.end()) {
177       new_index[symbol] = next_index;
178       (*out)[next_index] = tmp[symbol];
179       ++next_index;
180     }
181   }
182   out->resize(next_index);
183   for (uint32_t& symbol : *symbols) {
184     symbol = new_index[symbol];
185   }
186 }
187 
188 }  // namespace
189 
190 // Clusters similar histograms in 'in' together, the selected histograms are
191 // placed in 'out', and for each index in 'in', *histogram_symbols will
192 // indicate which of the 'out' histograms is the best approximation.
ClusterHistograms(const HistogramParams params,const std::vector<Histogram> & in,const size_t num_contexts,size_t max_histograms,std::vector<Histogram> * out,std::vector<uint32_t> * histogram_symbols)193 void ClusterHistograms(const HistogramParams params,
194                        const std::vector<Histogram>& in,
195                        const size_t num_contexts, size_t max_histograms,
196                        std::vector<Histogram>* out,
197                        std::vector<uint32_t>* histogram_symbols) {
198   constexpr float kMinDistanceForDistinctFast = 64.0f;
199   constexpr float kMinDistanceForDistinctBest = 16.0f;
200   max_histograms = std::min(max_histograms, params.max_histograms);
201   if (params.clustering == HistogramParams::ClusteringType::kFastest) {
202     HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
203     (in, num_contexts, 4, kMinDistanceForDistinctFast, out, histogram_symbols);
204   } else if (params.clustering == HistogramParams::ClusteringType::kFast) {
205     HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
206     (in, num_contexts, max_histograms, kMinDistanceForDistinctFast, out,
207      histogram_symbols);
208   } else {
209     PROFILER_FUNC;
210     HWY_DYNAMIC_DISPATCH(FastClusterHistograms)
211     (in, num_contexts, max_histograms, kMinDistanceForDistinctBest, out,
212      histogram_symbols);
213     for (size_t i = 0; i < out->size(); i++) {
214       (*out)[i].entropy_ =
215           ANSPopulationCost((*out)[i].data_.data(), (*out)[i].data_.size());
216     }
217     uint32_t next_version = 2;
218     std::vector<uint32_t> version(out->size(), 1);
219     std::vector<uint32_t> renumbering(out->size());
220     std::iota(renumbering.begin(), renumbering.end(), 0);
221 
222     // Try to pair up clusters if doing so reduces the total cost.
223 
224     struct HistogramPair {
225       // validity of a pair: p.version == max(version[i], version[j])
226       float cost;
227       uint32_t first;
228       uint32_t second;
229       uint32_t version;
230       // We use > because priority queues sort in *decreasing* order, but we
231       // want lower cost elements to appear first.
232       bool operator<(const HistogramPair& other) const {
233         return std::make_tuple(cost, first, second, version) >
234                std::make_tuple(other.cost, other.first, other.second,
235                                other.version);
236       }
237     };
238 
239     // Create list of all pairs by increasing merging cost.
240     std::priority_queue<HistogramPair> pairs_to_merge;
241     for (uint32_t i = 0; i < out->size(); i++) {
242       for (uint32_t j = i + 1; j < out->size(); j++) {
243         Histogram histo;
244         histo.AddHistogram((*out)[i]);
245         histo.AddHistogram((*out)[j]);
246         float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
247                      (*out)[i].entropy_ - (*out)[j].entropy_;
248         // Avoid enqueueing pairs that are not advantageous to merge.
249         if (cost >= 0) continue;
250         pairs_to_merge.push(
251             HistogramPair{cost, i, j, std::max(version[i], version[j])});
252       }
253     }
254 
255     // Merge the best pair to merge, add new pairs that get formed as a
256     // consequence.
257     while (!pairs_to_merge.empty()) {
258       uint32_t first = pairs_to_merge.top().first;
259       uint32_t second = pairs_to_merge.top().second;
260       uint32_t ver = pairs_to_merge.top().version;
261       pairs_to_merge.pop();
262       if (ver != std::max(version[first], version[second]) ||
263           version[first] == 0 || version[second] == 0) {
264         continue;
265       }
266       (*out)[first].AddHistogram((*out)[second]);
267       (*out)[first].entropy_ = ANSPopulationCost((*out)[first].data_.data(),
268                                                  (*out)[first].data_.size());
269       for (size_t i = 0; i < renumbering.size(); i++) {
270         if (renumbering[i] == second) {
271           renumbering[i] = first;
272         }
273       }
274       version[second] = 0;
275       version[first] = next_version++;
276       for (uint32_t j = 0; j < out->size(); j++) {
277         if (j == first) continue;
278         if (version[j] == 0) continue;
279         Histogram histo;
280         histo.AddHistogram((*out)[first]);
281         histo.AddHistogram((*out)[j]);
282         float cost = ANSPopulationCost(histo.data_.data(), histo.data_.size()) -
283                      (*out)[first].entropy_ - (*out)[j].entropy_;
284         // Avoid enqueueing pairs that are not advantageous to merge.
285         if (cost >= 0) continue;
286         pairs_to_merge.push(
287             HistogramPair{cost, std::min(first, j), std::max(first, j),
288                           std::max(version[first], version[j])});
289       }
290     }
291     std::vector<uint32_t> reverse_renumbering(out->size(), -1);
292     size_t num_alive = 0;
293     for (size_t i = 0; i < out->size(); i++) {
294       if (version[i] == 0) continue;
295       (*out)[num_alive++] = (*out)[i];
296       reverse_renumbering[i] = num_alive - 1;
297     }
298     out->resize(num_alive);
299     for (size_t i = 0; i < histogram_symbols->size(); i++) {
300       (*histogram_symbols)[i] =
301           reverse_renumbering[renumbering[(*histogram_symbols)[i]]];
302     }
303   }
304 
305   // Convert the context map to a canonical form.
306   HistogramReindex(out, histogram_symbols);
307 }
308 
309 }  // namespace jxl
310 #endif  // HWY_ONCE
311