1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/dec_external_image.h"
7 
8 #include <string.h>
9 
10 #include <algorithm>
11 #include <array>
12 #include <functional>
13 #include <utility>
14 #include <vector>
15 
16 #undef HWY_TARGET_INCLUDE
17 #define HWY_TARGET_INCLUDE "lib/jxl/dec_external_image.cc"
18 #include <hwy/foreach_target.h>
19 #include <hwy/highway.h>
20 
21 #include "lib/jxl/alpha.h"
22 #include "lib/jxl/base/byte_order.h"
23 #include "lib/jxl/base/cache_aligned.h"
24 #include "lib/jxl/base/compiler_specific.h"
25 #include "lib/jxl/color_management.h"
26 #include "lib/jxl/common.h"
27 #include "lib/jxl/sanitizers.h"
28 #include "lib/jxl/transfer_functions-inl.h"
29 
30 HWY_BEFORE_NAMESPACE();
31 namespace jxl {
32 namespace HWY_NAMESPACE {
33 
FloatToU32(const float * in,uint32_t * out,size_t num,float mul,size_t bits_per_sample)34 void FloatToU32(const float* in, uint32_t* out, size_t num, float mul,
35                 size_t bits_per_sample) {
36   // TODO(eustas): investigate 24..31 bpp cases.
37   if (bits_per_sample == 32) {
38     // Conversion to real 32-bit *unsigned* integers requires more intermediate
39     // precision that what is given by the usual f32 -> i32 conversion
40     // instructions, so we run the non-SIMD path for those.
41     const uint32_t cap = (1ull << bits_per_sample) - 1;
42     for (size_t x = 0; x < num; x++) {
43       float v = in[x];
44       if (v >= 1.0f) {
45         out[x] = cap;
46       } else if (v >= 0.0f) {  // Inverted condition => NaN -> 0.
47         out[x] = static_cast<uint32_t>(v * mul + 0.5f);
48       } else {
49         out[x] = 0;
50       }
51     }
52     return;
53   }
54 
55   // General SIMD case for less than 32 bits output.
56   const HWY_FULL(float) d;
57   const hwy::HWY_NAMESPACE::Rebind<uint32_t, decltype(d)> du;
58 
59   // Unpoison accessing partially-uninitialized vectors with memory sanitizer.
60   // This is because we run NearestInt() on the vector, which triggers msan even
61   // it it safe to do so since the values are not mixed between lanes.
62   const size_t num_round_up = RoundUpTo(num, Lanes(d));
63   msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num));
64 
65   const auto one = Set(d, 1.0f);
66   const auto scale = Set(d, mul);
67   for (size_t x = 0; x < num; x += Lanes(d)) {
68     auto v = Load(d, in + x);
69     // Clamp turns NaN to 'min'.
70     v = Clamp(v, Zero(d), one);
71     auto i = NearestInt(v * scale);
72     Store(BitCast(du, i), du, out + x);
73   }
74 
75   // Poison back the output.
76   msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num));
77 }
78 
FloatToF16(const float * in,hwy::float16_t * out,size_t num)79 void FloatToF16(const float* in, hwy::float16_t* out, size_t num) {
80   const HWY_FULL(float) d;
81   const hwy::HWY_NAMESPACE::Rebind<hwy::float16_t, decltype(d)> du;
82 
83   // Unpoison accessing partially-uninitialized vectors with memory sanitizer.
84   // This is because we run DemoteTo() on the vector which triggers msan.
85   const size_t num_round_up = RoundUpTo(num, Lanes(d));
86   msan::UnpoisonMemory(in + num, sizeof(in[0]) * (num_round_up - num));
87 
88   for (size_t x = 0; x < num; x += Lanes(d)) {
89     auto v = Load(d, in + x);
90     auto v16 = DemoteTo(du, v);
91     Store(v16, du, out + x);
92   }
93 
94   // Poison back the output.
95   msan::PoisonMemory(out + num, sizeof(out[0]) * (num_round_up - num));
96 }
97 
98 // NOLINTNEXTLINE(google-readability-namespace-comments)
99 }  // namespace HWY_NAMESPACE
100 }  // namespace jxl
101 HWY_AFTER_NAMESPACE();
102 
103 #if HWY_ONCE
104 
105 namespace jxl {
106 namespace {
107 
108 // Stores a float in big endian
StoreBEFloat(float value,uint8_t * p)109 void StoreBEFloat(float value, uint8_t* p) {
110   uint32_t u;
111   memcpy(&u, &value, 4);
112   StoreBE32(u, p);
113 }
114 
115 // Stores a float in little endian
StoreLEFloat(float value,uint8_t * p)116 void StoreLEFloat(float value, uint8_t* p) {
117   uint32_t u;
118   memcpy(&u, &value, 4);
119   StoreLE32(u, p);
120 }
121 
122 // The orientation may not be identity.
123 // TODO(lode): SIMDify where possible
124 template <typename T>
UndoOrientation(jxl::Orientation undo_orientation,const Plane<T> & image,Plane<T> & out,jxl::ThreadPool * pool)125 void UndoOrientation(jxl::Orientation undo_orientation, const Plane<T>& image,
126                      Plane<T>& out, jxl::ThreadPool* pool) {
127   const size_t xsize = image.xsize();
128   const size_t ysize = image.ysize();
129 
130   if (undo_orientation == Orientation::kFlipHorizontal) {
131     out = Plane<T>(xsize, ysize);
132     RunOnPool(
133         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
134         [&](const int task, int /*thread*/) {
135           const int64_t y = task;
136           const T* JXL_RESTRICT row_in = image.Row(y);
137           T* JXL_RESTRICT row_out = out.Row(y);
138           for (size_t x = 0; x < xsize; ++x) {
139             row_out[xsize - x - 1] = row_in[x];
140           }
141         },
142         "UndoOrientation");
143   } else if (undo_orientation == Orientation::kRotate180) {
144     out = Plane<T>(xsize, ysize);
145     RunOnPool(
146         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
147         [&](const int task, int /*thread*/) {
148           const int64_t y = task;
149           const T* JXL_RESTRICT row_in = image.Row(y);
150           T* JXL_RESTRICT row_out = out.Row(ysize - y - 1);
151           for (size_t x = 0; x < xsize; ++x) {
152             row_out[xsize - x - 1] = row_in[x];
153           }
154         },
155         "UndoOrientation");
156   } else if (undo_orientation == Orientation::kFlipVertical) {
157     out = Plane<T>(xsize, ysize);
158     RunOnPool(
159         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
160         [&](const int task, int /*thread*/) {
161           const int64_t y = task;
162           const T* JXL_RESTRICT row_in = image.Row(y);
163           T* JXL_RESTRICT row_out = out.Row(ysize - y - 1);
164           for (size_t x = 0; x < xsize; ++x) {
165             row_out[x] = row_in[x];
166           }
167         },
168         "UndoOrientation");
169   } else if (undo_orientation == Orientation::kTranspose) {
170     out = Plane<T>(ysize, xsize);
171     RunOnPool(
172         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
173         [&](const int task, int /*thread*/) {
174           const int64_t y = task;
175           const T* JXL_RESTRICT row_in = image.Row(y);
176           for (size_t x = 0; x < xsize; ++x) {
177             out.Row(x)[y] = row_in[x];
178           }
179         },
180         "UndoOrientation");
181   } else if (undo_orientation == Orientation::kRotate90) {
182     out = Plane<T>(ysize, xsize);
183     RunOnPool(
184         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
185         [&](const int task, int /*thread*/) {
186           const int64_t y = task;
187           const T* JXL_RESTRICT row_in = image.Row(y);
188           for (size_t x = 0; x < xsize; ++x) {
189             out.Row(x)[ysize - y - 1] = row_in[x];
190           }
191         },
192         "UndoOrientation");
193   } else if (undo_orientation == Orientation::kAntiTranspose) {
194     out = Plane<T>(ysize, xsize);
195     RunOnPool(
196         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
197         [&](const int task, int /*thread*/) {
198           const int64_t y = task;
199           const T* JXL_RESTRICT row_in = image.Row(y);
200           for (size_t x = 0; x < xsize; ++x) {
201             out.Row(xsize - x - 1)[ysize - y - 1] = row_in[x];
202           }
203         },
204         "UndoOrientation");
205   } else if (undo_orientation == Orientation::kRotate270) {
206     out = Plane<T>(ysize, xsize);
207     RunOnPool(
208         pool, 0, static_cast<uint32_t>(ysize), ThreadPool::SkipInit(),
209         [&](const int task, int /*thread*/) {
210           const int64_t y = task;
211           const T* JXL_RESTRICT row_in = image.Row(y);
212           for (size_t x = 0; x < xsize; ++x) {
213             out.Row(xsize - x - 1)[y] = row_in[x];
214           }
215         },
216         "UndoOrientation");
217   }
218 }
219 }  // namespace
220 
// Make the per-target versions of the SIMD functions defined above available
// by name; the HWY_DYNAMIC_DISPATCH(...) calls further down in this file refer
// to these two names.
HWY_EXPORT(FloatToU32);
HWY_EXPORT(FloatToF16);
223 
224 namespace {
225 
226 using StoreFuncType = void(uint32_t value, uint8_t* dest);
227 template <StoreFuncType StoreFunc>
StoreUintRow(uint32_t * JXL_RESTRICT * rows_u32,size_t num_channels,size_t xsize,size_t bytes_per_sample,uint8_t * JXL_RESTRICT out)228 void StoreUintRow(uint32_t* JXL_RESTRICT* rows_u32, size_t num_channels,
229                   size_t xsize, size_t bytes_per_sample,
230                   uint8_t* JXL_RESTRICT out) {
231   for (size_t x = 0; x < xsize; ++x) {
232     for (size_t c = 0; c < num_channels; c++) {
233       StoreFunc(rows_u32[c][x],
234                 out + (num_channels * x + c) * bytes_per_sample);
235     }
236   }
237 }
238 
239 template <void(StoreFunc)(float, uint8_t*)>
StoreFloatRow(const float * JXL_RESTRICT * rows_in,size_t num_channels,size_t xsize,uint8_t * JXL_RESTRICT out)240 void StoreFloatRow(const float* JXL_RESTRICT* rows_in, size_t num_channels,
241                    size_t xsize, uint8_t* JXL_RESTRICT out) {
242   for (size_t x = 0; x < xsize; ++x) {
243     for (size_t c = 0; c < num_channels; c++) {
244       StoreFunc(rows_in[c][x], out + (num_channels * x + c) * sizeof(float));
245     }
246   }
247 }
248 
Store8(uint32_t value,uint8_t * dest)249 void JXL_INLINE Store8(uint32_t value, uint8_t* dest) { *dest = value & 0xff; }
250 
// Upper bound on the number of channels ConvertChannelsToExternal can
// interleave; also sizes its per-row pointer arrays.
constexpr size_t kConvertMaxChannels = 4;
253 
254 // Converts a list of channels to an interleaved image, applying transformations
255 // when needed.
256 // The input channels are given as a (non-const!) array of channel pointers and
257 // interleaved in that order.
258 //
259 // Note: if a pointer in channels[] is nullptr, a 1.0 value will be used
260 // instead. This is useful for handling when a user requests an alpha channel
261 // from an image that doesn't have one. The first channel in the list may not
262 // be nullptr, since it is used to determine the image size.
ConvertChannelsToExternal(const ImageF * channels[],size_t num_channels,size_t bits_per_sample,bool float_out,JxlEndianness endianness,size_t stride,jxl::ThreadPool * pool,void * out_image,size_t out_size,JxlImageOutCallback out_callback,void * out_opaque,jxl::Orientation undo_orientation)263 Status ConvertChannelsToExternal(const ImageF* channels[], size_t num_channels,
264                                  size_t bits_per_sample, bool float_out,
265                                  JxlEndianness endianness, size_t stride,
266                                  jxl::ThreadPool* pool, void* out_image,
267                                  size_t out_size,
268                                  JxlImageOutCallback out_callback,
269                                  void* out_opaque,
270                                  jxl::Orientation undo_orientation) {
271   JXL_DASSERT(num_channels != 0 && num_channels <= kConvertMaxChannels);
272   JXL_DASSERT(channels[0] != nullptr);
273 
274   if (bits_per_sample < 1 || bits_per_sample > 32) {
275     return JXL_FAILURE("Invalid bits_per_sample value.");
276   }
277   if (!!out_image == !!out_callback) {
278     return JXL_FAILURE(
279         "Must provide either an out_image or an out_callback, but not both.");
280   }
281   // TODO(deymo): Implement 1-bit per pixel packed in 8 samples per byte.
282   if (bits_per_sample == 1) {
283     return JXL_FAILURE("packed 1-bit per sample is not yet supported");
284   }
285 
286   // bytes_per_channel and is only valid for bits_per_sample > 1.
287   const size_t bytes_per_channel = DivCeil(bits_per_sample, jxl::kBitsPerByte);
288   const size_t bytes_per_pixel = num_channels * bytes_per_channel;
289 
290   std::vector<std::vector<uint8_t>> row_out_callback;
291   auto InitOutCallback = [&](size_t num_threads) {
292     if (out_callback) {
293       row_out_callback.resize(num_threads);
294       for (size_t i = 0; i < num_threads; ++i) {
295         row_out_callback[i].resize(stride);
296       }
297     }
298   };
299 
300   // Channels used to store the transformed original channels if needed.
301   ImageF temp_channels[kConvertMaxChannels];
302   if (undo_orientation != Orientation::kIdentity) {
303     for (size_t c = 0; c < num_channels; ++c) {
304       if (channels[c]) {
305         UndoOrientation(undo_orientation, *channels[c], temp_channels[c], pool);
306         channels[c] = &(temp_channels[c]);
307       }
308     }
309   }
310 
311   // First channel may not be nullptr.
312   size_t xsize = channels[0]->xsize();
313   size_t ysize = channels[0]->ysize();
314 
315   if (stride < bytes_per_pixel * xsize) {
316     return JXL_FAILURE(
317         "stride is smaller than scanline width in bytes: %zu vs %zu", stride,
318         bytes_per_pixel * xsize);
319   }
320 
321   const bool little_endian =
322       endianness == JXL_LITTLE_ENDIAN ||
323       (endianness == JXL_NATIVE_ENDIAN && IsLittleEndian());
324 
325   // Handle the case where a channel is nullptr by creating a single row with
326   // ones to use instead.
327   ImageF ones;
328   for (size_t c = 0; c < num_channels; ++c) {
329     if (!channels[c]) {
330       ones = ImageF(xsize, 1);
331       FillImage(1.0f, &ones);
332       break;
333     }
334   }
335 
336   if (float_out) {
337     if (bits_per_sample == 16) {
338       bool swap_endianness = little_endian != IsLittleEndian();
339       Plane<hwy::float16_t> f16_cache;
340       RunOnPool(
341           pool, 0, static_cast<uint32_t>(ysize),
342           [&](size_t num_threads) {
343             f16_cache =
344                 Plane<hwy::float16_t>(xsize, num_channels * num_threads);
345             InitOutCallback(num_threads);
346             return true;
347           },
348           [&](const int task, int thread) {
349             const int64_t y = task;
350             const float* JXL_RESTRICT row_in[kConvertMaxChannels];
351             for (size_t c = 0; c < num_channels; c++) {
352               row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
353             }
354             hwy::float16_t* JXL_RESTRICT row_f16[kConvertMaxChannels];
355             for (size_t c = 0; c < num_channels; c++) {
356               row_f16[c] = f16_cache.Row(c + thread * num_channels);
357               HWY_DYNAMIC_DISPATCH(FloatToF16)
358               (row_in[c], row_f16[c], xsize);
359             }
360             uint8_t* row_out =
361                 out_callback
362                     ? row_out_callback[thread].data()
363                     : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
364             // interleave the one scanline
365             hwy::float16_t* row_f16_out =
366                 reinterpret_cast<hwy::float16_t*>(row_out);
367             for (size_t x = 0; x < xsize; x++) {
368               for (size_t c = 0; c < num_channels; c++) {
369                 row_f16_out[x * num_channels + c] = row_f16[c][x];
370               }
371             }
372             if (swap_endianness) {
373               size_t size = xsize * num_channels * 2;
374               for (size_t i = 0; i < size; i += 2) {
375                 std::swap(row_out[i + 0], row_out[i + 1]);
376               }
377             }
378             if (out_callback) {
379               (*out_callback)(out_opaque, 0, y, xsize, row_out);
380             }
381           },
382           "ConvertF16");
383     } else if (bits_per_sample == 32) {
384       RunOnPool(
385           pool, 0, static_cast<uint32_t>(ysize),
386           [&](size_t num_threads) {
387             InitOutCallback(num_threads);
388             return true;
389           },
390           [&](const int task, int thread) {
391             const int64_t y = task;
392             uint8_t* row_out =
393                 out_callback
394                     ? row_out_callback[thread].data()
395                     : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
396             const float* JXL_RESTRICT row_in[kConvertMaxChannels];
397             for (size_t c = 0; c < num_channels; c++) {
398               row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
399             }
400             if (little_endian) {
401               StoreFloatRow<StoreLEFloat>(row_in, num_channels, xsize, row_out);
402             } else {
403               StoreFloatRow<StoreBEFloat>(row_in, num_channels, xsize, row_out);
404             }
405             if (out_callback) {
406               (*out_callback)(out_opaque, 0, y, xsize, row_out);
407             }
408           },
409           "ConvertFloat");
410     } else {
411       return JXL_FAILURE("float other than 16-bit and 32-bit not supported");
412     }
413   } else {
414     // Multiplier to convert from floating point 0-1 range to the integer
415     // range.
416     float mul = (1ull << bits_per_sample) - 1;
417     Plane<uint32_t> u32_cache;
418     RunOnPool(
419         pool, 0, static_cast<uint32_t>(ysize),
420         [&](size_t num_threads) {
421           u32_cache = Plane<uint32_t>(xsize, num_channels * num_threads);
422           InitOutCallback(num_threads);
423           return true;
424         },
425         [&](const int task, int thread) {
426           const int64_t y = task;
427           uint8_t* row_out =
428               out_callback
429                   ? row_out_callback[thread].data()
430                   : &(reinterpret_cast<uint8_t*>(out_image))[stride * y];
431           const float* JXL_RESTRICT row_in[kConvertMaxChannels];
432           for (size_t c = 0; c < num_channels; c++) {
433             row_in[c] = channels[c] ? channels[c]->Row(y) : ones.Row(0);
434           }
435           uint32_t* JXL_RESTRICT row_u32[kConvertMaxChannels];
436           for (size_t c = 0; c < num_channels; c++) {
437             row_u32[c] = u32_cache.Row(c + thread * num_channels);
438             // row_u32[] is a per-thread temporary row storage, this isn't
439             // intended to be initialized on a previous run.
440             msan::PoisonMemory(row_u32[c], xsize * sizeof(row_u32[c][0]));
441             HWY_DYNAMIC_DISPATCH(FloatToU32)
442             (row_in[c], row_u32[c], xsize, mul, bits_per_sample);
443           }
444           // TODO(deymo): add bits_per_sample == 1 case here.
445           if (bits_per_sample <= 8) {
446             StoreUintRow<Store8>(row_u32, num_channels, xsize, 1, row_out);
447           } else if (bits_per_sample <= 16) {
448             if (little_endian) {
449               StoreUintRow<StoreLE16>(row_u32, num_channels, xsize, 2, row_out);
450             } else {
451               StoreUintRow<StoreBE16>(row_u32, num_channels, xsize, 2, row_out);
452             }
453           } else if (bits_per_sample <= 24) {
454             if (little_endian) {
455               StoreUintRow<StoreLE24>(row_u32, num_channels, xsize, 3, row_out);
456             } else {
457               StoreUintRow<StoreBE24>(row_u32, num_channels, xsize, 3, row_out);
458             }
459           } else {
460             if (little_endian) {
461               StoreUintRow<StoreLE32>(row_u32, num_channels, xsize, 4, row_out);
462             } else {
463               StoreUintRow<StoreBE32>(row_u32, num_channels, xsize, 4, row_out);
464             }
465           }
466           if (out_callback) {
467             (*out_callback)(out_opaque, 0, y, xsize, row_out);
468           }
469         },
470         "ConvertUint");
471   }
472   return true;
473 }
474 
475 }  // namespace
476 
ConvertToExternal(const jxl::ImageBundle & ib,size_t bits_per_sample,bool float_out,size_t num_channels,JxlEndianness endianness,size_t stride,jxl::ThreadPool * pool,void * out_image,size_t out_size,JxlImageOutCallback out_callback,void * out_opaque,jxl::Orientation undo_orientation)477 Status ConvertToExternal(const jxl::ImageBundle& ib, size_t bits_per_sample,
478                          bool float_out, size_t num_channels,
479                          JxlEndianness endianness, size_t stride,
480                          jxl::ThreadPool* pool, void* out_image,
481                          size_t out_size, JxlImageOutCallback out_callback,
482                          void* out_opaque, jxl::Orientation undo_orientation) {
483   bool want_alpha = num_channels == 2 || num_channels == 4;
484   size_t color_channels = num_channels <= 2 ? 1 : 3;
485 
486   const Image3F* color = &ib.color();
487   // Undo premultiplied alpha.
488   Image3F unpremul;
489   if (ib.AlphaIsPremultiplied() && ib.HasAlpha()) {
490     unpremul = Image3F(color->xsize(), color->ysize());
491     CopyImageTo(*color, &unpremul);
492     for (size_t y = 0; y < unpremul.ysize(); y++) {
493       UnpremultiplyAlpha(unpremul.PlaneRow(0, y), unpremul.PlaneRow(1, y),
494                          unpremul.PlaneRow(2, y), ib.alpha().Row(y),
495                          unpremul.xsize());
496     }
497     color = &unpremul;
498   }
499 
500   const ImageF* channels[kConvertMaxChannels];
501   size_t c = 0;
502   for (; c < color_channels; c++) {
503     channels[c] = &color->Plane(c);
504   }
505   if (want_alpha) {
506     channels[c++] = ib.HasAlpha() ? &ib.alpha() : nullptr;
507   }
508   JXL_ASSERT(num_channels == c);
509 
510   return ConvertChannelsToExternal(
511       channels, num_channels, bits_per_sample, float_out, endianness, stride,
512       pool, out_image, out_size, out_callback, out_opaque, undo_orientation);
513 }
514 
ConvertToExternal(const jxl::ImageF & channel,size_t bits_per_sample,bool float_out,JxlEndianness endianness,size_t stride,jxl::ThreadPool * pool,void * out_image,size_t out_size,JxlImageOutCallback out_callback,void * out_opaque,jxl::Orientation undo_orientation)515 Status ConvertToExternal(const jxl::ImageF& channel, size_t bits_per_sample,
516                          bool float_out, JxlEndianness endianness,
517                          size_t stride, jxl::ThreadPool* pool, void* out_image,
518                          size_t out_size, JxlImageOutCallback out_callback,
519                          void* out_opaque, jxl::Orientation undo_orientation) {
520   const ImageF* channels[1];
521   channels[0] = &channel;
522   return ConvertChannelsToExternal(
523       channels, 1, bits_per_sample, float_out, endianness, stride, pool,
524       out_image, out_size, out_callback, out_opaque, undo_orientation);
525 }
526 
527 }  // namespace jxl
528 #endif  // HWY_ONCE
529