1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/compute/kernels/common.h"
19 
20 namespace arrow {
21 
22 using internal::checked_cast;
23 using internal::checked_pointer_cast;
24 using util::string_view;
25 
26 namespace compute {
27 namespace internal {
28 
29 namespace {
30 
31 struct Equal {
32   template <typename T>
Callarrow::compute::internal::__anon6653aee70111::Equal33   static constexpr bool Call(KernelContext*, const T& left, const T& right) {
34     return left == right;
35   }
36 };
37 
38 struct NotEqual {
39   template <typename T>
Callarrow::compute::internal::__anon6653aee70111::NotEqual40   static constexpr bool Call(KernelContext*, const T& left, const T& right) {
41     return left != right;
42   }
43 };
44 
45 struct Greater {
46   template <typename T>
Callarrow::compute::internal::__anon6653aee70111::Greater47   static constexpr bool Call(KernelContext*, const T& left, const T& right) {
48     return left > right;
49   }
50 };
51 
52 struct GreaterEqual {
53   template <typename T>
Callarrow::compute::internal::__anon6653aee70111::GreaterEqual54   static constexpr bool Call(KernelContext*, const T& left, const T& right) {
55     return left >= right;
56   }
57 };
58 
// Less and LessEqual get no Op of their own: their kernels are built by flipping the
// arguments of Greater and GreaterEqual (see MakeFlippedFunction).
60 
61 template <typename Op>
AddIntegerCompare(const std::shared_ptr<DataType> & ty,ScalarFunction * func)62 void AddIntegerCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
63   auto exec =
64       GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
65   DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
66 }
67 
68 template <typename InType, typename Op>
AddGenericCompare(const std::shared_ptr<DataType> & ty,ScalarFunction * func)69 void AddGenericCompare(const std::shared_ptr<DataType>& ty, ScalarFunction* func) {
70   DCHECK_OK(
71       func->AddKernel({ty, ty}, boolean(),
72                       applicator::ScalarBinaryEqualTypes<BooleanType, InType, Op>::Exec));
73 }
74 
75 struct CompareFunction : ScalarFunction {
76   using ScalarFunction::ScalarFunction;
77 
DispatchBestarrow::compute::internal::__anon6653aee70111::CompareFunction78   Result<const Kernel*> DispatchBest(std::vector<ValueDescr>* values) const override {
79     RETURN_NOT_OK(CheckArity(*values));
80 
81     using arrow::compute::detail::DispatchExactImpl;
82     if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
83 
84     EnsureDictionaryDecoded(values);
85     ReplaceNullWithOtherType(values);
86 
87     if (auto type = CommonNumeric(*values)) {
88       ReplaceTypes(type, values);
89     } else if (auto type = CommonTimestamp(*values)) {
90       ReplaceTypes(type, values);
91     } else if (auto type = CommonBinary(*values)) {
92       ReplaceTypes(type, values);
93     }
94 
95     if (auto kernel = DispatchExactImpl(this, *values)) return kernel;
96     return arrow::compute::detail::NoMatchingKernel(this, *values);
97   }
98 };
99 
100 template <typename Op>
MakeCompareFunction(std::string name,const FunctionDoc * doc)101 std::shared_ptr<ScalarFunction> MakeCompareFunction(std::string name,
102                                                     const FunctionDoc* doc) {
103   auto func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
104 
105   DCHECK_OK(func->AddKernel(
106       {boolean(), boolean()}, boolean(),
107       applicator::ScalarBinary<BooleanType, BooleanType, BooleanType, Op>::Exec));
108 
109   for (const std::shared_ptr<DataType>& ty : IntTypes()) {
110     AddIntegerCompare<Op>(ty, func.get());
111   }
112   AddIntegerCompare<Op>(date32(), func.get());
113   AddIntegerCompare<Op>(date64(), func.get());
114 
115   AddGenericCompare<FloatType, Op>(float32(), func.get());
116   AddGenericCompare<DoubleType, Op>(float64(), func.get());
117 
118   // Add timestamp kernels
119   for (auto unit : AllTimeUnits()) {
120     InputType in_type(match::TimestampTypeUnit(unit));
121     auto exec =
122         GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
123             int64());
124     DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
125   }
126 
127   // Duration
128   for (auto unit : AllTimeUnits()) {
129     InputType in_type(match::DurationTypeUnit(unit));
130     auto exec =
131         GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
132             int64());
133     DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
134   }
135 
136   // Time32 and Time64
137   for (auto unit : {TimeUnit::SECOND, TimeUnit::MILLI}) {
138     InputType in_type(match::Time32TypeUnit(unit));
139     auto exec =
140         GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
141             int32());
142     DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
143   }
144   for (auto unit : {TimeUnit::MICRO, TimeUnit::NANO}) {
145     InputType in_type(match::Time64TypeUnit(unit));
146     auto exec =
147         GeneratePhysicalInteger<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(
148             int64());
149     DCHECK_OK(func->AddKernel({in_type, in_type}, boolean(), std::move(exec)));
150   }
151 
152   for (const std::shared_ptr<DataType>& ty : BaseBinaryTypes()) {
153     auto exec =
154         GenerateVarBinaryBase<applicator::ScalarBinaryEqualTypes, BooleanType, Op>(*ty);
155     DCHECK_OK(func->AddKernel({ty, ty}, boolean(), std::move(exec)));
156   }
157 
158   return func;
159 }
160 
MakeFlippedFunction(std::string name,const ScalarFunction & func,const FunctionDoc * doc)161 std::shared_ptr<ScalarFunction> MakeFlippedFunction(std::string name,
162                                                     const ScalarFunction& func,
163                                                     const FunctionDoc* doc) {
164   auto flipped_func = std::make_shared<CompareFunction>(name, Arity::Binary(), doc);
165   for (const ScalarKernel* kernel : func.kernels()) {
166     ScalarKernel flipped_kernel = *kernel;
167     flipped_kernel.exec = MakeFlippedBinaryExec(kernel->exec);
168     DCHECK_OK(flipped_func->AddKernel(std::move(flipped_kernel)));
169   }
170   return flipped_func;
171 }
172 
173 const FunctionDoc equal_doc{"Compare values for equality (x == y)",
174                             ("A null on either side emits a null comparison result."),
175                             {"x", "y"}};
176 
177 const FunctionDoc not_equal_doc{"Compare values for inequality (x != y)",
178                                 ("A null on either side emits a null comparison result."),
179                                 {"x", "y"}};
180 
181 const FunctionDoc greater_doc{"Compare values for ordered inequality (x > y)",
182                               ("A null on either side emits a null comparison result."),
183                               {"x", "y"}};
184 
185 const FunctionDoc greater_equal_doc{
186     "Compare values for ordered inequality (x >= y)",
187     ("A null on either side emits a null comparison result."),
188     {"x", "y"}};
189 
190 const FunctionDoc less_doc{"Compare values for ordered inequality (x < y)",
191                            ("A null on either side emits a null comparison result."),
192                            {"x", "y"}};
193 
194 const FunctionDoc less_equal_doc{
195     "Compare values for ordered inequality (x <= y)",
196     ("A null on either side emits a null comparison result."),
197     {"x", "y"}};
198 
199 }  // namespace
200 
RegisterScalarComparison(FunctionRegistry * registry)201 void RegisterScalarComparison(FunctionRegistry* registry) {
202   DCHECK_OK(registry->AddFunction(MakeCompareFunction<Equal>("equal", &equal_doc)));
203   DCHECK_OK(
204       registry->AddFunction(MakeCompareFunction<NotEqual>("not_equal", &not_equal_doc)));
205 
206   auto greater = MakeCompareFunction<Greater>("greater", &greater_doc);
207   auto greater_equal =
208       MakeCompareFunction<GreaterEqual>("greater_equal", &greater_equal_doc);
209 
210   auto less = MakeFlippedFunction("less", *greater, &less_doc);
211   auto less_equal = MakeFlippedFunction("less_equal", *greater_equal, &less_equal_doc);
212   DCHECK_OK(registry->AddFunction(std::move(less)));
213   DCHECK_OK(registry->AddFunction(std::move(less_equal)));
214   DCHECK_OK(registry->AddFunction(std::move(greater)));
215   DCHECK_OK(registry->AddFunction(std::move(greater_equal)));
216 }
217 
218 }  // namespace internal
219 }  // namespace compute
220 }  // namespace arrow
221