1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include <math.h>
7 #include <stdint.h>
8 #include <stdio.h>
9 #include <stdlib.h>
10 #include <string.h>
11 
#include <algorithm>
#include <memory>
#include <mutex>
#include <numeric>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <utility>
#include <vector>
20 
21 #include "jxl/decode.h"
22 #include "lib/extras/codec.h"
23 #include "lib/extras/codec_png.h"
24 #include "lib/extras/color_hints.h"
25 #include "lib/extras/time.h"
26 #include "lib/jxl/alpha.h"
27 #include "lib/jxl/base/cache_aligned.h"
28 #include "lib/jxl/base/compiler_specific.h"
29 #include "lib/jxl/base/data_parallel.h"
30 #include "lib/jxl/base/file_io.h"
31 #include "lib/jxl/base/padded_bytes.h"
32 #include "lib/jxl/base/profiler.h"
33 #include "lib/jxl/base/robust_statistics.h"
34 #include "lib/jxl/base/span.h"
35 #include "lib/jxl/base/status.h"
36 #include "lib/jxl/base/thread_pool_internal.h"
37 #include "lib/jxl/codec_in_out.h"
38 #include "lib/jxl/color_encoding_internal.h"
39 #include "lib/jxl/color_management.h"
40 #include "lib/jxl/enc_butteraugli_comparator.h"
41 #include "lib/jxl/enc_butteraugli_pnorm.h"
42 #include "lib/jxl/image.h"
43 #include "lib/jxl/image_bundle.h"
44 #include "lib/jxl/image_ops.h"
45 #include "tools/benchmark/benchmark_args.h"
46 #include "tools/benchmark/benchmark_codec.h"
47 #include "tools/benchmark/benchmark_file_io.h"
48 #include "tools/benchmark/benchmark_stats.h"
49 #include "tools/benchmark/benchmark_utils.h"
50 #include "tools/codec_config.h"
51 #include "tools/cpu/cpu.h"
52 #include "tools/cpu/os_specific.h"
53 #include "tools/speed_stats.h"
54 
55 namespace jxl {
56 namespace {
57 
WritePNG(Image3F && image,ThreadPool * pool,const std::string & filename)58 Status WritePNG(Image3F&& image, ThreadPool* pool,
59                 const std::string& filename) {
60   CodecInOut io;
61   io.metadata.m.SetUintSamples(8);
62   io.metadata.m.color_encoding = ColorEncoding::SRGB();
63   io.SetFromImage(std::move(image), io.metadata.m.color_encoding);
64   PaddedBytes compressed;
65   JXL_CHECK(
66       extras::EncodeImagePNG(&io, io.Main().c_current(), 8, pool, &compressed));
67   return WriteFile(compressed, filename);
68 }
69 
ReadPNG(const std::string & filename,Image3F * image)70 Status ReadPNG(const std::string& filename, Image3F* image) {
71   CodecInOut io;
72   JXL_CHECK(SetFromFile(filename, ColorHints(), &io));
73   *image = CopyImage(*io.Main().color());
74   return true;
75 }
76 
DoCompress(const std::string & filename,const CodecInOut & io,const std::vector<std::string> & extra_metrics_commands,ImageCodec * codec,ThreadPoolInternal * inner_pool,PaddedBytes * compressed,BenchmarkStats * s)77 void DoCompress(const std::string& filename, const CodecInOut& io,
78                 const std::vector<std::string>& extra_metrics_commands,
79                 ImageCodec* codec, ThreadPoolInternal* inner_pool,
80                 PaddedBytes* compressed, BenchmarkStats* s) {
81   PROFILER_FUNC;
82   ++s->total_input_files;
83 
84   if (io.frames.size() != 1) {
85     // Multiple frames not supported (io.xsize() will checkfail)
86     s->total_errors++;
87     if (!Args()->silent_errors) {
88       JXL_WARNING("multiframe input image not supported %s", filename.c_str());
89     }
90     return;
91   }
92   const size_t xsize = io.xsize();
93   const size_t ysize = io.ysize();
94   const size_t input_pixels = xsize * ysize;
95 
96   jpegxl::tools::SpeedStats speed_stats;
97   jpegxl::tools::SpeedStats::Summary summary;
98 
99   bool valid = true;  // false if roundtrip, encoding or decoding errors occur.
100 
101   if (!Args()->decode_only && (io.xsize() == 0 || io.ysize() == 0)) {
102     // This means the benchmark couldn't load the image, e.g. due to invalid
103     // ICC profile. Warning message about that was already printed. Continue
104     // this function to indicate it as error in the stats.
105     valid = false;
106   }
107 
108   std::string ext = FileExtension(filename);
109   if (valid && !Args()->decode_only) {
110     for (size_t i = 0; i < Args()->encode_reps; ++i) {
111       if (codec->CanRecompressJpeg() && (ext == ".jpg" || ext == ".jpeg")) {
112         std::string data_in;
113         JXL_CHECK(ReadFile(filename, &data_in));
114         JXL_CHECK(
115             codec->RecompressJpeg(filename, data_in, compressed, &speed_stats));
116       } else {
117         Status status = codec->Compress(filename, &io, inner_pool, compressed,
118                                         &speed_stats);
119         if (!status) {
120           valid = false;
121           if (!Args()->silent_errors) {
122             std::string message = codec->GetErrorMessage();
123             if (!message.empty()) {
124               fprintf(stderr, "Error in %s codec: %s\n",
125                       codec->description().c_str(), message.c_str());
126             } else {
127               fprintf(stderr, "Error in %s codec\n",
128                       codec->description().c_str());
129             }
130           }
131         }
132       }
133     }
134     JXL_CHECK(speed_stats.GetSummary(&summary));
135     s->total_time_encode += summary.central_tendency;
136   }
137 
138   if (valid && Args()->decode_only) {
139     std::string data_in;
140     JXL_CHECK(ReadFile(filename, &data_in));
141     compressed->append((uint8_t*)data_in.data(),
142                        (uint8_t*)data_in.data() + data_in.size());
143   }
144 
145   // Decompress
146   CodecInOut io2;
147   io2.metadata.m = io.metadata.m;
148   if (valid) {
149     speed_stats = jpegxl::tools::SpeedStats();
150     for (size_t i = 0; i < Args()->decode_reps; ++i) {
151       if (!codec->Decompress(filename, Span<const uint8_t>(*compressed),
152                              inner_pool, &io2, &speed_stats)) {
153         if (!Args()->silent_errors) {
154           fprintf(stderr,
155                   "%s failed to decompress encoded image. Original source:"
156                   " %s\n",
157                   codec->description().c_str(), filename.c_str());
158         }
159         valid = false;
160       }
161 
162       // io2.dec_pixels increases each time, but the total should be independent
163       // of decode_reps, so only take the value from the first iteration.
164       if (i == 0) s->total_input_pixels += io2.dec_pixels;
165     }
166     JXL_CHECK(speed_stats.GetSummary(&summary));
167     s->total_time_decode += summary.central_tendency;
168   }
169 
170   std::string name = FileBaseName(filename);
171   std::string codec_name = codec->description();
172 
173   if (!valid) {
174     s->total_errors++;
175   }
176 
177   if (io.frames.size() != io2.frames.size()) {
178     if (!Args()->silent_errors) {
179       // Animated gifs not supported yet?
180       fprintf(stderr,
181               "Frame sizes not equal, is this an animated gif? %s %s %zu %zu\n",
182               codec_name.c_str(), name.c_str(), io.frames.size(),
183               io2.frames.size());
184     }
185     valid = false;
186   }
187 
188   bool lossless = codec->IsJpegTranscoder();
189   bool skip_butteraugli =
190       Args()->skip_butteraugli || Args()->decode_only || lossless;
191   ImageF distmap;
192   float max_distance = 1.0f;
193 
194   if (valid && !skip_butteraugli) {
195     JXL_ASSERT(io.frames.size() == io2.frames.size());
196     for (size_t i = 0; i < io.frames.size(); i++) {
197       const ImageBundle& ib1 = io.frames[i];
198       ImageBundle& ib2 = io2.frames[i];
199 
200       // Verify output
201       PROFILER_ZONE("Benchmark stats");
202       float distance;
203       if (SameSize(ib1, ib2)) {
204         ButteraugliParams params = codec->BaParams();
205         if (ib1.metadata()->IntensityTarget() !=
206             ib2.metadata()->IntensityTarget()) {
207           fprintf(stderr,
208                   "WARNING: input and output images have different intensity "
209                   "targets");
210         }
211         params.intensity_target = ib1.metadata()->IntensityTarget();
212         // Hack the default intensity target value to be 80.0, the intensity
213         // target of sRGB images and a more reasonable viewing default than
214         // JPEG XL file format's default.
215         if (fabs(params.intensity_target - 255.0f) < 1e-3) {
216           params.intensity_target = 80.0;
217         }
218         distance = ButteraugliDistance(ib1, ib2, params, &distmap, inner_pool);
219         // Ensure pixels in range 0-1
220         s->distance_2 += ComputeDistance2(ib1, ib2);
221       } else {
222         // TODO(veluca): re-upsample and compute proper distance.
223         distance = 1e+4f;
224         distmap = ImageF(1, 1);
225         distmap.Row(0)[0] = distance;
226         s->distance_2 += distance;
227       }
228       // Update stats
229       s->distance_p_norm +=
230           ComputeDistanceP(distmap, Args()->ba_params, Args()->error_pnorm) *
231           input_pixels;
232       s->max_distance = std::max(s->max_distance, distance);
233       s->distances.push_back(distance);
234       max_distance = std::max(max_distance, distance);
235     }
236   }
237 
238   s->total_compressed_size += compressed->size();
239   s->total_adj_compressed_size += compressed->size() * max_distance;
240   codec->GetMoreStats(s);
241 
242   if (io2.frames.size() == 1 &&
243       (Args()->save_compressed || Args()->save_decompressed)) {
244     JXL_ASSERT(io2.frames.size() == 1);
245     ImageBundle& ib2 = io2.Main();
246 
247     // By default the benchmark will save the image after roundtrip with the
248     // same color encoding as the image before roundtrip. Not all codecs
249     // necessarily preserve the amount of channels (1 for gray, 3 for RGB)
250     // though, since not all image formats necessarily allow a way to remember
251     // what amount of channels you happened to give the benchmark codec
252     // input (say, an RGB-only format) and that is fine since in the end what
253     // matters is that the pixels look the same on a 3-channel RGB monitor
254     // while using grayscale encoding is an internal compression optimization.
255     // If that is the case, output with the current color model instead,
256     // because CodecInOut does not automatically convert between 1 or 3
257     // channels, and giving a ColorEncoding  with a different amount of
258     // channels is not allowed.
259     const ColorEncoding* c_desired =
260         (ib2.metadata()->color_encoding.Channels() ==
261          ib2.c_current().Channels())
262             ? &ib2.metadata()->color_encoding
263             : &ib2.c_current();
264     // Allow overriding via --output_encoding.
265     if (!Args()->output_description.empty()) {
266       c_desired = &Args()->output_encoding;
267     }
268 
269     std::string dir = FileDirName(filename);
270     std::string outdir =
271         Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir;
272     // Make compatible for filename
273     std::replace(codec_name.begin(), codec_name.end(), ':', '_');
274     std::string compressed_fn = outdir + "/" + name + "." + codec_name;
275     std::string decompressed_fn = compressed_fn + Args()->output_extension;
276     std::string heatmap_fn = compressed_fn + ".heatmap.png";
277     JXL_CHECK(MakeDir(outdir));
278     if (Args()->save_compressed) {
279       std::string compressed_str(
280           reinterpret_cast<const char*>(compressed->data()),
281           compressed->size());
282       JXL_CHECK(WriteFile(compressed_str, compressed_fn));
283     }
284     if (Args()->save_decompressed && valid) {
285       // For verifying HDR: scale output.
286       if (Args()->mul_output != 0.0) {
287         fprintf(stderr, "WARNING: scaling outputs by %f\n", Args()->mul_output);
288         JXL_CHECK(ib2.TransformTo(ColorEncoding::LinearSRGB(ib2.IsGray()),
289                                   inner_pool));
290         ScaleImage(static_cast<float>(Args()->mul_output), ib2.color());
291       }
292 
293       JXL_CHECK(EncodeToFile(io2, *c_desired,
294                              ib2.metadata()->bit_depth.bits_per_sample,
295                              decompressed_fn));
296       if (!skip_butteraugli) {
297         float good = Args()->heatmap_good > 0.0f ? Args()->heatmap_good
298                                                  : ButteraugliFuzzyInverse(1.5);
299         float bad = Args()->heatmap_bad > 0.0f ? Args()->heatmap_bad
300                                                : ButteraugliFuzzyInverse(0.5);
301         JXL_CHECK(WritePNG(CreateHeatMapImage(distmap, good, bad), inner_pool,
302                            heatmap_fn));
303       }
304     }
305   }
306   if (!extra_metrics_commands.empty()) {
307     CodecInOut in_copy;
308     in_copy.SetFromImage(std::move(*io.Main().Copy().color()),
309                          io.Main().c_current());
310     TemporaryFile tmp_in("original", "pfm");
311     TemporaryFile tmp_out("decoded", "pfm");
312     TemporaryFile tmp_res("result", "txt");
313     std::string tmp_in_fn, tmp_out_fn, tmp_res_fn;
314     JXL_CHECK(tmp_in.GetFileName(&tmp_in_fn));
315     JXL_CHECK(tmp_out.GetFileName(&tmp_out_fn));
316     JXL_CHECK(tmp_res.GetFileName(&tmp_res_fn));
317 
318     // Convert everything to non-linear SRGB - this is what most metrics expect.
319     const ColorEncoding& c_desired = ColorEncoding::SRGB(io.Main().IsGray());
320     JXL_CHECK(EncodeToFile(io, c_desired,
321                            io.metadata.m.bit_depth.bits_per_sample, tmp_in_fn));
322     JXL_CHECK(EncodeToFile(
323         io2, c_desired, io.metadata.m.bit_depth.bits_per_sample, tmp_out_fn));
324     if (io.metadata.m.IntensityTarget() != io2.metadata.m.IntensityTarget()) {
325       fprintf(stderr,
326               "WARNING: original and decoded have different intensity targets "
327               "(%f vs. %f).\n",
328               io.metadata.m.IntensityTarget(),
329               io2.metadata.m.IntensityTarget());
330     }
331     std::string intensity_target;
332     {
333       std::ostringstream intensity_target_oss;
334       intensity_target_oss << io.metadata.m.IntensityTarget();
335       intensity_target = intensity_target_oss.str();
336     }
337     for (size_t i = 0; i < extra_metrics_commands.size(); i++) {
338       float res = nanf("");
339       bool error = false;
340       if (RunCommand(extra_metrics_commands[i],
341                      {tmp_in_fn, tmp_out_fn, tmp_res_fn, intensity_target})) {
342         FILE* f = fopen(tmp_res_fn.c_str(), "r");
343         if (fscanf(f, "%f", &res) != 1) {
344           error = true;
345         }
346         fclose(f);
347       } else {
348         error = true;
349       }
350       if (error) {
351         fprintf(stderr,
352                 "WARNING: Computation of metric with command %s failed\n",
353                 extra_metrics_commands[i].c_str());
354       }
355       s->extra_metrics.push_back(res);
356     }
357   }
358 
359   if (Args()->show_progress) {
360     fprintf(stderr, ".");
361     fflush(stderr);
362   }
363 }
364 
365 // Makes a base64 data URI for embedded image in HTML
Base64Image(const std::string & filename)366 std::string Base64Image(const std::string& filename) {
367   PaddedBytes bytes;
368   if (!ReadFile(filename, &bytes)) {
369     return "";
370   }
371   static const char* symbols =
372       "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
373   std::string result;
374   for (size_t i = 0; i < bytes.size(); i += 3) {
375     uint8_t o0 = bytes[i + 0];
376     uint8_t o1 = (i + 1 < bytes.size()) ? bytes[i + 1] : 0;
377     uint8_t o2 = (i + 2 < bytes.size()) ? bytes[i + 2] : 0;
378     uint32_t value = (o0 << 16) | (o1 << 8) | o2;
379     for (size_t j = 0; j < 4; j++) {
380       result += (i + j <= bytes.size()) ? symbols[(value >> (6 * (3 - j))) & 63]
381                                         : '=';
382     }
383   }
384   // NOTE: Chrome supports max 2MB of data this way for URLs, but appears to
385   // support larger images anyway as long as it's embedded in the HTML file
386   // itself. If more data is needed, use createObjectURL.
387   return "data:image;base64," + result;
388 }
389 
// One unit of benchmark work together with its results.
struct Task {
  // Codec instance associated with this task.
  ImageCodecPtr codec;
  // Index into the list of input file names.
  size_t idx_image;
  // Index into the list of method names.
  size_t idx_method;
  // Input image associated with this task (non-owning pointer).
  const CodecInOut* image;
  // Statistics recorded for this task.
  BenchmarkStats stats;
};
397 
WriteHtmlReport(const std::string & codec_desc,const std::vector<std::string> & fnames,const std::vector<const Task * > & tasks,const std::vector<const CodecInOut * > & images,bool self_contained)398 void WriteHtmlReport(const std::string& codec_desc,
399                      const std::vector<std::string>& fnames,
400                      const std::vector<const Task*>& tasks,
401                      const std::vector<const CodecInOut*>& images,
402                      bool self_contained) {
403   std::string toggle_js =
404       "<script type=\"text/javascript\">\n"
405       "  var codecname = '" +
406       codec_desc + "';\n";
407   toggle_js += R"(
408   var maintitle = codecname + ' - click images to toggle, press space to' +
409       ' toggle all, h to toggle all heatmaps. Zoom in with CTRL+wheel or' +
410       ' CTRL+plus.';
411   document.title = maintitle;
412   var counter = [];
413   function setState(i, s) {
414     var preview = document.getElementById("preview" + i);
415     var orig = document.getElementById("orig" + i);
416     var hm = document.getElementById("hm" + i);
417     if (s == 0) {
418       preview.style.display = 'none';
419       orig.style.display = 'block';
420       hm.style.display = 'none';
421     } else if (s == 1) {
422       preview.style.display = 'block';
423       orig.style.display = 'none';
424       hm.style.display = 'none';
425     } else if (s == 2) {
426       preview.style.display = 'none';
427       orig.style.display = 'none';
428       hm.style.display = 'block';
429     }
430   }
431   function toggle3(i) {
432     for (index = counter.length; index <= i; index++) {
433       counter.push(1);
434     }
435     setState(i, counter[i]);
436     counter[i] = (counter[i] + 1) % 3;
437     document.title = maintitle;
438   }
439   var toggleall_state = 1;
440   document.body.onkeydown = function(e) {
441     // space (32) to toggle orig/compr, 'h' (72) to toggle heatmap/compr
442     if (e.keyCode == 32 || e.keyCode == 72) {
443       var divs = document.getElementsByTagName('div');
444       var key_state = (e.keyCode == 32) ? 0 : 2;
445       toggleall_state = (toggleall_state == key_state) ? 1 : key_state;
446       document.title = codecname + ' - ' + (toggleall_state == 0 ?
447           'originals' : (toggleall_state == 1 ? 'compressed' : 'heatmaps'));
448       for (var i = 0; i < divs.length; i++) {
449         setState(i, toggleall_state);
450       }
451       return false;
452     }
453   };
454 </script>
455 )";
456   std::string out_html;
457   std::string outdir;
458   std::string toggleall = "function(e) {if(e.keyCode == 32 { ";
459   out_html += "<body bgcolor=\"#000\">\n";
460   out_html += "<style>img { image-rendering: pixelated; }</style>\n";
461   std::string codec_name = codec_desc;
462   // Make compatible for filename
463   std::replace(codec_name.begin(), codec_name.end(), ':', '_');
464   for (size_t i = 0; i < fnames.size(); ++i) {
465     std::string name = FileBaseName(fnames[i]);
466     std::string dir = FileDirName(fnames[i]);
467     outdir = Args()->output_dir.empty() ? dir + "/out" : Args()->output_dir;
468     std::string name_out = name + "." + codec_name + Args()->output_extension;
469     std::string heatmap_out = name + "." + codec_name + ".heatmap.png";
470 
471     std::string fname_orig = fnames[i];
472     std::string fname_out = outdir + "/" + name_out;
473     std::string fname_heatmap = outdir + "/" + heatmap_out;
474     std::string url_orig = Args()->originals_url.empty()
475                                ? ("file://" + fnames[i])
476                                : (Args()->originals_url + "/" + name);
477     std::string url_out = name_out;
478     std::string url_heatmap = heatmap_out;
479     if (self_contained) {
480       url_orig = Base64Image(fname_orig);
481       url_out = Base64Image(fname_out);
482       url_heatmap = Base64Image(fname_heatmap);
483     }
484     std::string number = StringPrintf("%zu", i);
485     const CodecInOut& image = *images[i];
486     size_t xsize = image.frames.size() == 1 ? image.xsize() : 0;
487     size_t ysize = image.frames.size() == 1 ? image.ysize() : 0;
488     std::string html_width = StringPrintf("%zupx", xsize);
489     std::string html_height = StringPrintf("%zupx", ysize);
490     double bpp = tasks[i]->stats.total_compressed_size * 8.0 /
491                  tasks[i]->stats.total_input_pixels;
492     double pnorm =
493         tasks[i]->stats.distance_p_norm / tasks[i]->stats.total_input_pixels;
494     double max_dist = tasks[i]->stats.max_distance;
495     std::string compressed_title = StringPrintf(
496         "compressed. bpp: %f, pnorm: %f, max dist: %f", bpp, pnorm, max_dist);
497     out_html += "<div onclick=\"toggle3(" + number +
498                 ");\" style=\"display:inline-block;width:" + html_width +
499                 ";height:" + html_height +
500                 ";\">\n"
501                 "  <img title=\"" +
502                 compressed_title + "\" id=\"preview" + number + "\" src=";
503     out_html += "\"" + url_out + "\"";
504     out_html +=
505         " style=\"display:block;\"/>\n"
506         "  <img title=\"original\" id=\"orig" +
507         number + "\" src=";
508     out_html += "\"" + url_orig + "\"";
509     out_html +=
510         " style=\"display:none;\"/>\n"
511         "  <img title=\"heatmap\" id=\"hm" +
512         number + "\" src=";
513     out_html += "\"" + url_heatmap + "\"";
514     out_html += " style=\"display:none;\"/>\n</div>\n";
515   }
516   out_html += "</body>\n";
517   out_html += toggle_js;
518   JXL_CHECK(WriteFile(out_html, outdir + "/index." + codec_name + ".html"));
519 }
520 
521 // Prints the detailed and aggregate statistics, in the correct order but as
522 // soon as possible when multithreaded tasks are done.
523 struct StatPrinter {
StatPrinterjxl::__anone21a92ee0111::StatPrinter524   StatPrinter(const std::vector<std::string>& methods,
525               const std::vector<std::string>& extra_metrics_names,
526               const std::vector<std::string>& fnames,
527               const std::vector<Task>& tasks)
528       : methods_(&methods),
529         extra_metrics_names_(&extra_metrics_names),
530         fnames_(&fnames),
531         tasks_(&tasks),
532         tasks_done_(0),
533         stats_printed_(0),
534         details_printed_(0) {
535     stats_done_.resize(methods.size(), 0);
536     details_done_.resize(tasks.size(), 0);
537     max_fname_width_ = 0;
538     for (const auto& fname : fnames) {
539       max_fname_width_ = std::max(max_fname_width_, FileBaseName(fname).size());
540     }
541     max_method_width_ = 0;
542     for (const auto& method : methods) {
543       max_method_width_ =
544           std::max(max_method_width_, FileBaseName(method).size());
545     }
546   }
547 
TaskDonejxl::__anone21a92ee0111::StatPrinter548   void TaskDone(size_t task_index, const Task& t) {
549     PROFILER_FUNC;
550     std::lock_guard<std::mutex> guard(mutex);
551     tasks_done_++;
552     if (Args()->print_details || Args()->show_progress) {
553       if (Args()->print_details) {
554         // Render individual results as soon as they are ready and all previous
555         // ones in task order are ready.
556         details_done_[task_index] = 1;
557         if (task_index == details_printed_) {
558           while (details_printed_ < tasks_->size() &&
559                  details_done_[details_printed_]) {
560             PrintDetails((*tasks_)[details_printed_]);
561             details_printed_++;
562           }
563         }
564       }
565       // When using "show_progress" or "print_details", the table must be
566       // rendered at the very end, else the details or progress would be
567       // rendered in-between the table rows.
568       if (tasks_done_ == tasks_->size()) {
569         PrintStatsHeader();
570         for (size_t i = 0; i < methods_->size(); i++) {
571           PrintStats((*methods_)[i], i);
572         }
573         PrintStatsFooter();
574       }
575     } else {
576       if (tasks_done_ == 1) {
577         PrintStatsHeader();
578       }
579       // Render lines of the table as soon as it is ready and all previous
580       // lines have been printed.
581       stats_done_[t.idx_method]++;
582       if (stats_done_[t.idx_method] == fnames_->size() &&
583           t.idx_method == stats_printed_) {
584         while (stats_printed_ < stats_done_.size() &&
585                stats_done_[stats_printed_] == fnames_->size()) {
586           PrintStats((*methods_)[stats_printed_], stats_printed_);
587           stats_printed_++;
588         }
589       }
590       if (tasks_done_ == tasks_->size()) {
591         PrintStatsFooter();
592       }
593     }
594   }
595 
PrintDetailsjxl::__anone21a92ee0111::StatPrinter596   void PrintDetails(const Task& t) {
597     double comp_bpp =
598         t.stats.total_compressed_size * 8.0 / t.stats.total_input_pixels;
599     double p_norm = t.stats.distance_p_norm / t.stats.total_input_pixels;
600     double bpp_p_norm = p_norm * comp_bpp;
601 
602     const double adj_comp_bpp =
603         t.stats.total_adj_compressed_size * 8.0 / t.stats.total_input_pixels;
604 
605     const double rmse =
606         std::sqrt(t.stats.distance_2 / t.stats.total_input_pixels);
607     const double psnr = t.stats.total_compressed_size == 0 ? 0.0
608                         : (t.stats.distance_2 == 0)
609                             ? 99.99
610                             : (20 * std::log10(1 / rmse));
611     size_t pixels = t.stats.total_input_pixels;
612 
613     const double enc_mps =
614         t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_encode);
615     const double dec_mps =
616         t.stats.total_input_pixels / (1000000.0 * t.stats.total_time_decode);
617     if (Args()->print_details_csv) {
618       printf("%s,%s,%zd,%zd,%zd,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f,%.8f",
619              (*methods_)[t.idx_method].c_str(),
620              FileBaseName((*fnames_)[t.idx_image]).c_str(),
621              t.stats.total_errors, t.stats.total_compressed_size, pixels,
622              enc_mps, dec_mps, comp_bpp, t.stats.max_distance, psnr, p_norm,
623              bpp_p_norm, adj_comp_bpp);
624       for (float m : t.stats.extra_metrics) {
625         printf(",%.8f", m);
626       }
627       printf("\n");
628     } else {
629       printf("%s", (*methods_)[t.idx_method].c_str());
630       for (size_t i = (*methods_)[t.idx_method].size(); i <= max_method_width_;
631            i++) {
632         printf(" ");
633       }
634       printf("%s", FileBaseName((*fnames_)[t.idx_image]).c_str());
635       for (size_t i = FileBaseName((*fnames_)[t.idx_image]).size();
636            i <= max_fname_width_; i++) {
637         printf(" ");
638       }
639       printf(
640           "error:%zd    size:%8zd    pixels:%9zd    enc_speed:%8.8f"
641           "    dec_speed:%8.8f    bpp:%10.8f    dist:%10.8f"
642           "    psnr:%10.8f    p:%10.8f    bppp:%10.8f    qabpp:%10.8f ",
643           t.stats.total_errors, t.stats.total_compressed_size, pixels, enc_mps,
644           dec_mps, comp_bpp, t.stats.max_distance, psnr, p_norm, bpp_p_norm,
645           adj_comp_bpp);
646       for (size_t i = 0; i < t.stats.extra_metrics.size(); i++) {
647         printf(" %s:%.8f", (*extra_metrics_names_)[i].c_str(),
648                t.stats.extra_metrics[i]);
649       }
650       printf("\n");
651     }
652     fflush(stdout);
653   }
654 
PrintStatsjxl::__anone21a92ee0111::StatPrinter655   void PrintStats(const std::string& method, size_t idx_method) {
656     PROFILER_FUNC;
657     // Assimilate all tasks with the same idx_method.
658     BenchmarkStats method_stats;
659     std::vector<const CodecInOut*> images;
660     std::vector<const Task*> tasks;
661     for (const Task& t : *tasks_) {
662       if (t.idx_method == idx_method) {
663         method_stats.Assimilate(t.stats);
664         images.push_back(t.image);
665         tasks.push_back(&t);
666       }
667     }
668 
669     std::string out;
670 
671     method_stats.PrintMoreStats();  // not concurrent
672     out += method_stats.PrintLine(method, fnames_->size());
673 
674     if (Args()->write_html_report) {
675       WriteHtmlReport(method, *fnames_, tasks, images,
676                       Args()->html_report_self_contained);
677     }
678 
679     stats_aggregate_.push_back(
680         method_stats.ComputeColumns(method, fnames_->size()));
681 
682     printf("%s", out.c_str());
683     fflush(stdout);
684   }
685 
PrintStatsHeaderjxl::__anone21a92ee0111::StatPrinter686   void PrintStatsHeader() {
687     if (Args()->markdown) {
688       if (Args()->show_progress) {
689         fprintf(stderr, "\n");
690         fflush(stderr);
691       }
692       printf("```\n");
693     }
694     if (fnames_->size() == 1) printf("%s\n", (*fnames_)[0].c_str());
695     printf("%s", PrintHeader(*extra_metrics_names_).c_str());
696     fflush(stdout);
697   }
698 
PrintStatsFooterjxl::__anone21a92ee0111::StatPrinter699   void PrintStatsFooter() {
700     printf(
701         "%s",
702         PrintAggregate(extra_metrics_names_->size(), stats_aggregate_).c_str());
703     if (Args()->markdown) printf("```\n");
704     printf("\n");
705     fflush(stdout);
706   }
707 
708   const std::vector<std::string>* methods_;
709   const std::vector<std::string>* extra_metrics_names_;
710   const std::vector<std::string>* fnames_;
711   const std::vector<Task>* tasks_;
712 
713   size_t tasks_done_;
714 
715   size_t stats_printed_;
716   std::vector<size_t> stats_done_;
717 
718   size_t details_printed_;
719   std::vector<size_t> details_done_;
720 
721   size_t max_fname_width_;
722   size_t max_method_width_;
723 
724   std::vector<std::vector<ColumnValue>> stats_aggregate_;
725 
726   std::mutex mutex;
727 };
728 
729 class Benchmark {
730   using StringVec = std::vector<std::string>;
731 
732  public:
733   // Return the exit code of the program.
Run()734   static int Run() {
735     int ret = EXIT_SUCCESS;
736     {
737       PROFILER_FUNC;
738 
739       const StringVec methods = GetMethods();
740       const StringVec extra_metrics_names = GetExtraMetricsNames();
741       const StringVec extra_metrics_commands = GetExtraMetricsCommands();
742       const StringVec fnames = GetFilenames();
743       bool all_color_aware;
744       bool jpeg_transcoding_requested;
745       // (non-const because Task.stats are updated)
746       std::vector<Task> tasks = CreateTasks(methods, fnames, &all_color_aware,
747                                             &jpeg_transcoding_requested);
748 
749       std::unique_ptr<ThreadPoolInternal> pool;
750       std::vector<std::unique_ptr<ThreadPoolInternal>> inner_pools;
751       InitThreads(static_cast<int>(tasks.size()), &pool, &inner_pools);
752 
753       const std::vector<CodecInOut> loaded_images = LoadImages(
754           fnames, all_color_aware, jpeg_transcoding_requested, pool.get());
755 
756       if (RunTasks(methods, extra_metrics_names, extra_metrics_commands, fnames,
757                    loaded_images, pool.get(), inner_pools, &tasks) != 0) {
758         ret = EXIT_FAILURE;
759         if (!Args()->silent_errors) {
760           fprintf(stderr, "There were error(s) in the benchmark.\n");
761         }
762       }
763     }
764 
765     // Must have exited profiler zone above before calling.
766     if (Args()->profiler) {
767       PROFILER_PRINT_RESULTS();
768     }
769     CacheAligned::PrintStats();
770     return ret;
771   }
772 
773  private:
NumCores()774   static int NumCores() {
775     jpegxl::tools::cpu::ProcessorTopology topology;
776     JXL_CHECK(DetectProcessorTopology(&topology));
777     const int num_cores =
778         static_cast<int>(topology.packages * topology.cores_per_package);
779     JXL_CHECK(num_cores != 0);
780     return num_cores;
781   }
782 
NumOuterThreads(const int num_cores,const int num_tasks)783   static int NumOuterThreads(const int num_cores, const int num_tasks) {
784     int num_threads = Args()->num_threads;
785     // Default to #cores
786     if (num_threads < 0) num_threads = num_cores;
787 
788     // As a safety precaution, limit the number of threads to 4x the number of
789     // available CPUs.
790     num_threads =
791         std::min<int>(num_threads, 4 * std::thread::hardware_concurrency());
792 
793     // Don't create more threads than there are tasks (pointless/wasteful).
794     num_threads = std::min(num_threads, num_tasks);
795 
796     // Just one thread is counterproductive.
797     if (num_threads == 1) num_threads = 0;
798 
799     return num_threads;
800   }
801 
NumInnerThreads(const int num_cores,const int num_threads)802   static int NumInnerThreads(const int num_cores, const int num_threads) {
803     int num_inner = Args()->inner_threads;
804 
805     // Default: distribute remaining cores among tasks.
806     if (num_inner < 0) {
807       const int cores_for_outer = num_cores - num_threads;
808       num_inner = num_threads == 0 ? num_cores : cores_for_outer / num_threads;
809     }
810 
811     // Just one thread is counterproductive.
812     if (num_inner == 1) num_inner = 0;
813 
814     return num_inner;
815   }
816 
817   // Pins the first worker thread in pool to cpus[*next_index] etc.
818   // Not thread-safe (non-atomic update of next_index).
PinThreads(ThreadPoolInternal * pool,const std::vector<int> & cpus,size_t * next_index)819   static void PinThreads(ThreadPoolInternal* pool, const std::vector<int>& cpus,
820                          size_t* next_index) {
821     // No benefit to pinning if no actual worker threads.
822     if (pool->NumWorkerThreads() == 0) return;
823 
824     pool->RunOnEachThread([&](int /*task*/, const int thread) {
825       const size_t index = *next_index + static_cast<size_t>(thread);
826       if (index < cpus.size()) {
827         // printf("pin pool %p thread %3d to index %3zu = cpu %3d\n",
828         //        static_cast<void*>(pool), thread, index, cpus[index]);
829         if (!jpegxl::tools::cpu::PinThreadToCPU(cpus[index])) {
830           fprintf(stderr, "WARNING: failed to pin thread %d, next %zu.\n",
831                   thread, *next_index);
832         }
833       }
834     });
835     *next_index += pool->NumWorkerThreads();
836   }
837 
InitThreads(const int num_tasks,std::unique_ptr<ThreadPoolInternal> * pool,std::vector<std::unique_ptr<ThreadPoolInternal>> * inner_pools)838   static void InitThreads(
839       const int num_tasks, std::unique_ptr<ThreadPoolInternal>* pool,
840       std::vector<std::unique_ptr<ThreadPoolInternal>>* inner_pools) {
841     const int num_cores = NumCores();
842     const int num_threads = NumOuterThreads(num_cores, num_tasks);
843     const int num_inner = NumInnerThreads(num_cores, num_threads);
844 
845     fprintf(stderr, "%d cores, %d tasks, %d threads, %d inner threads\n",
846             num_cores, num_tasks, num_threads, num_inner);
847 
848     pool->reset(new ThreadPoolInternal(num_threads));
849     // Main thread OR worker threads in pool each get a possibly empty nested
850     // pool (helps use all available cores when #tasks < #threads)
851     for (size_t i = 0; i < (*pool)->NumThreads(); ++i) {
852       inner_pools->emplace_back(new ThreadPoolInternal(num_inner));
853     }
854 
855     // Pin all actual worker threads to available CPUs.
856     const std::vector<int> cpus = jpegxl::tools::cpu::AvailableCPUs();
857     size_t next_index = 0;
858     PinThreads(pool->get(), cpus, &next_index);
859     for (std::unique_ptr<ThreadPoolInternal>& inner : *inner_pools) {
860       PinThreads(inner.get(), cpus, &next_index);
861     }
862   }
863 
GetMethods()864   static StringVec GetMethods() {
865     StringVec methods = SplitString(Args()->codec, ',');
866     for (auto it = methods.begin(); it != methods.end();) {
867       if (it->empty()) {
868         it = methods.erase(it);
869       } else {
870         ++it;
871       }
872     }
873     return methods;
874   }
875 
GetExtraMetricsNames()876   static StringVec GetExtraMetricsNames() {
877     StringVec metrics = SplitString(Args()->extra_metrics, ',');
878     for (auto it = metrics.begin(); it != metrics.end();) {
879       if (it->empty()) {
880         it = metrics.erase(it);
881       } else {
882         *it = SplitString(*it, ':')[0];
883         ++it;
884       }
885     }
886     return metrics;
887   }
888 
GetExtraMetricsCommands()889   static StringVec GetExtraMetricsCommands() {
890     StringVec metrics = SplitString(Args()->extra_metrics, ',');
891     for (auto it = metrics.begin(); it != metrics.end();) {
892       if (it->empty()) {
893         it = metrics.erase(it);
894       } else {
895         auto s = SplitString(*it, ':');
896         JXL_CHECK(s.size() == 2);
897         *it = s[1];
898         ++it;
899       }
900     }
901     return metrics;
902   }
903 
SampleFromInput(const StringVec & fnames,const std::string & sample_tmp_dir,int num_samples,size_t size)904   static StringVec SampleFromInput(const StringVec& fnames,
905                                    const std::string& sample_tmp_dir,
906                                    int num_samples, size_t size) {
907     JXL_CHECK(!sample_tmp_dir.empty());
908     fprintf(stderr, "Creating samples of %zux%zu tiles...\n", size, size);
909     StringVec fnames_out;
910     std::vector<Image3F> images;
911     std::vector<size_t> offsets;
912     size_t total_num_tiles = 0;
913     for (const auto& fname : fnames) {
914       Image3F img;
915       JXL_CHECK(ReadPNG(fname, &img));
916       JXL_CHECK(img.xsize() >= size);
917       JXL_CHECK(img.ysize() >= size);
918       total_num_tiles += (img.xsize() - size + 1) * (img.ysize() - size + 1);
919       offsets.push_back(total_num_tiles);
920       images.emplace_back(std::move(img));
921     }
922     JXL_CHECK(MakeDir(sample_tmp_dir));
923     std::mt19937_64 rng;
924     for (int i = 0; i < num_samples; ++i) {
925       int val = std::uniform_int_distribution<>(0, offsets.back())(rng);
926       size_t idx = (std::lower_bound(offsets.begin(), offsets.end(), val) -
927                     offsets.begin());
928       JXL_CHECK(idx < images.size());
929       const Image3F& img = images[idx];
930       int x0 = std::uniform_int_distribution<>(0, img.xsize() - size)(rng);
931       int y0 = std::uniform_int_distribution<>(0, img.ysize() - size)(rng);
932       Image3F sample(size, size);
933       for (size_t c = 0; c < 3; ++c) {
934         for (size_t y = 0; y < size; ++y) {
935           const float* JXL_RESTRICT row_in = img.PlaneRow(c, y0 + y);
936           float* JXL_RESTRICT row_out = sample.PlaneRow(c, y);
937           memcpy(row_out, &row_in[x0], size * sizeof(row_out[0]));
938         }
939       }
940       std::string fn_output =
941           StringPrintf("%s/%s.crop_%dx%d+%d+%d.png", sample_tmp_dir.c_str(),
942                        FileBaseName(fnames[idx]).c_str(), size, size, x0, y0);
943       ThreadPool* null_pool = nullptr;
944       JXL_CHECK(WritePNG(std::move(sample), null_pool, fn_output));
945       fnames_out.push_back(fn_output);
946     }
947     fprintf(stderr, "Created %d sample tiles\n", num_samples);
948     return fnames_out;
949   }
950 
GetFilenames()951   static StringVec GetFilenames() {
952     StringVec fnames;
953     JXL_CHECK(MatchFiles(Args()->input, &fnames));
954     if (fnames.empty()) {
955       JXL_ABORT("No input file matches pattern: '%s'", Args()->input.c_str());
956     }
957     if (Args()->print_details) {
958       std::sort(fnames.begin(), fnames.end());
959     }
960 
961     if (Args()->num_samples > 0) {
962       fnames = SampleFromInput(fnames, Args()->sample_tmp_dir,
963                                Args()->num_samples, Args()->sample_dimensions);
964     }
965     return fnames;
966   }
967 
968   // (Load only once, not for every codec)
LoadImages(const StringVec & fnames,const bool all_color_aware,const bool jpeg_transcoding_requested,ThreadPool * pool)969   static std::vector<CodecInOut> LoadImages(
970       const StringVec& fnames, const bool all_color_aware,
971       const bool jpeg_transcoding_requested, ThreadPool* pool) {
972     PROFILER_FUNC;
973     std::vector<CodecInOut> loaded_images;
974     loaded_images.resize(fnames.size());
975     RunOnPool(
976         pool, 0, static_cast<int>(fnames.size()), ThreadPool::SkipInit(),
977         [&](const int task, int /*thread*/) {
978           const size_t i = static_cast<size_t>(task);
979           Status ok = true;
980 
981           loaded_images[i].target_nits = Args()->intensity_target;
982           loaded_images[i].dec_target = jpeg_transcoding_requested
983                                             ? DecodeTarget::kQuantizedCoeffs
984                                             : DecodeTarget::kPixels;
985           if (!Args()->decode_only) {
986             ok = SetFromFile(fnames[i], Args()->color_hints, &loaded_images[i]);
987           }
988           if (!ok) {
989             if (!Args()->silent_errors) {
990               fprintf(stderr, "Failed to load image %s\n", fnames[i].c_str());
991             }
992             return;
993           }
994 
995           if (!Args()->decode_only && all_color_aware) {
996             const bool is_gray = loaded_images[i].Main().IsGray();
997             const ColorEncoding& c_desired = ColorEncoding::LinearSRGB(is_gray);
998             if (!loaded_images[i].TransformTo(c_desired, /*pool=*/nullptr)) {
999               JXL_ABORT("Failed to transform to lin. sRGB %s",
1000                         fnames[i].c_str());
1001             }
1002           }
1003 
1004           if (!Args()->decode_only && Args()->override_bitdepth != 0) {
1005             if (Args()->override_bitdepth == 32) {
1006               loaded_images[i].metadata.m.SetFloat32Samples();
1007             } else {
1008               loaded_images[i].metadata.m.SetUintSamples(
1009                   Args()->override_bitdepth);
1010             }
1011           }
1012         },
1013         "Load images");
1014     return loaded_images;
1015   }
1016 
CreateTasks(const StringVec & methods,const StringVec & fnames,bool * all_color_aware,bool * jpeg_transcoding_requested)1017   static std::vector<Task> CreateTasks(const StringVec& methods,
1018                                        const StringVec& fnames,
1019                                        bool* all_color_aware,
1020                                        bool* jpeg_transcoding_requested) {
1021     std::vector<Task> tasks;
1022     tasks.reserve(methods.size() * fnames.size());
1023     *all_color_aware = true;
1024     *jpeg_transcoding_requested = false;
1025     for (size_t idx_image = 0; idx_image < fnames.size(); ++idx_image) {
1026       for (size_t idx_method = 0; idx_method < methods.size(); ++idx_method) {
1027         tasks.emplace_back();
1028         Task& t = tasks.back();
1029         t.codec = CreateImageCodec(methods[idx_method]);
1030         *all_color_aware &= t.codec->IsColorAware();
1031         *jpeg_transcoding_requested |= t.codec->IsJpegTranscoder();
1032         t.idx_image = idx_image;
1033         t.idx_method = idx_method;
1034         // t.stats is default-initialized.
1035       }
1036     }
1037     JXL_ASSERT(tasks.size() == tasks.capacity());
1038     return tasks;
1039   }
1040 
1041   // Return the total number of errors.
RunTasks(const StringVec & methods,const StringVec & extra_metrics_names,const StringVec & extra_metrics_commands,const StringVec & fnames,const std::vector<CodecInOut> & loaded_images,ThreadPoolInternal * pool,const std::vector<std::unique_ptr<ThreadPoolInternal>> & inner_pools,std::vector<Task> * tasks)1042   static size_t RunTasks(
1043       const StringVec& methods, const StringVec& extra_metrics_names,
1044       const StringVec& extra_metrics_commands, const StringVec& fnames,
1045       const std::vector<CodecInOut>& loaded_images, ThreadPoolInternal* pool,
1046       const std::vector<std::unique_ptr<ThreadPoolInternal>>& inner_pools,
1047       std::vector<Task>* tasks) {
1048     PROFILER_FUNC;
1049     StatPrinter printer(methods, extra_metrics_names, fnames, *tasks);
1050     if (Args()->print_details_csv) {
1051       // Print CSV header
1052       printf(
1053           "method,image,error,size,pixels,enc_speed,dec_speed,"
1054           "bpp,dist,psnr,p,bppp,qabpp");
1055       for (const std::string& s : extra_metrics_names) {
1056         printf(",%s", s.c_str());
1057       }
1058       printf("\n");
1059     }
1060 
1061     std::vector<uint64_t> errors_thread;
1062     RunOnPool(
1063         pool, 0, tasks->size(),
1064         [&](size_t num_threads) {
1065           // Reduce false sharing by only writing every 8th slot (64 bytes).
1066           errors_thread.resize(8 * num_threads);
1067           return true;
1068         },
1069         [&](const int i, const int thread) {
1070           Task& t = (*tasks)[i];
1071           const CodecInOut& image = loaded_images[t.idx_image];
1072           t.image = &image;
1073           PaddedBytes compressed;
1074           DoCompress(fnames[t.idx_image], image, extra_metrics_commands,
1075                      t.codec.get(), inner_pools[thread].get(), &compressed,
1076                      &t.stats);
1077           printer.TaskDone(i, t);
1078           errors_thread[8 * thread] += t.stats.total_errors;
1079         },
1080         "Benchmark tasks");
1081     if (Args()->show_progress) fprintf(stderr, "\n");
1082     return std::accumulate(errors_thread.begin(), errors_thread.end(), 0);
1083   }
1084 };
1085 
BenchmarkMain(int argc,const char ** argv)1086 int BenchmarkMain(int argc, const char** argv) {
1087   fprintf(stderr, "benchmark_xl %s\n",
1088           jpegxl::tools::CodecConfigString(JxlDecoderVersion()).c_str());
1089 
1090   JXL_CHECK(Args()->AddCommandLineOptions());
1091 
1092   if (!Args()->Parse(argc, argv)) {
1093     fprintf(stderr, "Use '%s -h' for more information\n", argv[0]);
1094     return 1;
1095   }
1096 
1097   if (Args()->cmdline.HelpFlagPassed()) {
1098     Args()->PrintHelp();
1099     return 0;
1100   }
1101   if (!Args()->ValidateArgs()) {
1102     fprintf(stderr, "Use '%s -h' for more information\n", argv[0]);
1103     return 1;
1104   }
1105   return Benchmark::Run();
1106 }
1107 
1108 }  // namespace
1109 }  // namespace jxl
1110 
main(int argc,const char ** argv)1111 int main(int argc, const char** argv) { return jxl::BenchmarkMain(argc, argv); }
1112