1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/dataset/scanner.h"
19
20 #include <memory>
21
22 #include "arrow/dataset/test_util.h"
23 #include "arrow/record_batch.h"
24 #include "arrow/testing/generator.h"
25 #include "arrow/testing/util.h"
26
27 namespace arrow {
28 namespace dataset {
29
30 class TestScanner : public DatasetFixtureMixin {
31 protected:
32 static constexpr int64_t kNumberChildDatasets = 2;
33 static constexpr int64_t kNumberBatches = 16;
34 static constexpr int64_t kBatchSize = 1024;
35
MakeScanner(std::shared_ptr<RecordBatch> batch)36 Scanner MakeScanner(std::shared_ptr<RecordBatch> batch) {
37 std::vector<std::shared_ptr<RecordBatch>> batches{static_cast<size_t>(kNumberBatches),
38 batch};
39
40 DatasetVector children{static_cast<size_t>(kNumberChildDatasets),
41 std::make_shared<InMemoryDataset>(batch->schema(), batches)};
42
43 EXPECT_OK_AND_ASSIGN(auto dataset, UnionDataset::Make(batch->schema(), children));
44
45 return Scanner{dataset, options_, ctx_};
46 }
47
AssertScannerEqualsRepetitionsOf(Scanner scanner,std::shared_ptr<RecordBatch> batch,const int64_t total_batches=kNumberChildDatasets * kNumberBatches)48 void AssertScannerEqualsRepetitionsOf(
49 Scanner scanner, std::shared_ptr<RecordBatch> batch,
50 const int64_t total_batches = kNumberChildDatasets * kNumberBatches) {
51 auto expected = ConstantArrayGenerator::Repeat(total_batches, batch);
52
53 // Verifies that the unified BatchReader is equivalent to flattening all the
54 // structures of the scanner, i.e. Scanner[Dataset[ScanTask[RecordBatch]]]
55 AssertScannerEquals(expected.get(), &scanner);
56 }
57 }; // namespace dataset
58
59 constexpr int64_t TestScanner::kNumberChildDatasets;
60 constexpr int64_t TestScanner::kNumberBatches;
61 constexpr int64_t TestScanner::kBatchSize;
62
// An unfiltered, unprojected scan must hand back the source batch unchanged,
// once for every (child dataset, batch) pair built by MakeScanner.
TEST_F(TestScanner, Scan) {
  SetSchema({field("i32", int32()), field("f64", float64())});
  const auto zeroes = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  AssertScannerEqualsRepetitionsOf(MakeScanner(zeroes), zeroes);
}
68
// Capping options_->batch_size at half the input length doubles the number of
// scanned batches, each holding half the rows of a source batch.
TEST_F(TestScanner, ScanWithCappedBatchSize) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  const int64_t half = kBatchSize / 2;
  auto full_batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);
  options_->batch_size = half;

  // Every row is zero, so the trailing half of the batch equals any
  // half-sized piece of it.
  auto half_batch = full_batch->Slice(half);
  const int64_t expected_batches = 2 * kNumberChildDatasets * kNumberBatches;

  AssertScannerEqualsRepetitionsOf(MakeScanner(full_batch), half_batch,
                                   expected_batches);
}
77
// A `f64 > 0` filter evaluated by TreeEvaluator keeps only the positive rows
// of each scanned batch.
TEST_F(TestScanner, FilteredScan) {
  SetSchema({field("f64", float64())});

  // Input column (kBatchSize values): 0.5, -0.5, 1.5, -1.5, 2.5, -2.5, ...
  double magnitude = 0.5;
  ASSERT_OK_AND_ASSIGN(auto unfiltered_f64,
                       ArrayFromBuilderVisitor(float64(), kBatchSize, kBatchSize / 2,
                                               [&](DoubleBuilder* builder) {
                                                 builder->UnsafeAppend(magnitude);
                                                 builder->UnsafeAppend(-magnitude);
                                                 magnitude += 1.0;
                                               }));
  auto source_batch =
      RecordBatch::Make(schema_, unfiltered_f64->length(), {unfiltered_f64});

  options_->filter = ("f64"_ > 0.0).Copy();
  options_->evaluator = std::make_shared<TreeEvaluator>();

  // Expected column (kBatchSize / 2 values): 0.5, 1.5, 2.5, ...
  magnitude = 0.5;
  ASSERT_OK_AND_ASSIGN(
      auto expected_f64,
      ArrayFromBuilderVisitor(float64(), kBatchSize / 2, [&](DoubleBuilder* builder) {
        builder->UnsafeAppend(magnitude);
        magnitude += 1.0;
      }));
  auto expected_batch =
      RecordBatch::Make(schema_, expected_f64->length(), {expected_f64});

  AssertScannerEqualsRepetitionsOf(MakeScanner(source_batch), expected_batch);
}
108
// Source batches carry only "i32"; the projector's default value must be used
// to fill in the "f64" column that the scan schema requires.
TEST_F(TestScanner, MaterializeMissingColumn) {
  SetSchema({field("i32", int32()), field("f64", float64())});

  auto source_schema = schema({field("i32", int32())});
  auto i32_only = ConstantArrayGenerator::Zeroes(kBatchSize, source_schema);

  const double fill_value = 2.5;
  ASSERT_OK(options_->projector.SetDefaultValue(schema_->GetFieldIndex("f64"),
                                                MakeScalar(fill_value)));

  ASSERT_OK_AND_ASSIGN(auto filled_f64,
                       ArrayFromBuilderVisitor(float64(), kBatchSize,
                                               [&](DoubleBuilder* builder) {
                                                 builder->UnsafeAppend(fill_value);
                                               }));
  auto expected_batch = RecordBatch::Make(schema_, filled_f64->length(),
                                          {i32_only->column(0), filled_f64});

  AssertScannerEqualsRepetitionsOf(MakeScanner(i32_only), expected_batch);
}
126
// Scanner::ToTable must collect every scanned batch, with and without threads.
TEST_F(TestScanner, ToTable) {
  SetSchema({field("i32", int32()), field("f64", float64())});
  auto batch = ConstantArrayGenerator::Zeroes(kBatchSize, schema_);

  // Explicit cast (as in MakeScanner) plus parentheses: select the
  // (count, value) constructor without relying on constant-expression
  // narrowing rules or on the initializer_list overload failing to match.
  std::vector<std::shared_ptr<RecordBatch>> batches(
      static_cast<size_t>(kNumberBatches * kNumberChildDatasets), batch);

  ASSERT_OK_AND_ASSIGN(auto expected, Table::FromRecordBatches(batches));

  auto scanner = MakeScanner(batch);
  std::shared_ptr<Table> actual;

  ctx_->use_threads = false;
  ASSERT_OK_AND_ASSIGN(actual, scanner.ToTable());
  AssertTablesEqual(*expected, *actual);

  // There is no guarantee on the ordering when using multiple threads, but
  // since the RecordBatch is always the same it will pass.
  ctx_->use_threads = true;
  ASSERT_OK_AND_ASSIGN(actual, scanner.ToTable());
  AssertTablesEqual(*expected, *actual);
}
148
149 class TestScannerBuilder : public ::testing::Test {
SetUp()150 void SetUp() {
151 DatasetVector sources;
152
153 schema_ = schema({
154 field("b", boolean()),
155 field("i8", int8()),
156 field("i16", int16()),
157 field("i32", int32()),
158 field("i64", int64()),
159 });
160
161 ASSERT_OK_AND_ASSIGN(dataset_, UnionDataset::Make(schema_, sources));
162 }
163
164 protected:
165 std::shared_ptr<ScanContext> ctx_;
166 std::shared_ptr<Schema> schema_;
167 std::shared_ptr<Dataset> dataset_;
168 };
169
// ScannerBuilder::Project accepts any list of existing column names and
// rejects lists containing an unknown column.
TEST_F(TestScannerBuilder, TestProject) {
  ScannerBuilder builder(dataset_, ctx_);

  // Requesting no columns is valid: e.g. `SELECT 1 FROM t WHERE t.a > 0`
  // projects nothing, even though the filter still needs to read column `a`.
  ASSERT_OK(builder.Project({}));
  ASSERT_OK(builder.Project({"i64", "b", "i8"}));
  // Duplicate column names are accepted.
  ASSERT_OK(builder.Project({"i16", "i16"}));

  ASSERT_RAISES(Invalid, builder.Project({"not_found_column"}));
  ASSERT_RAISES(Invalid, builder.Project({"i8", "not_found_column"}));
}
182
// ScannerBuilder::Filter accepts well-typed predicates over known columns and
// rejects type mismatches and references to unknown columns.
TEST_F(TestScannerBuilder, TestFilter) {
  ScannerBuilder builder(dataset_, ctx_);

  // Accepted predicates.
  ASSERT_OK(builder.Filter(scalar(true)));
  ASSERT_OK(builder.Filter("i64"_ == int64_t(10)));
  ASSERT_OK(builder.Filter("i64"_ == int64_t(10) || "b"_ == true));

  // "i64" compared against an int32 literal is a type mismatch.
  ASSERT_RAISES(TypeError, builder.Filter("i64"_ == int32_t(10)));

  // Unknown columns are rejected on their own and inside a disjunction.
  ASSERT_RAISES(Invalid, builder.Filter("not_a_column"_ == true));
  ASSERT_RAISES(Invalid,
                builder.Filter("i64"_ == int64_t(10) || "not_a_column"_ == true));
}
195
196 using testing::ElementsAre;
197 using testing::IsEmpty;
198
// MaterializedFields lists the projected schema's field names followed by the
// names referenced by the filter, duplicates included.
TEST(ScanOptions, TestMaterializedFields) {
  auto i32 = field("i32", int32());
  auto i64 = field("i64", int64());

  // Empty schema: nothing is materialized until a filter references a field.
  auto no_fields = ScanOptions::Make(schema({}));
  EXPECT_THAT(no_fields->MaterializedFields(), IsEmpty());

  no_fields->filter = ("i32"_ == 10).Copy();
  EXPECT_THAT(no_fields->MaterializedFields(), ElementsAre("i32"));

  // Every projected field is materialized.
  auto both_fields = ScanOptions::Make(schema({i32, i64}));
  EXPECT_THAT(both_fields->MaterializedFields(), ElementsAre("i32", "i64"));

  // Replacing the schema narrows the projection.
  auto only_i32 = both_fields->ReplaceSchema(schema({i32}));
  EXPECT_THAT(only_i32->MaterializedFields(), ElementsAre("i32"));

  // Filter fields follow the projected ones, even when already projected.
  only_i32->filter = ("i32"_ == 10).Copy();
  EXPECT_THAT(only_i32->MaterializedFields(), ElementsAre("i32", "i32"));

  only_i32->filter = ("i64"_ == 10).Copy();
  EXPECT_THAT(only_i32->MaterializedFields(), ElementsAre("i32", "i64"));
}
221
222 } // namespace dataset
223 } // namespace arrow
224