1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
#include "lib/jxl/render_pipeline/stage_write.h"

#include <algorithm>
#include <functional>

#include "lib/jxl/common.h"
#include "lib/jxl/image_bundle.h"
10 
11 #undef HWY_TARGET_INCLUDE
12 #define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_write.cc"
13 #include <hwy/foreach_target.h>
14 #include <hwy/highway.h>
15 
16 HWY_BEFORE_NAMESPACE();
17 namespace jxl {
18 namespace HWY_NAMESPACE {
19 
// Interleaves the 8-bit channel vectors r, g, b (and a, when `alpha` is true)
// into buf as packed RGB (3 bytes/pixel) or RGBA (4 bytes/pixel).
//
// d:     descriptor of the uint8_t vectors; only the generic (non-NEON,
//        non-scalar) path uses it.
// n:     number of pixels that must end up in buf. The generic path reads
//        lanes 0..n-1, so n is presumably never larger than the lane count.
// extra: number of pixels that may safely be written starting at buf. Only
//        the NEON path reads it: with extra >= 8 it stores a whole 8-pixel
//        register directly into buf, so it may write more than n pixels.
// buf:   destination of the packed pixels.
template <typename D, typename V>
void StoreRGBA(D d, V r, V g, V b, V a, bool alpha, size_t n, size_t extra,
               uint8_t* buf) {
#if HWY_TARGET == HWY_SCALAR
  // Single-lane vectors: this path ignores n/extra and writes one pixel.
  buf[0] = r.raw;
  buf[1] = g.raw;
  buf[2] = b.raw;
  if (alpha) {
    buf[3] = a.raw;
  }
#elif HWY_TARGET == HWY_NEON
  // vst4_u8 / vst3_u8 interleave and store all 8 lanes of the registers.
  // If fewer than 8 pixels may be written, interleave into a stack buffer
  // and copy only the first n pixels out.
  if (alpha) {
    uint8x8x4_t data = {r.raw, g.raw, b.raw, a.raw};
    if (extra >= 8) {
      vst4_u8(buf, data);
    } else {
      uint8_t tmp[8 * 4];
      vst4_u8(tmp, data);
      memcpy(buf, tmp, n * 4);
    }
  } else {
    uint8x8x3_t data = {r.raw, g.raw, b.raw};
    if (extra >= 8) {
      vst3_u8(buf, data);
    } else {
      uint8_t tmp[8 * 3];
      vst3_u8(tmp, data);
      memcpy(buf, tmp, n * 3);
    }
  }
#else
  // TODO(veluca): implement this for x86.
  // Generic fallback: spill each channel to an aligned scratch array, then
  // scatter its first n bytes to every `mul`-th byte of buf.
  size_t mul = alpha ? 4 : 3;
  HWY_ALIGN uint8_t bytes[16];
  Store(r, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i] = bytes[i];
  }
  Store(g, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i + 1] = bytes[i];
  }
  Store(b, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i + 2] = bytes[i];
  }
  if (alpha) {
    // With alpha, mul == 4, so the stride 4 here matches the loops above.
    Store(a, d, bytes);
    for (size_t i = 0; i < n; i++) {
      buf[4 * i + 3] = bytes[i];
    }
  }
#endif
}
74 
75 class WriteToU8Stage : public RenderPipelineStage {
76  public:
WriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)77   WriteToU8Stage(uint8_t* rgb, size_t stride, size_t width, size_t height,
78                  bool rgba, bool has_alpha, size_t alpha_c)
79       : RenderPipelineStage(RenderPipelineStage::Settings()),
80         rgb_(rgb),
81         stride_(stride),
82         width_(width),
83         height_(height),
84         rgba_(rgba),
85         has_alpha_(has_alpha),
86         alpha_c_(alpha_c) {}
87 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const88   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
89                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
90                   float* JXL_RESTRICT temp) const final {
91     if (ypos >= height_) return;
92     size_t bytes = rgba_ ? 4 : 3;
93     const float* JXL_RESTRICT row_in_r = GetInputRow(input_rows, 0, 0);
94     const float* JXL_RESTRICT row_in_g = GetInputRow(input_rows, 1, 0);
95     const float* JXL_RESTRICT row_in_b = GetInputRow(input_rows, 2, 0);
96     const float* JXL_RESTRICT row_in_a =
97         has_alpha_ ? GetInputRow(input_rows, alpha_c_, 0) : nullptr;
98     size_t base_ptr = ypos * stride_ + bytes * (xpos - xextra);
99     using D = HWY_CAPPED(float, 4);
100     const D d;
101     D::Rebind<uint32_t> du;
102     auto zero = Zero(d);
103     auto one = Set(d, 1.0f);
104     auto mul = Set(d, 255.0f);
105 
106     ssize_t x0 = -RoundUpTo(xextra, Lanes(d));
107     ssize_t x1 = RoundUpTo(xsize + xextra, Lanes(d));
108 
109     for (ssize_t x = x0; x < x1; x += Lanes(d)) {
110       auto rf = Clamp(zero, Load(d, row_in_r + x), one) * mul;
111       auto gf = Clamp(zero, Load(d, row_in_g + x), one) * mul;
112       auto bf = Clamp(zero, Load(d, row_in_b + x), one) * mul;
113       auto af = row_in_a ? Clamp(zero, Load(d, row_in_a + x), one) * mul
114                          : Set(d, 255.0f);
115       auto r8 = U8FromU32(BitCast(du, NearestInt(rf)));
116       auto g8 = U8FromU32(BitCast(du, NearestInt(gf)));
117       auto b8 = U8FromU32(BitCast(du, NearestInt(bf)));
118       auto a8 = U8FromU32(BitCast(du, NearestInt(af)));
119       size_t n = width_ - xpos - x;
120       if (JXL_LIKELY(n >= Lanes(d))) {
121         StoreRGBA(D::Rebind<uint8_t>(), r8, g8, b8, a8, rgba_, Lanes(d), n,
122                   rgb_ + base_ptr + bytes * x);
123       } else {
124         StoreRGBA(D::Rebind<uint8_t>(), r8, g8, b8, a8, rgba_, n, n,
125                   rgb_ + base_ptr + bytes * x);
126       }
127     }
128   }
129 
GetChannelMode(size_t c) const130   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
131     return c < 3 || (has_alpha_ && c == alpha_c_)
132                ? RenderPipelineChannelMode::kInput
133                : RenderPipelineChannelMode::kIgnored;
134   }
135 
136  private:
137   uint8_t* rgb_;
138   size_t stride_;
139   size_t width_;
140   size_t height_;
141   bool rgba_;
142   bool has_alpha_;
143   size_t alpha_c_;
144   std::vector<float> opaque_alpha_;
145 };
146 
GetWriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)147 std::unique_ptr<RenderPipelineStage> GetWriteToU8Stage(
148     uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba,
149     bool has_alpha, size_t alpha_c) {
150   return jxl::make_unique<WriteToU8Stage>(rgb, stride, width, height, rgba,
151                                           has_alpha, alpha_c);
152 }
153 
154 // NOLINTNEXTLINE(google-readability-namespace-comments)
155 }  // namespace HWY_NAMESPACE
156 }  // namespace jxl
157 HWY_AFTER_NAMESPACE();
158 
159 #if HWY_ONCE
160 
161 namespace jxl {
162 
163 HWY_EXPORT(GetWriteToU8Stage);
164 
165 namespace {
166 class WriteToImageBundleStage : public RenderPipelineStage {
167  public:
WriteToImageBundleStage(ImageBundle * image_bundle)168   explicit WriteToImageBundleStage(ImageBundle* image_bundle)
169       : RenderPipelineStage(RenderPipelineStage::Settings()),
170         image_bundle_(image_bundle) {}
171 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const172   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
173                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
174                   float* JXL_RESTRICT temp) const final {
175     for (size_t c = 0; c < 3; c++) {
176       memcpy(image_bundle_->color()->PlaneRow(c, ypos) + xpos - xextra,
177              GetInputRow(input_rows, c, 0) - xextra,
178              sizeof(float) * (xsize + 2 * xextra));
179     }
180     for (size_t ec = 0; ec < image_bundle_->extra_channels().size(); ec++) {
181       JXL_ASSERT(ec < image_bundle_->extra_channels().size());
182       JXL_ASSERT(image_bundle_->extra_channels()[ec].xsize() <=
183                  xpos + xsize + xextra);
184       memcpy(image_bundle_->extra_channels()[ec].Row(ypos) + xpos - xextra,
185              GetInputRow(input_rows, 3 + ec, 0) - xextra,
186              sizeof(float) * (xsize + 2 * xextra));
187     }
188   }
189 
GetChannelMode(size_t c) const190   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
191     return c < 3 + image_bundle_->extra_channels().size()
192                ? RenderPipelineChannelMode::kInput
193                : RenderPipelineChannelMode::kIgnored;
194   }
195 
196  private:
197   ImageBundle* image_bundle_;
198 };
199 
200 class WriteToImage3FStage : public RenderPipelineStage {
201  public:
WriteToImage3FStage(Image3F * image)202   explicit WriteToImage3FStage(Image3F* image)
203       : RenderPipelineStage(RenderPipelineStage::Settings()), image_(image) {}
204 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const205   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
206                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
207                   float* JXL_RESTRICT temp) const final {
208     for (size_t c = 0; c < 3; c++) {
209       memcpy(image_->PlaneRow(c, ypos) + xpos - xextra,
210              GetInputRow(input_rows, c, 0) - xextra,
211              sizeof(float) * (xsize + 2 * xextra));
212     }
213   }
214 
GetChannelMode(size_t c) const215   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
216     return c < 3 ? RenderPipelineChannelMode::kInput
217                  : RenderPipelineChannelMode::kIgnored;
218   }
219 
220  private:
221   Image3F* image_;
222 };
223 
224 class WriteToPixelCallbackStage : public RenderPipelineStage {
225  public:
WriteToPixelCallbackStage(const std::function<void (const float *,size_t,size_t,size_t)> & pixel_callback,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)226   WriteToPixelCallbackStage(
227       const std::function<void(const float*, size_t, size_t, size_t)>&
228           pixel_callback,
229       size_t width, size_t height, bool rgba, bool has_alpha, size_t alpha_c)
230       : RenderPipelineStage(RenderPipelineStage::Settings()),
231         pixel_callback_(pixel_callback),
232         width_(width),
233         height_(height),
234         rgba_(rgba),
235         has_alpha_(has_alpha),
236         alpha_c_(alpha_c),
237         opaque_alpha_(kMaxPixelsPerCall, 1.0f) {
238     settings_.temp_buffer_size = kMaxPixelsPerCall * (rgba_ ? 4 : 3);
239   }
240 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const241   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
242                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
243                   float* JXL_RESTRICT temp) const final {
244     if (ypos >= height_) return;
245     const float* line_buffers[4];
246     for (size_t c = 0; c < 3; c++) {
247       line_buffers[c] = GetInputRow(input_rows, c, 0);
248     }
249     if (has_alpha_) {
250       line_buffers[3] = GetInputRow(input_rows, alpha_c_, 0);
251     } else {
252       line_buffers[3] = opaque_alpha_.data();
253     }
254     // TODO(veluca): SIMD.
255     ssize_t limit = std::min(xextra + xsize, width_ - xpos);
256     for (ssize_t x0 = -xextra; x0 < limit; x0 += kMaxPixelsPerCall) {
257       size_t j = 0;
258       size_t ix = 0;
259       for (; ix < kMaxPixelsPerCall && ssize_t(ix) + x0 < limit; ix++) {
260         temp[j++] = line_buffers[0][x0 + ix];
261         temp[j++] = line_buffers[1][x0 + ix];
262         temp[j++] = line_buffers[2][x0 + ix];
263         if (rgba_) {
264           temp[j++] = line_buffers[3][x0 + ix];
265         }
266       }
267       pixel_callback_(temp, xpos + x0, ypos, ix);
268     }
269   }
270 
GetChannelMode(size_t c) const271   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
272     return c < 3 || (has_alpha_ && c == alpha_c_)
273                ? RenderPipelineChannelMode::kInput
274                : RenderPipelineChannelMode::kIgnored;
275   }
276 
277  private:
278   static constexpr size_t kMaxPixelsPerCall = 1024;
279   const std::function<void(const float*, size_t, size_t, size_t)>&
280       pixel_callback_;
281   size_t width_;
282   size_t height_;
283   bool rgba_;
284   bool has_alpha_;
285   size_t alpha_c_;
286   std::vector<float> opaque_alpha_;
287 };
288 
289 }  // namespace
290 
GetWriteToImageBundleStage(ImageBundle * image_bundle)291 std::unique_ptr<RenderPipelineStage> GetWriteToImageBundleStage(
292     ImageBundle* image_bundle) {
293   return jxl::make_unique<WriteToImageBundleStage>(image_bundle);
294 }
295 
GetWriteToImage3FStage(Image3F * image)296 std::unique_ptr<RenderPipelineStage> GetWriteToImage3FStage(Image3F* image) {
297   return jxl::make_unique<WriteToImage3FStage>(image);
298 }
299 
GetWriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)300 std::unique_ptr<RenderPipelineStage> GetWriteToU8Stage(
301     uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba,
302     bool has_alpha, size_t alpha_c) {
303   return HWY_DYNAMIC_DISPATCH(GetWriteToU8Stage)(rgb, stride, width, height,
304                                                  rgba, has_alpha, alpha_c);
305 }
306 
GetWriteToPixelCallbackStage(const std::function<void (const float *,size_t,size_t,size_t)> & pixel_callback,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)307 std::unique_ptr<RenderPipelineStage> GetWriteToPixelCallbackStage(
308     const std::function<void(const float*, size_t, size_t, size_t)>&
309         pixel_callback,
310     size_t width, size_t height, bool rgba, bool has_alpha, size_t alpha_c) {
311   return jxl::make_unique<WriteToPixelCallbackStage>(
312       pixel_callback, width, height, rgba, has_alpha, alpha_c);
313 }
314 
315 }  // namespace jxl
316 
317 #endif
318