1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/render_pipeline/stage_chroma_upsampling.h"
7 
8 #undef HWY_TARGET_INCLUDE
9 #define HWY_TARGET_INCLUDE "lib/jxl/render_pipeline/stage_chroma_upsampling.cc"
10 #include <hwy/foreach_target.h>
11 #include <hwy/highway.h>
12 
13 #include "lib/jxl/simd_util-inl.h"
14 
15 HWY_BEFORE_NAMESPACE();
16 namespace jxl {
17 namespace HWY_NAMESPACE {
18 
19 class HorizontalChromaUpsamplingStage : public RenderPipelineStage {
20  public:
HorizontalChromaUpsamplingStage(size_t channel)21   explicit HorizontalChromaUpsamplingStage(size_t channel)
22       : RenderPipelineStage(RenderPipelineStage::Settings::ShiftX(
23             /*shift=*/1, /*border=*/1)),
24         c_(channel) {}
25 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const26   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
27                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
28                   float* JXL_RESTRICT temp) const final {
29     PROFILER_ZONE("HorizontalChromaUpsampling");
30     HWY_FULL(float) df;
31     xextra = RoundUpTo(xextra, Lanes(df));
32     auto threefour = Set(df, 0.75f);
33     auto onefour = Set(df, 0.25f);
34     const float* row_in = GetInputRow(input_rows, c_, 0);
35     float* row_out = GetOutputRow(output_rows, c_, 0);
36     for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra);
37          x += Lanes(df)) {
38       auto current = Load(df, row_in + x) * threefour;
39       auto prev = LoadU(df, row_in + x - 1);
40       auto next = LoadU(df, row_in + x + 1);
41       auto left = MulAdd(onefour, prev, current);
42       auto right = MulAdd(onefour, next, current);
43       StoreInterleaved(df, left, right, row_out + x * 2);
44     }
45   }
46 
GetChannelMode(size_t c) const47   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
48     return c == c_ ? RenderPipelineChannelMode::kInOut
49                    : RenderPipelineChannelMode::kIgnored;
50   }
51 
52  private:
53   size_t c_;
54 };
55 
56 class VerticalChromaUpsamplingStage : public RenderPipelineStage {
57  public:
VerticalChromaUpsamplingStage(size_t channel)58   explicit VerticalChromaUpsamplingStage(size_t channel)
59       : RenderPipelineStage(RenderPipelineStage::Settings::ShiftY(
60             /*shift=*/1, /*border=*/1)),
61         c_(channel) {}
62 
ProcessRow(const RowInfo & input_rows,const RowInfo & output_rows,size_t xextra,size_t xsize,size_t xpos,size_t ypos,float * JXL_RESTRICT temp) const63   void ProcessRow(const RowInfo& input_rows, const RowInfo& output_rows,
64                   size_t xextra, size_t xsize, size_t xpos, size_t ypos,
65                   float* JXL_RESTRICT temp) const final {
66     PROFILER_ZONE("VerticalChromaUpsampling");
67     HWY_FULL(float) df;
68     xextra = RoundUpTo(xextra, Lanes(df));
69     auto threefour = Set(df, 0.75f);
70     auto onefour = Set(df, 0.25f);
71     const float* row_top = GetInputRow(input_rows, c_, -1);
72     const float* row_mid = GetInputRow(input_rows, c_, 0);
73     const float* row_bot = GetInputRow(input_rows, c_, 1);
74     float* row_out0 = GetOutputRow(output_rows, c_, 0);
75     float* row_out1 = GetOutputRow(output_rows, c_, 1);
76     for (ssize_t x = -xextra; x < static_cast<ssize_t>(xsize + xextra);
77          x += Lanes(df)) {
78       auto it = Load(df, row_top + x);
79       auto im = Load(df, row_mid + x);
80       auto ib = Load(df, row_bot + x);
81       auto im_scaled = im * threefour;
82       Store(MulAdd(it, onefour, im_scaled), df, row_out0 + x);
83       Store(MulAdd(ib, onefour, im_scaled), df, row_out1 + x);
84     }
85   }
86 
GetChannelMode(size_t c) const87   RenderPipelineChannelMode GetChannelMode(size_t c) const final {
88     return c == c_ ? RenderPipelineChannelMode::kInOut
89                    : RenderPipelineChannelMode::kIgnored;
90   }
91 
92  private:
93   size_t c_;
94 };
95 
GetChromaUpsamplingStage(size_t channel,bool horizontal)96 std::unique_ptr<RenderPipelineStage> GetChromaUpsamplingStage(size_t channel,
97                                                               bool horizontal) {
98   if (horizontal) {
99     return jxl::make_unique<HorizontalChromaUpsamplingStage>(channel);
100   } else {
101     return jxl::make_unique<VerticalChromaUpsamplingStage>(channel);
102   }
103 }
104 
105 // NOLINTNEXTLINE(google-readability-namespace-comments)
106 }  // namespace HWY_NAMESPACE
107 }  // namespace jxl
108 HWY_AFTER_NAMESPACE();
109 
110 #if HWY_ONCE
111 namespace jxl {
112 
113 HWY_EXPORT(GetChromaUpsamplingStage);
114 
GetChromaUpsamplingStage(size_t channel,bool horizontal)115 std::unique_ptr<RenderPipelineStage> GetChromaUpsamplingStage(size_t channel,
116                                                               bool horizontal) {
117   return HWY_DYNAMIC_DISPATCH(GetChromaUpsamplingStage)(channel, horizontal);
118 }
119 
120 }  // namespace jxl
121 #endif
122