1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/compute/kernels/common.h"
19
20 namespace arrow {
21
22 using internal::checked_cast;
23 using internal::checked_pointer_cast;
24 using util::string_view;
25
26 namespace compute {
27 namespace internal {
28
29 namespace {
30
31 struct Equal {
32 template <typename T>
Callarrow::compute::internal::__anon6653aee70111::Equal33 static constexpr bool Call(KernelContext*, const T& left, const T& right) {
34 return left == right;
35 }
36 };
37
38 struct NotEqual {
39 template <typename T>
Callarrow::compute::internal::__anon6653aee70111::NotEqual40 static constexpr bool Call(KernelContext*, const T& left, const T& right) {
41 return left != right;
42 }
43 };
44
45 struct Greater {
46 template <typename T>
Callarrow::compute::internal::__anon6653aee70111::Greater47 static constexpr bool Call(KernelContext*, const T& left, const T& right) {
48 return left > right;
49 }
50 };
51
52 struct GreaterEqual {
53 template <typename T>
Callarrow::compute::internal::__anon6653aee70111::GreaterEqual54 static constexpr bool Call(KernelContext*, const T& left, const T& right) {
55 return left >= right;
56 }
57 };
58
// Less and LessEqual have no functors of their own: "less" and "less_equal"
// are built from the Greater and GreaterEqual kernels with flipped arguments.
60
61 template <typename Op>
AddIntegerCompare(const std::shared_ptr<DataType> & ty,ScalarFunction * func)62 void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
63 auto exec =
64 GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
65 DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
66 }
67
68 template <typename InType, typename Op>
AddGenericCompare(const std::shared_ptr<DataType> & ty,ScalarFunction * func)69 void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
70 DCHECK_OK(
71 func->AddKernel({ty, ty}, boolean(),
72 applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec));
73 }
74
75 struct CompareFunction : ScalarFunction {
76 using ScalarFunction::ScalarFunction;
77
DispatchBestarrow::compute::internal::__anon6653aee70111::CompareFunction78 Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
79 RETURN_NOT_OK(CheckArity(*values));
80
81 using arrow::compute::detail::DispatchExactImpl;
82 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
83
84 EnsureDictionaryDecoded(values);
85 ReplaceNullWithOtherType(values);
86
87 if (auto type = CommonNumeric(*values)) {
88 ReplaceTypes(type, values);
89 } else if (auto type = CommonTimestamp(*values)) {
90 ReplaceTypes(type, values);
91 } else if (auto type = CommonBinary(*values)) {
92 ReplaceTypes(type, values);
93 }
94
95 if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
96 return arrow::compute::detail::NoMatchingKernel(this, *values);
97 }
98 };
99
100 template <typename Op>
MakeCompareFunction(std::string name,const FunctionDoc * doc)101 std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name,
102 const FunctionDoc* doc) {
103 auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
104
105 DCHECK_OK(func->AddKernel(
106 {boolean(), boolean()}, boolean(),
107 applicator::ScalarBinary<BooleanType, BooleanType, BooleanType, Op>::Exec));
108
109 for (const std::shared_ptr<DataType>& ty : IntTypes()) {
110 AddIntegerCompare<Op>(ty, func.get());
111 }
112 AddIntegerCompare<Op>(date32(), func.get());
113 AddIntegerCompare<Op>(date64(), func.get());
114
115 AddGenericCompare<FloatType, Op>(float32(), func.get());
116 AddGenericCompare<DoubleType, Op>(float64(), func.get());
117
118 // Add timestamp kernels
119 for (auto unit : AllTimeUnits()) {
120 InputType in_type(match::TimestampTypeUnit(unit));
121 auto exec =
122 GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
123 int64());
124 DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
125 }
126
127 // Duration
128 for (auto unit : AllTimeUnits()) {
129 InputType in_type(match::DurationTypeUnit(unit));
130 auto exec =
131 GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
132 int64());
133 DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
134 }
135
136 // Time32 and Time64
137 for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) {
138 InputType in_type(match::Time32TypeUnit(unit));
139 auto exec =
140 GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
141 int32());
142 DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
143 }
144 for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) {
145 InputType in_type(match::Time64TypeUnit(unit));
146 auto exec =
147 GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
148 int64());
149 DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
150 }
151
152 for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
153 auto exec =
154 GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
155 DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
156 }
157
158 return func;
159 }
160
MakeFlippedFunction(std::string name,const ScalarFunction & func,const FunctionDoc * doc)161 std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name,
162 const ScalarFunction& func,
163 const FunctionDoc* doc) {
164 auto flipped_func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
165 for (const ScalarKernel* kernel : func.kernels()) {
166 ScalarKernel flipped_kernel = *kernel;
167 flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec);
168 DCHECK_OK(flipped_func->AddKernel(std::move(flipped_kernel)));
169 }
170 return flipped_func;
171 }
172
173 const FunctionDoc equal_doc{"Compare values for equality (x == y)",
174 ("A null on either side emits a null comparison result."),
175 {"x", "y"}};
176
177 const FunctionDoc not_equal_doc{"Compare values for inequality (x != y)",
178 ("A null on either side emits a null comparison result."),
179 {"x", "y"}};
180
181 const FunctionDoc greater_doc{"Compare values for ordered inequality (x > y)",
182 ("A null on either side emits a null comparison result."),
183 {"x", "y"}};
184
185 const FunctionDoc greater_equal_doc{
186 "Compare values for ordered inequality (x >= y)",
187 ("A null on either side emits a null comparison result."),
188 {"x", "y"}};
189
190 const FunctionDoc less_doc{"Compare values for ordered inequality (x < y)",
191 ("A null on either side emits a null comparison result."),
192 {"x", "y"}};
193
194 const FunctionDoc less_equal_doc{
195 "Compare values for ordered inequality (x <= y)",
196 ("A null on either side emits a null comparison result."),
197 {"x", "y"}};
198
199 } // namespace
200
RegisterScalarComparison(FunctionRegistry * registry)201 void RegisterScalarComparison(FunctionRegistry* registry) {
202 DCHECK_OK(registry->AddFunction(MakeCompareFunction<Equal>("equal", &equal_doc)));
203 DCHECK_OK(
204 registry->AddFunction(MakeCompareFunction<NotEqual>("not_equal", ¬_equal_doc)));
205
206 auto greater = MakeCompareFunction<Greater>("greater", &greater_doc);
207 auto greater_equal =
208 MakeCompareFunction<GreaterEqual>("greater_equal", &greater_equal_doc);
209
210 auto less = MakeFlippedFunction("less", *greater, &less_doc);
211 auto less_equal = MakeFlippedFunction("less_equal", *greater_equal, &less_equal_doc);
212 DCHECK_OK(registry->AddFunction(std::move(less)));
213 DCHECK_OK(registry->AddFunction(std::move(less_equal)));
214 DCHECK_OK(registry->AddFunction(std::move(greater)));
215 DCHECK_OK(registry->AddFunction(std::move(greater_equal)));
216 }
217
218 } // namespace internal
219 } // namespace compute
220 } // namespace arrow
221