1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements.  See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.  The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License.  You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied.  See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
#include "arrow/dbi/hiveserver2/operation.h"

#include <chrono>
#include <cstdlib>
#include <memory>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include "arrow/dbi/hiveserver2/service.h"
#include "arrow/dbi/hiveserver2/session.h"
#include "arrow/dbi/hiveserver2/thrift_internal.h"

#include "arrow/status.h"
#include "arrow/testing/gtest_util.h"
33 
34 namespace arrow {
35 namespace hiveserver2 {
36 
// Returns the HiveServer2 host the tests connect to.
//
// This is the value of the ARROW_HIVESERVER2_TEST_HOST environment variable.
// "localhost" is returned when that variable is unset or set to an empty string.
static std::string GetTestHost() {
  const char* host = std::getenv("ARROW_HIVESERVER2_TEST_HOST");
  // An empty value (e.g. `ARROW_HIVESERVER2_TEST_HOST=`) is treated like an unset
  // variable instead of producing an empty hostname that can never be connected to.
  if (host == nullptr || *host == '\0') {
    return "localhost";
  }
  return std::string(host);
}
41 
42 // Convenience functions for finding a row of values given several columns.
43 template <typename VType, typename CType>
FindRow(VType value,CType * column)44 bool FindRow(VType value, CType* column) {
45   for (int i = 0; i < column->length(); ++i) {
46     if (column->data()[i] == value) {
47       return true;
48     }
49   }
50   return false;
51 }
52 
53 template <typename V1Type, typename V2Type, typename C1Type, typename C2Type>
FindRow(V1Type value1,V2Type value2,C1Type * column1,C2Type * column2)54 bool FindRow(V1Type value1, V2Type value2, C1Type* column1, C2Type* column2) {
55   EXPECT_EQ(column1->length(), column2->length());
56   for (int i = 0; i < column1->length(); ++i) {
57     if (column1->data()[i] == value1 && column2->data()[i] == value2) {
58       return true;
59     }
60   }
61   return false;
62 }
63 
64 template <typename V1Type, typename V2Type, typename V3Type, typename C1Type,
65           typename C2Type, typename C3Type>
FindRow(V1Type value1,V2Type value2,V3Type value3,C1Type * column1,C2Type * column2,C3Type column3)66 bool FindRow(V1Type value1, V2Type value2, V3Type value3, C1Type* column1,
67              C2Type* column2, C3Type column3) {
68   EXPECT_EQ(column1->length(), column2->length());
69   EXPECT_EQ(column1->length(), column3->length());
70   for (int i = 0; i < column1->length(); ++i) {
71     if (column1->data()[i] == value1 && column2->data()[i] == value2 &&
72         column3->data()[i] == value3) {
73       return true;
74     }
75   }
76   return false;
77 }
78 
79 // Waits for this operation to reach the given state, sleeping for sleep microseconds
80 // between checks, and failing after max_retries checks.
Wait(const std::unique_ptr<Operation> & op,Operation::State state=Operation::State::FINISHED,int sleep_us=10000,int max_retries=100)81 Status Wait(const std::unique_ptr<Operation>& op,
82             Operation::State state = Operation::State::FINISHED, int sleep_us = 10000,
83             int max_retries = 100) {
84   int retries = 0;
85   Operation::State op_state;
86   RETURN_NOT_OK(op->GetState(&op_state));
87   while (op_state != state && retries < max_retries) {
88     usleep(sleep_us);
89     RETURN_NOT_OK(op->GetState(&op_state));
90     ++retries;
91   }
92 
93   if (op_state == state) {
94     return Status::OK();
95   } else {
96     return Status::IOError("Failed to reach state '", OperationStateToString(state),
97                            "' after ", retries, " retries");
98   }
99 }
100 
101 // Creates a service, session, and database for use in tests.
102 class HS2ClientTest : public ::testing::Test {
103  protected:
SetUp()104   virtual void SetUp() {
105     hostname_ = GetTestHost();
106 
107     int conn_timeout = 0;
108     ProtocolVersion protocol_version = ProtocolVersion::PROTOCOL_V7;
109     ASSERT_OK(
110         Service::Connect(hostname_, port, conn_timeout, protocol_version, &service_));
111 
112     std::string user = "user";
113     HS2ClientConfig config;
114     ASSERT_OK(service_->OpenSession(user, config, &session_));
115 
116     std::unique_ptr<Operation> drop_db_op;
117     ASSERT_OK(session_->ExecuteStatement(
118         "drop database if exists " + TEST_DB + " cascade", &drop_db_op));
119     ASSERT_OK(drop_db_op->Close());
120 
121     std::unique_ptr<Operation> create_db_op;
122     ASSERT_OK(session_->ExecuteStatement("create database " + TEST_DB, &create_db_op));
123     ASSERT_OK(create_db_op->Close());
124 
125     std::unique_ptr<Operation> use_db_op;
126     ASSERT_OK(session_->ExecuteStatement("use " + TEST_DB, &use_db_op));
127     ASSERT_OK(use_db_op->Close());
128   }
129 
TearDown()130   virtual void TearDown() {
131     std::unique_ptr<Operation> use_db_op;
132     if (session_) {
133       // We were able to create a session and service
134       ASSERT_OK(session_->ExecuteStatement("use default", &use_db_op));
135       ASSERT_OK(use_db_op->Close());
136 
137       std::unique_ptr<Operation> drop_db_op;
138       ASSERT_OK(session_->ExecuteStatement("drop database " + TEST_DB + " cascade",
139                                            &drop_db_op));
140       ASSERT_OK(drop_db_op->Close());
141 
142       ASSERT_OK(session_->Close());
143       ASSERT_OK(service_->Close());
144     }
145   }
146 
CreateTestTable()147   void CreateTestTable() {
148     std::unique_ptr<Operation> create_table_op;
149     ASSERT_OK(session_->ExecuteStatement(
150         "create table " + TEST_TBL + " (" + TEST_COL1 + " int, " + TEST_COL2 + " string)",
151         &create_table_op));
152     ASSERT_OK(create_table_op->Close());
153   }
154 
InsertIntoTestTable(std::vector<int> int_col_data,std::vector<std::string> string_col_data)155   void InsertIntoTestTable(std::vector<int> int_col_data,
156                            std::vector<std::string> string_col_data) {
157     ASSERT_EQ(int_col_data.size(), string_col_data.size());
158 
159     std::stringstream query;
160     query << "insert into " << TEST_TBL << " VALUES ";
161     for (size_t i = 0; i < int_col_data.size(); i++) {
162       if (int_col_data[i] == NULL_INT_VALUE) {
163         query << " (NULL, ";
164       } else {
165         query << " (" << int_col_data[i] << ", ";
166       }
167 
168       if (string_col_data[i] == "NULL") {
169         query << "NULL)";
170       } else {
171         query << "'" << string_col_data[i] << "')";
172       }
173 
174       if (i != int_col_data.size() - 1) {
175         query << ", ";
176       }
177     }
178 
179     std::unique_ptr<Operation> insert_op;
180     ASSERT_OK(session_->ExecuteStatement(query.str(), &insert_op));
181     ASSERT_OK(Wait(insert_op));
182     Operation::State insert_op_state;
183     ASSERT_OK(insert_op->GetState(&insert_op_state));
184     ASSERT_EQ(insert_op_state, Operation::State::FINISHED);
185     ASSERT_OK(insert_op->Close());
186   }
187   std::string hostname_;
188 
189   int port = 21050;
190 
191   const std::string TEST_DB = "hs2client_test_db";
192   const std::string TEST_TBL = "hs2client_test_table";
193   const std::string TEST_COL1 = "int_col";
194   const std::string TEST_COL2 = "string_col";
195 
196   const int NULL_INT_VALUE = -1;
197 
198   std::unique_ptr<Service> service_;
199   std::unique_ptr<Session> session_;
200 };
201 
202 class OperationTest : public HS2ClientTest {};
203 
TEST_F(OperationTest,TestFetch)204 TEST_F(OperationTest, TestFetch) {
205   CreateTestTable();
206   InsertIntoTestTable(std::vector<int>({1, 2, 3, 4}),
207                       std::vector<std::string>({"a", "b", "c", "d"}));
208 
209   std::unique_ptr<Operation> select_op;
210   ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL + " order by int_col",
211                                        &select_op));
212 
213   std::unique_ptr<ColumnarRowSet> results;
214   bool has_more_rows = false;
215   // Impala only supports NEXT and FIRST.
216   ASSERT_RAISES(IOError,
217                 select_op->Fetch(2, FetchOrientation::LAST, &results, &has_more_rows));
218 
219   // Fetch the results in two batches by passing max_rows to Fetch.
220   ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
221   ASSERT_OK(Wait(select_op));
222   ASSERT_TRUE(select_op->HasResultSet());
223   std::unique_ptr<Int32Column> int_col = results->GetInt32Col(0);
224   std::unique_ptr<StringColumn> string_col = results->GetStringCol(1);
225   ASSERT_EQ(int_col->data(), std::vector<int>({1, 2}));
226   ASSERT_EQ(string_col->data(), std::vector<std::string>({"a", "b"}));
227   ASSERT_TRUE(has_more_rows);
228 
229   ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
230   int_col = results->GetInt32Col(0);
231   string_col = results->GetStringCol(1);
232   ASSERT_EQ(int_col->data(), std::vector<int>({3, 4}));
233   ASSERT_EQ(string_col->data(), std::vector<std::string>({"c", "d"}));
234 
235   ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
236   int_col = results->GetInt32Col(0);
237   string_col = results->GetStringCol(1);
238   ASSERT_EQ(int_col->length(), 0);
239   ASSERT_EQ(string_col->length(), 0);
240   ASSERT_FALSE(has_more_rows);
241 
242   ASSERT_OK(select_op->Fetch(2, FetchOrientation::NEXT, &results, &has_more_rows));
243   int_col = results->GetInt32Col(0);
244   string_col = results->GetStringCol(1);
245   ASSERT_EQ(int_col->length(), 0);
246   ASSERT_EQ(string_col->length(), 0);
247   ASSERT_FALSE(has_more_rows);
248 
249   ASSERT_OK(select_op->Close());
250 }
251 
// Inserts rows with NULLs in both columns and checks that Column::IsNull() reports
// exactly those positions. Rows are read back ordered by int_col; the expectations
// below assume the row whose int_col is NULL sorts last.
TEST_F(OperationTest, TestIsNull) {
  ASSERT_NO_FATAL_FAILURE(CreateTestTable());
  // NULL_INT_VALUE and "NULL" are inserted as SQL NULL by the helper.
  ASSERT_NO_FATAL_FAILURE(InsertIntoTestTable(
      std::vector<int>({1, 2, 3, 4, 5, NULL_INT_VALUE}),
      std::vector<std::string>({"a", "b", "NULL", "d", "NULL", "f"})));

  std::unique_ptr<Operation> select_nulls_op;
  ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL + " order by int_col",
                                       &select_nulls_op));

  std::unique_ptr<ColumnarRowSet> nulls_results;
  bool has_more_rows = false;
  ASSERT_OK(select_nulls_op->Fetch(&nulls_results, &has_more_rows));
  std::unique_ptr<Int32Column> int_col = nulls_results->GetInt32Col(0);
  std::unique_ptr<StringColumn> string_col = nulls_results->GetStringCol(1);
  // Both lengths are pinned to 6 here, so the expectation arrays below are never
  // indexed out of bounds.
  ASSERT_EQ(int_col->length(), 6);
  ASSERT_EQ(int_col->length(), string_col->length());

  const bool int_nulls[] = {false, false, false, false, false, true};
  for (int i = 0; i < int_col->length(); i++) {
    ASSERT_EQ(int_col->IsNull(i), int_nulls[i]) << "int_col row " << i;
  }
  const bool string_nulls[] = {false, false, true, false, true, false};
  for (int i = 0; i < string_col->length(); i++) {
    ASSERT_EQ(string_col->IsNull(i), string_nulls[i]) << "string_col row " << i;
  }

  ASSERT_OK(select_nulls_op->Close());
}
281 
// Starts a query, cancels it, and checks two things: the operation ends in the
// ERROR state, and its profile mentions "Cancelled".
TEST_F(OperationTest, TestCancel) {
  ASSERT_NO_FATAL_FAILURE(CreateTestTable());
  ASSERT_NO_FATAL_FAILURE(InsertIntoTestTable(
      std::vector<int>({1, 2, 3, 4}), std::vector<std::string>({"a", "b", "c", "d"})));

  std::unique_ptr<Operation> op;
  ASSERT_OK(session_->ExecuteStatement("select count(*) from " + TEST_TBL, &op));
  ASSERT_OK(op->Cancel());
  // Impala currently returns ERROR and not CANCELED for canceled queries
  // due to the use of beeswax states, which don't support a canceled state.
  ASSERT_OK(Wait(op, Operation::State::ERROR));

  std::string profile;
  ASSERT_OK(op->GetProfile(&profile));
  // Print the profile on failure so a missing marker is easy to diagnose.
  ASSERT_NE(profile.find("Cancelled"), std::string::npos) << profile;

  ASSERT_OK(op->Close());
}
300 
// Checks that an executed query exposes a non-empty operation log.
TEST_F(OperationTest, TestGetLog) {
  // Stop here if the table could not be created; the helper's own ASSERTs only
  // return from the helper.
  ASSERT_NO_FATAL_FAILURE(CreateTestTable());

  std::unique_ptr<Operation> op;
  ASSERT_OK(session_->ExecuteStatement("select count(*) from " + TEST_TBL, &op));
  std::string log;
  ASSERT_OK(op->GetLog(&log));
  ASSERT_NE(log, "");

  ASSERT_OK(op->Close());
}
312 
// Creates a table with INT, VARCHAR(n) and DECIMAL(p, s) columns and checks the
// result set metadata of a SELECT on it: names, types, positions and type
// parameters. It then checks that an INSERT reports no result set columns.
TEST_F(OperationTest, TestGetResultSetMetadata) {
  // Named distinctly so they do not shadow the fixture's TEST_COL1/TEST_COL2.
  const std::string INT_COL = "int_col";
  const std::string VARCHAR_COL = "varchar_col";
  const int MAX_LENGTH = 10;
  const std::string DECIMAL_COL = "decimal_col";
  const int PRECISION = 5;
  const int SCALE = 3;
  std::stringstream create_query;
  create_query << "create table " << TEST_TBL << " (" << INT_COL << " int, "
               << VARCHAR_COL << " varchar(" << MAX_LENGTH << "), " << DECIMAL_COL
               << " decimal(" << PRECISION << ", " << SCALE << "))";
  std::unique_ptr<Operation> create_table_op;
  ASSERT_OK(session_->ExecuteStatement(create_query.str(), &create_table_op));
  ASSERT_OK(create_table_op->Close());

  // Perform a select, and check that we get the right metadata back.
  std::unique_ptr<Operation> select_op;
  ASSERT_OK(session_->ExecuteStatement("select * from " + TEST_TBL, &select_op));
  std::vector<ColumnDesc> column_descs;
  ASSERT_OK(select_op->GetResultSetMetadata(&column_descs));
  // Unsigned literal: size() is unsigned, avoiding a signed/unsigned comparison.
  ASSERT_EQ(column_descs.size(), 3U);

  ASSERT_EQ(column_descs[0].column_name(), INT_COL);
  ASSERT_EQ(column_descs[0].type()->ToString(), "INT");
  ASSERT_EQ(column_descs[0].type()->type_id(), ColumnType::TypeId::INT);
  ASSERT_EQ(column_descs[0].position(), 0);

  ASSERT_EQ(column_descs[1].column_name(), VARCHAR_COL);
  ASSERT_EQ(column_descs[1].type()->ToString(), "VARCHAR");
  ASSERT_EQ(column_descs[1].type()->type_id(), ColumnType::TypeId::VARCHAR);
  ASSERT_EQ(column_descs[1].position(), 1);
  ASSERT_EQ(column_descs[1].GetCharacterType()->max_length(), MAX_LENGTH);

  ASSERT_EQ(column_descs[2].column_name(), DECIMAL_COL);
  ASSERT_EQ(column_descs[2].type()->ToString(), "DECIMAL");
  ASSERT_EQ(column_descs[2].type()->type_id(), ColumnType::TypeId::DECIMAL);
  ASSERT_EQ(column_descs[2].position(), 2);
  ASSERT_EQ(column_descs[2].GetDecimalType()->precision(), PRECISION);
  ASSERT_EQ(column_descs[2].GetDecimalType()->scale(), SCALE);

  ASSERT_OK(select_op->Close());

  // Insert ops don't have result sets.
  std::stringstream insert_query;
  insert_query << "insert into " << TEST_TBL << " VALUES (1, cast('a' as varchar("
               << MAX_LENGTH << ")), cast(1 as decimal(" << PRECISION << ", " << SCALE
               << ")))";
  std::unique_ptr<Operation> insert_op;
  ASSERT_OK(session_->ExecuteStatement(insert_query.str(), &insert_op));
  std::vector<ColumnDesc> insert_column_descs;
  ASSERT_OK(insert_op->GetResultSetMetadata(&insert_column_descs));
  ASSERT_EQ(insert_column_descs.size(), 0U);
  ASSERT_OK(insert_op->Close());
}
367 
368 class SessionTest : public HS2ClientTest {};
369 
TEST_F(SessionTest,TestSessionConfig)370 TEST_F(SessionTest, TestSessionConfig) {
371   // Create a table in TEST_DB.
372   const std::string& TEST_TBL = "hs2client_test_table";
373   std::unique_ptr<Operation> create_table_op;
374   ASSERT_OK(session_->ExecuteStatement(
375       "create table " + TEST_TBL + " (int_col int, string_col string)",
376       &create_table_op));
377   ASSERT_OK(create_table_op->Close());
378 
379   // Start a new session with the use:database session option.
380   std::string user = "user";
381   HS2ClientConfig config_use;
382   config_use.SetOption("use:database", TEST_DB);
383   std::unique_ptr<Session> session_ok;
384   ASSERT_OK(service_->OpenSession(user, config_use, &session_ok));
385 
386   // Ensure the use:database worked and we can access the table.
387   std::unique_ptr<Operation> select_op;
388   ASSERT_OK(session_ok->ExecuteStatement("select * from " + TEST_TBL, &select_op));
389   ASSERT_OK(select_op->Close());
390   ASSERT_OK(session_ok->Close());
391 
392   // Start another session without use:database.
393   HS2ClientConfig config_no_use;
394   std::unique_ptr<Session> session_error;
395   ASSERT_OK(service_->OpenSession(user, config_no_use, &session_error));
396 
397   // Ensure the we can't access the table.
398   std::unique_ptr<Operation> select_op_error;
399   ASSERT_RAISES(IOError, session_error->ExecuteStatement("select * from " + TEST_TBL,
400                                                          &select_op_error));
401   ASSERT_OK(session_error->Close());
402 }
403 
// Connects a Service directly, without the fixture, and opens and closes a
// session. It then checks that Close() can be repeated and that a closed service
// refuses to open new sessions.
TEST(ServiceTest, TestConnect) {
  // Open a connection.
  std::string host = GetTestHost();
  int port = 21050;
  int conn_timeout = 0;
  ProtocolVersion protocol_version = ProtocolVersion::PROTOCOL_V7;
  std::unique_ptr<Service> service;
  ASSERT_OK(Service::Connect(host, port, conn_timeout, protocol_version, &service));
  ASSERT_TRUE(service->IsConnected());

  // Check that we can start a session.
  std::string user = "user";
  HS2ClientConfig config;
  std::unique_ptr<Session> session1;
  ASSERT_OK(service->OpenSession(user, config, &session1));
  ASSERT_OK(session1->Close());

  // Close the service; closing it a second time must also succeed. We should not
  // be able to open a session afterwards.
  ASSERT_OK(service->Close());
  ASSERT_FALSE(service->IsConnected());
  ASSERT_OK(service->Close());
  std::unique_ptr<Session> session3;
  ASSERT_RAISES(IOError, service->OpenSession(user, config, &session3));
  // Nothing here guarantees that a failed OpenSession() hands back a session
  // object, so only close one if it was actually produced (avoids a null
  // dereference).
  if (session3 != nullptr) {
    ASSERT_OK(session3->Close());
  }

  // We should be able to call Close again without errors.
  ASSERT_OK(service->Close());
  ASSERT_FALSE(service->IsConnected());
}
433 
// Each Service::Connect() below is expected to fail:
// - an unresolvable host and an invalid port are reported as IOError;
// - a protocol version the client does not support is reported as NotImplemented.
TEST(ServiceTest, TestFailedConnect) {
  std::string good_host = GetTestHost();
  const int good_port = 21050;

  // A 100ms connection timeout keeps these failing attempts quick.
  const int timeout = 100;

  const ProtocolVersion good_protocol = ProtocolVersion::PROTOCOL_V7;
  std::unique_ptr<Service> svc;

  // Hostname that cannot be resolved.
  std::string bad_host = "does_not_exist";
  ASSERT_RAISES(IOError,
                Service::Connect(bad_host, good_port, timeout, good_protocol, &svc));

  // Negative port number.
  const int bad_port = -1;
  ASSERT_RAISES(IOError,
                Service::Connect(good_host, bad_port, timeout, good_protocol, &svc));

  // Protocol version the client does not accept.
  const ProtocolVersion bad_protocol = ProtocolVersion::PROTOCOL_V2;
  ASSERT_RAISES(NotImplemented,
                Service::Connect(good_host, good_port, timeout, bad_protocol, &svc));
}
456 
457 }  // namespace hiveserver2
458 }  // namespace arrow
459