1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/dec_noise.h"
7 
8 #include <stdint.h>
9 #include <stdio.h>
10 #include <stdlib.h>
11 
12 #include <algorithm>
13 #include <numeric>
14 #include <utility>
15 
16 #undef HWY_TARGET_INCLUDE
17 #define HWY_TARGET_INCLUDE "lib/jxl/dec_noise.cc"
18 #include <hwy/foreach_target.h>
19 #include <hwy/highway.h>
20 
21 #include "lib/jxl/base/compiler_specific.h"
22 #include "lib/jxl/chroma_from_luma.h"
23 #include "lib/jxl/image_ops.h"
24 #include "lib/jxl/opsin_params.h"
25 #include "lib/jxl/sanitizers.h"
26 #include "lib/jxl/xorshift128plus-inl.h"
27 HWY_BEFORE_NAMESPACE();
28 namespace jxl {
29 namespace HWY_NAMESPACE {
30 
// These templates are not found via ADL.
using hwy::HWY_NAMESPACE::ShiftRight;
using hwy::HWY_NAMESPACE::Vec;

// Float descriptor capped at kBlockDim lanes; used by the noise-strength
// evaluation and by AddNoise below.
using D = HWY_CAPPED(float, kBlockDim);
// int32 descriptor with the same number of lanes as D (holds the integer LUT
// indices computed in StrengthEvalLut).
using DI = hwy::HWY_NAMESPACE::Rebind<int, D>;
// Byte descriptor covering the same vector as D; used to broadcast-load the
// 16-byte lookup tables consumed by TableLookupBytes.
using DI8 = hwy::HWY_NAMESPACE::Repartition<uint8_t, D>;
38 
39 // Converts one vector's worth of random bits to floats in [1, 2).
40 // NOTE: as the convolution kernel sums to 0, it doesn't matter if inputs are in
41 // [0, 1) or in [1, 2).
BitsToFloat(const uint32_t * JXL_RESTRICT random_bits,float * JXL_RESTRICT floats)42 void BitsToFloat(const uint32_t* JXL_RESTRICT random_bits,
43                  float* JXL_RESTRICT floats) {
44   const HWY_FULL(float) df;
45   const HWY_FULL(uint32_t) du;
46 
47   const auto bits = Load(du, random_bits);
48   // 1.0 + 23 random mantissa bits = [1, 2)
49   const auto rand12 = BitCast(df, ShiftRight<9>(bits) | Set(du, 0x3F800000));
50   Store(rand12, df, floats);
51 }
52 
RandomImage(Xorshift128Plus * rng,const Rect & rect,ImageF * JXL_RESTRICT noise)53 void RandomImage(Xorshift128Plus* rng, const Rect& rect,
54                  ImageF* JXL_RESTRICT noise) {
55   const size_t xsize = rect.xsize();
56   const size_t ysize = rect.ysize();
57 
58   // May exceed the vector size, hence we have two loops over x below.
59   constexpr size_t kFloatsPerBatch =
60       Xorshift128Plus::N * sizeof(uint64_t) / sizeof(float);
61   HWY_ALIGN uint64_t batch[Xorshift128Plus::N];
62 
63   const HWY_FULL(float) df;
64   const size_t N = Lanes(df);
65 
66   for (size_t y = 0; y < ysize; ++y) {
67     float* JXL_RESTRICT row = rect.Row(noise, y);
68 
69     size_t x = 0;
70     // Only entire batches (avoids exceeding the image padding).
71     for (; x + kFloatsPerBatch <= xsize; x += kFloatsPerBatch) {
72       rng->Fill(batch);
73       for (size_t i = 0; i < kFloatsPerBatch; i += Lanes(df)) {
74         BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + i, row + x + i);
75       }
76     }
77 
78     // Any remaining pixels, rounded up to vectors (safe due to padding).
79     rng->Fill(batch);
80     size_t batch_pos = 0;  // < kFloatsPerBatch
81     for (; x < xsize; x += N) {
82       BitsToFloat(reinterpret_cast<const uint32_t*>(batch) + batch_pos,
83                   row + x);
84       batch_pos += N;
85     }
86   }
87 }
88 
89 // [0, max_value]
90 template <class D, class V>
Clamp0ToMax(D d,const V x,const V max_value)91 static HWY_INLINE V Clamp0ToMax(D d, const V x, const V max_value) {
92   const auto clamped = Min(x, max_value);
93   return ZeroIfNegative(clamped);
94 }
95 
96 // x is in [0+delta, 1+delta], delta ~= 0.06
97 template <class StrengthEval>
NoiseStrength(const StrengthEval & eval,const typename StrengthEval::V x)98 typename StrengthEval::V NoiseStrength(const StrengthEval& eval,
99                                        const typename StrengthEval::V x) {
100   return Clamp0ToMax(D(), eval(x), Set(D(), 1.0f));
101 }
102 
103 // TODO(veluca): SIMD-fy.
104 class StrengthEvalLut {
105  public:
106   using V = Vec<D>;
107 
StrengthEvalLut(const NoiseParams & noise_params)108   explicit StrengthEvalLut(const NoiseParams& noise_params)
109 #if HWY_TARGET == HWY_SCALAR
110       : noise_params_(noise_params)
111 #endif
112   {
113 #if HWY_TARGET != HWY_SCALAR
114     uint32_t lut[8];
115     memcpy(lut, noise_params.lut, sizeof(lut));
116     for (size_t i = 0; i < 8; i++) {
117       low16_lut[2 * i] = (lut[i] >> 0) & 0xFF;
118       low16_lut[2 * i + 1] = (lut[i] >> 8) & 0xFF;
119       high16_lut[2 * i] = (lut[i] >> 16) & 0xFF;
120       high16_lut[2 * i + 1] = (lut[i] >> 24) & 0xFF;
121     }
122 #endif
123   }
124 
operator ()(const V vx) const125   V operator()(const V vx) const {
126     constexpr size_t kScale = NoiseParams::kNumNoisePoints - 2;
127     auto scaled_vx = Max(Zero(D()), vx * Set(D(), kScale));
128     auto floor_x = Floor(scaled_vx);
129     auto frac_x = scaled_vx - floor_x;
130     floor_x = IfThenElse(scaled_vx >= Set(D(), kScale), Set(D(), kScale - 1),
131                          floor_x);
132     frac_x = IfThenElse(scaled_vx >= Set(D(), kScale), Set(D(), 1), frac_x);
133     auto floor_x_int = ConvertTo(DI(), floor_x);
134 #if HWY_TARGET == HWY_SCALAR
135     auto low = Set(D(), noise_params_.lut[floor_x_int.raw]);
136     auto hi = Set(D(), noise_params_.lut[floor_x_int.raw + 1]);
137 #else
138     // Set each lane's bytes to {0, 0, 2x+1, 2x}.
139     auto floorx_indices_low =
140         floor_x_int * Set(DI(), 0x0202) + Set(DI(), 0x0100);
141     // Set each lane's bytes to {2x+1, 2x, 0, 0}.
142     auto floorx_indices_hi =
143         floor_x_int * Set(DI(), 0x02020000) + Set(DI(), 0x01000000);
144     // load LUT
145     auto low16 = BitCast(DI(), LoadDup128(DI8(), low16_lut));
146     auto lowm = Set(DI(), 0xFFFF);
147     auto hi16 = BitCast(DI(), LoadDup128(DI8(), high16_lut));
148     auto him = Set(DI(), 0xFFFF0000);
149     // low = noise_params.lut[floor_x]
150     auto low =
151         BitCast(D(), (TableLookupBytes(low16, floorx_indices_low) & lowm) |
152                          (TableLookupBytes(hi16, floorx_indices_hi) & him));
153     // hi = noise_params.lut[floor_x+1]
154     floorx_indices_low += Set(DI(), 0x0202);
155     floorx_indices_hi += Set(DI(), 0x02020000);
156     auto hi =
157         BitCast(D(), (TableLookupBytes(low16, floorx_indices_low) & lowm) |
158                          (TableLookupBytes(hi16, floorx_indices_hi) & him));
159 #endif
160     return MulAdd(hi - low, frac_x, low);
161   }
162 
163  private:
164 #if HWY_TARGET != HWY_SCALAR
165   // noise_params.lut transformed into two 16-bit lookup tables.
166   HWY_ALIGN uint8_t high16_lut[16];
167   HWY_ALIGN uint8_t low16_lut[16];
168 #else
169   const NoiseParams& noise_params_;
170 #endif
171 };
172 
173 template <class D>
AddNoiseToRGB(const D d,const Vec<D> rnd_noise_r,const Vec<D> rnd_noise_g,const Vec<D> rnd_noise_cor,const Vec<D> noise_strength_g,const Vec<D> noise_strength_r,float ytox,float ytob,float * JXL_RESTRICT out_x,float * JXL_RESTRICT out_y,float * JXL_RESTRICT out_b)174 void AddNoiseToRGB(const D d, const Vec<D> rnd_noise_r,
175                    const Vec<D> rnd_noise_g, const Vec<D> rnd_noise_cor,
176                    const Vec<D> noise_strength_g, const Vec<D> noise_strength_r,
177                    float ytox, float ytob, float* JXL_RESTRICT out_x,
178                    float* JXL_RESTRICT out_y, float* JXL_RESTRICT out_b) {
179   const auto kRGCorr = Set(d, 0.9921875f);   // 127/128
180   const auto kRGNCorr = Set(d, 0.0078125f);  // 1/128
181 
182   const auto red_noise = kRGNCorr * rnd_noise_r * noise_strength_r +
183                          kRGCorr * rnd_noise_cor * noise_strength_r;
184   const auto green_noise = kRGNCorr * rnd_noise_g * noise_strength_g +
185                            kRGCorr * rnd_noise_cor * noise_strength_g;
186 
187   auto vx = Load(d, out_x);
188   auto vy = Load(d, out_y);
189   auto vb = Load(d, out_b);
190 
191   vx += red_noise - green_noise + Set(d, ytox) * (red_noise + green_noise);
192   vy += red_noise + green_noise;
193   vb += Set(d, ytob) * (red_noise + green_noise);
194 
195   Store(vx, d, out_x);
196   Store(vy, d, out_y);
197   Store(vb, d, out_b);
198 }
199 
AddNoise(const NoiseParams & noise_params,const Rect & noise_rect,const Image3F & noise,const Rect & opsin_rect,const ColorCorrelationMap & cmap,Image3F * opsin)200 void AddNoise(const NoiseParams& noise_params, const Rect& noise_rect,
201               const Image3F& noise, const Rect& opsin_rect,
202               const ColorCorrelationMap& cmap, Image3F* opsin) {
203   if (!noise_params.HasAny()) return;
204   const StrengthEvalLut noise_model(noise_params);
205   D d;
206   const auto half = Set(d, 0.5f);
207 
208   const size_t xsize = opsin_rect.xsize();
209   const size_t ysize = opsin_rect.ysize();
210 
211   // With the prior subtract-random Laplacian approximation, rnd_* ranges were
212   // about [-1.5, 1.6]; Laplacian3 about doubles this to [-3.6, 3.6], so the
213   // normalizer is half of what it was before (0.5).
214   const auto norm_const = Set(d, 0.22f);
215 
216   float ytox = cmap.YtoXRatio(0);
217   float ytob = cmap.YtoBRatio(0);
218 
219   const size_t xsize_v = RoundUpTo(xsize, Lanes(d));
220 
221   for (size_t y = 0; y < ysize; ++y) {
222     float* JXL_RESTRICT row_x = opsin_rect.PlaneRow(opsin, 0, y);
223     float* JXL_RESTRICT row_y = opsin_rect.PlaneRow(opsin, 1, y);
224     float* JXL_RESTRICT row_b = opsin_rect.PlaneRow(opsin, 2, y);
225     const float* JXL_RESTRICT row_rnd_r = noise_rect.ConstPlaneRow(noise, 0, y);
226     const float* JXL_RESTRICT row_rnd_g = noise_rect.ConstPlaneRow(noise, 1, y);
227     const float* JXL_RESTRICT row_rnd_c = noise_rect.ConstPlaneRow(noise, 2, y);
228     // Needed by the calls to Floor() in StrengthEvalLut. Only arithmetic and
229     // shuffles are otherwise done on the data, so this is safe.
230     msan::UnpoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float));
231     msan::UnpoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float));
232     for (size_t x = 0; x < xsize; x += Lanes(d)) {
233       const auto vx = Load(d, row_x + x);
234       const auto vy = Load(d, row_y + x);
235       const auto in_g = vy - vx;
236       const auto in_r = vy + vx;
237       const auto noise_strength_g = NoiseStrength(noise_model, in_g * half);
238       const auto noise_strength_r = NoiseStrength(noise_model, in_r * half);
239       const auto addit_rnd_noise_red = Load(d, row_rnd_r + x) * norm_const;
240       const auto addit_rnd_noise_green = Load(d, row_rnd_g + x) * norm_const;
241       const auto addit_rnd_noise_correlated =
242           Load(d, row_rnd_c + x) * norm_const;
243       AddNoiseToRGB(D(), addit_rnd_noise_red, addit_rnd_noise_green,
244                     addit_rnd_noise_correlated, noise_strength_g,
245                     noise_strength_r, ytox, ytob, row_x + x, row_y + x,
246                     row_b + x);
247     }
248     msan::PoisonMemory(row_x + xsize, (xsize_v - xsize) * sizeof(float));
249     msan::PoisonMemory(row_y + xsize, (xsize_v - xsize) * sizeof(float));
250     msan::PoisonMemory(row_b + xsize, (xsize_v - xsize) * sizeof(float));
251   }
252 }
253 
RandomImage3(size_t seed,const Rect & rect,Image3F * JXL_RESTRICT noise)254 void RandomImage3(size_t seed, const Rect& rect, Image3F* JXL_RESTRICT noise) {
255   HWY_ALIGN Xorshift128Plus rng(seed);
256   RandomImage(&rng, rect, &noise->Plane(0));
257   RandomImage(&rng, rect, &noise->Plane(1));
258   RandomImage(&rng, rect, &noise->Plane(2));
259 }
260 
261 // NOLINTNEXTLINE(google-readability-namespace-comments)
262 }  // namespace HWY_NAMESPACE
263 }  // namespace jxl
264 HWY_AFTER_NAMESPACE();
265 
266 #if HWY_ONCE
267 namespace jxl {
268 
269 HWY_EXPORT(AddNoise);
AddNoise(const NoiseParams & noise_params,const Rect & noise_rect,const Image3F & noise,const Rect & opsin_rect,const ColorCorrelationMap & cmap,Image3F * opsin)270 void AddNoise(const NoiseParams& noise_params, const Rect& noise_rect,
271               const Image3F& noise, const Rect& opsin_rect,
272               const ColorCorrelationMap& cmap, Image3F* opsin) {
273   return HWY_DYNAMIC_DISPATCH(AddNoise)(noise_params, noise_rect, noise,
274                                         opsin_rect, cmap, opsin);
275 }
276 
277 HWY_EXPORT(RandomImage3);
RandomImage3(size_t seed,const Rect & rect,Image3F * JXL_RESTRICT noise)278 void RandomImage3(size_t seed, const Rect& rect, Image3F* JXL_RESTRICT noise) {
279   return HWY_DYNAMIC_DISPATCH(RandomImage3)(seed, rect, noise);
280 }
281 
DecodeFloatParam(float precision,float * val,BitReader * br)282 void DecodeFloatParam(float precision, float* val, BitReader* br) {
283   const int absval_quant = br->ReadFixedBits<10>();
284   *val = absval_quant / precision;
285 }
286 
DecodeNoise(BitReader * br,NoiseParams * noise_params)287 Status DecodeNoise(BitReader* br, NoiseParams* noise_params) {
288   for (float& i : noise_params->lut) {
289     DecodeFloatParam(kNoisePrecision, &i, br);
290   }
291   return true;
292 }
293 
294 }  // namespace jxl
295 #endif  // HWY_ONCE
296