1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
#include "lib/jxl/render_pipeline/stage_write.h"

#include <algorithm>

#include "lib/jxl/common.h"
#include "lib/jxl/image_bundle.h"
10
11 #undef HWY_TARGET_INCLUDE
12 #define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_write.cc"
13 #include <hwy/foreach_target.h>
14 #include <hwy/highway.h>
15
16 HWY_BEFORE_NAMESPACE();
17 namespace jxl {
18 namespace HWY_NAMESPACE {
19
// Stores the first `n` lanes of the 8-bit vectors `r`, `g`, `b` (and `a` when
// `alpha` is true) to `buf`, interleaved as RGB (3 bytes per pixel) or RGBA
// (4 bytes per pixel).
//
// `extra` is the number of pixels that the caller allows to be written at
// `buf`. On NEON, when `extra >= 8`, a whole 8-pixel register is stored
// directly, so pixels past the first `n` may receive unspecified values. The
// caller must therefore only pass an `extra` covering pixels it will
// overwrite itself later.
template <typename D, typename V>
void StoreRGBA(D d, V r, V g, V b, V a, bool alpha, size_t n, size_t extra,
               uint8_t* buf) {
#if HWY_TARGET == HWY_SCALAR
  // A scalar "vector" holds exactly one pixel; honour n == 0 like the other
  // targets instead of unconditionally writing past the requested range.
  if (n == 0) return;
  buf[0] = r.raw;
  buf[1] = g.raw;
  buf[2] = b.raw;
  if (alpha) {
    buf[3] = a.raw;
  }
#elif HWY_TARGET == HWY_NEON
  if (alpha) {
    uint8x8x4_t data = {r.raw, g.raw, b.raw, a.raw};
    if (extra >= 8) {
      vst4_u8(buf, data);
    } else {
      // Not enough room for a full register store: go through a temporary
      // and copy only the requested pixels.
      uint8_t tmp[8 * 4];
      vst4_u8(tmp, data);
      memcpy(buf, tmp, n * 4);
    }
  } else {
    uint8x8x3_t data = {r.raw, g.raw, b.raw};
    if (extra >= 8) {
      vst3_u8(buf, data);
    } else {
      uint8_t tmp[8 * 3];
      vst3_u8(tmp, data);
      memcpy(buf, tmp, n * 3);
    }
  }
#else
  // TODO(veluca): implement this for x86.
  size_t mul = alpha ? 4 : 3;
  HWY_ALIGN uint8_t bytes[16];
  Store(r, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i] = bytes[i];
  }
  Store(g, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i + 1] = bytes[i];
  }
  Store(b, d, bytes);
  for (size_t i = 0; i < n; i++) {
    buf[mul * i + 2] = bytes[i];
  }
  if (alpha) {
    Store(a, d, bytes);
    for (size_t i = 0; i < n; i++) {
      buf[4 * i + 3] = bytes[i];
    }
  }
#endif
}
74
75 class WriteToU8Stage : public RenderPipelineStage {
76 public:
WriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)77 WriteToU8Stage(uint8_t* rgb, size_t stride, size_t width, size_t height,
78 bool rgba, bool has_alpha, size_t alpha_c)
79 : RenderPipelineStage(RenderPipelineStage::Settings()),
80 rgb_(rgb),
81 stride_(stride),
82 width_(width),
83 height_(height),
84 rgba_(rgba),
85 has_alpha_(has_alpha),
86 alpha_c_(alpha_c) {}
87
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const88 void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
89 size_t xextra, size_t xsize, size_t xpos, size_t ypos,
90 float* JXL_RESTRICT temp) const final {
91 if (ypos >= height_) return;
92 size_t bytes = rgba_ ? 4 : 3;
93 const float* JXL_RESTRICT row_in_r = GetInputRow(input_rows, 0, 0);
94 const float* JXL_RESTRICT row_in_g = GetInputRow(input_rows, 1, 0);
95 const float* JXL_RESTRICT row_in_b = GetInputRow(input_rows, 2, 0);
96 const float* JXL_RESTRICT row_in_a =
97 has_alpha_ ? GetInputRow(input_rows, alpha_c_, 0) : nullptr;
98 size_t base_ptr = ypos * stride_ + bytes * (xpos - xextra);
99 using D = HWY_CAPPED(float, 4);
100 const D d;
101 D::Rebind<uint32_t> du;
102 auto zero = Zero(d);
103 auto one = Set(d, 1.0f);
104 auto mul = Set(d, 255.0f);
105
106 ssize_t x0 = -RoundUpTo(xextra, Lanes(d));
107 ssize_t x1 = RoundUpTo(xsize + xextra, Lanes(d));
108
109 for (ssize_t x = x0; x < x1; x += Lanes(d)) {
110 auto rf = Clamp(zero, Load(d, row_in_r + x), one) * mul;
111 auto gf = Clamp(zero, Load(d, row_in_g + x), one) * mul;
112 auto bf = Clamp(zero, Load(d, row_in_b + x), one) * mul;
113 auto af = row_in_a ? Clamp(zero, Load(d, row_in_a + x), one) * mul
114 : Set(d, 255.0f);
115 auto r8 = U8FromU32(BitCast(du, NearestInt(rf)));
116 auto g8 = U8FromU32(BitCast(du, NearestInt(gf)));
117 auto b8 = U8FromU32(BitCast(du, NearestInt(bf)));
118 auto a8 = U8FromU32(BitCast(du, NearestInt(af)));
119 size_t n = width_ - xpos - x;
120 if (JXL_LIKELY(n >= Lanes(d))) {
121 StoreRGBA(D::Rebind<uint8_t>(), r8, g8, b8, a8, rgba_, Lanes(d), n,
122 rgb_ + base_ptr + bytes * x);
123 } else {
124 StoreRGBA(D::Rebind<uint8_t>(), r8, g8, b8, a8, rgba_, n, n,
125 rgb_ + base_ptr + bytes * x);
126 }
127 }
128 }
129
GetChannelMode(size_t c) const130 RenderPipelineChannelMode GetChannelMode(size_t c) const final {
131 return c < 3 || (has_alpha_ && c == alpha_c_)
132 ? RenderPipelineChannelMode::kInput
133 : RenderPipelineChannelMode::kIgnored;
134 }
135
136 private:
137 uint8_t* rgb_;
138 size_t stride_;
139 size_t width_;
140 size_t height_;
141 bool rgba_;
142 bool has_alpha_;
143 size_t alpha_c_;
144 std::vector<float> opaque_alpha_;
145 };
146
GetWriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)147 std::unique_ptr<RenderPipelineStage> GetWriteToU8Stage(
148 uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba,
149 bool has_alpha, size_t alpha_c) {
150 return jxl::make_unique<WriteToU8Stage>(rgb, stride, width, height, rgba,
151 has_alpha, alpha_c);
152 }
153
154 // NOLINTNEXTLINE(google-readability-namespace-comments)
155 } // namespace HWY_NAMESPACE
156 } // namespace jxl
157 HWY_AFTER_NAMESPACE();
158
159 #if HWY_ONCE
160
161 namespace jxl {
162
// Builds the table of per-target GetWriteToU8Stage implementations that
// HWY_DYNAMIC_DISPATCH(GetWriteToU8Stage) selects from below.
HWY_EXPORT(GetWriteToU8Stage);
164
165 namespace {
166 class WriteToImageBundleStage : public RenderPipelineStage {
167 public:
WriteToImageBundleStage(ImageBundle * image_bundle)168 explicit WriteToImageBundleStage(ImageBundle* image_bundle)
169 : RenderPipelineStage(RenderPipelineStage::Settings()),
170 image_bundle_(image_bundle) {}
171
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const172 void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
173 size_t xextra, size_t xsize, size_t xpos, size_t ypos,
174 float* JXL_RESTRICT temp) const final {
175 for (size_t c = 0; c < 3; c++) {
176 memcpy(image_bundle_->color()->PlaneRow(c, ypos) + xpos - xextra,
177 GetInputRow(input_rows, c, 0) - xextra,
178 sizeof(float) * (xsize + 2 * xextra));
179 }
180 for (size_t ec = 0; ec < image_bundle_->extra_channels().size(); ec++) {
181 JXL_ASSERT(ec < image_bundle_->extra_channels().size());
182 JXL_ASSERT(image_bundle_->extra_channels()[ec].xsize() <=
183 xpos + xsize + xextra);
184 memcpy(image_bundle_->extra_channels()[ec].Row(ypos) + xpos - xextra,
185 GetInputRow(input_rows, 3 + ec, 0) - xextra,
186 sizeof(float) * (xsize + 2 * xextra));
187 }
188 }
189
GetChannelMode(size_t c) const190 RenderPipelineChannelMode GetChannelMode(size_t c) const final {
191 return c < 3 + image_bundle_->extra_channels().size()
192 ? RenderPipelineChannelMode::kInput
193 : RenderPipelineChannelMode::kIgnored;
194 }
195
196 private:
197 ImageBundle* image_bundle_;
198 };
199
200 class WriteToImage3FStage : public RenderPipelineStage {
201 public:
WriteToImage3FStage(Image3F * image)202 explicit WriteToImage3FStage(Image3F* image)
203 : RenderPipelineStage(RenderPipelineStage::Settings()), image_(image) {}
204
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const205 void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
206 size_t xextra, size_t xsize, size_t xpos, size_t ypos,
207 float* JXL_RESTRICT temp) const final {
208 for (size_t c = 0; c < 3; c++) {
209 memcpy(image_->PlaneRow(c, ypos) + xpos - xextra,
210 GetInputRow(input_rows, c, 0) - xextra,
211 sizeof(float) * (xsize + 2 * xextra));
212 }
213 }
214
GetChannelMode(size_t c) const215 RenderPipelineChannelMode GetChannelMode(size_t c) const final {
216 return c < 3 ? RenderPipelineChannelMode::kInput
217 : RenderPipelineChannelMode::kIgnored;
218 }
219
220 private:
221 Image3F* image_;
222 };
223
224 class WriteToPixelCallbackStage : public RenderPipelineStage {
225 public:
WriteToPixelCallbackStage(const std::function<void (const float *,size_t,size_t,size_t)> & pixel_callback,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)226 WriteToPixelCallbackStage(
227 const std::function<void(const float*, size_t, size_t, size_t)>&
228 pixel_callback,
229 size_t width, size_t height, bool rgba, bool has_alpha, size_t alpha_c)
230 : RenderPipelineStage(RenderPipelineStage::Settings()),
231 pixel_callback_(pixel_callback),
232 width_(width),
233 height_(height),
234 rgba_(rgba),
235 has_alpha_(has_alpha),
236 alpha_c_(alpha_c),
237 opaque_alpha_(kMaxPixelsPerCall, 1.0f) {
238 settings_.temp_buffer_size = kMaxPixelsPerCall * (rgba_ ? 4 : 3);
239 }
240
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const241 void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
242 size_t xextra, size_t xsize, size_t xpos, size_t ypos,
243 float* JXL_RESTRICT temp) const final {
244 if (ypos >= height_) return;
245 const float* line_buffers[4];
246 for (size_t c = 0; c < 3; c++) {
247 line_buffers[c] = GetInputRow(input_rows, c, 0);
248 }
249 if (has_alpha_) {
250 line_buffers[3] = GetInputRow(input_rows, alpha_c_, 0);
251 } else {
252 line_buffers[3] = opaque_alpha_.data();
253 }
254 // TODO(veluca): SIMD.
255 ssize_t limit = std::min(xextra + xsize, width_ - xpos);
256 for (ssize_t x0 = -xextra; x0 < limit; x0 += kMaxPixelsPerCall) {
257 size_t j = 0;
258 size_t ix = 0;
259 for (; ix < kMaxPixelsPerCall && ssize_t(ix) + x0 < limit; ix++) {
260 temp[j++] = line_buffers[0][x0 + ix];
261 temp[j++] = line_buffers[1][x0 + ix];
262 temp[j++] = line_buffers[2][x0 + ix];
263 if (rgba_) {
264 temp[j++] = line_buffers[3][x0 + ix];
265 }
266 }
267 pixel_callback_(temp, xpos + x0, ypos, ix);
268 }
269 }
270
GetChannelMode(size_t c) const271 RenderPipelineChannelMode GetChannelMode(size_t c) const final {
272 return c < 3 || (has_alpha_ && c == alpha_c_)
273 ? RenderPipelineChannelMode::kInput
274 : RenderPipelineChannelMode::kIgnored;
275 }
276
277 private:
278 static constexpr size_t kMaxPixelsPerCall = 1024;
279 const std::function<void(const float*, size_t, size_t, size_t)>&
280 pixel_callback_;
281 size_t width_;
282 size_t height_;
283 bool rgba_;
284 bool has_alpha_;
285 size_t alpha_c_;
286 std::vector<float> opaque_alpha_;
287 };
288
289 } // namespace
290
GetWriteToImageBundleStage(ImageBundle * image_bundle)291 std::unique_ptr<RenderPipelineStage> GetWriteToImageBundleStage(
292 ImageBundle* image_bundle) {
293 return jxl::make_unique<WriteToImageBundleStage>(image_bundle);
294 }
295
GetWriteToImage3FStage(Image3F * image)296 std::unique_ptr<RenderPipelineStage> GetWriteToImage3FStage(Image3F* image) {
297 return jxl::make_unique<WriteToImage3FStage>(image);
298 }
299
GetWriteToU8Stage(uint8_t * rgb,size_t stride,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)300 std::unique_ptr<RenderPipelineStage> GetWriteToU8Stage(
301 uint8_t* rgb, size_t stride, size_t width, size_t height, bool rgba,
302 bool has_alpha, size_t alpha_c) {
303 return HWY_DYNAMIC_DISPATCH(GetWriteToU8Stage)(rgb, stride, width, height,
304 rgba, has_alpha, alpha_c);
305 }
306
GetWriteToPixelCallbackStage(const std::function<void (const float *,size_t,size_t,size_t)> & pixel_callback,size_t width,size_t height,bool rgba,bool has_alpha,size_t alpha_c)307 std::unique_ptr<RenderPipelineStage> GetWriteToPixelCallbackStage(
308 const std::function<void(const float*, size_t, size_t, size_t)>&
309 pixel_callback,
310 size_t width, size_t height, bool rgba, bool has_alpha, size_t alpha_c) {
311 return jxl::make_unique<WriteToPixelCallbackStage>(
312 pixel_callback, width, height, rgba, has_alpha, alpha_c);
313 }
314
315 } // namespace jxl
316
317 #endif
318