1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 #include "arrow/dataset/scanner.h"
19 
20 #include <memory>
21 
22 #include "arrow/dataset/test_util.h"
23 #include "arrow/record_batch.h"
24 #include "arrow/testing/generator.h"
25 #include "arrow/testing/util.h"
26 
27 namespace arrow {
28 namespace dataset {
29 
30 class TestScanner : public DatasetFixtureMixin {
31  protected:
32   static constexpr int64_t kNumberChildDatasets = 2;
33   static constexpr int64_t kNumberBatches = 16;
34   static constexpr int64_t kBatchSize = 1024;
35 
MakeScanner(std::shared_ptr<RecordBatch> batch)36   Scanner MakeScanner(std::shared_ptr<RecordBatch> batch) {
37     std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
38                                                       batch};
39 
40     DatasetVector children{static_cast<size_t>(kNumberChildDatasets),
41                            std::make_shared<InMemoryDataset>(batch->schema(), batches)};
42 
43     EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children));
44 
45     return Scanner{dataset, options_, ctx_};
46   }
47 
AssertScannerEqualsRepetitionsOf(Scanner scanner,std::shared_ptr<RecordBatch> batch,const int64_t total_batches=kNumberChildDatasets * kNumberBatches)48   void AssertScannerEqualsRepetitionsOf(
49       Scanner scanner, std::shared_ptr<RecordBatch> batch,
50       const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
51     auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
52 
53     // Verifies that the unified BatchReader is equivalent to flattening all the
54     // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]]
55     AssertScannerEquals(expected.get(), &scanner);
56   }
57 };  // namespace dataset
58 
59 constexpr int64_t TestScanner::kNumberChildDatasets;
60 constexpr int64_t TestScanner::kNumberBatches;
61 constexpr int64_t TestScanner::kBatchSize;
62 
// A plain scan with default options must return every input batch unchanged.
TEST_F(TestScanner, Scan) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  auto zero_batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);

  AssertScannerEqualsRepetitionsOf(MakeScanner(zero_batch), zero_batch);
}
68 
// With batch_size set to half of the input batch length, the scan is expected
// to produce twice as many batches, each half as long.
TEST_F(TestScanner, ScanWithCappedBatchSize) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  const int64_t capped_size = kBatchSize / 2;
  const int64_t expected_batches = kNumberChildDatasets * kNumberBatches * 2;

  auto full_batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  options_->batch_size = capped_size;

  // The input is all zeroes, so any slice of the capped length is a valid
  // expected batch.
  auto half_batch = full_batch->Slice(capped_size);

  AssertScannerEqualsRepetitionsOf(MakeScanner(full_batch), half_batch,
                                   expected_batches);
}
77 
// Scanning with the filter `f64 > 0` must drop the negative half of each batch.
TEST_F(TestScanner, FilteredScan) {
  SetSchema({field("f64", float64())});

  options_->filter = ("f64"_ > 0.0).Copy();
  options_->evaluator = std::make_shared<TreeEvaluator>();

  // Input column: the pairs (+m, -m) for m = 0.5, 1.5, 2.5, ...
  double magnitude = 0.5;
  ASSERT_OK_AND_ASSIGN(auto f64,
                       ArrayFromBuilderVisitor(float64(), kBatchSize, kBatchSize / 2,
                                               [&](DoubleBuilder* builder) {
                                                 builder->UnsafeAppend(magnitude);
                                                 builder->UnsafeAppend(-magnitude);
                                                 magnitude += 1.0;
                                               }));
  auto input_batch = RecordBatch::Make(schema_, f64->length(), {f64});

  // Expected column: only the positive values, in the same order.
  double positive = 0.5;
  ASSERT_OK_AND_ASSIGN(
      auto f64_filtered,
      ArrayFromBuilderVisitor(float64(), kBatchSize / 2, [&](DoubleBuilder* builder) {
        builder->UnsafeAppend(positive);
        positive += 1.0;
      }));
  auto expected_batch =
      RecordBatch::Make(schema_, f64_filtered->length(), {f64_filtered});

  AssertScannerEqualsRepetitionsOf(MakeScanner(input_batch), expected_batch);
}
108 
// A column that is in the scan schema but absent from the stored batches must
// be filled with the default value registered on the projector.
TEST_F(TestScanner, MaterializeMissingColumn) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  // Stored data only carries the "i32" column.
  auto stored_batch =
      ConstantArrayGenerator::Zeroes(kBatchSize, schema({field("i32", int32())}));

  // Expected "f64" column: kBatchSize copies of the default value 2.5.
  ASSERT_OK_AND_ASSIGN(auto f64, ArrayFromBuilderVisitor(float64(), kBatchSize,
                                                         [&](DoubleBuilder* builder) {
                                                           builder->UnsafeAppend(2.5);
                                                         }));
  auto expected_batch =
      RecordBatch::Make(schema_, f64->length(), {stored_batch->column(0), f64});

  // Register 2.5 as the default for "f64" before the scan runs.
  ASSERT_OK(options_->projector.SetDefaultValue(schema_->GetFieldIndex("f64"),
                                                MakeScalar(2.5)));

  AssertScannerEqualsRepetitionsOf(MakeScanner(stored_batch), expected_batch);
}
126 
// Scanner::ToTable() must produce the concatenation of every scanned batch,
// both with and without threads.
TEST_F(TestScanner, ToTable) {
  SetSchema({field("i32", int32()), field("f64", float64())});
  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);

  // The fixture constants are int64_t while the vector's size parameter is
  // unsigned; cast explicitly (as MakeScanner does) instead of relying on an
  // implicit signed-to-unsigned conversion inside a braced initializer.
  std::vector<std::shared_ptr<RecordBatch>> batches{
      static_cast<size_t>(kNumberBatches * kNumberChildDatasets), batch};

  ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));

  auto scanner = MakeScanner(batch);
  std::shared_ptr<Table> actual;

  // The scanner was built with the fixture's ctx_, so toggling use_threads
  // here affects the subsequent ToTable() calls.
  ctx_->use_threads = false;
  ASSERT_OK_AND_ASSIGN(actual, scanner.ToTable());
  AssertTablesEqual(*expected, *actual);

  // There is no guarantee on the ordering when using multiple threads, but
  // since the RecordBatch is always the same it will pass.
  ctx_->use_threads = true;
  ASSERT_OK_AND_ASSIGN(actual, scanner.ToTable());
  AssertTablesEqual(*expected, *actual);
}
148 
149 class TestScannerBuilder : public ::testing::Test {
SetUp()150   void SetUp() {
151     DatasetVector sources;
152 
153     schema_ = schema({
154         field("b", boolean()),
155         field("i8", int8()),
156         field("i16", int16()),
157         field("i32", int32()),
158         field("i64", int64()),
159     });
160 
161     ASSERT_OK_AND_ASSIGN(dataset_, UnionDataset::Make(schema_, sources));
162   }
163 
164  protected:
165   std::shared_ptr<ScanContext> ctx_;
166   std::shared_ptr<Schema> schema_;
167   std::shared_ptr<Dataset> dataset_;
168 };
169 
// Project() accepts any list of existing column names and rejects unknown ones.
TEST_F(TestScannerBuilder, TestProject) {
  ScannerBuilder scanner_builder(dataset_, ctx_);

  // Requesting no columns is valid: e.g. `SELECT 1 FROM t WHERE t.a > 0`
  // projects nothing, yet it still needs to touch the `a` column.
  ASSERT_OK(scanner_builder.Project({}));

  // Columns may be listed in any order, and the same column may be repeated.
  ASSERT_OK(scanner_builder.Project({"i64", "b", "i8"}));
  ASSERT_OK(scanner_builder.Project({"i16", "i16"}));

  // A single unknown name invalidates the whole projection.
  ASSERT_RAISES(Invalid, scanner_builder.Project({"not_found_column"}));
  ASSERT_RAISES(Invalid, scanner_builder.Project({"i8", "not_found_column"}));
}
182 
// Filter() accepts expressions that match the dataset schema and rejects those
// with a mismatched scalar type or an unknown column.
TEST_F(TestScannerBuilder, TestFilter) {
  ScannerBuilder scanner_builder(dataset_, ctx_);

  // Well-formed filters.
  ASSERT_OK(scanner_builder.Filter(scalar(true)));
  ASSERT_OK(scanner_builder.Filter("i64"_ == static_cast<int64_t>(10)));
  ASSERT_OK(
      scanner_builder.Filter("i64"_ == static_cast<int64_t>(10) || "b"_ == true));

  // Comparing the int64 column against an int32 scalar is a type error.
  ASSERT_RAISES(TypeError,
                scanner_builder.Filter("i64"_ == static_cast<int32_t>(10)));

  // Unknown columns are rejected, including inside a compound expression.
  ASSERT_RAISES(Invalid, scanner_builder.Filter("not_a_column"_ == true));
  ASSERT_RAISES(Invalid, scanner_builder.Filter("i64"_ == static_cast<int64_t>(10) ||
                                                "not_a_column"_ == true));
}
195 
196 using testing::ElementsAre;
197 using testing::IsEmpty;
198 
TEST(ScanOptions,TestMaterializedFields)199 TEST(ScanOptions, TestMaterializedFields) {
200   auto i32 = field("i32", int32());
201   auto i64 = field("i64", int64());
202 
203   auto opts = ScanOptions::Make(schema({}));
204   EXPECT_THAT(opts->MaterializedFields(), IsEmpty());
205 
206   opts->filter = ("i32"_ == 10).Copy();
207   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
208 
209   opts = ScanOptions::Make(schema({i32, i64}));
210   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64"));
211 
212   opts = opts->ReplaceSchema(schema({i32}));
213   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32"));
214 
215   opts->filter = ("i32"_ == 10).Copy();
216   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i32"));
217 
218   opts->filter = ("i64"_ == 10).Copy();
219   EXPECT_THAT(opts->MaterializedFields(), ElementsAre("i32", "i64"));
220 }
221 
222 }  // namespace dataset
223 }  // namespace arrow
224