1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 #include "arrow/record_batch.h"
19
20 #include <algorithm>
21 #include <atomic>
22 #include <cstdlib>
23 #include <memory>
24 #include <sstream>
25 #include <string>
26 #include <utility>
27
28 #include "arrow/array.h"
29 #include "arrow/array/validate.h"
30 #include "arrow/pretty_print.h"
31 #include "arrow/status.h"
32 #include "arrow/table.h"
33 #include "arrow/type.h"
34 #include "arrow/util/atomic_shared_ptr.h"
35 #include "arrow/util/iterator.h"
36 #include "arrow/util/logging.h"
37 #include "arrow/util/vector.h"
38
39 namespace arrow {
40
AddColumn(int i,std::string field_name,const std::shared_ptr<Array> & column) const41 Result<std::shared_ptr<RecordBatch>> RecordBatch::AddColumn(
42 int i, std::string field_name, const std::shared_ptr<Array>& column) const {
43 auto field = ::arrow::field(std::move(field_name), column->type());
44 return AddColumn(i, field, column);
45 }
46
AddColumn(int i,std::string field_name,const std::shared_ptr<Array> & column,std::shared_ptr<RecordBatch> * out) const47 Status RecordBatch::AddColumn(int i, std::string field_name,
48 const std::shared_ptr<Array>& column,
49 std::shared_ptr<RecordBatch>* out) const {
50 return AddColumn(i, std::move(field_name), column).Value(out);
51 }
52
AddColumn(int i,const std::shared_ptr<Field> & field,const std::shared_ptr<Array> & column,std::shared_ptr<RecordBatch> * out) const53 Status RecordBatch::AddColumn(int i, const std::shared_ptr<Field>& field,
54 const std::shared_ptr<Array>& column,
55 std::shared_ptr<RecordBatch>* out) const {
56 return AddColumn(i, field, column).Value(out);
57 }
58
RemoveColumn(int i,std::shared_ptr<RecordBatch> * out) const59 Status RecordBatch::RemoveColumn(int i, std::shared_ptr<RecordBatch>* out) const {
60 return RemoveColumn(i).Value(out);
61 }
62
GetColumnByName(const std::string & name) const63 std::shared_ptr<Array> RecordBatch::GetColumnByName(const std::string& name) const {
64 auto i = schema_->GetFieldIndex(name);
65 return i == -1 ? NULLPTR : column(i);
66 }
67
num_columns() const68 int RecordBatch::num_columns() const { return schema_->num_fields(); }
69
70 /// \class SimpleRecordBatch
71 /// \brief A basic, non-lazy in-memory record batch
72 class SimpleRecordBatch : public RecordBatch {
73 public:
SimpleRecordBatch(std::shared_ptr<Schema> schema,int64_t num_rows,std::vector<std::shared_ptr<Array>> columns)74 SimpleRecordBatch(std::shared_ptr<Schema> schema, int64_t num_rows,
75 std::vector<std::shared_ptr<Array>> columns)
76 : RecordBatch(std::move(schema), num_rows), boxed_columns_(std::move(columns)) {
77 columns_.resize(boxed_columns_.size());
78 for (size_t i = 0; i < columns_.size(); ++i) {
79 columns_[i] = boxed_columns_[i]->data();
80 }
81 }
82
SimpleRecordBatch(const std::shared_ptr<Schema> & schema,int64_t num_rows,std::vector<std::shared_ptr<ArrayData>> columns)83 SimpleRecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows,
84 std::vector<std::shared_ptr<ArrayData>> columns)
85 : RecordBatch(std::move(schema), num_rows), columns_(std::move(columns)) {
86 boxed_columns_.resize(schema_->num_fields());
87 }
88
column(int i) const89 std::shared_ptr<Array> column(int i) const override {
90 std::shared_ptr<Array> result = internal::atomic_load(&boxed_columns_[i]);
91 if (!result) {
92 result = MakeArray(columns_[i]);
93 internal::atomic_store(&boxed_columns_[i], result);
94 }
95 return result;
96 }
97
column_data(int i) const98 std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; }
99
column_data() const100 ArrayDataVector column_data() const override { return columns_; }
101
AddColumn(int i,const std::shared_ptr<Field> & field,const std::shared_ptr<Array> & column) const102 Result<std::shared_ptr<RecordBatch>> AddColumn(
103 int i, const std::shared_ptr<Field>& field,
104 const std::shared_ptr<Array>& column) const override {
105 ARROW_CHECK(field != nullptr);
106 ARROW_CHECK(column != nullptr);
107
108 if (!field->type()->Equals(column->type())) {
109 return Status::Invalid("Column data type ", field->type()->name(),
110 " does not match field data type ", column->type()->name());
111 }
112 if (column->length() != num_rows_) {
113 return Status::Invalid(
114 "Added column's length must match record batch's length. Expected length ",
115 num_rows_, " but got length ", column->length());
116 }
117
118 ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->AddField(i, field));
119
120 return RecordBatch::Make(new_schema, num_rows_,
121 internal::AddVectorElement(columns_, i, column->data()));
122 }
123
RemoveColumn(int i) const124 Result<std::shared_ptr<RecordBatch>> RemoveColumn(int i) const override {
125 ARROW_ASSIGN_OR_RAISE(auto new_schema, schema_->RemoveField(i));
126
127 return RecordBatch::Make(new_schema, num_rows_,
128 internal::DeleteVectorElement(columns_, i));
129 }
130
ReplaceSchemaMetadata(const std::shared_ptr<const KeyValueMetadata> & metadata) const131 std::shared_ptr<RecordBatch> ReplaceSchemaMetadata(
132 const std::shared_ptr<const KeyValueMetadata>& metadata) const override {
133 auto new_schema = schema_->WithMetadata(metadata);
134 return RecordBatch::Make(new_schema, num_rows_, columns_);
135 }
136
Slice(int64_t offset,int64_t length) const137 std::shared_ptr<RecordBatch> Slice(int64_t offset, int64_t length) const override {
138 std::vector<std::shared_ptr<ArrayData>> arrays;
139 arrays.reserve(num_columns());
140 for (const auto& field : columns_) {
141 int64_t col_length = std::min(field->length - offset, length);
142 int64_t col_offset = field->offset + offset;
143
144 auto new_data = std::make_shared<ArrayData>(*field);
145 new_data->length = col_length;
146 new_data->offset = col_offset;
147 new_data->null_count = kUnknownNullCount;
148 arrays.emplace_back(new_data);
149 }
150 int64_t num_rows = std::min(num_rows_ - offset, length);
151 return std::make_shared<SimpleRecordBatch>(schema_, num_rows, std::move(arrays));
152 }
153
Validate() const154 Status Validate() const override {
155 if (static_cast<int>(columns_.size()) != schema_->num_fields()) {
156 return Status::Invalid("Number of columns did not match schema");
157 }
158 return RecordBatch::Validate();
159 }
160
161 private:
162 std::vector<std::shared_ptr<ArrayData>> columns_;
163
164 // Caching boxed array data
165 mutable std::vector<std::shared_ptr<Array>> boxed_columns_;
166 };
167
RecordBatch(const std::shared_ptr<Schema> & schema,int64_t num_rows)168 RecordBatch::RecordBatch(const std::shared_ptr<Schema>& schema, int64_t num_rows)
169 : schema_(schema), num_rows_(num_rows) {}
170
Make(std::shared_ptr<Schema> schema,int64_t num_rows,std::vector<std::shared_ptr<Array>> columns)171 std::shared_ptr<RecordBatch> RecordBatch::Make(
172 std::shared_ptr<Schema> schema, int64_t num_rows,
173 std::vector<std::shared_ptr<Array>> columns) {
174 DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
175 return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows, columns);
176 }
177
Make(std::shared_ptr<Schema> schema,int64_t num_rows,std::vector<std::shared_ptr<ArrayData>> columns)178 std::shared_ptr<RecordBatch> RecordBatch::Make(
179 std::shared_ptr<Schema> schema, int64_t num_rows,
180 std::vector<std::shared_ptr<ArrayData>> columns) {
181 DCHECK_EQ(schema->num_fields(), static_cast<int>(columns.size()));
182 return std::make_shared<SimpleRecordBatch>(std::move(schema), num_rows,
183 std::move(columns));
184 }
185
FromStructArray(const std::shared_ptr<Array> & array)186 Result<std::shared_ptr<RecordBatch>> RecordBatch::FromStructArray(
187 const std::shared_ptr<Array>& array) {
188 // TODO fail if null_count != 0?
189 if (array->type_id() != Type::STRUCT) {
190 return Status::Invalid("Cannot construct record batch from array of type ",
191 *array->type());
192 }
193 return Make(arrow::schema(array->type()->fields()), array->length(),
194 array->data()->child_data);
195 }
196
ToStructArray() const197 Result<std::shared_ptr<Array>> RecordBatch::ToStructArray() const {
198 return StructArray::Make(columns(), schema()->fields());
199 }
200
columns() const201 std::vector<std::shared_ptr<Array>> RecordBatch::columns() const {
202 std::vector<std::shared_ptr<Array>> children(num_columns());
203 for (int i = 0; i < num_columns(); ++i) {
204 children[i] = column(i);
205 }
206 return children;
207 }
208
column_name(int i) const209 const std::string& RecordBatch::column_name(int i) const {
210 return schema_->field(i)->name();
211 }
212
Equals(const RecordBatch & other,bool check_metadata) const213 bool RecordBatch::Equals(const RecordBatch& other, bool check_metadata) const {
214 if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
215 return false;
216 }
217
218 if (check_metadata) {
219 if (!schema_->Equals(*other.schema(), /*check_metadata=*/true)) {
220 return false;
221 }
222 }
223
224 for (int i = 0; i < num_columns(); ++i) {
225 if (!column(i)->Equals(other.column(i))) {
226 return false;
227 }
228 }
229
230 return true;
231 }
232
ApproxEquals(const RecordBatch & other) const233 bool RecordBatch::ApproxEquals(const RecordBatch& other) const {
234 if (num_columns() != other.num_columns() || num_rows_ != other.num_rows()) {
235 return false;
236 }
237
238 for (int i = 0; i < num_columns(); ++i) {
239 if (!column(i)->ApproxEquals(other.column(i))) {
240 return false;
241 }
242 }
243
244 return true;
245 }
246
Slice(int64_t offset) const247 std::shared_ptr<RecordBatch> RecordBatch::Slice(int64_t offset) const {
248 return Slice(offset, this->num_rows() - offset);
249 }
250
ToString() const251 std::string RecordBatch::ToString() const {
252 std::stringstream ss;
253 ARROW_CHECK_OK(PrettyPrint(*this, 0, &ss));
254 return ss.str();
255 }
256
Validate() const257 Status RecordBatch::Validate() const {
258 for (int i = 0; i < num_columns(); ++i) {
259 const auto& array = *this->column(i);
260 if (array.length() != num_rows_) {
261 return Status::Invalid("Number of rows in column ", i,
262 " did not match batch: ", array.length(), " vs ", num_rows_);
263 }
264 const auto& schema_type = *schema_->field(i)->type();
265 if (!array.type()->Equals(schema_type)) {
266 return Status::Invalid("Column ", i,
267 " type not match schema: ", array.type()->ToString(), " vs ",
268 schema_type.ToString());
269 }
270 RETURN_NOT_OK(internal::ValidateArray(array));
271 }
272 return Status::OK();
273 }
274
ValidateFull() const275 Status RecordBatch::ValidateFull() const {
276 RETURN_NOT_OK(Validate());
277 for (int i = 0; i < num_columns(); ++i) {
278 const auto& array = *this->column(i);
279 RETURN_NOT_OK(internal::ValidateArrayData(array));
280 }
281 return Status::OK();
282 }
283
284 // ----------------------------------------------------------------------
285 // Base record batch reader
286
ReadAll(std::vector<std::shared_ptr<RecordBatch>> * batches)287 Status RecordBatchReader::ReadAll(std::vector<std::shared_ptr<RecordBatch>>* batches) {
288 while (true) {
289 std::shared_ptr<RecordBatch> batch;
290 RETURN_NOT_OK(ReadNext(&batch));
291 if (!batch) {
292 break;
293 }
294 batches->emplace_back(std::move(batch));
295 }
296 return Status::OK();
297 }
298
ReadAll(std::shared_ptr<Table> * table)299 Status RecordBatchReader::ReadAll(std::shared_ptr<Table>* table) {
300 std::vector<std::shared_ptr<RecordBatch>> batches;
301 RETURN_NOT_OK(ReadAll(&batches));
302 return Table::FromRecordBatches(schema(), std::move(batches)).Value(table);
303 }
304
305 class SimpleRecordBatchReader : public RecordBatchReader {
306 public:
SimpleRecordBatchReader(Iterator<std::shared_ptr<RecordBatch>> it,std::shared_ptr<Schema> schema)307 SimpleRecordBatchReader(Iterator<std::shared_ptr<RecordBatch>> it,
308 std::shared_ptr<Schema> schema)
309 : schema_(std::move(schema)), it_(std::move(it)) {}
310
SimpleRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,std::shared_ptr<Schema> schema)311 SimpleRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,
312 std::shared_ptr<Schema> schema)
313 : schema_(std::move(schema)), it_(MakeVectorIterator(std::move(batches))) {}
314
ReadNext(std::shared_ptr<RecordBatch> * batch)315 Status ReadNext(std::shared_ptr<RecordBatch>* batch) override {
316 return it_.Next().Value(batch);
317 }
318
schema() const319 std::shared_ptr<Schema> schema() const override { return schema_; }
320
321 protected:
322 std::shared_ptr<Schema> schema_;
323 Iterator<std::shared_ptr<RecordBatch>> it_;
324 };
325
Make(std::vector<std::shared_ptr<RecordBatch>> batches,std::shared_ptr<Schema> schema)326 Result<std::shared_ptr<RecordBatchReader>> RecordBatchReader::Make(
327 std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema) {
328 if (schema == nullptr) {
329 if (batches.size() == 0 || batches[0] == nullptr) {
330 return Status::Invalid("Cannot infer schema from empty vector or nullptr");
331 }
332
333 schema = batches[0]->schema();
334 }
335
336 return std::make_shared<SimpleRecordBatchReader>(std::move(batches), schema);
337 }
338
MakeRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,std::shared_ptr<Schema> schema)339 Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader(
340 std::vector<std::shared_ptr<RecordBatch>> batches, std::shared_ptr<Schema> schema) {
341 return RecordBatchReader::Make(std::move(batches), std::move(schema));
342 }
343
MakeRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,std::shared_ptr<Schema> schema,std::shared_ptr<RecordBatchReader> * out)344 Status MakeRecordBatchReader(std::vector<std::shared_ptr<RecordBatch>> batches,
345 std::shared_ptr<Schema> schema,
346 std::shared_ptr<RecordBatchReader>* out) {
347 return RecordBatchReader::Make(std::move(batches), std::move(schema)).Value(out);
348 }
349
350 } // namespace arrow
351