1 /* Copyright (C) 2014 InfiniDB, Inc.
2    Copyright (C) 2019 MariaDB Corporation
3 
4    This program is free software; you can redistribute it and/or
5    modify it under the terms of the GNU General Public License
6    as published by the Free Software Foundation; version 2 of
7    the License.
8 
9    This program is distributed in the hope that it will be useful,
10    but WITHOUT ANY WARRANTY; without even the implied warranty of
11    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12    GNU General Public License for more details.
13 
14    You should have received a copy of the GNU General Public License
15    along with this program; if not, write to the Free Software
16    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
17    MA 02110-1301, USA. */
18 
19 /***********************************************************************
20 *   $Id$
21 *
22 *
23 ***********************************************************************/
24 /** @file */
25 
26 #ifndef ARITHMETICOPERATOR_H
27 #define ARITHMETICOPERATOR_H
28 #include <string>
29 #include <iosfwd>
30 #include <cmath>
31 #include <sstream>
32 
33 #include "operator.h"
34 #include "parsetree.h"
35 
36 namespace messageqcpp
37 {
38 class ByteStream;
39 }
40 
41 namespace execplan
42 {
43 
44 class ArithmeticOperator : public Operator
45 {
46 
47 public:
48     ArithmeticOperator();
49     ArithmeticOperator(const std::string& operatorName);
50     ArithmeticOperator(const ArithmeticOperator& rhs);
51 
52     virtual ~ArithmeticOperator();
53 
54     /** return a copy of this pointer
55      *
56      * deep copy of this pointer and return the copy
57      */
clone()58     inline virtual ArithmeticOperator* clone() const
59     {
60         return new ArithmeticOperator (*this);
61     }
62 
timeZone()63     inline const std::string& timeZone() const
64     {
65         return fTimeZone;
66     }
timeZone(const std::string & timeZone)67     inline void timeZone(const std::string& timeZone)
68     {
69         fTimeZone = timeZone;
70     }
71 
72     /**
73      * The serialization interface
74      */
75     virtual void serialize(messageqcpp::ByteStream&) const;
76     virtual void unserialize(messageqcpp::ByteStream&);
77 
78     /** @brief Do a deep, strict (as opposed to semantic) equivalence test
79      *
80      * Do a deep, strict (as opposed to semantic) equivalence test.
81      * @return true iff every member of t is a duplicate copy of every member of this; false otherwise
82      */
83     virtual bool operator==(const TreeNode* t) const;
84 
85     /** @brief Do a deep, strict (as opposed to semantic) equivalence test
86      *
87      * Do a deep, strict (as opposed to semantic) equivalence test.
88      * @return true iff every member of t is a duplicate copy of every member of this; false otherwise
89      */
90     bool operator==(const ArithmeticOperator& t) const;
91 
92     /** @brief Do a deep, strict (as opposed to semantic) equivalence test
93      *
94      * Do a deep, strict (as opposed to semantic) equivalence test.
95      * @return false iff every member of t is a duplicate copy of every member of this; true otherwise
96      */
97     virtual bool operator!=(const TreeNode* t) const;
98 
99     /** @brief Do a deep, strict (as opposed to semantic) equivalence test
100      *
101      * Do a deep, strict (as opposed to semantic) equivalence test.
102      * @return false iff every member of t is a duplicate copy of every member of this; true otherwise
103      */
104     bool operator!=(const ArithmeticOperator& t) const;
105 
106     /***********************************************************
107      *                 F&E framework                           *
108      ***********************************************************/
109     using Operator::evaluate;
110     inline virtual void evaluate(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop);
111 
112     using Operator::getStrVal;
getStrVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)113     virtual const std::string& getStrVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
114     {
115         evaluate(row, isNull, lop, rop);
116         return TreeNode::getStrVal(fTimeZone);
117     }
118     using Operator::getIntVal;
getIntVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)119     virtual int64_t getIntVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
120     {
121         evaluate(row, isNull, lop, rop);
122         return TreeNode::getIntVal();
123     }
124     using Operator::getUintVal;
getUintVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)125     virtual uint64_t getUintVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
126     {
127         evaluate(row, isNull, lop, rop);
128         return TreeNode::getUintVal();
129     }
130     using Operator::getFloatVal;
getFloatVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)131     virtual float getFloatVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
132     {
133         evaluate(row, isNull, lop, rop);
134         return TreeNode::getFloatVal();
135     }
136     using Operator::getDoubleVal;
getDoubleVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)137     virtual double getDoubleVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
138     {
139         evaluate(row, isNull, lop, rop);
140         return TreeNode::getDoubleVal();
141     }
142     using Operator::getLongDoubleVal;
getLongDoubleVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)143     virtual long double getLongDoubleVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
144     {
145         evaluate(row, isNull, lop, rop);
146         return TreeNode::getLongDoubleVal();
147     }
148     using Operator::getDecimalVal;
getDecimalVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)149     virtual IDB_Decimal getDecimalVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
150     {
151         evaluate(row, isNull, lop, rop);
152 
153         // @bug5736, double type with precision -1 indicates that this type is for decimal math,
154         //      the original decimal scale is stored in scale field, which is no use for double.
155         if (fResultType.colDataType == CalpontSystemCatalog::DOUBLE && fResultType.precision == -1)
156         {
157             IDB_Decimal rv;
158             rv.scale = fResultType.scale;
159             rv.precision = 15;
160             rv.value = (int64_t)(TreeNode::getDoubleVal() * IDB_pow[rv.scale]);
161 
162             return rv;
163         }
164 
165         return TreeNode::getDecimalVal();
166     }
167     using Operator::getDateIntVal;
getDateIntVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)168     virtual int32_t getDateIntVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
169     {
170         evaluate(row, isNull, lop, rop);
171         return TreeNode::getDateIntVal();
172     }
173     using Operator::getDatetimeIntVal;
getDatetimeIntVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)174     virtual int64_t getDatetimeIntVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
175     {
176         evaluate(row, isNull, lop, rop);
177         return TreeNode::getDatetimeIntVal();
178     }
179     using Operator::getTimestampIntVal;
getTimestampIntVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)180     virtual int64_t getTimestampIntVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
181     {
182         evaluate(row, isNull, lop, rop);
183         return TreeNode::getTimestampIntVal();
184     }
185     using Operator::getTimeIntVal;
getTimeIntVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)186     virtual int64_t getTimeIntVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
187     {
188         evaluate(row, isNull, lop, rop);
189         return TreeNode::getTimeIntVal();
190     }
191     using Operator::getBoolVal;
getBoolVal(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)192     virtual bool getBoolVal(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
193     {
194         evaluate(row, isNull, lop, rop);
195         return TreeNode::getBoolVal();
196     }
197     void adjustResultType(const CalpontSystemCatalog::ColType& m);
198 
199 private:
200     template <typename result_t>
201     inline result_t execute(result_t op1, result_t op2, bool& isNull);
202     inline void execute(IDB_Decimal& result, IDB_Decimal op1, IDB_Decimal op2, bool& isNull);
203     std::string fTimeZone;
204 };
205 
206 #include "parsetree.h"
207 
evaluate(rowgroup::Row & row,bool & isNull,ParseTree * lop,ParseTree * rop)208 inline void ArithmeticOperator::evaluate(rowgroup::Row& row, bool& isNull, ParseTree* lop, ParseTree* rop)
209 {
210     // fOpType should have already been set on the connector during parsing
211     switch (fOperationType.colDataType)
212     {
213         case execplan::CalpontSystemCatalog::BIGINT:
214         case execplan::CalpontSystemCatalog::INT:
215         case execplan::CalpontSystemCatalog::MEDINT:
216         case execplan::CalpontSystemCatalog::SMALLINT:
217         case execplan::CalpontSystemCatalog::TINYINT:
218             fResult.intVal = execute(lop->getIntVal(row, isNull), rop->getIntVal(row, isNull), isNull);
219             break;
220 
221         case execplan::CalpontSystemCatalog::UBIGINT:
222         case execplan::CalpontSystemCatalog::UINT:
223         case execplan::CalpontSystemCatalog::UMEDINT:
224         case execplan::CalpontSystemCatalog::USMALLINT:
225         case execplan::CalpontSystemCatalog::UTINYINT:
226             fResult.uintVal = execute(lop->getUintVal(row, isNull), rop->getUintVal(row, isNull), isNull);
227             break;
228 
229         case execplan::CalpontSystemCatalog::DOUBLE:
230         case execplan::CalpontSystemCatalog::FLOAT:
231         case execplan::CalpontSystemCatalog::UDOUBLE:
232         case execplan::CalpontSystemCatalog::UFLOAT:
233             fResult.doubleVal = execute(lop->getDoubleVal(row, isNull), rop->getDoubleVal(row, isNull), isNull);
234             break;
235 
236         case execplan::CalpontSystemCatalog::LONGDOUBLE:
237             fResult.longDoubleVal = execute(lop->getLongDoubleVal(row, isNull), rop->getLongDoubleVal(row, isNull), isNull);
238             break;
239 
240         case execplan::CalpontSystemCatalog::DECIMAL:
241         case execplan::CalpontSystemCatalog::UDECIMAL:
242             execute (fResult.decimalVal, lop->getDecimalVal(row, isNull), rop->getDecimalVal(row, isNull), isNull);
243             break;
244 
245         default:
246         {
247             std::ostringstream oss;
248             oss << "invalid arithmetic operand type: " << fOperationType.colDataType;
249             throw logging::InvalidArgumentExcept(oss.str());
250         }
251     }
252 }
253 
254 template <typename result_t>
execute(result_t op1,result_t op2,bool & isNull)255 inline result_t ArithmeticOperator::execute(result_t op1, result_t op2, bool& isNull)
256 {
257     switch (fOp)
258     {
259         case OP_ADD:
260             return op1 + op2;
261 
262         case OP_SUB:
263             return op1 - op2;
264 
265         case OP_MUL:
266             return op1 * op2;
267 
268         case OP_DIV:
269             if (op2)
270                 return op1 / op2;
271             else
272                 isNull = true;
273 
274             return 0;
275 
276         default:
277         {
278             std::ostringstream oss;
279             oss << "invalid arithmetic operation: " << fOp;
280             throw logging::InvalidOperationExcept(oss.str());
281         }
282     }
283 }
284 
execute(IDB_Decimal & result,IDB_Decimal op1,IDB_Decimal op2,bool & isNull)285 inline void ArithmeticOperator::execute(IDB_Decimal& result, IDB_Decimal op1, IDB_Decimal op2, bool& isNull)
286 {
287     switch (fOp)
288     {
289         case OP_ADD:
290             if (result.scale == op1.scale && result.scale == op2.scale)
291             {
292                 result.value = op1.value + op2.value;
293                 break;
294             }
295 
296             if (result.scale >= op1.scale)
297                 op1.value *= IDB_pow[result.scale - op1.scale];
298             else
299                 op1.value = (int64_t)(op1.value > 0 ?
300                                       (double)op1.value / IDB_pow[op1.scale - result.scale] + 0.5 :
301                                       (double)op1.value / IDB_pow[op1.scale - result.scale] - 0.5);
302 
303             if (result.scale >= op2.scale)
304                 op2.value *= IDB_pow[result.scale - op2.scale];
305             else
306                 op2.value = (int64_t)(op2.value > 0 ?
307                                       (double)op2.value / IDB_pow[op2.scale - result.scale] + 0.5 :
308                                       (double)op2.value / IDB_pow[op2.scale - result.scale] - 0.5);
309 
310             result.value = op1.value + op2.value;
311             break;
312 
313         case OP_SUB:
314             if (result.scale == op1.scale && result.scale == op2.scale)
315             {
316                 result.value = op1.value - op2.value;
317                 break;
318             }
319 
320             if (result.scale >= op1.scale)
321                 op1.value *= IDB_pow[result.scale - op1.scale];
322             else
323                 op1.value = (int64_t)(op1.value > 0 ?
324                                       (double)op1.value / IDB_pow[op1.scale - result.scale] + 0.5 :
325                                       (double)op1.value / IDB_pow[op1.scale - result.scale] - 0.5);
326 
327             if (result.scale >= op2.scale)
328                 op2.value *= IDB_pow[result.scale - op2.scale];
329             else
330                 op2.value = (int64_t)(op2.value > 0 ?
331                                       (double)op2.value / IDB_pow[op2.scale - result.scale] + 0.5 :
332                                       (double)op2.value / IDB_pow[op2.scale - result.scale] - 0.5);
333 
334             result.value = op1.value - op2.value;
335             break;
336 
337         case OP_MUL:
338             if (result.scale >= op1.scale + op2.scale)
339                 result.value = op1.value * op2.value * IDB_pow[result.scale - (op1.scale + op2.scale)];
340             else
341                 result.value = (int64_t)(( (op1.value > 0 && op2.value > 0) || (op1.value < 0 && op2.value < 0) ?
342                                            (double)op1.value * op2.value / IDB_pow[op1.scale + op2.scale - result.scale] + 0.5 :
343                                            (double)op1.value * op2.value / IDB_pow[op1.scale + op2.scale - result.scale] - 0.5));
344 
345             break;
346 
347         case OP_DIV:
348             if (op2.value == 0)
349             {
350                 isNull = true;
351                 break;
352             }
353 
354             if (result.scale >= op1.scale - op2.scale)
355                 result.value = (int64_t)(( (op1.value > 0 && op2.value > 0) || (op1.value < 0 && op2.value < 0) ?
356                                            (long double)op1.value / op2.value * IDB_pow[result.scale - (op1.scale - op2.scale)] + 0.5 :
357                                            (long double)op1.value / op2.value * IDB_pow[result.scale - (op1.scale - op2.scale)] - 0.5));
358             else
359                 result.value = (int64_t)(( (op1.value > 0 && op2.value > 0) || (op1.value < 0 && op2.value < 0) ?
360                                            (long double)op1.value / op2.value / IDB_pow[op1.scale - op2.scale - result.scale] + 0.5 :
361                                            (long double)op1.value / op2.value / IDB_pow[op1.scale - op2.scale - result.scale] - 0.5));
362 
363             break;
364 
365         default:
366         {
367             std::ostringstream oss;
368             oss << "invalid arithmetic operation: " << fOp;
369             throw logging::InvalidOperationExcept(oss.str());
370         }
371     }
372 }
373 
/** Stream-insertion operator for ArithmeticOperator. Only declared here;
 *  the definition is not in this header. */
std::ostream& operator<<(std::ostream& os, const ArithmeticOperator& rhs);
375 }
376 
377 #endif
378 
379