1
2 /**
3 * Copyright (C) 2018-present MongoDB, Inc.
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the Server Side Public License, version 1,
7 * as published by MongoDB, Inc.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * Server Side Public License for more details.
13 *
14 * You should have received a copy of the Server Side Public License
15 * along with this program. If not, see
16 * <http://www.mongodb.com/licensing/server-side-public-license>.
17 *
18 * As a special exception, the copyright holders give permission to link the
19 * code of portions of this program with the OpenSSL library under certain
20 * conditions as described in each individual source file and distribute
21 * linked combinations including the program with the OpenSSL library. You
22 * must comply with the Server Side Public License in all respects for
23 * all of the code used other than as permitted herein. If you modify file(s)
24 * with this exception, you may extend this exception to your version of the
25 * file(s), but you are not obligated to do so. If you do not wish to do so,
26 * delete this exception statement from your version. If you delete this
27 * exception statement from all source files in the program, then also delete
28 * it in the license file.
29 */
30
31
32 #include "mongo/platform/basic.h"
33
34 #include "mongo/db/pipeline/expression.h"
35
36 #include <algorithm>
37 #include <boost/algorithm/string.hpp>
38 #include <cstdio>
39 #include <vector>
40
41 #include "mongo/db/jsobj.h"
42 #include "mongo/db/pipeline/document.h"
43 #include "mongo/db/pipeline/expression_context.h"
44 #include "mongo/db/pipeline/value.h"
45 #include "mongo/db/query/datetime/date_time_support.h"
46 #include "mongo/platform/bits.h"
47 #include "mongo/platform/decimal128.h"
48 #include "mongo/util/mongoutils/str.h"
49 #include "mongo/util/string_map.h"
50 #include "mongo/util/summation.h"
51
52 namespace mongo {
53 using Parser = Expression::Parser;
54
55 using namespace mongoutils;
56
57 using boost::intrusive_ptr;
58 using std::map;
59 using std::move;
60 using std::pair;
61 using std::set;
62 using std::string;
63 using std::vector;
64
65 /// Helper function to easily wrap constants with $const.
serializeConstant(Value val)66 static Value serializeConstant(Value val) {
67 if (val.missing()) {
68 return Value("$$REMOVE"_sd);
69 }
70
71 return Value(DOC("$const" << val));
72 }
73
74 /* --------------------------- Expression ------------------------------ */
75
/**
 * Strips the leading '$' from a field path reference, e.g. "$a.b" -> "a.b".
 *
 * Throws 16419 if 'prefixedField' contains an embedded null byte; the message reports the byte
 * offset of the first one. Throws 15982 if 'prefixedField' does not start with '$', which
 * includes the empty string.
 */
string Expression::removeFieldPrefix(const string& prefixedField) {
    // Search for the null *character*. The previous message used find("\0"): that is the empty
    // C-string, so it always reported position 0.
    const auto nullPos = prefixedField.find('\0');
    uassert(16419,
            str::stream() << "field path must not contain embedded null characters, found one at "
                             "position "
                          << nullPos,
            nullPos == string::npos);

    uassert(15982,
            str::stream() << "field path references must be prefixed with a '$' ('"
                          << prefixedField << "')",
            !prefixedField.empty() && prefixedField[0] == '$');

    // There is no embedded '\0' (checked above), so this matches the c_str() + 1 view.
    return prefixedField.substr(1);
}
91
parseObject(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)92 intrusive_ptr<Expression> Expression::parseObject(
93 const boost::intrusive_ptr<ExpressionContext>& expCtx,
94 BSONObj obj,
95 const VariablesParseState& vps) {
96 if (obj.isEmpty()) {
97 return ExpressionObject::create(expCtx, {});
98 }
99
100 if (obj.firstElementFieldName()[0] == '$') {
101 // Assume this is an expression like {$add: [...]}.
102 return parseExpression(expCtx, obj, vps);
103 }
104
105 return ExpressionObject::parse(expCtx, obj, vps);
106 }
107
namespace {
// Registry of expression parsers, keyed by the full operator name as it appears in a query
// (e.g. "$add"). Entries are added by Expression::registerExpression() and looked up by
// Expression::parseExpression().
StringMap<Parser> parserMap;
}  // namespace
111
registerExpression(string key,Parser parser)112 void Expression::registerExpression(string key, Parser parser) {
113 auto op = parserMap.find(key);
114 massert(17064,
115 str::stream() << "Duplicate expression (" << key << ") registered.",
116 op == parserMap.end());
117 parserMap[key] = parser;
118 }
119
parseExpression(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)120 intrusive_ptr<Expression> Expression::parseExpression(
121 const boost::intrusive_ptr<ExpressionContext>& expCtx,
122 BSONObj obj,
123 const VariablesParseState& vps) {
124 uassert(15983,
125 str::stream() << "An object representing an expression must have exactly one "
126 "field: "
127 << obj.toString(),
128 obj.nFields() == 1);
129
130 // Look up the parser associated with the expression name.
131 const char* opName = obj.firstElementFieldName();
132 auto op = parserMap.find(opName);
133 uassert(ErrorCodes::InvalidPipelineOperator,
134 str::stream() << "Unrecognized expression '" << opName << "'",
135 op != parserMap.end());
136 return op->second(expCtx, obj.firstElement(), vps);
137 }
138
parseArguments(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)139 Expression::ExpressionVector ExpressionNary::parseArguments(
140 const boost::intrusive_ptr<ExpressionContext>& expCtx,
141 BSONElement exprElement,
142 const VariablesParseState& vps) {
143 ExpressionVector out;
144 if (exprElement.type() == Array) {
145 BSONForEach(elem, exprElement.Obj()) {
146 out.push_back(Expression::parseOperand(expCtx, elem, vps));
147 }
148 } else { // Assume it's an operand that accepts a single argument.
149 out.push_back(Expression::parseOperand(expCtx, exprElement, vps));
150 }
151
152 return out;
153 }
154
parseOperand(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)155 intrusive_ptr<Expression> Expression::parseOperand(
156 const boost::intrusive_ptr<ExpressionContext>& expCtx,
157 BSONElement exprElement,
158 const VariablesParseState& vps) {
159 BSONType type = exprElement.type();
160
161 if (type == String && exprElement.valuestr()[0] == '$') {
162 /* if we got here, this is a field path expression */
163 return ExpressionFieldPath::parse(expCtx, exprElement.str(), vps);
164 } else if (type == Object) {
165 return Expression::parseObject(expCtx, exprElement.Obj(), vps);
166 } else if (type == Array) {
167 return ExpressionArray::parse(expCtx, exprElement, vps);
168 } else {
169 return ExpressionConstant::parse(expCtx, exprElement, vps);
170 }
171 }
172
173 namespace {
174 /**
175 * UTF-8 multi-byte code points consist of one leading byte of the form 11xxxxxx, and potentially
176 * many continuation bytes of the form 10xxxxxx. This method checks whether 'charByte' is a leading
177 * byte.
178 */
isLeadingByte(char charByte)179 bool isLeadingByte(char charByte) {
180 return (charByte & 0xc0) == 0xc0;
181 }
182
183 /**
184 * UTF-8 single-byte code points are of the form 0xxxxxxx. This method checks whether 'charByte' is
185 * a single-byte code point.
186 */
isSingleByte(char charByte)187 bool isSingleByte(char charByte) {
188 return (charByte & 0x80) == 0x0;
189 }
190
getCodePointLength(char charByte)191 size_t getCodePointLength(char charByte) {
192 if (isSingleByte(charByte)) {
193 return 1;
194 }
195
196 invariant(isLeadingByte(charByte));
197
198 // In UTF-8, the number of leading ones is the number of bytes the code point takes up.
199 return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8)));
200 }
201 } // namespace
202
203 /* ------------------------- Register Date Expressions ----------------------------- */
204
// Wire the date-part operators into the expression parser map. NOTE(review): REGISTER_EXPRESSION
// presumably prefixes the key with '$' (parseExpression() looks operators up by their full field
// name, e.g. "$year") -- confirm against the macro definition.
REGISTER_EXPRESSION(dayOfMonth, ExpressionDayOfMonth::parse);
REGISTER_EXPRESSION(dayOfWeek, ExpressionDayOfWeek::parse);
REGISTER_EXPRESSION(dayOfYear, ExpressionDayOfYear::parse);
REGISTER_EXPRESSION(hour, ExpressionHour::parse);
REGISTER_EXPRESSION(isoDayOfWeek, ExpressionIsoDayOfWeek::parse);
REGISTER_EXPRESSION(isoWeek, ExpressionIsoWeek::parse);
REGISTER_EXPRESSION(isoWeekYear, ExpressionIsoWeekYear::parse);
REGISTER_EXPRESSION(millisecond, ExpressionMillisecond::parse);
REGISTER_EXPRESSION(minute, ExpressionMinute::parse);
REGISTER_EXPRESSION(month, ExpressionMonth::parse);
REGISTER_EXPRESSION(second, ExpressionSecond::parse);
REGISTER_EXPRESSION(week, ExpressionWeek::parse);
REGISTER_EXPRESSION(year, ExpressionYear::parse);
218
219 /* ----------------------- ExpressionAbs ---------------------------- */
220
evaluateNumericArg(const Value & numericArg) const221 Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const {
222 BSONType type = numericArg.getType();
223 if (type == NumberDouble) {
224 return Value(std::abs(numericArg.getDouble()));
225 } else if (type == NumberDecimal) {
226 return Value(numericArg.getDecimal().toAbs());
227 } else {
228 long long num = numericArg.getLong();
229 uassert(28680,
230 "can't take $abs of long long min",
231 num != std::numeric_limits<long long>::min());
232 long long absVal = std::abs(num);
233 return type == NumberLong ? Value(absVal) : Value::createIntOrLong(absVal);
234 }
235 }
236
237 REGISTER_EXPRESSION(abs, ExpressionAbs::parse);
getOpName() const238 const char* ExpressionAbs::getOpName() const {
239 return "$abs";
240 }
241
242 /* ------------------------- ExpressionAdd ----------------------------- */
243
evaluate(const Document & root,Variables * variables) const244 Value ExpressionAdd::evaluate(const Document& root, Variables* variables) const {
245 // We'll try to return the narrowest possible result value while avoiding overflow, loss
246 // of precision due to intermediate rounding or implicit use of decimal types. To do that,
247 // compute a compensated sum for non-decimal values and a separate decimal sum for decimal
248 // values, and track the current narrowest type.
249 DoubleDoubleSummation nonDecimalTotal;
250 Decimal128 decimalTotal;
251 BSONType totalType = NumberInt;
252 bool haveDate = false;
253
254 const size_t n = vpOperand.size();
255 for (size_t i = 0; i < n; ++i) {
256 Value val = vpOperand[i]->evaluate(root, variables);
257
258 switch (val.getType()) {
259 case NumberDecimal:
260 decimalTotal = decimalTotal.add(val.getDecimal());
261 totalType = NumberDecimal;
262 break;
263 case NumberDouble:
264 nonDecimalTotal.addDouble(val.getDouble());
265 if (totalType != NumberDecimal)
266 totalType = NumberDouble;
267 break;
268 case NumberLong:
269 nonDecimalTotal.addLong(val.getLong());
270 if (totalType == NumberInt)
271 totalType = NumberLong;
272 break;
273 case NumberInt:
274 nonDecimalTotal.addDouble(val.getInt());
275 break;
276 case Date:
277 uassert(16612, "only one date allowed in an $add expression", !haveDate);
278 haveDate = true;
279 nonDecimalTotal.addLong(val.getDate().toMillisSinceEpoch());
280 break;
281 default:
282 uassert(16554,
283 str::stream() << "$add only supports numeric or date types, not "
284 << typeName(val.getType()),
285 val.nullish());
286 return Value(BSONNULL);
287 }
288 }
289
290 if (haveDate) {
291 int64_t longTotal;
292 if (totalType == NumberDecimal) {
293 longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong();
294 } else {
295 uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong());
296 longTotal = nonDecimalTotal.getLong();
297 }
298 return Value(Date_t::fromMillisSinceEpoch(longTotal));
299 }
300 switch (totalType) {
301 case NumberDecimal:
302 return Value(decimalTotal.add(nonDecimalTotal.getDecimal()));
303 case NumberLong:
304 dassert(nonDecimalTotal.isInteger());
305 if (nonDecimalTotal.fitsLong())
306 return Value(nonDecimalTotal.getLong());
307 // Fallthrough.
308 case NumberInt:
309 if (nonDecimalTotal.fitsLong())
310 return Value::createIntOrLong(nonDecimalTotal.getLong());
311 // Fallthrough.
312 case NumberDouble:
313 return Value(nonDecimalTotal.getDouble());
314 default:
315 massert(16417, "$add resulted in a non-numeric type", false);
316 }
317 }
318
319 REGISTER_EXPRESSION(add, ExpressionAdd::parse);
getOpName() const320 const char* ExpressionAdd::getOpName() const {
321 return "$add";
322 }
323
324 /* ------------------------- ExpressionAllElementsTrue -------------------------- */
325
evaluate(const Document & root,Variables * variables) const326 Value ExpressionAllElementsTrue::evaluate(const Document& root, Variables* variables) const {
327 const Value arr = vpOperand[0]->evaluate(root, variables);
328 uassert(17040,
329 str::stream() << getOpName() << "'s argument must be an array, but is "
330 << typeName(arr.getType()),
331 arr.isArray());
332 const vector<Value>& array = arr.getArray();
333 for (vector<Value>::const_iterator it = array.begin(); it != array.end(); ++it) {
334 if (!it->coerceToBool()) {
335 return Value(false);
336 }
337 }
338 return Value(true);
339 }
340
341 REGISTER_EXPRESSION(allElementsTrue, ExpressionAllElementsTrue::parse);
getOpName() const342 const char* ExpressionAllElementsTrue::getOpName() const {
343 return "$allElementsTrue";
344 }
345
346 /* ------------------------- ExpressionAnd ----------------------------- */
347
optimize()348 intrusive_ptr<Expression> ExpressionAnd::optimize() {
349 /* optimize the conjunction as much as possible */
350 intrusive_ptr<Expression> pE(ExpressionNary::optimize());
351
352 /* if the result isn't a conjunction, we can't do anything */
353 ExpressionAnd* pAnd = dynamic_cast<ExpressionAnd*>(pE.get());
354 if (!pAnd)
355 return pE;
356
357 /*
358 Check the last argument on the result; if it's not constant (as
359 promised by ExpressionNary::optimize(),) then there's nothing
360 we can do.
361 */
362 const size_t n = pAnd->vpOperand.size();
363 // ExpressionNary::optimize() generates an ExpressionConstant for {$and:[]}.
364 verify(n > 0);
365 intrusive_ptr<Expression> pLast(pAnd->vpOperand[n - 1]);
366 const ExpressionConstant* pConst = dynamic_cast<ExpressionConstant*>(pLast.get());
367 if (!pConst)
368 return pE;
369
370 /*
371 Evaluate and coerce the last argument to a boolean. If it's false,
372 then we can replace this entire expression.
373 */
374 bool last = pConst->getValue().coerceToBool();
375 if (!last) {
376 intrusive_ptr<ExpressionConstant> pFinal(
377 ExpressionConstant::create(getExpressionContext(), Value(false)));
378 return pFinal;
379 }
380
381 /*
382 If we got here, the final operand was true, so we don't need it
383 anymore. If there was only one other operand, we don't need the
384 conjunction either. Note we still need to keep the promise that
385 the result will be a boolean.
386 */
387 if (n == 2) {
388 intrusive_ptr<Expression> pFinal(
389 ExpressionCoerceToBool::create(getExpressionContext(), pAnd->vpOperand[0]));
390 return pFinal;
391 }
392
393 /*
394 Remove the final "true" value, and return the new expression.
395
396 CW TODO:
397 Note that because of any implicit conversions, we may need to
398 apply an implicit boolean conversion.
399 */
400 pAnd->vpOperand.resize(n - 1);
401 return pE;
402 }
403
evaluate(const Document & root,Variables * variables) const404 Value ExpressionAnd::evaluate(const Document& root, Variables* variables) const {
405 const size_t n = vpOperand.size();
406 for (size_t i = 0; i < n; ++i) {
407 Value pValue(vpOperand[i]->evaluate(root, variables));
408 if (!pValue.coerceToBool())
409 return Value(false);
410 }
411
412 return Value(true);
413 }
414
415 REGISTER_EXPRESSION(and, ExpressionAnd::parse);
getOpName() const416 const char* ExpressionAnd::getOpName() const {
417 return "$and";
418 }
419
420 /* ------------------------- ExpressionAnyElementTrue -------------------------- */
421
evaluate(const Document & root,Variables * variables) const422 Value ExpressionAnyElementTrue::evaluate(const Document& root, Variables* variables) const {
423 const Value arr = vpOperand[0]->evaluate(root, variables);
424 uassert(17041,
425 str::stream() << getOpName() << "'s argument must be an array, but is "
426 << typeName(arr.getType()),
427 arr.isArray());
428 const vector<Value>& array = arr.getArray();
429 for (vector<Value>::const_iterator it = array.begin(); it != array.end(); ++it) {
430 if (it->coerceToBool()) {
431 return Value(true);
432 }
433 }
434 return Value(false);
435 }
436
437 REGISTER_EXPRESSION(anyElementTrue, ExpressionAnyElementTrue::parse);
getOpName() const438 const char* ExpressionAnyElementTrue::getOpName() const {
439 return "$anyElementTrue";
440 }
441
442 /* ---------------------- ExpressionArray --------------------------- */
443
evaluate(const Document & root,Variables * variables) const444 Value ExpressionArray::evaluate(const Document& root, Variables* variables) const {
445 vector<Value> values;
446 values.reserve(vpOperand.size());
447 for (auto&& expr : vpOperand) {
448 Value elemVal = expr->evaluate(root, variables);
449 values.push_back(elemVal.missing() ? Value(BSONNULL) : std::move(elemVal));
450 }
451 return Value(std::move(values));
452 }
453
serialize(bool explain) const454 Value ExpressionArray::serialize(bool explain) const {
455 vector<Value> expressions;
456 expressions.reserve(vpOperand.size());
457 for (auto&& expr : vpOperand) {
458 expressions.push_back(expr->serialize(explain));
459 }
460 return Value(std::move(expressions));
461 }
462
getOpName() const463 const char* ExpressionArray::getOpName() const {
464 // This should never be called, but is needed to inherit from ExpressionNary.
465 return "$array";
466 }
467
468 /* ------------------------- ExpressionArrayElemAt -------------------------- */
469
evaluate(const Document & root,Variables * variables) const470 Value ExpressionArrayElemAt::evaluate(const Document& root, Variables* variables) const {
471 const Value array = vpOperand[0]->evaluate(root, variables);
472 const Value indexArg = vpOperand[1]->evaluate(root, variables);
473
474 if (array.nullish() || indexArg.nullish()) {
475 return Value(BSONNULL);
476 }
477
478 uassert(28689,
479 str::stream() << getOpName() << "'s first argument must be an array, but is "
480 << typeName(array.getType()),
481 array.isArray());
482 uassert(28690,
483 str::stream() << getOpName() << "'s second argument must be a numeric value,"
484 << " but is "
485 << typeName(indexArg.getType()),
486 indexArg.numeric());
487 uassert(28691,
488 str::stream() << getOpName() << "'s second argument must be representable as"
489 << " a 32-bit integer: "
490 << indexArg.coerceToDouble(),
491 indexArg.integral());
492
493 long long i = indexArg.coerceToLong();
494 if (i < 0 && static_cast<size_t>(std::abs(i)) > array.getArrayLength()) {
495 // Positive indices that are too large are handled automatically by Value.
496 return Value();
497 } else if (i < 0) {
498 // Index from the back of the array.
499 i = array.getArrayLength() + i;
500 }
501 const size_t index = static_cast<size_t>(i);
502 return array[index];
503 }
504
505 REGISTER_EXPRESSION(arrayElemAt, ExpressionArrayElemAt::parse);
getOpName() const506 const char* ExpressionArrayElemAt::getOpName() const {
507 return "$arrayElemAt";
508 }
509
510 /* ------------------------- ExpressionObjectToArray -------------------------- */
511
512
evaluate(const Document & root,Variables * variables) const513 Value ExpressionObjectToArray::evaluate(const Document& root, Variables* variables) const {
514 const Value targetVal = vpOperand[0]->evaluate(root, variables);
515
516 if (targetVal.nullish()) {
517 return Value(BSONNULL);
518 }
519
520 uassert(40390,
521 str::stream() << "$objectToArray requires a document input, found: "
522 << typeName(targetVal.getType()),
523 (targetVal.getType() == BSONType::Object));
524
525 vector<Value> output;
526
527 FieldIterator iter = targetVal.getDocument().fieldIterator();
528 while (iter.more()) {
529 Document::FieldPair pair = iter.next();
530 MutableDocument keyvalue;
531 keyvalue.addField("k", Value(pair.first));
532 keyvalue.addField("v", pair.second);
533 output.push_back(keyvalue.freezeToValue());
534 }
535
536 return Value(output);
537 }
538
539 REGISTER_EXPRESSION(objectToArray, ExpressionObjectToArray::parse);
getOpName() const540 const char* ExpressionObjectToArray::getOpName() const {
541 return "$objectToArray";
542 }
543
544 /* ------------------------- ExpressionArrayToObject -------------------------- */
545
evaluate(const Document & root,Variables * variables) const546 Value ExpressionArrayToObject::evaluate(const Document& root, Variables* variables) const {
547 const Value input = vpOperand[0]->evaluate(root, variables);
548 if (input.nullish()) {
549 return Value(BSONNULL);
550 }
551
552 uassert(40386,
553 str::stream() << "$arrayToObject requires an array input, found: "
554 << typeName(input.getType()),
555 input.isArray());
556
557 MutableDocument output;
558 const vector<Value>& array = input.getArray();
559 if (array.empty()) {
560 return output.freezeToValue();
561 }
562
563 // There are two accepted input formats in an array: [ [key, val] ] or [ {k:key, v:val} ]. The
564 // first array element determines the format for the rest of the array. Mixing input formats is
565 // not allowed.
566 bool inputArrayFormat;
567 if (array[0].isArray()) {
568 inputArrayFormat = true;
569 } else if (array[0].getType() == BSONType::Object) {
570 inputArrayFormat = false;
571 } else {
572 uasserted(40398,
573 str::stream() << "Unrecognised input type format for $arrayToObject: "
574 << typeName(array[0].getType()));
575 }
576
577 for (auto&& elem : array) {
578 if (inputArrayFormat == true) {
579 uassert(
580 40396,
581 str::stream() << "$arrayToObject requires a consistent input format. Elements must"
582 "all be arrays or all be objects. Array was detected, now found: "
583 << typeName(elem.getType()),
584 elem.isArray());
585
586 const vector<Value>& valArray = elem.getArray();
587
588 uassert(40397,
589 str::stream() << "$arrayToObject requires an array of size 2 arrays,"
590 "found array of size: "
591 << valArray.size(),
592 (valArray.size() == 2));
593
594 uassert(40395,
595 str::stream() << "$arrayToObject requires an array of key-value pairs, where "
596 "the key must be of type string. Found key type: "
597 << typeName(valArray[0].getType()),
598 (valArray[0].getType() == BSONType::String));
599
600 auto keyName = valArray[0].getString();
601
602 uassert(4940400,
603 "Key field cannot contain an embedded null byte",
604 keyName.find('\0') == std::string::npos);
605
606 output[keyName] = valArray[1];
607
608 } else {
609 uassert(
610 40391,
611 str::stream() << "$arrayToObject requires a consistent input format. Elements must"
612 "all be arrays or all be objects. Object was detected, now found: "
613 << typeName(elem.getType()),
614 (elem.getType() == BSONType::Object));
615
616 uassert(40392,
617 str::stream() << "$arrayToObject requires an object keys of 'k' and 'v'. "
618 "Found incorrect number of keys:"
619 << elem.getDocument().size(),
620 (elem.getDocument().size() == 2));
621
622 Value key = elem.getDocument().getField("k");
623 Value value = elem.getDocument().getField("v");
624
625 uassert(40393,
626 str::stream() << "$arrayToObject requires an object with keys 'k' and 'v'. "
627 "Missing either or both keys from: "
628 << elem.toString(),
629 (!key.missing() && !value.missing()));
630
631 uassert(
632 40394,
633 str::stream() << "$arrayToObject requires an object with keys 'k' and 'v', where "
634 "the value of 'k' must be of type string. Found type: "
635 << typeName(key.getType()),
636 (key.getType() == BSONType::String));
637
638 auto keyName = key.getString();
639
640 uassert(4940401,
641 "Key field cannot contain an embedded null byte",
642 keyName.find('\0') == std::string::npos);
643
644 output[keyName] = value;
645 }
646 }
647
648 return output.freezeToValue();
649 }
650
651 REGISTER_EXPRESSION(arrayToObject, ExpressionArrayToObject::parse);
getOpName() const652 const char* ExpressionArrayToObject::getOpName() const {
653 return "$arrayToObject";
654 }
655
656 /* ------------------------- ExpressionCeil -------------------------- */
657
evaluateNumericArg(const Value & numericArg) const658 Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const {
659 // There's no point in taking the ceiling of integers or longs, it will have no effect.
660 switch (numericArg.getType()) {
661 case NumberDouble:
662 return Value(std::ceil(numericArg.getDouble()));
663 case NumberDecimal:
664 // Round toward the nearest decimal with a zero exponent in the positive direction.
665 return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
666 Decimal128::kRoundTowardPositive));
667 default:
668 return numericArg;
669 }
670 }
671
672 REGISTER_EXPRESSION(ceil, ExpressionCeil::parse);
getOpName() const673 const char* ExpressionCeil::getOpName() const {
674 return "$ceil";
675 }
676
677 /* -------------------- ExpressionCoerceToBool ------------------------- */
678
create(const intrusive_ptr<ExpressionContext> & expCtx,const intrusive_ptr<Expression> & pExpression)679 intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create(
680 const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) {
681 intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(expCtx, pExpression));
682 return pNew;
683 }
684
ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext> & expCtx,const intrusive_ptr<Expression> & pTheExpression)685 ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext>& expCtx,
686 const intrusive_ptr<Expression>& pTheExpression)
687 : Expression(expCtx), pExpression(pTheExpression) {}
688
optimize()689 intrusive_ptr<Expression> ExpressionCoerceToBool::optimize() {
690 /* optimize the operand */
691 pExpression = pExpression->optimize();
692
693 /* if the operand already produces a boolean, then we don't need this */
694 /* LATER - Expression to support a "typeof" query? */
695 Expression* pE = pExpression.get();
696 if (dynamic_cast<ExpressionAnd*>(pE) || dynamic_cast<ExpressionOr*>(pE) ||
697 dynamic_cast<ExpressionNot*>(pE) || dynamic_cast<ExpressionCoerceToBool*>(pE))
698 return pExpression;
699
700 return intrusive_ptr<Expression>(this);
701 }
702
_doAddDependencies(DepsTracker * deps) const703 void ExpressionCoerceToBool::_doAddDependencies(DepsTracker* deps) const {
704 pExpression->addDependencies(deps);
705 }
706
evaluate(const Document & root,Variables * variables) const707 Value ExpressionCoerceToBool::evaluate(const Document& root, Variables* variables) const {
708 Value pResult(pExpression->evaluate(root, variables));
709 bool b = pResult.coerceToBool();
710 if (b)
711 return Value(true);
712 return Value(false);
713 }
714
serialize(bool explain) const715 Value ExpressionCoerceToBool::serialize(bool explain) const {
716 // When not explaining, serialize to an $and expression. When parsed, the $and expression
717 // will be optimized back into a ExpressionCoerceToBool.
718 const char* name = explain ? "$coerceToBool" : "$and";
719 return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain))));
720 }
721
722 /* ----------------------- ExpressionCompare --------------------------- */
723
// All comparison operators share ExpressionCompare::parse. stdx::bind fixes its trailing CmpOp
// argument, so each registered parser has the common (expCtx, element, vps) shape.
REGISTER_EXPRESSION(cmp,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::CMP));
REGISTER_EXPRESSION(eq,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::EQ));
REGISTER_EXPRESSION(gt,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::GT));
REGISTER_EXPRESSION(gte,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::GTE));
REGISTER_EXPRESSION(lt,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::LT));
REGISTER_EXPRESSION(lte,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::LTE));
REGISTER_EXPRESSION(ne,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::NE));
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement bsonExpr,const VariablesParseState & vps,CmpOp op)766 intrusive_ptr<Expression> ExpressionCompare::parse(
767 const boost::intrusive_ptr<ExpressionContext>& expCtx,
768 BSONElement bsonExpr,
769 const VariablesParseState& vps,
770 CmpOp op) {
771 intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, op);
772 ExpressionVector args = parseArguments(expCtx, bsonExpr, vps);
773 expr->validateArguments(args);
774 expr->vpOperand = args;
775 return expr;
776 }
777
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,CmpOp cmpOp,const boost::intrusive_ptr<Expression> & exprLeft,const boost::intrusive_ptr<Expression> & exprRight)778 boost::intrusive_ptr<ExpressionCompare> ExpressionCompare::create(
779 const boost::intrusive_ptr<ExpressionContext>& expCtx,
780 CmpOp cmpOp,
781 const boost::intrusive_ptr<Expression>& exprLeft,
782 const boost::intrusive_ptr<Expression>& exprRight) {
783 boost::intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, cmpOp);
784 expr->vpOperand = {exprLeft, exprRight};
785 return expr;
786 }
787
788 namespace {
789 // Lookup table for truth value returns
790 struct CmpLookup {
791 const bool truthValue[3]; // truth value for -1, 0, 1
792 const ExpressionCompare::CmpOp reverse; // reverse(b,a) returns the same as op(a,b)
793 const char name[5]; // string name with trailing '\0'
794 };
795 static const CmpLookup cmpLookup[7] = {
796 /* -1 0 1 reverse name */
797 /* EQ */ {{false, true, false}, ExpressionCompare::EQ, "$eq"},
798 /* NE */ {{true, false, true}, ExpressionCompare::NE, "$ne"},
799 /* GT */ {{false, false, true}, ExpressionCompare::LT, "$gt"},
800 /* GTE */ {{false, true, true}, ExpressionCompare::LTE, "$gte"},
801 /* LT */ {{true, false, false}, ExpressionCompare::GT, "$lt"},
802 /* LTE */ {{true, true, false}, ExpressionCompare::GTE, "$lte"},
803
804 // CMP is special. Only name is used.
805 /* CMP */ {{false, false, false}, ExpressionCompare::CMP, "$cmp"},
806 };
807 } // namespace
808
809
evaluate(const Document & root,Variables * variables) const810 Value ExpressionCompare::evaluate(const Document& root, Variables* variables) const {
811 Value pLeft(vpOperand[0]->evaluate(root, variables));
812 Value pRight(vpOperand[1]->evaluate(root, variables));
813
814 int cmp = getExpressionContext()->getValueComparator().compare(pLeft, pRight);
815
816 // Make cmp one of 1, 0, or -1.
817 if (cmp == 0) {
818 // leave as 0
819 } else if (cmp < 0) {
820 cmp = -1;
821 } else if (cmp > 0) {
822 cmp = 1;
823 }
824
825 if (cmpOp == CMP)
826 return Value(cmp);
827
828 bool returnValue = cmpLookup[cmpOp].truthValue[cmp + 1];
829 return Value(returnValue);
830 }
831
getOpName() const832 const char* ExpressionCompare::getOpName() const {
833 return cmpLookup[cmpOp].name;
834 }
835
836 /* ------------------------- ExpressionConcat ----------------------------- */
837
evaluate(const Document & root,Variables * variables) const838 Value ExpressionConcat::evaluate(const Document& root, Variables* variables) const {
839 const size_t n = vpOperand.size();
840
841 StringBuilder result;
842 for (size_t i = 0; i < n; ++i) {
843 Value val = vpOperand[i]->evaluate(root, variables);
844 if (val.nullish())
845 return Value(BSONNULL);
846
847 uassert(16702,
848 str::stream() << "$concat only supports strings, not " << typeName(val.getType()),
849 val.getType() == String);
850
851 result << val.coerceToString();
852 }
853
854 return Value(result.str());
855 }
856
857 REGISTER_EXPRESSION(concat, ExpressionConcat::parse);
getOpName() const858 const char* ExpressionConcat::getOpName() const {
859 return "$concat";
860 }
861
862 /* ------------------------- ExpressionConcatArrays ----------------------------- */
863
evaluate(const Document & root,Variables * variables) const864 Value ExpressionConcatArrays::evaluate(const Document& root, Variables* variables) const {
865 const size_t n = vpOperand.size();
866 vector<Value> values;
867
868 for (size_t i = 0; i < n; ++i) {
869 Value val = vpOperand[i]->evaluate(root, variables);
870 if (val.nullish()) {
871 return Value(BSONNULL);
872 }
873
874 uassert(28664,
875 str::stream() << "$concatArrays only supports arrays, not "
876 << typeName(val.getType()),
877 val.isArray());
878
879 const auto& subValues = val.getArray();
880 values.insert(values.end(), subValues.begin(), subValues.end());
881 }
882 return Value(std::move(values));
883 }
884
885 REGISTER_EXPRESSION(concatArrays, ExpressionConcatArrays::parse);
getOpName() const886 const char* ExpressionConcatArrays::getOpName() const {
887 return "$concatArrays";
888 }
889
890 /* ----------------------- ExpressionCond ------------------------------ */
891
evaluate(const Document & root,Variables * variables) const892 Value ExpressionCond::evaluate(const Document& root, Variables* variables) const {
893 Value pCond(vpOperand[0]->evaluate(root, variables));
894 int idx = pCond.coerceToBool() ? 1 : 2;
895 return vpOperand[idx]->evaluate(root, variables);
896 }
897
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)898 intrusive_ptr<Expression> ExpressionCond::parse(
899 const boost::intrusive_ptr<ExpressionContext>& expCtx,
900 BSONElement expr,
901 const VariablesParseState& vps) {
902 if (expr.type() != Object) {
903 return Base::parse(expCtx, expr, vps);
904 }
905 verify(str::equals(expr.fieldName(), "$cond"));
906
907 intrusive_ptr<ExpressionCond> ret = new ExpressionCond(expCtx);
908 ret->vpOperand.resize(3);
909
910 const BSONObj args = expr.embeddedObject();
911 BSONForEach(arg, args) {
912 if (str::equals(arg.fieldName(), "if")) {
913 ret->vpOperand[0] = parseOperand(expCtx, arg, vps);
914 } else if (str::equals(arg.fieldName(), "then")) {
915 ret->vpOperand[1] = parseOperand(expCtx, arg, vps);
916 } else if (str::equals(arg.fieldName(), "else")) {
917 ret->vpOperand[2] = parseOperand(expCtx, arg, vps);
918 } else {
919 uasserted(17083,
920 str::stream() << "Unrecognized parameter to $cond: " << arg.fieldName());
921 }
922 }
923
924 uassert(17080, "Missing 'if' parameter to $cond", ret->vpOperand[0]);
925 uassert(17081, "Missing 'then' parameter to $cond", ret->vpOperand[1]);
926 uassert(17082, "Missing 'else' parameter to $cond", ret->vpOperand[2]);
927
928 return ret;
929 }
930
931 REGISTER_EXPRESSION(cond, ExpressionCond::parse);
getOpName() const932 const char* ExpressionCond::getOpName() const {
933 return "$cond";
934 }
935
936 /* ---------------------- ExpressionConstant --------------------------- */
937
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)938 intrusive_ptr<Expression> ExpressionConstant::parse(
939 const boost::intrusive_ptr<ExpressionContext>& expCtx,
940 BSONElement exprElement,
941 const VariablesParseState& vps) {
942 return new ExpressionConstant(expCtx, Value(exprElement));
943 }
944
945
create(const intrusive_ptr<ExpressionContext> & expCtx,const Value & value)946 intrusive_ptr<ExpressionConstant> ExpressionConstant::create(
947 const intrusive_ptr<ExpressionContext>& expCtx, const Value& value) {
948 intrusive_ptr<ExpressionConstant> pEC(new ExpressionConstant(expCtx, value));
949 return pEC;
950 }
951
ExpressionConstant(const boost::intrusive_ptr<ExpressionContext> & expCtx,const Value & value)952 ExpressionConstant::ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx,
953 const Value& value)
954 : Expression(expCtx), _value(value) {}
955
956
optimize()957 intrusive_ptr<Expression> ExpressionConstant::optimize() {
958 /* nothing to do */
959 return intrusive_ptr<Expression>(this);
960 }
961
_doAddDependencies(DepsTracker * deps) const962 void ExpressionConstant::_doAddDependencies(DepsTracker* deps) const {
963 /* nothing to do */
964 }
965
evaluate(const Document & root,Variables * variables) const966 Value ExpressionConstant::evaluate(const Document& root, Variables* variables) const {
967 return _value;
968 }
969
serialize(bool explain) const970 Value ExpressionConstant::serialize(bool explain) const {
971 return serializeConstant(_value);
972 }
973
974 REGISTER_EXPRESSION(const, ExpressionConstant::parse);
975 REGISTER_EXPRESSION(literal, ExpressionConstant::parse); // alias
getOpName() const976 const char* ExpressionConstant::getOpName() const {
977 return "$const";
978 }
979
980 /* ---------------------- ExpressionDateFromParts ----------------------- */
981
/* Helper functions shared by the date expressions below ($dateFromParts, $dateFromString,
 * $dateToParts and $dateToString). */
983
984 namespace {
985
makeTimeZone(const TimeZoneDatabase * tzdb,const Document & root,const Expression * timeZone,Variables * variables)986 boost::optional<TimeZone> makeTimeZone(const TimeZoneDatabase* tzdb,
987 const Document& root,
988 const Expression* timeZone,
989 Variables* variables) {
990 invariant(tzdb);
991
992 if (!timeZone) {
993 return mongo::TimeZoneDatabase::utcZone();
994 }
995
996 auto timeZoneId = timeZone->evaluate(root, variables);
997
998 if (timeZoneId.nullish()) {
999 return boost::none;
1000 }
1001
1002 uassert(40517,
1003 str::stream() << "timezone must evaluate to a string, found "
1004 << typeName(timeZoneId.getType()),
1005 timeZoneId.getType() == BSONType::String);
1006
1007 return tzdb->getTimeZone(timeZoneId.getString());
1008 }
1009
1010 } // namespace
1011
1012
1013 REGISTER_EXPRESSION(dateFromParts, ExpressionDateFromParts::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1014 intrusive_ptr<Expression> ExpressionDateFromParts::parse(
1015 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1016 BSONElement expr,
1017 const VariablesParseState& vps) {
1018
1019 uassert(40519,
1020 "$dateFromParts only supports an object as its argument",
1021 expr.type() == BSONType::Object);
1022
1023 BSONElement yearElem;
1024 BSONElement monthElem;
1025 BSONElement dayElem;
1026 BSONElement hourElem;
1027 BSONElement minuteElem;
1028 BSONElement secondElem;
1029 BSONElement millisecondElem;
1030 BSONElement isoWeekYearElem;
1031 BSONElement isoWeekElem;
1032 BSONElement isoDayOfWeekElem;
1033 BSONElement timeZoneElem;
1034
1035 const BSONObj args = expr.embeddedObject();
1036 for (auto&& arg : args) {
1037 auto field = arg.fieldNameStringData();
1038
1039 if (field == "year"_sd) {
1040 yearElem = arg;
1041 } else if (field == "month"_sd) {
1042 monthElem = arg;
1043 } else if (field == "day"_sd) {
1044 dayElem = arg;
1045 } else if (field == "hour"_sd) {
1046 hourElem = arg;
1047 } else if (field == "minute"_sd) {
1048 minuteElem = arg;
1049 } else if (field == "second"_sd) {
1050 secondElem = arg;
1051 } else if (field == "millisecond"_sd) {
1052 millisecondElem = arg;
1053 } else if (field == "isoWeekYear"_sd) {
1054 isoWeekYearElem = arg;
1055 } else if (field == "isoWeek"_sd) {
1056 isoWeekElem = arg;
1057 } else if (field == "isoDayOfWeek"_sd) {
1058 isoDayOfWeekElem = arg;
1059 } else if (field == "timezone"_sd) {
1060 timeZoneElem = arg;
1061 } else {
1062 uasserted(40518,
1063 str::stream() << "Unrecognized argument to $dateFromParts: "
1064 << arg.fieldName());
1065 }
1066 }
1067
1068 if (!yearElem && !isoWeekYearElem) {
1069 uasserted(40516, "$dateFromParts requires either 'year' or 'isoWeekYear' to be present");
1070 }
1071
1072 if (yearElem && (isoWeekYearElem || isoWeekElem || isoDayOfWeekElem)) {
1073 uasserted(40489, "$dateFromParts does not allow mixing natural dates with ISO dates");
1074 }
1075
1076 if (isoWeekYearElem && (yearElem || monthElem || dayElem)) {
1077 uasserted(40525, "$dateFromParts does not allow mixing ISO dates with natural dates");
1078 }
1079
1080 return new ExpressionDateFromParts(
1081 expCtx,
1082 yearElem ? parseOperand(expCtx, yearElem, vps) : nullptr,
1083 monthElem ? parseOperand(expCtx, monthElem, vps) : nullptr,
1084 dayElem ? parseOperand(expCtx, dayElem, vps) : nullptr,
1085 hourElem ? parseOperand(expCtx, hourElem, vps) : nullptr,
1086 minuteElem ? parseOperand(expCtx, minuteElem, vps) : nullptr,
1087 secondElem ? parseOperand(expCtx, secondElem, vps) : nullptr,
1088 millisecondElem ? parseOperand(expCtx, millisecondElem, vps) : nullptr,
1089 isoWeekYearElem ? parseOperand(expCtx, isoWeekYearElem, vps) : nullptr,
1090 isoWeekElem ? parseOperand(expCtx, isoWeekElem, vps) : nullptr,
1091 isoDayOfWeekElem ? parseOperand(expCtx, isoDayOfWeekElem, vps) : nullptr,
1092 timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps) : nullptr);
1093 }
1094
ExpressionDateFromParts(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> year,intrusive_ptr<Expression> month,intrusive_ptr<Expression> day,intrusive_ptr<Expression> hour,intrusive_ptr<Expression> minute,intrusive_ptr<Expression> second,intrusive_ptr<Expression> millisecond,intrusive_ptr<Expression> isoWeekYear,intrusive_ptr<Expression> isoWeek,intrusive_ptr<Expression> isoDayOfWeek,intrusive_ptr<Expression> timeZone)1095 ExpressionDateFromParts::ExpressionDateFromParts(
1096 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1097 intrusive_ptr<Expression> year,
1098 intrusive_ptr<Expression> month,
1099 intrusive_ptr<Expression> day,
1100 intrusive_ptr<Expression> hour,
1101 intrusive_ptr<Expression> minute,
1102 intrusive_ptr<Expression> second,
1103 intrusive_ptr<Expression> millisecond,
1104 intrusive_ptr<Expression> isoWeekYear,
1105 intrusive_ptr<Expression> isoWeek,
1106 intrusive_ptr<Expression> isoDayOfWeek,
1107 intrusive_ptr<Expression> timeZone)
1108 : Expression(expCtx),
1109 _year(std::move(year)),
1110 _month(std::move(month)),
1111 _day(std::move(day)),
1112 _hour(std::move(hour)),
1113 _minute(std::move(minute)),
1114 _second(std::move(second)),
1115 _millisecond(std::move(millisecond)),
1116 _isoWeekYear(std::move(isoWeekYear)),
1117 _isoWeek(std::move(isoWeek)),
1118 _isoDayOfWeek(std::move(isoDayOfWeek)),
1119 _timeZone(std::move(timeZone)) {}
1120
optimize()1121 intrusive_ptr<Expression> ExpressionDateFromParts::optimize() {
1122 if (_year) {
1123 _year = _year->optimize();
1124 }
1125 if (_month) {
1126 _month = _month->optimize();
1127 }
1128 if (_day) {
1129 _day = _day->optimize();
1130 }
1131 if (_hour) {
1132 _hour = _hour->optimize();
1133 }
1134 if (_minute) {
1135 _minute = _minute->optimize();
1136 }
1137 if (_second) {
1138 _second = _second->optimize();
1139 }
1140 if (_millisecond) {
1141 _millisecond = _millisecond->optimize();
1142 }
1143 if (_isoWeekYear) {
1144 _isoWeekYear = _isoWeekYear->optimize();
1145 }
1146 if (_isoWeek) {
1147 _isoWeek = _isoWeek->optimize();
1148 }
1149 if (_isoDayOfWeek) {
1150 _isoDayOfWeek = _isoDayOfWeek->optimize();
1151 }
1152 if (_timeZone) {
1153 _timeZone = _timeZone->optimize();
1154 }
1155
1156 if (ExpressionConstant::allNullOrConstant({_year,
1157 _month,
1158 _day,
1159 _hour,
1160 _minute,
1161 _second,
1162 _millisecond,
1163 _isoWeekYear,
1164 _isoWeek,
1165 _isoDayOfWeek,
1166 _timeZone})) {
1167
1168 // Everything is a constant, so we can turn into a constant.
1169 return ExpressionConstant::create(
1170 getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1171 }
1172
1173 return this;
1174 }
1175
serialize(bool explain) const1176 Value ExpressionDateFromParts::serialize(bool explain) const {
1177 return Value(Document{
1178 {"$dateFromParts",
1179 Document{{"year", _year ? _year->serialize(explain) : Value()},
1180 {"month", _month ? _month->serialize(explain) : Value()},
1181 {"day", _day ? _day->serialize(explain) : Value()},
1182 {"hour", _hour ? _hour->serialize(explain) : Value()},
1183 {"minute", _minute ? _minute->serialize(explain) : Value()},
1184 {"second", _second ? _second->serialize(explain) : Value()},
1185 {"millisecond", _millisecond ? _millisecond->serialize(explain) : Value()},
1186 {"isoWeekYear", _isoWeekYear ? _isoWeekYear->serialize(explain) : Value()},
1187 {"isoWeek", _isoWeek ? _isoWeek->serialize(explain) : Value()},
1188 {"isoDayOfWeek", _isoDayOfWeek ? _isoDayOfWeek->serialize(explain) : Value()},
1189 {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1190 }
1191
1192 /**
1193 * This function checks whether a field is a number, and fits in the given range.
1194 *
1195 * If the field does not exist, the default value is returned trough the returnValue out parameter
1196 * and the function returns true.
1197 *
1198 * If the field exists:
1199 * - if the value is "nullish", the function returns false, so that the calling function can return
1200 * a BSONNULL value.
1201 * - if the value can not be coerced to an integral value, an exception is returned.
1202 * - if the value is out of the range [minValue..maxValue], an exception is returned.
1203 * - otherwise, the coerced integral value is returned through the returnValue
1204 * out parameter, and the function returns true.
1205 */
evaluateNumberWithinRange(const Document & root,const Expression * field,StringData fieldName,int defaultValue,int minValue,int maxValue,int * returnValue,Variables * variables) const1206 bool ExpressionDateFromParts::evaluateNumberWithinRange(const Document& root,
1207 const Expression* field,
1208 StringData fieldName,
1209 int defaultValue,
1210 int minValue,
1211 int maxValue,
1212 int* returnValue,
1213 Variables* variables) const {
1214 if (!field) {
1215 *returnValue = defaultValue;
1216 return true;
1217 }
1218
1219 auto fieldValue = field->evaluate(root, variables);
1220
1221 if (fieldValue.nullish()) {
1222 return false;
1223 }
1224
1225 uassert(40515,
1226 str::stream() << "'" << fieldName << "' must evaluate to an integer, found "
1227 << typeName(fieldValue.getType())
1228 << " with value "
1229 << fieldValue.toString(),
1230 fieldValue.integral());
1231
1232 *returnValue = fieldValue.coerceToInt();
1233
1234 uassert(40523,
1235 str::stream() << "'" << fieldName << "' must evaluate to an integer in the range "
1236 << minValue
1237 << " to "
1238 << maxValue
1239 << ", found "
1240 << *returnValue,
1241 *returnValue >= minValue && *returnValue <= maxValue);
1242
1243 return true;
1244 }
1245
evaluate(const Document & root,Variables * variables) const1246 Value ExpressionDateFromParts::evaluate(const Document& root, Variables* variables) const {
1247 int hour, minute, second, millisecond;
1248
1249 if (!evaluateNumberWithinRange(root, _hour.get(), "hour"_sd, 0, 0, 24, &hour, variables) ||
1250 !evaluateNumberWithinRange(
1251 root, _minute.get(), "minute"_sd, 0, 0, 59, &minute, variables) ||
1252 !evaluateNumberWithinRange(
1253 root, _second.get(), "second"_sd, 0, 0, 59, &second, variables) ||
1254 !evaluateNumberWithinRange(
1255 root, _millisecond.get(), "millisecond"_sd, 0, 0, 999, &millisecond, variables)) {
1256 return Value(BSONNULL);
1257 }
1258
1259 auto timeZone =
1260 makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1261
1262 if (!timeZone) {
1263 return Value(BSONNULL);
1264 }
1265
1266 if (_year) {
1267 int year, month, day;
1268
1269 if (!evaluateNumberWithinRange(
1270 root, _year.get(), "year"_sd, 1970, 0, 9999, &year, variables) ||
1271 !evaluateNumberWithinRange(
1272 root, _month.get(), "month"_sd, 1, 1, 12, &month, variables) ||
1273 !evaluateNumberWithinRange(root, _day.get(), "day"_sd, 1, 1, 31, &day, variables)) {
1274 return Value(BSONNULL);
1275 }
1276
1277 return Value(
1278 timeZone->createFromDateParts(year, month, day, hour, minute, second, millisecond));
1279 }
1280
1281 if (_isoWeekYear) {
1282 int isoWeekYear, isoWeek, isoDayOfWeek;
1283
1284 if (!evaluateNumberWithinRange(root,
1285 _isoWeekYear.get(),
1286 "isoWeekYear"_sd,
1287 1970,
1288 0,
1289 9999,
1290 &isoWeekYear,
1291 variables) ||
1292 !evaluateNumberWithinRange(
1293 root, _isoWeek.get(), "isoWeek"_sd, 1, 1, 53, &isoWeek, variables) ||
1294 !evaluateNumberWithinRange(
1295 root, _isoDayOfWeek.get(), "isoDayOfWeek"_sd, 1, 1, 7, &isoDayOfWeek, variables)) {
1296 return Value(BSONNULL);
1297 }
1298
1299 return Value(timeZone->createFromIso8601DateParts(
1300 isoWeekYear, isoWeek, isoDayOfWeek, hour, minute, second, millisecond));
1301 }
1302
1303 MONGO_UNREACHABLE;
1304 }
1305
_doAddDependencies(DepsTracker * deps) const1306 void ExpressionDateFromParts::_doAddDependencies(DepsTracker* deps) const {
1307 if (_year) {
1308 _year->addDependencies(deps);
1309 }
1310 if (_month) {
1311 _month->addDependencies(deps);
1312 }
1313 if (_day) {
1314 _day->addDependencies(deps);
1315 }
1316 if (_hour) {
1317 _hour->addDependencies(deps);
1318 }
1319 if (_minute) {
1320 _minute->addDependencies(deps);
1321 }
1322 if (_second) {
1323 _second->addDependencies(deps);
1324 }
1325 if (_millisecond) {
1326 _millisecond->addDependencies(deps);
1327 }
1328 if (_isoWeekYear) {
1329 _isoWeekYear->addDependencies(deps);
1330 }
1331 if (_isoWeek) {
1332 _isoWeek->addDependencies(deps);
1333 }
1334 if (_isoDayOfWeek) {
1335 _isoDayOfWeek->addDependencies(deps);
1336 }
1337 if (_timeZone) {
1338 _timeZone->addDependencies(deps);
1339 }
1340 }
1341
1342 /* ---------------------- ExpressionDateFromString --------------------- */
1343
1344 REGISTER_EXPRESSION(dateFromString, ExpressionDateFromString::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1345 intrusive_ptr<Expression> ExpressionDateFromString::parse(
1346 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1347 BSONElement expr,
1348 const VariablesParseState& vps) {
1349
1350 uassert(40540,
1351 str::stream() << "$dateFromString only supports an object as an argument, found: "
1352 << typeName(expr.type()),
1353 expr.type() == BSONType::Object);
1354
1355 BSONElement dateStringElem;
1356 BSONElement timeZoneElem;
1357
1358 const BSONObj args = expr.embeddedObject();
1359 for (auto&& arg : args) {
1360 auto field = arg.fieldNameStringData();
1361
1362 if (field == "dateString"_sd) {
1363 dateStringElem = arg;
1364 } else if (field == "timezone"_sd) {
1365 timeZoneElem = arg;
1366 } else {
1367 uasserted(40541,
1368 str::stream() << "Unrecognized argument to $dateFromString: "
1369 << arg.fieldName());
1370 }
1371 }
1372
1373 uassert(40542, "Missing 'dateString' parameter to $dateFromString", dateStringElem);
1374
1375 return new ExpressionDateFromString(expCtx,
1376 parseOperand(expCtx, dateStringElem, vps),
1377 timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps)
1378 : nullptr);
1379 }
1380
ExpressionDateFromString(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> dateString,intrusive_ptr<Expression> timeZone)1381 ExpressionDateFromString::ExpressionDateFromString(
1382 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1383 intrusive_ptr<Expression> dateString,
1384 intrusive_ptr<Expression> timeZone)
1385 : Expression(expCtx), _dateString(std::move(dateString)), _timeZone(std::move(timeZone)) {}
1386
optimize()1387 intrusive_ptr<Expression> ExpressionDateFromString::optimize() {
1388 _dateString = _dateString->optimize();
1389 if (_timeZone) {
1390 _timeZone = _timeZone->optimize();
1391 }
1392
1393 if (ExpressionConstant::allNullOrConstant({_dateString, _timeZone})) {
1394 // Everything is a constant, so we can turn into a constant.
1395 return ExpressionConstant::create(
1396 getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1397 }
1398 return this;
1399 }
1400
serialize(bool explain) const1401 Value ExpressionDateFromString::serialize(bool explain) const {
1402 return Value(
1403 Document{{"$dateFromString",
1404 Document{{"dateString", _dateString->serialize(explain)},
1405 {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1406 }
1407
evaluate(const Document & root,Variables * variables) const1408 Value ExpressionDateFromString::evaluate(const Document& root, Variables* variables) const {
1409 const Value dateString = _dateString->evaluate(root, variables);
1410
1411 auto timeZone =
1412 makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1413
1414 if (!timeZone || dateString.nullish()) {
1415 return Value(BSONNULL);
1416 }
1417
1418 uassert(40543,
1419 str::stream() << "$dateFromString requires that 'dateString' be a string, found: "
1420 << typeName(dateString.getType())
1421 << " with value "
1422 << dateString.toString(),
1423 dateString.getType() == BSONType::String);
1424 const std::string& dateTimeString = dateString.getString();
1425
1426 return Value(getExpressionContext()->timeZoneDatabase->fromString(dateTimeString, timeZone));
1427 }
1428
_doAddDependencies(DepsTracker * deps) const1429 void ExpressionDateFromString::_doAddDependencies(DepsTracker* deps) const {
1430 _dateString->addDependencies(deps);
1431 if (_timeZone) {
1432 _timeZone->addDependencies(deps);
1433 }
1434 }
1435
1436 /* ---------------------- ExpressionDateToParts ----------------------- */
1437
1438 REGISTER_EXPRESSION(dateToParts, ExpressionDateToParts::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1439 intrusive_ptr<Expression> ExpressionDateToParts::parse(
1440 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1441 BSONElement expr,
1442 const VariablesParseState& vps) {
1443
1444 uassert(40524,
1445 "$dateToParts only supports an object as its argument",
1446 expr.type() == BSONType::Object);
1447
1448 BSONElement dateElem;
1449 BSONElement timeZoneElem;
1450 BSONElement isoDateElem;
1451
1452 const BSONObj args = expr.embeddedObject();
1453 for (auto&& arg : args) {
1454 auto field = arg.fieldNameStringData();
1455
1456 if (field == "date"_sd) {
1457 dateElem = arg;
1458 } else if (field == "timezone"_sd) {
1459 timeZoneElem = arg;
1460 } else if (field == "iso8601"_sd) {
1461 isoDateElem = arg;
1462 } else {
1463 uasserted(40520,
1464 str::stream() << "Unrecognized argument to $dateToParts: "
1465 << arg.fieldName());
1466 }
1467 }
1468
1469 uassert(40522, "Missing 'date' parameter to $dateToParts", dateElem);
1470
1471 return new ExpressionDateToParts(
1472 expCtx,
1473 parseOperand(expCtx, dateElem, vps),
1474 timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps) : nullptr,
1475 isoDateElem ? parseOperand(expCtx, isoDateElem, vps) : nullptr);
1476 }
1477
ExpressionDateToParts(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> date,intrusive_ptr<Expression> timeZone,intrusive_ptr<Expression> iso8601)1478 ExpressionDateToParts::ExpressionDateToParts(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1479 intrusive_ptr<Expression> date,
1480 intrusive_ptr<Expression> timeZone,
1481 intrusive_ptr<Expression> iso8601)
1482 : Expression(expCtx),
1483 _date(std::move(date)),
1484 _timeZone(std::move(timeZone)),
1485 _iso8601(std::move(iso8601)) {}
1486
optimize()1487 intrusive_ptr<Expression> ExpressionDateToParts::optimize() {
1488 _date = _date->optimize();
1489 if (_timeZone) {
1490 _timeZone = _timeZone->optimize();
1491 }
1492 if (_iso8601) {
1493 _iso8601 = _iso8601->optimize();
1494 }
1495
1496 if (ExpressionConstant::allNullOrConstant({_date, _iso8601, _timeZone})) {
1497 // Everything is a constant, so we can turn into a constant.
1498 return ExpressionConstant::create(
1499 getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1500 }
1501
1502 return this;
1503 }
1504
serialize(bool explain) const1505 Value ExpressionDateToParts::serialize(bool explain) const {
1506 return Value(
1507 Document{{"$dateToParts",
1508 Document{{"date", _date->serialize(explain)},
1509 {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()},
1510 {"iso8601", _iso8601 ? _iso8601->serialize(explain) : Value()}}}});
1511 }
1512
evaluateIso8601Flag(const Document & root,Variables * variables) const1513 boost::optional<int> ExpressionDateToParts::evaluateIso8601Flag(const Document& root,
1514 Variables* variables) const {
1515 if (!_iso8601) {
1516 return false;
1517 }
1518
1519 auto iso8601Output = _iso8601->evaluate(root, variables);
1520
1521 if (iso8601Output.nullish()) {
1522 return boost::none;
1523 }
1524
1525 uassert(40521,
1526 str::stream() << "iso8601 must evaluate to a bool, found "
1527 << typeName(iso8601Output.getType()),
1528 iso8601Output.getType() == BSONType::Bool);
1529
1530 return iso8601Output.getBool();
1531 }
1532
evaluate(const Document & root,Variables * variables) const1533 Value ExpressionDateToParts::evaluate(const Document& root, Variables* variables) const {
1534 const Value date = _date->evaluate(root, variables);
1535
1536 auto timeZone =
1537 makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1538 if (!timeZone) {
1539 return Value(BSONNULL);
1540 }
1541
1542 auto iso8601 = evaluateIso8601Flag(root, variables);
1543 if (!iso8601) {
1544 return Value(BSONNULL);
1545 }
1546
1547 if (date.nullish()) {
1548 return Value(BSONNULL);
1549 }
1550
1551 auto dateValue = date.coerceToDate();
1552
1553 if (*iso8601) {
1554 auto parts = timeZone->dateIso8601Parts(dateValue);
1555 return Value(Document{{"isoWeekYear", parts.year},
1556 {"isoWeek", parts.weekOfYear},
1557 {"isoDayOfWeek", parts.dayOfWeek},
1558 {"hour", parts.hour},
1559 {"minute", parts.minute},
1560 {"second", parts.second},
1561 {"millisecond", parts.millisecond}});
1562 } else {
1563 auto parts = timeZone->dateParts(dateValue);
1564 return Value(Document{{"year", parts.year},
1565 {"month", parts.month},
1566 {"day", parts.dayOfMonth},
1567 {"hour", parts.hour},
1568 {"minute", parts.minute},
1569 {"second", parts.second},
1570 {"millisecond", parts.millisecond}});
1571 }
1572 }
1573
_doAddDependencies(DepsTracker * deps) const1574 void ExpressionDateToParts::_doAddDependencies(DepsTracker* deps) const {
1575 _date->addDependencies(deps);
1576 if (_timeZone) {
1577 _timeZone->addDependencies(deps);
1578 }
1579 if (_iso8601) {
1580 _iso8601->addDependencies(deps);
1581 }
1582 }
1583
1584
1585 /* ---------------------- ExpressionDateToString ----------------------- */
1586
1587 REGISTER_EXPRESSION(dateToString, ExpressionDateToString::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1588 intrusive_ptr<Expression> ExpressionDateToString::parse(
1589 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1590 BSONElement expr,
1591 const VariablesParseState& vps) {
1592 verify(str::equals(expr.fieldName(), "$dateToString"));
1593
1594 uassert(18629, "$dateToString only supports an object as its argument", expr.type() == Object);
1595
1596 BSONElement formatElem;
1597 BSONElement dateElem;
1598 BSONElement timeZoneElem;
1599 const BSONObj args = expr.embeddedObject();
1600 BSONForEach(arg, args) {
1601 if (str::equals(arg.fieldName(), "format")) {
1602 formatElem = arg;
1603 } else if (str::equals(arg.fieldName(), "date")) {
1604 dateElem = arg;
1605 } else if (str::equals(arg.fieldName(), "timezone")) {
1606 timeZoneElem = arg;
1607 } else {
1608 uasserted(18534,
1609 str::stream() << "Unrecognized argument to $dateToString: "
1610 << arg.fieldName());
1611 }
1612 }
1613
1614 uassert(18627, "Missing 'format' parameter to $dateToString", !formatElem.eoo());
1615 uassert(18628, "Missing 'date' parameter to $dateToString", !dateElem.eoo());
1616
1617 uassert(18533,
1618 "The 'format' parameter to $dateToString must be a string literal",
1619 formatElem.type() == String);
1620
1621 const string format = formatElem.str();
1622
1623 TimeZone::validateFormat(format);
1624
1625 return new ExpressionDateToString(expCtx,
1626 format,
1627 parseOperand(expCtx, dateElem, vps),
1628 timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps)
1629 : nullptr);
1630 }
1631
ExpressionDateToString(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & format,intrusive_ptr<Expression> date,intrusive_ptr<Expression> timeZone)1632 ExpressionDateToString::ExpressionDateToString(
1633 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1634 const string& format,
1635 intrusive_ptr<Expression> date,
1636 intrusive_ptr<Expression> timeZone)
1637 : Expression(expCtx), _format(format), _date(std::move(date)), _timeZone(std::move(timeZone)) {}
1638
optimize()1639 intrusive_ptr<Expression> ExpressionDateToString::optimize() {
1640 _date = _date->optimize();
1641 if (_timeZone) {
1642 _timeZone = _timeZone->optimize();
1643 }
1644
1645 if (ExpressionConstant::allNullOrConstant({_date, _timeZone})) {
1646 // Everything is a constant, so we can turn into a constant.
1647 return ExpressionConstant::create(
1648 getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1649 }
1650
1651 return this;
1652 }
1653
serialize(bool explain) const1654 Value ExpressionDateToString::serialize(bool explain) const {
1655 return Value(
1656 Document{{"$dateToString",
1657 Document{{"format", _format},
1658 {"date", _date->serialize(explain)},
1659 {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1660 }
1661
evaluate(const Document & root,Variables * variables) const1662 Value ExpressionDateToString::evaluate(const Document& root, Variables* variables) const {
1663 const Value date = _date->evaluate(root, variables);
1664
1665 auto timeZone =
1666 makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1667 if (!timeZone) {
1668 return Value(BSONNULL);
1669 }
1670
1671 if (date.nullish()) {
1672 return Value(BSONNULL);
1673 }
1674
1675 return Value(timeZone->formatDate(_format, date.coerceToDate()));
1676 }
1677
_doAddDependencies(DepsTracker * deps) const1678 void ExpressionDateToString::_doAddDependencies(DepsTracker* deps) const {
1679 _date->addDependencies(deps);
1680 if (_timeZone) {
1681 _timeZone->addDependencies(deps);
1682 }
1683 }
1684
1685 /* ----------------------- ExpressionDivide ---------------------------- */
1686
evaluate(const Document & root,Variables * variables) const1687 Value ExpressionDivide::evaluate(const Document& root, Variables* variables) const {
1688 Value lhs = vpOperand[0]->evaluate(root, variables);
1689 Value rhs = vpOperand[1]->evaluate(root, variables);
1690
1691 auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); };
1692
1693 if (lhs.numeric() && rhs.numeric()) {
1694 // If, and only if, either side is decimal, return decimal.
1695 if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) {
1696 Decimal128 numer = lhs.coerceToDecimal();
1697 Decimal128 denom = rhs.coerceToDecimal();
1698 assertNonZero(!denom.isZero());
1699 return Value(numer.divide(denom));
1700 }
1701
1702 double numer = lhs.coerceToDouble();
1703 double denom = rhs.coerceToDouble();
1704 assertNonZero(denom != 0.0);
1705
1706 return Value(numer / denom);
1707 } else if (lhs.nullish() || rhs.nullish()) {
1708 return Value(BSONNULL);
1709 } else {
1710 uasserted(16609,
1711 str::stream() << "$divide only supports numeric types, not "
1712 << typeName(lhs.getType())
1713 << " and "
1714 << typeName(rhs.getType()));
1715 }
1716 }
1717
1718 REGISTER_EXPRESSION(divide, ExpressionDivide::parse);
getOpName() const1719 const char* ExpressionDivide::getOpName() const {
1720 return "$divide";
1721 }
1722
1723 /* ----------------------- ExpressionExp ---------------------------- */
1724
evaluateNumericArg(const Value & numericArg) const1725 Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const {
1726 // $exp always returns either a double or a decimal number, as e is irrational.
1727 if (numericArg.getType() == NumberDecimal)
1728 return Value(numericArg.coerceToDecimal().exponential());
1729
1730 return Value(exp(numericArg.coerceToDouble()));
1731 }
1732
1733 REGISTER_EXPRESSION(exp, ExpressionExp::parse);
getOpName() const1734 const char* ExpressionExp::getOpName() const {
1735 return "$exp";
1736 }
1737
1738 /* ---------------------- ExpressionObject --------------------------- */
1739
ExpressionObject(const boost::intrusive_ptr<ExpressionContext> & expCtx,vector<pair<string,intrusive_ptr<Expression>>> && expressions)1740 ExpressionObject::ExpressionObject(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1741 vector<pair<string, intrusive_ptr<Expression>>>&& expressions)
1742 : Expression(expCtx), _expressions(std::move(expressions)) {}
1743
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,vector<pair<string,intrusive_ptr<Expression>>> && expressions)1744 intrusive_ptr<ExpressionObject> ExpressionObject::create(
1745 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1746 vector<pair<string, intrusive_ptr<Expression>>>&& expressions) {
1747 return new ExpressionObject(expCtx, std::move(expressions));
1748 }
1749
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)1750 intrusive_ptr<ExpressionObject> ExpressionObject::parse(
1751 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1752 BSONObj obj,
1753 const VariablesParseState& vps) {
1754 // Make sure we don't have any duplicate field names.
1755 stdx::unordered_set<string> specifiedFields;
1756
1757 vector<pair<string, intrusive_ptr<Expression>>> expressions;
1758 for (auto&& elem : obj) {
1759 // Make sure this element has a valid field name. Use StringData here so that we can detect
1760 // if the field name contains a null byte.
1761 FieldPath::uassertValidFieldName(elem.fieldNameStringData());
1762
1763 auto fieldName = elem.fieldName();
1764 uassert(16406,
1765 str::stream() << "duplicate field name specified in object literal: "
1766 << obj.toString(),
1767 specifiedFields.find(fieldName) == specifiedFields.end());
1768 specifiedFields.insert(fieldName);
1769 expressions.emplace_back(fieldName, parseOperand(expCtx, elem, vps));
1770 }
1771
1772 return new ExpressionObject{expCtx, std::move(expressions)};
1773 }
1774
optimize()1775 intrusive_ptr<Expression> ExpressionObject::optimize() {
1776 for (auto&& pair : _expressions) {
1777 pair.second = pair.second->optimize();
1778 }
1779 return this;
1780 }
1781
_doAddDependencies(DepsTracker * deps) const1782 void ExpressionObject::_doAddDependencies(DepsTracker* deps) const {
1783 for (auto&& pair : _expressions) {
1784 pair.second->addDependencies(deps);
1785 }
1786 }
1787
evaluate(const Document & root,Variables * variables) const1788 Value ExpressionObject::evaluate(const Document& root, Variables* variables) const {
1789 MutableDocument outputDoc;
1790 for (auto&& pair : _expressions) {
1791 outputDoc.addField(pair.first, pair.second->evaluate(root, variables));
1792 }
1793 return outputDoc.freezeToValue();
1794 }
1795
serialize(bool explain) const1796 Value ExpressionObject::serialize(bool explain) const {
1797 MutableDocument outputDoc;
1798 for (auto&& pair : _expressions) {
1799 outputDoc.addField(pair.first, pair.second->serialize(explain));
1800 }
1801 return outputDoc.freezeToValue();
1802 }
1803
getComputedPaths(const std::string & exprFieldPath,Variables::Id renamingVar) const1804 Expression::ComputedPaths ExpressionObject::getComputedPaths(const std::string& exprFieldPath,
1805 Variables::Id renamingVar) const {
1806 ComputedPaths outputPaths;
1807 for (auto&& pair : _expressions) {
1808 auto exprComputedPaths = pair.second->getComputedPaths(pair.first, renamingVar);
1809 for (auto&& renames : exprComputedPaths.renames) {
1810 auto newPath = FieldPath::getFullyQualifiedPath(exprFieldPath, renames.first);
1811 outputPaths.renames[std::move(newPath)] = renames.second;
1812 }
1813 for (auto&& path : exprComputedPaths.paths) {
1814 outputPaths.paths.insert(FieldPath::getFullyQualifiedPath(exprFieldPath, path));
1815 }
1816 }
1817
1818 return outputPaths;
1819 }
1820
1821 /* --------------------- ExpressionFieldPath --------------------------- */
1822
1823 // this is the old deprecated version only used by tests not using variables
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & fieldPath)1824 intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create(
1825 const boost::intrusive_ptr<ExpressionContext>& expCtx, const string& fieldPath) {
1826 return new ExpressionFieldPath(expCtx, "CURRENT." + fieldPath, Variables::kRootId);
1827 }
1828
1829 // this is the new version that supports every syntax
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & raw,const VariablesParseState & vps)1830 intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(
1831 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1832 const string& raw,
1833 const VariablesParseState& vps) {
1834 uassert(16873,
1835 str::stream() << "FieldPath '" << raw << "' doesn't start with $",
1836 raw.c_str()[0] == '$'); // c_str()[0] is always a valid reference.
1837
1838 uassert(16872,
1839 str::stream() << "'$' by itself is not a valid FieldPath",
1840 raw.size() >= 2); // need at least "$" and either "$" or a field name
1841
1842 if (raw[1] == '$') {
1843 const StringData rawSD = raw;
1844 const StringData fieldPath = rawSD.substr(2); // strip off $$
1845 const StringData varName = fieldPath.substr(0, fieldPath.find('.'));
1846 Variables::uassertValidNameForUserRead(varName);
1847 return new ExpressionFieldPath(expCtx, fieldPath.toString(), vps.getVariable(varName));
1848 } else {
1849 return new ExpressionFieldPath(expCtx,
1850 "CURRENT." + raw.substr(1), // strip the "$" prefix
1851 vps.getVariable("CURRENT"));
1852 }
1853 }
1854
ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & theFieldPath,Variables::Id variable)1855 ExpressionFieldPath::ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1856 const string& theFieldPath,
1857 Variables::Id variable)
1858 : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {}
1859
optimize()1860 intrusive_ptr<Expression> ExpressionFieldPath::optimize() {
1861 if (_variable == Variables::kRemoveId) {
1862 // The REMOVE system variable optimizes to a constant missing value.
1863 return ExpressionConstant::create(getExpressionContext(), Value());
1864 }
1865
1866 if (getExpressionContext()->variables.hasConstantValue(_variable)) {
1867 return ExpressionConstant::create(
1868 getExpressionContext(), evaluate(Document(), &(getExpressionContext()->variables)));
1869 }
1870
1871 return intrusive_ptr<Expression>(this);
1872 }
1873
_doAddDependencies(DepsTracker * deps) const1874 void ExpressionFieldPath::_doAddDependencies(DepsTracker* deps) const {
1875 if (_variable == Variables::kRootId) { // includes CURRENT when it is equivalent to ROOT.
1876 if (_fieldPath.getPathLength() == 1) {
1877 deps->needWholeDocument = true; // need full doc if just "$$ROOT"
1878 } else {
1879 deps->fields.insert(_fieldPath.tail().fullPath());
1880 }
1881 } else if (Variables::isUserDefinedVariable(_variable)) {
1882 deps->vars.insert(_variable);
1883 }
1884 }
1885
evaluatePathArray(size_t index,const Value & input) const1886 Value ExpressionFieldPath::evaluatePathArray(size_t index, const Value& input) const {
1887 dassert(input.isArray());
1888
1889 // Check for remaining path in each element of array
1890 vector<Value> result;
1891 const vector<Value>& array = input.getArray();
1892 for (size_t i = 0; i < array.size(); i++) {
1893 if (array[i].getType() != Object)
1894 continue;
1895
1896 const Value nested = evaluatePath(index, array[i].getDocument());
1897 if (!nested.missing())
1898 result.push_back(nested);
1899 }
1900
1901 return Value(std::move(result));
1902 }
evaluatePath(size_t index,const Document & input) const1903 Value ExpressionFieldPath::evaluatePath(size_t index, const Document& input) const {
1904 // Note this function is very hot so it is important that is is well optimized.
1905 // In particular, all return paths should support RVO.
1906
1907 /* if we've hit the end of the path, stop */
1908 if (index == _fieldPath.getPathLength() - 1)
1909 return input[_fieldPath.getFieldName(index)];
1910
1911 // Try to dive deeper
1912 const Value val = input[_fieldPath.getFieldName(index)];
1913 switch (val.getType()) {
1914 case Object:
1915 return evaluatePath(index + 1, val.getDocument());
1916
1917 case Array:
1918 return evaluatePathArray(index + 1, val);
1919
1920 default:
1921 return Value();
1922 }
1923 }
1924
evaluate(const Document & root,Variables * variables) const1925 Value ExpressionFieldPath::evaluate(const Document& root, Variables* variables) const {
1926 if (_fieldPath.getPathLength() == 1) // get the whole variable
1927 return variables->getValue(_variable, root);
1928
1929 if (_variable == Variables::kRootId) {
1930 // ROOT is always a document so use optimized code path
1931 return evaluatePath(1, root);
1932 }
1933
1934 Value var = variables->getValue(_variable, root);
1935 switch (var.getType()) {
1936 case Object:
1937 return evaluatePath(1, var.getDocument());
1938 case Array:
1939 return evaluatePathArray(1, var);
1940 default:
1941 return Value();
1942 }
1943 }
1944
serialize(bool explain) const1945 Value ExpressionFieldPath::serialize(bool explain) const {
1946 if (_fieldPath.getFieldName(0) == "CURRENT" && _fieldPath.getPathLength() > 1) {
1947 // use short form for "$$CURRENT.foo" but not just "$$CURRENT"
1948 return Value("$" + _fieldPath.tail().fullPath());
1949 } else {
1950 return Value("$$" + _fieldPath.fullPath());
1951 }
1952 }
1953
getComputedPaths(const std::string & exprFieldPath,Variables::Id renamingVar) const1954 Expression::ComputedPaths ExpressionFieldPath::getComputedPaths(const std::string& exprFieldPath,
1955 Variables::Id renamingVar) const {
1956 // An expression field path is either considered a rename or a computed path. We need to find
1957 // out which case we fall into.
1958 //
1959 // The caller has told us that renames must have 'varId' as the first component. We also check
1960 // that there is only one additional component---no dotted field paths are allowed! This is
1961 // because dotted ExpressionFieldPaths can actually reshape the document rather than just
1962 // changing the field names. This can happen only if there are arrays along the dotted path.
1963 //
1964 // For example, suppose you have document {a: [{b: 1}, {b: 2}]}. The projection {"c.d": "$a.b"}
1965 // does *not* perform the strict rename to yield document {c: [{d: 1}, {d: 2}]}. Instead, it
1966 // results in the document {c: {d: [1, 2]}}. Due to this reshaping, matches expressed over "a.b"
1967 // before the $project is applied may not have the same behavior when expressed over "c.d" after
1968 // the $project is applied.
1969 ComputedPaths outputPaths;
1970 if (_variable == renamingVar && _fieldPath.getPathLength() == 2u) {
1971 outputPaths.renames[exprFieldPath] = _fieldPath.tail().fullPath();
1972 } else {
1973 outputPaths.paths.insert(exprFieldPath);
1974 }
1975
1976 return outputPaths;
1977 }
1978
1979 /* ------------------------- ExpressionFilter ----------------------------- */
1980
1981 REGISTER_EXPRESSION(filter, ExpressionFilter::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)1982 intrusive_ptr<Expression> ExpressionFilter::parse(
1983 const boost::intrusive_ptr<ExpressionContext>& expCtx,
1984 BSONElement expr,
1985 const VariablesParseState& vpsIn) {
1986 verify(str::equals(expr.fieldName(), "$filter"));
1987
1988 uassert(28646, "$filter only supports an object as its argument", expr.type() == Object);
1989
1990 // "cond" must be parsed after "as" regardless of BSON order.
1991 BSONElement inputElem;
1992 BSONElement asElem;
1993 BSONElement condElem;
1994 for (auto elem : expr.Obj()) {
1995 if (str::equals(elem.fieldName(), "input")) {
1996 inputElem = elem;
1997 } else if (str::equals(elem.fieldName(), "as")) {
1998 asElem = elem;
1999 } else if (str::equals(elem.fieldName(), "cond")) {
2000 condElem = elem;
2001 } else {
2002 uasserted(28647,
2003 str::stream() << "Unrecognized parameter to $filter: " << elem.fieldName());
2004 }
2005 }
2006
2007 uassert(28648, "Missing 'input' parameter to $filter", !inputElem.eoo());
2008 uassert(28650, "Missing 'cond' parameter to $filter", !condElem.eoo());
2009
2010 // Parse "input", only has outer variables.
2011 intrusive_ptr<Expression> input = parseOperand(expCtx, inputElem, vpsIn);
2012
2013 // Parse "as".
2014 VariablesParseState vpsSub(vpsIn); // vpsSub gets our variable, vpsIn doesn't.
2015
2016 // If "as" is not specified, then use "this" by default.
2017 auto varName = asElem.eoo() ? "this" : asElem.str();
2018
2019 Variables::uassertValidNameForUserWrite(varName);
2020 Variables::Id varId = vpsSub.defineVariable(varName);
2021
2022 // Parse "cond", has access to "as" variable.
2023 intrusive_ptr<Expression> cond = parseOperand(expCtx, condElem, vpsSub);
2024
2025 return new ExpressionFilter(
2026 expCtx, std::move(varName), varId, std::move(input), std::move(cond));
2027 }
2028
ExpressionFilter(const boost::intrusive_ptr<ExpressionContext> & expCtx,string varName,Variables::Id varId,intrusive_ptr<Expression> input,intrusive_ptr<Expression> filter)2029 ExpressionFilter::ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2030 string varName,
2031 Variables::Id varId,
2032 intrusive_ptr<Expression> input,
2033 intrusive_ptr<Expression> filter)
2034 : Expression(expCtx),
2035 _varName(std::move(varName)),
2036 _varId(varId),
2037 _input(std::move(input)),
2038 _filter(std::move(filter)) {}
2039
optimize()2040 intrusive_ptr<Expression> ExpressionFilter::optimize() {
2041 // TODO handle when _input is constant.
2042 _input = _input->optimize();
2043 _filter = _filter->optimize();
2044 return this;
2045 }
2046
serialize(bool explain) const2047 Value ExpressionFilter::serialize(bool explain) const {
2048 return Value(
2049 DOC("$filter" << DOC("input" << _input->serialize(explain) << "as" << _varName << "cond"
2050 << _filter->serialize(explain))));
2051 }
2052
evaluate(const Document & root,Variables * variables) const2053 Value ExpressionFilter::evaluate(const Document& root, Variables* variables) const {
2054 // We are guaranteed at parse time that this isn't using our _varId.
2055 const Value inputVal = _input->evaluate(root, variables);
2056 if (inputVal.nullish())
2057 return Value(BSONNULL);
2058
2059 uassert(28651,
2060 str::stream() << "input to $filter must be an array not "
2061 << typeName(inputVal.getType()),
2062 inputVal.isArray());
2063
2064 const vector<Value>& input = inputVal.getArray();
2065
2066 if (input.empty())
2067 return inputVal;
2068
2069 vector<Value> output;
2070 for (const auto& elem : input) {
2071 variables->setValue(_varId, elem);
2072
2073 if (_filter->evaluate(root, variables).coerceToBool()) {
2074 output.push_back(std::move(elem));
2075 }
2076 }
2077
2078 return Value(std::move(output));
2079 }
2080
_doAddDependencies(DepsTracker * deps) const2081 void ExpressionFilter::_doAddDependencies(DepsTracker* deps) const {
2082 _input->addDependencies(deps);
2083 _filter->addDependencies(deps);
2084 }
2085
2086 /* ------------------------- ExpressionFloor -------------------------- */
2087
evaluateNumericArg(const Value & numericArg) const2088 Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const {
2089 // There's no point in taking the floor of integers or longs, it will have no effect.
2090 switch (numericArg.getType()) {
2091 case NumberDouble:
2092 return Value(std::floor(numericArg.getDouble()));
2093 case NumberDecimal:
2094 // Round toward the nearest decimal with a zero exponent in the negative direction.
2095 return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
2096 Decimal128::kRoundTowardNegative));
2097 default:
2098 return numericArg;
2099 }
2100 }
2101
2102 REGISTER_EXPRESSION(floor, ExpressionFloor::parse);
getOpName() const2103 const char* ExpressionFloor::getOpName() const {
2104 return "$floor";
2105 }
2106
2107 /* ------------------------- ExpressionLet ----------------------------- */
2108
2109 REGISTER_EXPRESSION(let, ExpressionLet::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2110 intrusive_ptr<Expression> ExpressionLet::parse(
2111 const boost::intrusive_ptr<ExpressionContext>& expCtx,
2112 BSONElement expr,
2113 const VariablesParseState& vpsIn) {
2114 verify(str::equals(expr.fieldName(), "$let"));
2115
2116 uassert(16874, "$let only supports an object as its argument", expr.type() == Object);
2117 const BSONObj args = expr.embeddedObject();
2118
2119 // varsElem must be parsed before inElem regardless of BSON order.
2120 BSONElement varsElem;
2121 BSONElement inElem;
2122 BSONForEach(arg, args) {
2123 if (str::equals(arg.fieldName(), "vars")) {
2124 varsElem = arg;
2125 } else if (str::equals(arg.fieldName(), "in")) {
2126 inElem = arg;
2127 } else {
2128 uasserted(16875,
2129 str::stream() << "Unrecognized parameter to $let: " << arg.fieldName());
2130 }
2131 }
2132
2133 uassert(16876, "Missing 'vars' parameter to $let", !varsElem.eoo());
2134 uassert(16877, "Missing 'in' parameter to $let", !inElem.eoo());
2135
2136 // parse "vars"
2137 VariablesParseState vpsSub(vpsIn); // vpsSub gets our vars, vpsIn doesn't.
2138 VariableMap vars;
2139 BSONForEach(varElem, varsElem.embeddedObjectUserCheck()) {
2140 const string varName = varElem.fieldName();
2141 Variables::uassertValidNameForUserWrite(varName);
2142 Variables::Id id = vpsSub.defineVariable(varName);
2143
2144 vars[id] = NameAndExpression(varName,
2145 parseOperand(expCtx, varElem, vpsIn)); // only has outer vars
2146 }
2147
2148 // parse "in"
2149 intrusive_ptr<Expression> subExpression = parseOperand(expCtx, inElem, vpsSub); // has our vars
2150
2151 return new ExpressionLet(expCtx, vars, subExpression);
2152 }
2153
ExpressionLet(const boost::intrusive_ptr<ExpressionContext> & expCtx,const VariableMap & vars,intrusive_ptr<Expression> subExpression)2154 ExpressionLet::ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2155 const VariableMap& vars,
2156 intrusive_ptr<Expression> subExpression)
2157 : Expression(expCtx), _variables(vars), _subExpression(subExpression) {}
2158
optimize()2159 intrusive_ptr<Expression> ExpressionLet::optimize() {
2160 if (_variables.empty()) {
2161 // we aren't binding any variables so just return the subexpression
2162 return _subExpression->optimize();
2163 }
2164
2165 for (VariableMap::iterator it = _variables.begin(), end = _variables.end(); it != end; ++it) {
2166 it->second.expression = it->second.expression->optimize();
2167 }
2168
2169 _subExpression = _subExpression->optimize();
2170
2171 return this;
2172 }
2173
serialize(bool explain) const2174 Value ExpressionLet::serialize(bool explain) const {
2175 MutableDocument vars;
2176 for (VariableMap::const_iterator it = _variables.begin(), end = _variables.end(); it != end;
2177 ++it) {
2178 vars[it->second.name] = it->second.expression->serialize(explain);
2179 }
2180
2181 return Value(
2182 DOC("$let" << DOC("vars" << vars.freeze() << "in" << _subExpression->serialize(explain))));
2183 }
2184
evaluate(const Document & root,Variables * variables) const2185 Value ExpressionLet::evaluate(const Document& root, Variables* variables) const {
2186 for (const auto& item : _variables) {
2187 // It is guaranteed at parse-time that these expressions don't use the variable ids we
2188 // are setting
2189 variables->setValue(item.first, item.second.expression->evaluate(root, variables));
2190 }
2191
2192 return _subExpression->evaluate(root, variables);
2193 }
2194
_doAddDependencies(DepsTracker * deps) const2195 void ExpressionLet::_doAddDependencies(DepsTracker* deps) const {
2196 for (auto&& idToNameExp : _variables) {
2197 // Add the external dependencies from the 'vars' statement.
2198 idToNameExp.second.expression->addDependencies(deps);
2199 }
2200
2201 // Add subexpression dependencies, which may contain a mix of local and external variable refs.
2202 _subExpression->addDependencies(deps);
2203 }
2204
2205 /* ------------------------- ExpressionMap ----------------------------- */
2206
2207 REGISTER_EXPRESSION(map, ExpressionMap::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2208 intrusive_ptr<Expression> ExpressionMap::parse(
2209 const boost::intrusive_ptr<ExpressionContext>& expCtx,
2210 BSONElement expr,
2211 const VariablesParseState& vpsIn) {
2212 verify(str::equals(expr.fieldName(), "$map"));
2213
2214 uassert(16878, "$map only supports an object as its argument", expr.type() == Object);
2215
2216 // "in" must be parsed after "as" regardless of BSON order
2217 BSONElement inputElem;
2218 BSONElement asElem;
2219 BSONElement inElem;
2220 const BSONObj args = expr.embeddedObject();
2221 BSONForEach(arg, args) {
2222 if (str::equals(arg.fieldName(), "input")) {
2223 inputElem = arg;
2224 } else if (str::equals(arg.fieldName(), "as")) {
2225 asElem = arg;
2226 } else if (str::equals(arg.fieldName(), "in")) {
2227 inElem = arg;
2228 } else {
2229 uasserted(16879,
2230 str::stream() << "Unrecognized parameter to $map: " << arg.fieldName());
2231 }
2232 }
2233
2234 uassert(16880, "Missing 'input' parameter to $map", !inputElem.eoo());
2235 uassert(16882, "Missing 'in' parameter to $map", !inElem.eoo());
2236
2237 // parse "input"
2238 intrusive_ptr<Expression> input =
2239 parseOperand(expCtx, inputElem, vpsIn); // only has outer vars
2240
2241 // parse "as"
2242 VariablesParseState vpsSub(vpsIn); // vpsSub gets our vars, vpsIn doesn't.
2243
2244 // If "as" is not specified, then use "this" by default.
2245 auto varName = asElem.eoo() ? "this" : asElem.str();
2246
2247 Variables::uassertValidNameForUserWrite(varName);
2248 Variables::Id varId = vpsSub.defineVariable(varName);
2249
2250 // parse "in"
2251 intrusive_ptr<Expression> in =
2252 parseOperand(expCtx, inElem, vpsSub); // has access to map variable
2253
2254 return new ExpressionMap(expCtx, varName, varId, input, in);
2255 }
2256
ExpressionMap(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & varName,Variables::Id varId,intrusive_ptr<Expression> input,intrusive_ptr<Expression> each)2257 ExpressionMap::ExpressionMap(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2258 const string& varName,
2259 Variables::Id varId,
2260 intrusive_ptr<Expression> input,
2261 intrusive_ptr<Expression> each)
2262 : Expression(expCtx), _varName(varName), _varId(varId), _input(input), _each(each) {}
2263
optimize()2264 intrusive_ptr<Expression> ExpressionMap::optimize() {
2265 // TODO handle when _input is constant
2266 _input = _input->optimize();
2267 _each = _each->optimize();
2268 return this;
2269 }
2270
serialize(bool explain) const2271 Value ExpressionMap::serialize(bool explain) const {
2272 return Value(DOC("$map" << DOC("input" << _input->serialize(explain) << "as" << _varName << "in"
2273 << _each->serialize(explain))));
2274 }
2275
evaluate(const Document & root,Variables * variables) const2276 Value ExpressionMap::evaluate(const Document& root, Variables* variables) const {
2277 // guaranteed at parse time that this isn't using our _varId
2278 const Value inputVal = _input->evaluate(root, variables);
2279 if (inputVal.nullish())
2280 return Value(BSONNULL);
2281
2282 uassert(16883,
2283 str::stream() << "input to $map must be an array not " << typeName(inputVal.getType()),
2284 inputVal.isArray());
2285
2286 const vector<Value>& input = inputVal.getArray();
2287
2288 if (input.empty())
2289 return inputVal;
2290
2291 vector<Value> output;
2292 output.reserve(input.size());
2293 for (size_t i = 0; i < input.size(); i++) {
2294 variables->setValue(_varId, input[i]);
2295
2296 Value toInsert = _each->evaluate(root, variables);
2297 if (toInsert.missing())
2298 toInsert = Value(BSONNULL); // can't insert missing values into array
2299
2300 output.push_back(toInsert);
2301 }
2302
2303 return Value(std::move(output));
2304 }
2305
_doAddDependencies(DepsTracker * deps) const2306 void ExpressionMap::_doAddDependencies(DepsTracker* deps) const {
2307 _input->addDependencies(deps);
2308 _each->addDependencies(deps);
2309 }
2310
getComputedPaths(const std::string & exprFieldPath,Variables::Id renamingVar) const2311 Expression::ComputedPaths ExpressionMap::getComputedPaths(const std::string& exprFieldPath,
2312 Variables::Id renamingVar) const {
2313 auto inputFieldPath = dynamic_cast<ExpressionFieldPath*>(_input.get());
2314 if (!inputFieldPath) {
2315 return {{exprFieldPath}, {}};
2316 }
2317
2318 auto inputComputedPaths = inputFieldPath->getComputedPaths("", renamingVar);
2319 if (inputComputedPaths.renames.empty()) {
2320 return {{exprFieldPath}, {}};
2321 }
2322 invariant(inputComputedPaths.renames.size() == 1u);
2323 auto fieldPathRenameIter = inputComputedPaths.renames.find("");
2324 invariant(fieldPathRenameIter != inputComputedPaths.renames.end());
2325 const auto& oldArrayName = fieldPathRenameIter->second;
2326
2327 auto eachComputedPaths = _each->getComputedPaths(exprFieldPath, _varId);
2328 if (eachComputedPaths.renames.empty()) {
2329 return {{exprFieldPath}, {}};
2330 }
2331
2332 // Append the name of the array to the beginning of the old field path.
2333 for (auto&& rename : eachComputedPaths.renames) {
2334 eachComputedPaths.renames[rename.first] =
2335 FieldPath::getFullyQualifiedPath(oldArrayName, rename.second);
2336 }
2337 return eachComputedPaths;
2338 }
2339
2340 /* ------------------------- ExpressionMeta ----------------------------- */
2341
2342 REGISTER_EXPRESSION(meta, ExpressionMeta::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2343 intrusive_ptr<Expression> ExpressionMeta::parse(
2344 const boost::intrusive_ptr<ExpressionContext>& expCtx,
2345 BSONElement expr,
2346 const VariablesParseState& vpsIn) {
2347 uassert(17307, "$meta only supports string arguments", expr.type() == String);
2348 if (expr.valueStringData() == "textScore") {
2349 return new ExpressionMeta(expCtx, MetaType::TEXT_SCORE);
2350 } else if (expr.valueStringData() == "randVal") {
2351 return new ExpressionMeta(expCtx, MetaType::RAND_VAL);
2352 } else {
2353 uasserted(17308, "Unsupported argument to $meta: " + expr.String());
2354 }
2355 }
2356
ExpressionMeta(const boost::intrusive_ptr<ExpressionContext> & expCtx,MetaType metaType)2357 ExpressionMeta::ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2358 MetaType metaType)
2359 : Expression(expCtx), _metaType(metaType) {}
2360
serialize(bool explain) const2361 Value ExpressionMeta::serialize(bool explain) const {
2362 switch (_metaType) {
2363 case MetaType::TEXT_SCORE:
2364 return Value(DOC("$meta"
2365 << "textScore"_sd));
2366 case MetaType::RAND_VAL:
2367 return Value(DOC("$meta"
2368 << "randVal"_sd));
2369 }
2370 MONGO_UNREACHABLE;
2371 }
2372
evaluate(const Document & root,Variables * variables) const2373 Value ExpressionMeta::evaluate(const Document& root, Variables* variables) const {
2374 switch (_metaType) {
2375 case MetaType::TEXT_SCORE:
2376 return root.hasTextScore() ? Value(root.getTextScore()) : Value();
2377 case MetaType::RAND_VAL:
2378 return root.hasRandMetaField() ? Value(root.getRandMetaField()) : Value();
2379 }
2380 MONGO_UNREACHABLE;
2381 }
2382
_doAddDependencies(DepsTracker * deps) const2383 void ExpressionMeta::_doAddDependencies(DepsTracker* deps) const {
2384 if (_metaType == MetaType::TEXT_SCORE) {
2385 deps->setNeedTextScore(true);
2386 }
2387 }
2388
2389 /* ----------------------- ExpressionMod ---------------------------- */
2390
evaluate(const Document & root,Variables * variables) const2391 Value ExpressionMod::evaluate(const Document& root, Variables* variables) const {
2392 Value lhs = vpOperand[0]->evaluate(root, variables);
2393 Value rhs = vpOperand[1]->evaluate(root, variables);
2394
2395 BSONType leftType = lhs.getType();
2396 BSONType rightType = rhs.getType();
2397
2398 if (lhs.numeric() && rhs.numeric()) {
2399 auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); };
2400
2401 // If either side is decimal, perform the operation in decimal.
2402 if (leftType == NumberDecimal || rightType == NumberDecimal) {
2403 Decimal128 left = lhs.coerceToDecimal();
2404 Decimal128 right = rhs.coerceToDecimal();
2405 assertNonZero(right.isZero());
2406
2407 return Value(left.modulo(right));
2408 }
2409
2410 // ensure we aren't modding by 0
2411 double right = rhs.coerceToDouble();
2412 assertNonZero(right == 0);
2413
2414 if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) {
2415 // Need to do fmod. Integer-valued double case is handled below.
2416
2417 double left = lhs.coerceToDouble();
2418 return Value(fmod(left, right));
2419 }
2420
2421 if (leftType == NumberLong || rightType == NumberLong) {
2422 // if either is long, return long
2423 long long left = lhs.coerceToLong();
2424 long long rightLong = rhs.coerceToLong();
2425 return Value(mongoSafeMod(left, rightLong));
2426 }
2427
2428 // lastly they must both be ints, return int
2429 int left = lhs.coerceToInt();
2430 int rightInt = rhs.coerceToInt();
2431 return Value(mongoSafeMod(left, rightInt));
2432 } else if (lhs.nullish() || rhs.nullish()) {
2433 return Value(BSONNULL);
2434 } else {
2435 uasserted(16611,
2436 str::stream() << "$mod only supports numeric types, not "
2437 << typeName(lhs.getType())
2438 << " and "
2439 << typeName(rhs.getType()));
2440 }
2441 }
2442
2443 REGISTER_EXPRESSION(mod, ExpressionMod::parse);
getOpName() const2444 const char* ExpressionMod::getOpName() const {
2445 return "$mod";
2446 }
2447
2448 /* ------------------------- ExpressionMultiply ----------------------------- */
2449
evaluate(const Document & root,Variables * variables) const2450 Value ExpressionMultiply::evaluate(const Document& root, Variables* variables) const {
2451 /*
2452 We'll try to return the narrowest possible result value. To do that
2453 without creating intermediate Values, do the arithmetic for double
2454 and integral types in parallel, tracking the current narrowest
2455 type.
2456 */
2457 double doubleProduct = 1;
2458 long long longProduct = 1;
2459 Decimal128 decimalProduct; // This will be initialized on encountering the first decimal.
2460
2461 BSONType productType = NumberInt;
2462
2463 const size_t n = vpOperand.size();
2464 for (size_t i = 0; i < n; ++i) {
2465 Value val = vpOperand[i]->evaluate(root, variables);
2466
2467 if (val.numeric()) {
2468 BSONType oldProductType = productType;
2469 productType = Value::getWidestNumeric(productType, val.getType());
2470 if (productType == NumberDecimal) {
2471 // On finding the first decimal, convert the partial product to decimal.
2472 if (oldProductType != NumberDecimal) {
2473 decimalProduct = oldProductType == NumberDouble
2474 ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits)
2475 : Decimal128(static_cast<int64_t>(longProduct));
2476 }
2477 decimalProduct = decimalProduct.multiply(val.coerceToDecimal());
2478 } else {
2479 doubleProduct *= val.coerceToDouble();
2480 if (!std::isfinite(val.coerceToDouble()) ||
2481 mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) {
2482 // The number is either Infinity or NaN, or the 'longProduct' would have
2483 // overflowed, so we're abandoning it.
2484 productType = NumberDouble;
2485 }
2486 }
2487 } else if (val.nullish()) {
2488 return Value(BSONNULL);
2489 } else {
2490 uasserted(16555,
2491 str::stream() << "$multiply only supports numeric types, not "
2492 << typeName(val.getType()));
2493 }
2494 }
2495
2496 if (productType == NumberDouble)
2497 return Value(doubleProduct);
2498 else if (productType == NumberLong)
2499 return Value(longProduct);
2500 else if (productType == NumberInt)
2501 return Value::createIntOrLong(longProduct);
2502 else if (productType == NumberDecimal)
2503 return Value(decimalProduct);
2504 else
2505 massert(16418, "$multiply resulted in a non-numeric type", false);
2506 }
2507
2508 REGISTER_EXPRESSION(multiply, ExpressionMultiply::parse);
getOpName() const2509 const char* ExpressionMultiply::getOpName() const {
2510 return "$multiply";
2511 }
2512
2513 /* ----------------------- ExpressionIfNull ---------------------------- */
2514
evaluate(const Document & root,Variables * variables) const2515 Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
2516 Value pLeft(vpOperand[0]->evaluate(root, variables));
2517 if (!pLeft.nullish())
2518 return pLeft;
2519
2520 Value pRight(vpOperand[1]->evaluate(root, variables));
2521 return pRight;
2522 }
2523
2524 REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse);
getOpName() const2525 const char* ExpressionIfNull::getOpName() const {
2526 return "$ifNull";
2527 }
2528
2529 /* ----------------------- ExpressionIn ---------------------------- */
2530
evaluate(const Document & root,Variables * variables) const2531 Value ExpressionIn::evaluate(const Document& root, Variables* variables) const {
2532 Value argument(vpOperand[0]->evaluate(root, variables));
2533 Value arrayOfValues(vpOperand[1]->evaluate(root, variables));
2534
2535 uassert(40081,
2536 str::stream() << "$in requires an array as a second argument, found: "
2537 << typeName(arrayOfValues.getType()),
2538 arrayOfValues.isArray());
2539 for (auto&& value : arrayOfValues.getArray()) {
2540 if (getExpressionContext()->getValueComparator().evaluate(argument == value)) {
2541 return Value(true);
2542 }
2543 }
2544 return Value(false);
2545 }
2546
2547 REGISTER_EXPRESSION(in, ExpressionIn::parse);
getOpName() const2548 const char* ExpressionIn::getOpName() const {
2549 return "$in";
2550 }
2551
2552 /* ----------------------- ExpressionIndexOfArray ------------------ */
2553
2554 namespace {
2555
uassertIfNotIntegralAndNonNegative(Value val,StringData expressionName,StringData argumentName)2556 void uassertIfNotIntegralAndNonNegative(Value val,
2557 StringData expressionName,
2558 StringData argumentName) {
2559 uassert(40096,
2560 str::stream() << expressionName << "requires an integral " << argumentName
2561 << ", found a value of type: "
2562 << typeName(val.getType())
2563 << ", with value: "
2564 << val.toString(),
2565 val.integral());
2566 uassert(40097,
2567 str::stream() << expressionName << " requires a nonnegative " << argumentName
2568 << ", found: "
2569 << val.toString(),
2570 val.coerceToInt() >= 0);
2571 }
2572
2573 } // namespace
2574
evaluate(const Document & root,Variables * variables) const2575 Value ExpressionIndexOfArray::evaluate(const Document& root, Variables* variables) const {
2576 Value arrayArg = vpOperand[0]->evaluate(root, variables);
2577
2578 if (arrayArg.nullish()) {
2579 return Value(BSONNULL);
2580 }
2581
2582 uassert(40090,
2583 str::stream() << "$indexOfArray requires an array as a first argument, found: "
2584 << typeName(arrayArg.getType()),
2585 arrayArg.isArray());
2586
2587 std::vector<Value> array = arrayArg.getArray();
2588
2589 Value searchItem = vpOperand[1]->evaluate(root, variables);
2590
2591 size_t startIndex = 0;
2592 if (vpOperand.size() > 2) {
2593 Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2594 uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2595 startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2596 }
2597
2598 size_t endIndex = array.size();
2599 if (vpOperand.size() > 3) {
2600 Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2601 uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2602 // Don't let 'endIndex' exceed the length of the array.
2603 endIndex = std::min(array.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
2604 }
2605
2606 for (size_t i = startIndex; i < endIndex; i++) {
2607 if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) {
2608 return Value(static_cast<int>(i));
2609 }
2610 }
2611
2612 return Value(-1);
2613 }
2614
2615 REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse);
getOpName() const2616 const char* ExpressionIndexOfArray::getOpName() const {
2617 return "$indexOfArray";
2618 }
2619
2620 /* ----------------------- ExpressionIndexOfBytes ------------------ */
2621
namespace {

/**
 * Returns true iff 'token' appears in 'input' starting exactly at byte offset 'index'.
 * A token that would extend past the end of 'input' never matches.
 */
bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) {
    const bool fitsInInput = index + token.size() <= input.size();
    return fitsInInput && input.compare(index, token.size(), token) == 0;
}

}  // namespace
2632
evaluate(const Document & root,Variables * variables) const2633 Value ExpressionIndexOfBytes::evaluate(const Document& root, Variables* variables) const {
2634 Value stringArg = vpOperand[0]->evaluate(root, variables);
2635
2636 if (stringArg.nullish()) {
2637 return Value(BSONNULL);
2638 }
2639
2640 uassert(40091,
2641 str::stream() << "$indexOfBytes requires a string as the first argument, found: "
2642 << typeName(stringArg.getType()),
2643 stringArg.getType() == String);
2644 const std::string& input = stringArg.getString();
2645
2646 Value tokenArg = vpOperand[1]->evaluate(root, variables);
2647 uassert(40092,
2648 str::stream() << "$indexOfBytes requires a string as the second argument, found: "
2649 << typeName(tokenArg.getType()),
2650 tokenArg.getType() == String);
2651 const std::string& token = tokenArg.getString();
2652
2653 size_t startIndex = 0;
2654 if (vpOperand.size() > 2) {
2655 Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2656 uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2657 startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2658 }
2659
2660 size_t endIndex = input.size();
2661 if (vpOperand.size() > 3) {
2662 Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2663 uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2664 // Don't let 'endIndex' exceed the length of the string.
2665 endIndex = std::min(input.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
2666 }
2667
2668 if (startIndex > input.length() || endIndex < startIndex) {
2669 return Value(-1);
2670 }
2671
2672 size_t position = input.substr(0, endIndex).find(token, startIndex);
2673 if (position == std::string::npos) {
2674 return Value(-1);
2675 }
2676
2677 return Value(static_cast<int>(position));
2678 }
2679
2680 REGISTER_EXPRESSION(indexOfBytes, ExpressionIndexOfBytes::parse);
getOpName() const2681 const char* ExpressionIndexOfBytes::getOpName() const {
2682 return "$indexOfBytes";
2683 }
2684
2685 /* ----------------------- ExpressionIndexOfCP --------------------- */
2686
evaluate(const Document & root,Variables * variables) const2687 Value ExpressionIndexOfCP::evaluate(const Document& root, Variables* variables) const {
2688 Value stringArg = vpOperand[0]->evaluate(root, variables);
2689
2690 if (stringArg.nullish()) {
2691 return Value(BSONNULL);
2692 }
2693
2694 uassert(40093,
2695 str::stream() << "$indexOfCP requires a string as the first argument, found: "
2696 << typeName(stringArg.getType()),
2697 stringArg.getType() == String);
2698 const std::string& input = stringArg.getString();
2699
2700 Value tokenArg = vpOperand[1]->evaluate(root, variables);
2701 uassert(40094,
2702 str::stream() << "$indexOfCP requires a string as the second argument, found: "
2703 << typeName(tokenArg.getType()),
2704 tokenArg.getType() == String);
2705 const std::string& token = tokenArg.getString();
2706
2707 size_t startCodePointIndex = 0;
2708 if (vpOperand.size() > 2) {
2709 Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2710 uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2711 startCodePointIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2712 }
2713
2714 // Compute the length (in code points) of the input, and convert 'startCodePointIndex' to a byte
2715 // index.
2716 size_t codePointLength = 0;
2717 size_t startByteIndex = 0;
2718 for (size_t byteIx = 0; byteIx < input.size(); ++codePointLength) {
2719 if (codePointLength == startCodePointIndex) {
2720 // We have determined the byte at which our search will start.
2721 startByteIndex = byteIx;
2722 }
2723
2724 uassert(40095,
2725 "$indexOfCP found bad UTF-8 in the input",
2726 !str::isUTF8ContinuationByte(input[byteIx]));
2727 byteIx += getCodePointLength(input[byteIx]);
2728 }
2729
2730 size_t endCodePointIndex = codePointLength;
2731 if (vpOperand.size() > 3) {
2732 Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2733 uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2734
2735 // Don't let 'endCodePointIndex' exceed the number of code points in the string.
2736 endCodePointIndex =
2737 std::min(codePointLength, static_cast<size_t>(endIndexArg.coerceToInt()));
2738 }
2739
2740 if (startByteIndex == 0 && input.empty() && token.empty()) {
2741 // If we are finding the index of "" in the string "", the below loop will not loop, so we
2742 // need a special case for this.
2743 return Value(0);
2744 }
2745
2746 // We must keep track of which byte, and which code point, we are examining, being careful not
2747 // to overflow either the length of the string or the ending code point.
2748
2749 size_t currentCodePointIndex = startCodePointIndex;
2750 for (size_t byteIx = startByteIndex; currentCodePointIndex < endCodePointIndex;
2751 ++currentCodePointIndex) {
2752 if (stringHasTokenAtIndex(byteIx, input, token)) {
2753 return Value(static_cast<int>(currentCodePointIndex));
2754 }
2755 byteIx += getCodePointLength(input[byteIx]);
2756 }
2757
2758 return Value(-1);
2759 }
2760
2761 REGISTER_EXPRESSION(indexOfCP, ExpressionIndexOfCP::parse);
getOpName() const2762 const char* ExpressionIndexOfCP::getOpName() const {
2763 return "$indexOfCP";
2764 }
2765
2766 /* ----------------------- ExpressionLn ---------------------------- */
2767
/**
 * Natural logarithm of a numeric argument.
 *
 * Positive decimals are computed in decimal arithmetic. Everything else goes through double
 * arithmetic, which throws (28766) for non-positive values and propagates NaN.
 */
Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const {
    if (numericArg.getType() == NumberDecimal) {
        Decimal128 argDecimal = numericArg.getDecimal();
        if (argDecimal.isGreater(Decimal128::kNormalizedZero)) {
            return Value(argDecimal.logarithm());
        }
        // Non-positive or NaN decimals continue to the double path below.
    }

    const double argDouble = numericArg.coerceToDouble();
    const bool inDomain = argDouble > 0 || std::isnan(argDouble);
    uassert(28766,
            str::stream() << "$ln's argument must be a positive number, but is " << argDouble,
            inDomain);
    return Value(std::log(argDouble));
}
2781
2782 REGISTER_EXPRESSION(ln, ExpressionLn::parse);
getOpName() const2783 const char* ExpressionLn::getOpName() const {
2784 return "$ln";
2785 }
2786
2787 /* ----------------------- ExpressionLog ---------------------------- */
2788
evaluate(const Document & root,Variables * variables) const2789 Value ExpressionLog::evaluate(const Document& root, Variables* variables) const {
2790 Value argVal = vpOperand[0]->evaluate(root, variables);
2791 Value baseVal = vpOperand[1]->evaluate(root, variables);
2792 if (argVal.nullish() || baseVal.nullish())
2793 return Value(BSONNULL);
2794
2795 uassert(28756,
2796 str::stream() << "$log's argument must be numeric, not " << typeName(argVal.getType()),
2797 argVal.numeric());
2798 uassert(28757,
2799 str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()),
2800 baseVal.numeric());
2801
2802 if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) {
2803 Decimal128 argDecimal = argVal.coerceToDecimal();
2804 Decimal128 baseDecimal = baseVal.coerceToDecimal();
2805
2806 if (argDecimal.isGreater(Decimal128::kNormalizedZero) &&
2807 baseDecimal.isNotEqual(Decimal128(1)) &&
2808 baseDecimal.isGreater(Decimal128::kNormalizedZero)) {
2809 return Value(argDecimal.logarithm(baseDecimal));
2810 }
2811 // Fall through for error cases.
2812 }
2813
2814 double argDouble = argVal.coerceToDouble();
2815 double baseDouble = baseVal.coerceToDouble();
2816 uassert(28758,
2817 str::stream() << "$log's argument must be a positive number, but is " << argDouble,
2818 argDouble > 0 || std::isnan(argDouble));
2819 uassert(28759,
2820 str::stream() << "$log's base must be a positive number not equal to 1, but is "
2821 << baseDouble,
2822 (baseDouble > 0 && baseDouble != 1) || std::isnan(baseDouble));
2823 return Value(std::log(argDouble) / std::log(baseDouble));
2824 }
2825
2826 REGISTER_EXPRESSION(log, ExpressionLog::parse);
getOpName() const2827 const char* ExpressionLog::getOpName() const {
2828 return "$log";
2829 }
2830
2831 /* ----------------------- ExpressionLog10 ---------------------------- */
2832
/**
 * Base-10 logarithm of a numeric argument.
 *
 * Positive decimals are computed in decimal arithmetic. Everything else goes through double
 * arithmetic, which throws (28761) for non-positive values and propagates NaN.
 */
Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const {
    if (numericArg.getType() == NumberDecimal) {
        Decimal128 argDecimal = numericArg.getDecimal();
        if (argDecimal.isGreater(Decimal128::kNormalizedZero)) {
            return Value(argDecimal.logarithm(Decimal128(10)));
        }
        // Non-positive or NaN decimals continue to the double path below.
    }

    const double argDouble = numericArg.coerceToDouble();
    const bool inDomain = argDouble > 0 || std::isnan(argDouble);
    uassert(28761,
            str::stream() << "$log10's argument must be a positive number, but is " << argDouble,
            inDomain);
    return Value(std::log10(argDouble));
}
2847
2848 REGISTER_EXPRESSION(log10, ExpressionLog10::parse);
getOpName() const2849 const char* ExpressionLog10::getOpName() const {
2850 return "$log10";
2851 }
2852
2853 /* ------------------------ ExpressionNary ----------------------------- */
2854
/**
 * Optimize a general Nary expression.
 *
 * The optimization has the following properties:
 *   1) Optimize each of the operands.
 *   2) If every operand is then a constant, replace the whole expression by a single constant
 *      holding its value.
 *   3) If the operator is associative, flatten nested operand expressions of the same type. I.e.:
 *      A+B+(C+D)+E => A+B+C+D+E
 *   4) If the operator is commutative & associative, group all constant operands. For example:
 *      c1 + c2 + n1 + c3 + n2 => n1 + n2 + c1 + c2 + c3
 *   5) If the operator is associative, execute the operation over all the contiguous constant
 *      operands and replace them by the result. For example: c1 + c2 + n1 + c3 + c4 + n5 =>
 *      c5 = c1 + c2, c6 = c3 + c4 => c5 + n1 + c6 + n5
 *
 * It returns the optimized expression. It can be exactly the same expression, a modified version
 * of the same expression or a completely different expression.
 *
 * Constant folding works by temporarily replacing 'vpOperand' with just the constant operands and
 * calling evaluate() on this expression.
 * NOTE(review): nothing restores 'vpOperand' if that evaluate() call throws, so the expression is
 * left holding only the constants in that case; confirm callers discard it on error.
 */
intrusive_ptr<Expression> ExpressionNary::optimize() {
    uint32_t constOperandCount = 0;

    // Optimize every operand in place, counting how many ended up as constants.
    for (auto& operand : vpOperand) {
        operand = operand->optimize();
        if (dynamic_cast<ExpressionConstant*>(operand.get())) {
            ++constOperandCount;
        }
    }
    // If all the operands are constant expressions, collapse the expression into one constant
    // expression. Every operand is a constant, so an empty root document is passed. (This also
    // applies to an expression with no operands at all.)
    if (constOperandCount == vpOperand.size()) {
        return intrusive_ptr<Expression>(ExpressionConstant::create(
            getExpressionContext(), evaluate(Document(), &(getExpressionContext()->variables))));
    }

    // If the expression is associative, we can collapse all the consecutive constant operands into
    // one by applying the expression to those consecutive constant operands.
    // If the expression is also commutative we can reorganize all the operands so that all of the
    // constant ones are together (arbitrarily at the back) and we can collapse all of them into
    // one.
    if (isAssociative()) {
        // Pending constants not yet folded, and the rebuilt operand list.
        ExpressionVector constExpressions;
        ExpressionVector optimizedOperands;
        // 'i' is advanced manually: flattening below rewrites vpOperand[i] and must re-examine it.
        for (size_t i = 0; i < vpOperand.size();) {
            intrusive_ptr<Expression> operand = vpOperand[i];
            // If the operand is a constant one, add it to the current list of consecutive constant
            // operands.
            if (dynamic_cast<ExpressionConstant*>(operand.get())) {
                constExpressions.push_back(operand);
                ++i;
                continue;
            }

            // If the operand is exactly the same type as the one we are currently optimizing and
            // is also associative, replace the expression for the operands it has.
            // E.g: sum(a, b, sum(c, d), e) => sum(a, b, c, d, e)
            // The nested operands are spliced in at position i without advancing 'i', so the
            // first of them is examined on the next iteration. 'operand' keeps 'nary' alive while
            // its operands are moved/copied out.
            ExpressionNary* nary = dynamic_cast<ExpressionNary*>(operand.get());
            if (nary && str::equals(nary->getOpName(), getOpName()) && nary->isAssociative()) {
                invariant(!nary->vpOperand.empty());
                vpOperand[i] = std::move(nary->vpOperand[0]);
                vpOperand.insert(
                    vpOperand.begin() + i + 1, nary->vpOperand.begin() + 1, nary->vpOperand.end());
                continue;
            }

            // If the operand is not a constant nor a same-type expression and the expression is
            // not commutative, evaluate an expression of the same type as the one we are
            // optimizing on the list of consecutive constant operands and use the resulting value
            // as a constant expression operand.
            // If the list of consecutive constant operands has less than 2 operands just place
            // back the operands.
            if (!isCommutative()) {
                if (constExpressions.size() > 1) {
                    // evaluate() reads 'vpOperand', so swap in only the pending constants, fold
                    // them, then put the full operand list back.
                    ExpressionVector vpOperandSave = std::move(vpOperand);
                    vpOperand = std::move(constExpressions);
                    optimizedOperands.emplace_back(ExpressionConstant::create(
                        getExpressionContext(),
                        evaluate(Document(), &getExpressionContext()->variables)));
                    vpOperand = std::move(vpOperandSave);
                } else {
                    optimizedOperands.insert(
                        optimizedOperands.end(), constExpressions.begin(), constExpressions.end());
                }
                constExpressions.clear();
            }
            optimizedOperands.push_back(operand);
            ++i;
        }

        // Fold any trailing run of constants (for a commutative operator: all of the constants).
        // 'vpOperand' is not restored here because it is overwritten immediately below.
        if (constExpressions.size() > 1) {
            vpOperand = std::move(constExpressions);
            optimizedOperands.emplace_back(ExpressionConstant::create(
                getExpressionContext(),
                evaluate(Document(), &(getExpressionContext()->variables))));
        } else {
            optimizedOperands.insert(
                optimizedOperands.end(), constExpressions.begin(), constExpressions.end());
        }

        vpOperand = std::move(optimizedOperands);
    }
    return this;
}
2955
_doAddDependencies(DepsTracker * deps) const2956 void ExpressionNary::_doAddDependencies(DepsTracker* deps) const {
2957 for (auto&& operand : vpOperand) {
2958 operand->addDependencies(deps);
2959 }
2960 }
2961
addOperand(const intrusive_ptr<Expression> & pExpression)2962 void ExpressionNary::addOperand(const intrusive_ptr<Expression>& pExpression) {
2963 vpOperand.push_back(pExpression);
2964 }
2965
serialize(bool explain) const2966 Value ExpressionNary::serialize(bool explain) const {
2967 const size_t nOperand = vpOperand.size();
2968 vector<Value> array;
2969 /* build up the array */
2970 for (size_t i = 0; i < nOperand; i++)
2971 array.push_back(vpOperand[i]->serialize(explain));
2972
2973 return Value(DOC(getOpName() << array));
2974 }
2975
2976 /* ------------------------- ExpressionNot ----------------------------- */
2977
evaluate(const Document & root,Variables * variables) const2978 Value ExpressionNot::evaluate(const Document& root, Variables* variables) const {
2979 Value pOp(vpOperand[0]->evaluate(root, variables));
2980
2981 bool b = pOp.coerceToBool();
2982 return Value(!b);
2983 }
2984
2985 REGISTER_EXPRESSION(not, ExpressionNot::parse);
getOpName() const2986 const char* ExpressionNot::getOpName() const {
2987 return "$not";
2988 }
2989
2990 /* -------------------------- ExpressionOr ----------------------------- */
2991
evaluate(const Document & root,Variables * variables) const2992 Value ExpressionOr::evaluate(const Document& root, Variables* variables) const {
2993 const size_t n = vpOperand.size();
2994 for (size_t i = 0; i < n; ++i) {
2995 Value pValue(vpOperand[i]->evaluate(root, variables));
2996 if (pValue.coerceToBool())
2997 return Value(true);
2998 }
2999
3000 return Value(false);
3001 }
3002
optimize()3003 intrusive_ptr<Expression> ExpressionOr::optimize() {
3004 /* optimize the disjunction as much as possible */
3005 intrusive_ptr<Expression> pE(ExpressionNary::optimize());
3006
3007 /* if the result isn't a disjunction, we can't do anything */
3008 ExpressionOr* pOr = dynamic_cast<ExpressionOr*>(pE.get());
3009 if (!pOr)
3010 return pE;
3011
3012 /*
3013 Check the last argument on the result; if it's not constant (as
3014 promised by ExpressionNary::optimize(),) then there's nothing
3015 we can do.
3016 */
3017 const size_t n = pOr->vpOperand.size();
3018 // ExpressionNary::optimize() generates an ExpressionConstant for {$or:[]}.
3019 verify(n > 0);
3020 intrusive_ptr<Expression> pLast(pOr->vpOperand[n - 1]);
3021 const ExpressionConstant* pConst = dynamic_cast<ExpressionConstant*>(pLast.get());
3022 if (!pConst)
3023 return pE;
3024
3025 /*
3026 Evaluate and coerce the last argument to a boolean. If it's true,
3027 then we can replace this entire expression.
3028 */
3029 bool last = pConst->getValue().coerceToBool();
3030 if (last) {
3031 intrusive_ptr<ExpressionConstant> pFinal(
3032 ExpressionConstant::create(getExpressionContext(), Value(true)));
3033 return pFinal;
3034 }
3035
3036 /*
3037 If we got here, the final operand was false, so we don't need it
3038 anymore. If there was only one other operand, we don't need the
3039 conjunction either. Note we still need to keep the promise that
3040 the result will be a boolean.
3041 */
3042 if (n == 2) {
3043 intrusive_ptr<Expression> pFinal(
3044 ExpressionCoerceToBool::create(getExpressionContext(), pOr->vpOperand[0]));
3045 return pFinal;
3046 }
3047
3048 /*
3049 Remove the final "false" value, and return the new expression.
3050 */
3051 pOr->vpOperand.resize(n - 1);
3052 return pE;
3053 }
3054
3055 REGISTER_EXPRESSION(or, ExpressionOr::parse);
getOpName() const3056 const char* ExpressionOr::getOpName() const {
3057 return "$or";
3058 }
3059
3060 namespace {
3061 /**
3062 * Helper for ExpressionPow to determine whether base^exp can be represented in a 64 bit int.
3063 *
3064 *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63].
3065 */
representableAsLong(long long base,long long exp)3066 bool representableAsLong(long long base, long long exp) {
3067 invariant(exp <= 63);
3068 invariant(exp >= 0);
3069 struct MinMax {
3070 long long min;
3071 long long max;
3072 };
3073
3074 // Array indices correspond to exponents 0 through 63. The values in each index are the min
3075 // and max bases, respectively, that can be raised to that exponent without overflowing a
3076 // 64-bit int. For max bases, this was computed by solving for b in
3077 // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even
3078 // exps the equation used was b = (2^63-1)^(1/exp), and for odd exps the equation used was
3079 // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the
3080 // magnitude of some of the min bases raised to odd exps is greater than the corresponding
3081 // max bases raised to the same exponents.
3082
3083 static const MinMax kBaseLimits[] = {
3084 {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()}, // 0
3085 {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()},
3086 {-3037000499LL, 3037000499LL},
3087 {-2097152, 2097151},
3088 {-55108, 55108},
3089 {-6208, 6208},
3090 {-1448, 1448},
3091 {-512, 511},
3092 {-234, 234},
3093 {-128, 127},
3094 {-78, 78}, // 10
3095 {-52, 52},
3096 {-38, 38},
3097 {-28, 28},
3098 {-22, 22},
3099 {-18, 18},
3100 {-15, 15},
3101 {-13, 13},
3102 {-11, 11},
3103 {-9, 9},
3104 {-8, 8}, // 20
3105 {-8, 7},
3106 {-7, 7},
3107 {-6, 6},
3108 {-6, 6},
3109 {-5, 5},
3110 {-5, 5},
3111 {-5, 5},
3112 {-4, 4},
3113 {-4, 4},
3114 {-4, 4}, // 30
3115 {-4, 4},
3116 {-3, 3},
3117 {-3, 3},
3118 {-3, 3},
3119 {-3, 3},
3120 {-3, 3},
3121 {-3, 3},
3122 {-3, 3},
3123 {-3, 3},
3124 {-2, 2}, // 40
3125 {-2, 2},
3126 {-2, 2},
3127 {-2, 2},
3128 {-2, 2},
3129 {-2, 2},
3130 {-2, 2},
3131 {-2, 2},
3132 {-2, 2},
3133 {-2, 2},
3134 {-2, 2}, // 50
3135 {-2, 2},
3136 {-2, 2},
3137 {-2, 2},
3138 {-2, 2},
3139 {-2, 2},
3140 {-2, 2},
3141 {-2, 2},
3142 {-2, 2},
3143 {-2, 2},
3144 {-2, 2}, // 60
3145 {-2, 2},
3146 {-2, 2},
3147 {-2, 1}};
3148
3149 return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max;
3150 };
3151 } // namespace
3152
3153 /* ----------------------- ExpressionPow ---------------------------- */
3154
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,Value base,Value exp)3155 intrusive_ptr<Expression> ExpressionPow::create(
3156 const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp) {
3157 intrusive_ptr<ExpressionPow> expr(new ExpressionPow(expCtx));
3158 expr->vpOperand.push_back(
3159 ExpressionConstant::create(expr->getExpressionContext(), std::move(base)));
3160 expr->vpOperand.push_back(
3161 ExpressionConstant::create(expr->getExpressionContext(), std::move(exp)));
3162 return expr;
3163 }
3164
/**
 * Evaluates {$pow: [<base>, <exponent>]}.
 *
 * Returns null if either operand is nullish. Throws if either operand is non-numeric (28762,
 * 28763) or if the base is 0 and the exponent negative (28764).
 *
 * Result type: decimal if either operand is decimal; otherwise double if either is double. For two
 * integral operands the result is a long if either operand is a long, otherwise an int when it
 * fits (a long when it does not); if the exact result cannot be represented as a long, a double
 * computed with std::pow is returned instead.
 */
Value ExpressionPow::evaluate(const Document& root, Variables* variables) const {
    Value baseVal = vpOperand[0]->evaluate(root, variables);
    Value expVal = vpOperand[1]->evaluate(root, variables);
    if (baseVal.nullish() || expVal.nullish())
        return Value(BSONNULL);

    BSONType baseType = baseVal.getType();
    BSONType expType = expVal.getType();

    uassert(28762,
            str::stream() << "$pow's base must be numeric, not " << typeName(baseType),
            baseVal.numeric());
    uassert(28763,
            str::stream() << "$pow's exponent must be numeric, not " << typeName(expType),
            expVal.numeric());

    // Shared check, applied on both the decimal and the non-decimal paths below.
    auto checkNonZeroAndNeg = [](bool isZeroAndNeg) {
        uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg);
    };

    // If either argument is decimal, return a decimal.
    if (baseType == NumberDecimal || expType == NumberDecimal) {
        Decimal128 baseDecimal = baseVal.coerceToDecimal();
        Decimal128 expDecimal = expVal.coerceToDecimal();
        checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative());
        return Value(baseDecimal.power(expDecimal));
    }

    // pow() will cast args to doubles.
    double baseDouble = baseVal.coerceToDouble();
    double expDouble = expVal.coerceToDouble();
    checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0);

    // If either argument is a double, return a double.
    if (baseType == NumberDouble || expType == NumberDouble) {
        return Value(std::pow(baseDouble, expDouble));
    }

    // From here on both operands are int or long.

    // If either number is a long, return a long. If both numbers are ints, then return an int if
    // the result fits or a long if it is too big.
    const auto formatResult = [baseType, expType](long long res) {
        if (baseType == NumberLong || expType == NumberLong) {
            return Value(res);
        }
        return Value::createIntOrLong(res);
    };

    const long long baseLong = baseVal.getLong();
    const long long expLong = expVal.getLong();

    // Use this when the result cannot be represented as a long. The long long arguments are
    // promoted to double by std::pow, so the result is a double Value.
    const auto computeDoubleResult = [baseLong, expLong]() {
        return Value(std::pow(baseLong, expLong));
    };

    // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1.
    if (baseLong == 0) {
        if (expLong == 0) {
            // 0^0 = 1.
            return formatResult(1);
        } else if (expLong > 0) {
            // 0^x where x > 0 is 0.
            return formatResult(0);
        }

        // We should have checked earlier that 0 to a negative power is banned.
        MONGO_UNREACHABLE;
    } else if (baseLong == 1) {
        return formatResult(1);
    } else if (baseLong == -1) {
        // -1^0 = -1^2 = -1^4 = -1^6 ... = 1
        // -1^1 = -1^3 = -1^5 = -1^7 ... = -1
        // (Also correct for negative exponents; C++ '%' keeps the sign, so odd ones yield -1.)
        return formatResult((expLong % 2 == 0) ? 1 : -1);
    } else if (expLong > 63 || expLong < 0) {
        // If the base is not 0, 1, or -1 and the exponent is too large, or negative,
        // the result cannot be represented as a long.
        return computeDoubleResult();
    }

    // It's still possible that the result cannot be represented as a long. If that's the case,
    // return a double. (expLong is now within [0, 63], as representableAsLong requires.)
    if (!representableAsLong(baseLong, expLong)) {
        return computeDoubleResult();
    }

    // Use repeated multiplication, since pow() casts args to doubles which could result in
    // loss of precision if arguments are very large.
    // This is exponentiation by squaring; every intermediate product is a factor of the final
    // result, which was just shown to fit in a long.
    const auto computeWithRepeatedMultiplication = [](long long base, long long exp) {
        long long result = 1;

        while (exp > 1) {
            if (exp % 2 == 1) {
                result *= base;
                exp--;
            }
            // 'exp' is now guaranteed to be even.
            base *= base;
            exp /= 2;
        }

        if (exp) {
            invariant(exp == 1);
            result *= base;
        }

        return result;
    };

    return formatResult(computeWithRepeatedMultiplication(baseLong, expLong));
}
3275
3276 REGISTER_EXPRESSION(pow, ExpressionPow::parse);
getOpName() const3277 const char* ExpressionPow::getOpName() const {
3278 return "$pow";
3279 }
3280
3281 /* ------------------------- ExpressionRange ------------------------------ */
3282
evaluate(const Document & root,Variables * variables) const3283 Value ExpressionRange::evaluate(const Document& root, Variables* variables) const {
3284 Value startVal(vpOperand[0]->evaluate(root, variables));
3285 Value endVal(vpOperand[1]->evaluate(root, variables));
3286
3287 uassert(34443,
3288 str::stream() << "$range requires a numeric starting value, found value of type: "
3289 << typeName(startVal.getType()),
3290 startVal.numeric());
3291 uassert(34444,
3292 str::stream() << "$range requires a starting value that can be represented as a 32-bit "
3293 "integer, found value: "
3294 << startVal.toString(),
3295 startVal.integral());
3296 uassert(34445,
3297 str::stream() << "$range requires a numeric ending value, found value of type: "
3298 << typeName(endVal.getType()),
3299 endVal.numeric());
3300 uassert(34446,
3301 str::stream() << "$range requires an ending value that can be represented as a 32-bit "
3302 "integer, found value: "
3303 << endVal.toString(),
3304 endVal.integral());
3305
3306 int current = startVal.coerceToInt();
3307 int end = endVal.coerceToInt();
3308
3309 int step = 1;
3310 if (vpOperand.size() == 3) {
3311 // A step was specified by the user.
3312 Value stepVal(vpOperand[2]->evaluate(root, variables));
3313
3314 uassert(34447,
3315 str::stream() << "$range requires a numeric step value, found value of type:"
3316 << typeName(stepVal.getType()),
3317 stepVal.numeric());
3318 uassert(34448,
3319 str::stream() << "$range requires a step value that can be represented as a 32-bit "
3320 "integer, found value: "
3321 << stepVal.toString(),
3322 stepVal.integral());
3323 step = stepVal.coerceToInt();
3324
3325 uassert(34449, "$range requires a non-zero step value", step != 0);
3326 }
3327
3328 std::vector<Value> output;
3329
3330 while ((step > 0 ? current < end : current > end)) {
3331 output.push_back(Value(current));
3332 current += step;
3333 }
3334
3335 return Value(output);
3336 }
3337
3338 REGISTER_EXPRESSION(range, ExpressionRange::parse);
getOpName() const3339 const char* ExpressionRange::getOpName() const {
3340 return "$range";
3341 }
3342
3343 /* ------------------------ ExpressionReduce ------------------------------ */
3344
3345 REGISTER_EXPRESSION(reduce, ExpressionReduce::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)3346 intrusive_ptr<Expression> ExpressionReduce::parse(
3347 const boost::intrusive_ptr<ExpressionContext>& expCtx,
3348 BSONElement expr,
3349 const VariablesParseState& vps) {
3350 uassert(40075,
3351 str::stream() << "$reduce requires an object as an argument, found: "
3352 << typeName(expr.type()),
3353 expr.type() == Object);
3354
3355 intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce(expCtx));
3356
3357 // vpsSub is used only to parse 'in', which must have access to $$this and $$value.
3358 VariablesParseState vpsSub(vps);
3359 reduce->_thisVar = vpsSub.defineVariable("this");
3360 reduce->_valueVar = vpsSub.defineVariable("value");
3361
3362 for (auto&& elem : expr.Obj()) {
3363 auto field = elem.fieldNameStringData();
3364
3365 if (field == "input") {
3366 reduce->_input = parseOperand(expCtx, elem, vps);
3367 } else if (field == "initialValue") {
3368 reduce->_initial = parseOperand(expCtx, elem, vps);
3369 } else if (field == "in") {
3370 reduce->_in = parseOperand(expCtx, elem, vpsSub);
3371 } else {
3372 uasserted(40076, str::stream() << "$reduce found an unknown argument: " << field);
3373 }
3374 }
3375
3376 uassert(40077, "$reduce requires 'input' to be specified", reduce->_input);
3377 uassert(40078, "$reduce requires 'initialValue' to be specified", reduce->_initial);
3378 uassert(40079, "$reduce requires 'in' to be specified", reduce->_in);
3379
3380 return reduce;
3381 }
3382
evaluate(const Document & root,Variables * variables) const3383 Value ExpressionReduce::evaluate(const Document& root, Variables* variables) const {
3384 Value inputVal = _input->evaluate(root, variables);
3385
3386 if (inputVal.nullish()) {
3387 return Value(BSONNULL);
3388 }
3389
3390 uassert(40080,
3391 str::stream() << "$reduce requires that 'input' be an array, found: "
3392 << inputVal.toString(),
3393 inputVal.isArray());
3394
3395 Value accumulatedValue = _initial->evaluate(root, variables);
3396
3397 for (auto&& elem : inputVal.getArray()) {
3398 variables->setValue(_thisVar, elem);
3399 variables->setValue(_valueVar, accumulatedValue);
3400
3401 accumulatedValue = _in->evaluate(root, variables);
3402 }
3403
3404 return accumulatedValue;
3405 }
3406
optimize()3407 intrusive_ptr<Expression> ExpressionReduce::optimize() {
3408 _input = _input->optimize();
3409 _initial = _initial->optimize();
3410 _in = _in->optimize();
3411 return this;
3412 }
3413
_doAddDependencies(DepsTracker * deps) const3414 void ExpressionReduce::_doAddDependencies(DepsTracker* deps) const {
3415 _input->addDependencies(deps);
3416 _initial->addDependencies(deps);
3417 _in->addDependencies(deps);
3418 }
3419
serialize(bool explain) const3420 Value ExpressionReduce::serialize(bool explain) const {
3421 return Value(Document{{"$reduce",
3422 Document{{"input", _input->serialize(explain)},
3423 {"initialValue", _initial->serialize(explain)},
3424 {"in", _in->serialize(explain)}}}});
3425 }
3426
3427 /* ------------------------ ExpressionReverseArray ------------------------ */
3428
evaluate(const Document & root,Variables * variables) const3429 Value ExpressionReverseArray::evaluate(const Document& root, Variables* variables) const {
3430 Value input(vpOperand[0]->evaluate(root, variables));
3431
3432 if (input.nullish()) {
3433 return Value(BSONNULL);
3434 }
3435
3436 uassert(34435,
3437 str::stream() << "The argument to $reverseArray must be an array, but was of type: "
3438 << typeName(input.getType()),
3439 input.isArray());
3440
3441 if (input.getArrayLength() < 2) {
3442 return input;
3443 }
3444
3445 std::vector<Value> array = input.getArray();
3446 std::reverse(array.begin(), array.end());
3447 return Value(array);
3448 }
3449
3450 REGISTER_EXPRESSION(reverseArray, ExpressionReverseArray::parse);
getOpName() const3451 const char* ExpressionReverseArray::getOpName() const {
3452 return "$reverseArray";
3453 }
3454
3455 namespace {
arrayToSet(const Value & val,const ValueComparator & valueComparator)3456 ValueSet arrayToSet(const Value& val, const ValueComparator& valueComparator) {
3457 const vector<Value>& array = val.getArray();
3458 ValueSet valueSet = valueComparator.makeOrderedValueSet();
3459 valueSet.insert(array.begin(), array.end());
3460 return valueSet;
3461 }
3462 } // namespace
3463
3464 /* ----------------------- ExpressionSetDifference ---------------------------- */
3465
evaluate(const Document & root,Variables * variables) const3466 Value ExpressionSetDifference::evaluate(const Document& root, Variables* variables) const {
3467 const Value lhs = vpOperand[0]->evaluate(root, variables);
3468 const Value rhs = vpOperand[1]->evaluate(root, variables);
3469
3470 if (lhs.nullish() || rhs.nullish()) {
3471 return Value(BSONNULL);
3472 }
3473
3474 uassert(17048,
3475 str::stream() << "both operands of $setDifference must be arrays. First "
3476 << "argument is of type: "
3477 << typeName(lhs.getType()),
3478 lhs.isArray());
3479 uassert(17049,
3480 str::stream() << "both operands of $setDifference must be arrays. Second "
3481 << "argument is of type: "
3482 << typeName(rhs.getType()),
3483 rhs.isArray());
3484
3485 ValueSet rhsSet = arrayToSet(rhs, getExpressionContext()->getValueComparator());
3486 const vector<Value>& lhsArray = lhs.getArray();
3487 vector<Value> returnVec;
3488
3489 for (vector<Value>::const_iterator it = lhsArray.begin(); it != lhsArray.end(); ++it) {
3490 // rhsSet serves the dual role of filtering out elements that were originally present
3491 // in RHS and of eleminating duplicates from LHS
3492 if (rhsSet.insert(*it).second) {
3493 returnVec.push_back(*it);
3494 }
3495 }
3496 return Value(std::move(returnVec));
3497 }
3498
3499 REGISTER_EXPRESSION(setDifference, ExpressionSetDifference::parse);
getOpName() const3500 const char* ExpressionSetDifference::getOpName() const {
3501 return "$setDifference";
3502 }
3503
3504 /* ----------------------- ExpressionSetEquals ---------------------------- */
3505
validateArguments(const ExpressionVector & args) const3506 void ExpressionSetEquals::validateArguments(const ExpressionVector& args) const {
3507 uassert(17045,
3508 str::stream() << "$setEquals needs at least two arguments had: " << args.size(),
3509 args.size() >= 2);
3510 }
3511
evaluate(const Document & root,Variables * variables) const3512 Value ExpressionSetEquals::evaluate(const Document& root, Variables* variables) const {
3513 const size_t n = vpOperand.size();
3514 const auto& valueComparator = getExpressionContext()->getValueComparator();
3515 ValueSet lhs = valueComparator.makeOrderedValueSet();
3516
3517 for (size_t i = 0; i < n; i++) {
3518 const Value nextEntry = vpOperand[i]->evaluate(root, variables);
3519 uassert(17044,
3520 str::stream() << "All operands of $setEquals must be arrays. One "
3521 << "argument is of type: "
3522 << typeName(nextEntry.getType()),
3523 nextEntry.isArray());
3524
3525 if (i == 0) {
3526 lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3527 } else {
3528 ValueSet rhs = valueComparator.makeOrderedValueSet();
3529 rhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3530 if (lhs.size() != rhs.size()) {
3531 return Value(false);
3532 }
3533
3534 if (!std::equal(lhs.begin(), lhs.end(), rhs.begin(), valueComparator.getEqualTo())) {
3535 return Value(false);
3536 }
3537 }
3538 }
3539 return Value(true);
3540 }
3541
3542 REGISTER_EXPRESSION(setEquals, ExpressionSetEquals::parse);
getOpName() const3543 const char* ExpressionSetEquals::getOpName() const {
3544 return "$setEquals";
3545 }
3546
3547 /* ----------------------- ExpressionSetIntersection ---------------------------- */
3548
evaluate(const Document & root,Variables * variables) const3549 Value ExpressionSetIntersection::evaluate(const Document& root, Variables* variables) const {
3550 const size_t n = vpOperand.size();
3551 const auto& valueComparator = getExpressionContext()->getValueComparator();
3552 ValueSet currentIntersection = valueComparator.makeOrderedValueSet();
3553 for (size_t i = 0; i < n; i++) {
3554 const Value nextEntry = vpOperand[i]->evaluate(root, variables);
3555 if (nextEntry.nullish()) {
3556 return Value(BSONNULL);
3557 }
3558 uassert(17047,
3559 str::stream() << "All operands of $setIntersection must be arrays. One "
3560 << "argument is of type: "
3561 << typeName(nextEntry.getType()),
3562 nextEntry.isArray());
3563
3564 if (i == 0) {
3565 currentIntersection.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3566 } else {
3567 ValueSet nextSet = arrayToSet(nextEntry, valueComparator);
3568 if (currentIntersection.size() > nextSet.size()) {
3569 // to iterate over whichever is the smaller set
3570 nextSet.swap(currentIntersection);
3571 }
3572 ValueSet::iterator it = currentIntersection.begin();
3573 while (it != currentIntersection.end()) {
3574 if (!nextSet.count(*it)) {
3575 ValueSet::iterator del = it;
3576 ++it;
3577 currentIntersection.erase(del);
3578 } else {
3579 ++it;
3580 }
3581 }
3582 }
3583 if (currentIntersection.empty()) {
3584 break;
3585 }
3586 }
3587 return Value(vector<Value>(currentIntersection.begin(), currentIntersection.end()));
3588 }
3589
3590 REGISTER_EXPRESSION(setIntersection, ExpressionSetIntersection::parse);
getOpName() const3591 const char* ExpressionSetIntersection::getOpName() const {
3592 return "$setIntersection";
3593 }
3594
3595 /* ----------------------- ExpressionSetIsSubset ---------------------------- */
3596
3597 namespace {
setIsSubsetHelper(const vector<Value> & lhs,const ValueSet & rhs)3598 Value setIsSubsetHelper(const vector<Value>& lhs, const ValueSet& rhs) {
3599 // do not shortcircuit when lhs.size() > rhs.size()
3600 // because lhs can have redundant entries
3601 for (vector<Value>::const_iterator it = lhs.begin(); it != lhs.end(); ++it) {
3602 if (!rhs.count(*it)) {
3603 return Value(false);
3604 }
3605 }
3606 return Value(true);
3607 }
3608 } // namespace
3609
evaluate(const Document & root,Variables * variables) const3610 Value ExpressionSetIsSubset::evaluate(const Document& root, Variables* variables) const {
3611 const Value lhs = vpOperand[0]->evaluate(root, variables);
3612 const Value rhs = vpOperand[1]->evaluate(root, variables);
3613
3614 uassert(17046,
3615 str::stream() << "both operands of $setIsSubset must be arrays. First "
3616 << "argument is of type: "
3617 << typeName(lhs.getType()),
3618 lhs.isArray());
3619 uassert(17042,
3620 str::stream() << "both operands of $setIsSubset must be arrays. Second "
3621 << "argument is of type: "
3622 << typeName(rhs.getType()),
3623 rhs.isArray());
3624
3625 return setIsSubsetHelper(lhs.getArray(),
3626 arrayToSet(rhs, getExpressionContext()->getValueComparator()));
3627 }
3628
3629 /**
3630 * This class handles the case where the RHS set is constant.
3631 *
3632 * Since it is constant we can construct the hashset once which makes the runtime performance
3633 * effectively constant with respect to the size of RHS. Large, constant RHS is expected to be a
3634 * major use case for $redact and this has been verified to improve performance significantly.
3635 */
3636 class ExpressionSetIsSubset::Optimized : public ExpressionSetIsSubset {
3637 public:
Optimized(const boost::intrusive_ptr<ExpressionContext> & expCtx,const ValueSet & cachedRhsSet,const ExpressionVector & operands)3638 Optimized(const boost::intrusive_ptr<ExpressionContext>& expCtx,
3639 const ValueSet& cachedRhsSet,
3640 const ExpressionVector& operands)
3641 : ExpressionSetIsSubset(expCtx), _cachedRhsSet(cachedRhsSet) {
3642 vpOperand = operands;
3643 }
3644
evaluate(const Document & root,Variables * variables) const3645 virtual Value evaluate(const Document& root, Variables* variables) const {
3646 const Value lhs = vpOperand[0]->evaluate(root, variables);
3647
3648 uassert(17310,
3649 str::stream() << "both operands of $setIsSubset must be arrays. First "
3650 << "argument is of type: "
3651 << typeName(lhs.getType()),
3652 lhs.isArray());
3653
3654 return setIsSubsetHelper(lhs.getArray(), _cachedRhsSet);
3655 }
3656
3657 private:
3658 const ValueSet _cachedRhsSet;
3659 };
3660
optimize()3661 intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() {
3662 // perfore basic optimizations
3663 intrusive_ptr<Expression> optimized = ExpressionNary::optimize();
3664
3665 // if ExpressionNary::optimize() created a new value, return it directly
3666 if (optimized.get() != this)
3667 return optimized;
3668
3669 if (ExpressionConstant* ec = dynamic_cast<ExpressionConstant*>(vpOperand[1].get())) {
3670 const Value rhs = ec->getValue();
3671 uassert(17311,
3672 str::stream() << "both operands of $setIsSubset must be arrays. Second "
3673 << "argument is of type: "
3674 << typeName(rhs.getType()),
3675 rhs.isArray());
3676
3677 intrusive_ptr<Expression> optimizedWithConstant(
3678 new Optimized(this->getExpressionContext(),
3679 arrayToSet(rhs, getExpressionContext()->getValueComparator()),
3680 vpOperand));
3681 return optimizedWithConstant;
3682 }
3683 return optimized;
3684 }
3685
3686 REGISTER_EXPRESSION(setIsSubset, ExpressionSetIsSubset::parse);
getOpName() const3687 const char* ExpressionSetIsSubset::getOpName() const {
3688 return "$setIsSubset";
3689 }
3690
3691 /* ----------------------- ExpressionSetUnion ---------------------------- */
3692
evaluate(const Document & root,Variables * variables) const3693 Value ExpressionSetUnion::evaluate(const Document& root, Variables* variables) const {
3694 ValueSet unionedSet = getExpressionContext()->getValueComparator().makeOrderedValueSet();
3695 const size_t n = vpOperand.size();
3696 for (size_t i = 0; i < n; i++) {
3697 const Value newEntries = vpOperand[i]->evaluate(root, variables);
3698 if (newEntries.nullish()) {
3699 return Value(BSONNULL);
3700 }
3701 uassert(17043,
3702 str::stream() << "All operands of $setUnion must be arrays. One argument"
3703 << " is of type: "
3704 << typeName(newEntries.getType()),
3705 newEntries.isArray());
3706
3707 unionedSet.insert(newEntries.getArray().begin(), newEntries.getArray().end());
3708 }
3709 return Value(vector<Value>(unionedSet.begin(), unionedSet.end()));
3710 }
3711
3712 REGISTER_EXPRESSION(setUnion, ExpressionSetUnion::parse);
getOpName() const3713 const char* ExpressionSetUnion::getOpName() const {
3714 return "$setUnion";
3715 }
3716
3717 /* ----------------------- ExpressionIsArray ---------------------------- */
3718
evaluate(const Document & root,Variables * variables) const3719 Value ExpressionIsArray::evaluate(const Document& root, Variables* variables) const {
3720 Value argument = vpOperand[0]->evaluate(root, variables);
3721 return Value(argument.isArray());
3722 }
3723
3724 REGISTER_EXPRESSION(isArray, ExpressionIsArray::parse);
getOpName() const3725 const char* ExpressionIsArray::getOpName() const {
3726 return "$isArray";
3727 }
3728
3729 /* ----------------------- ExpressionSlice ---------------------------- */
3730
evaluate(const Document & root,Variables * variables) const3731 Value ExpressionSlice::evaluate(const Document& root, Variables* variables) const {
3732 const size_t n = vpOperand.size();
3733
3734 Value arrayVal = vpOperand[0]->evaluate(root, variables);
3735 // Could be either a start index or the length from 0.
3736 Value arg2 = vpOperand[1]->evaluate(root, variables);
3737
3738 if (arrayVal.nullish() || arg2.nullish()) {
3739 return Value(BSONNULL);
3740 }
3741
3742 uassert(28724,
3743 str::stream() << "First argument to $slice must be an array, but is"
3744 << " of type: "
3745 << typeName(arrayVal.getType()),
3746 arrayVal.isArray());
3747 uassert(28725,
3748 str::stream() << "Second argument to $slice must be a numeric value,"
3749 << " but is of type: "
3750 << typeName(arg2.getType()),
3751 arg2.numeric());
3752 uassert(28726,
3753 str::stream() << "Second argument to $slice can't be represented as"
3754 << " a 32-bit integer: "
3755 << arg2.coerceToDouble(),
3756 arg2.integral());
3757
3758 const auto& array = arrayVal.getArray();
3759 size_t start;
3760 size_t end;
3761
3762 if (n == 2) {
3763 // Only count given.
3764 int count = arg2.coerceToInt();
3765 start = 0;
3766 end = array.size();
3767 if (count >= 0) {
3768 end = std::min(end, size_t(count));
3769 } else {
3770 // Negative count's start from the back. If a abs(count) is greater
3771 // than the
3772 // length of the array, return the whole array.
3773 start = std::max(0, static_cast<int>(array.size()) + count);
3774 }
3775 } else {
3776 // We have both a start index and a count.
3777 int startInt = arg2.coerceToInt();
3778 if (startInt < 0) {
3779 // Negative values start from the back. If a abs(start) is greater
3780 // than the length
3781 // of the array, start from 0.
3782 start = std::max(0, static_cast<int>(array.size()) + startInt);
3783 } else {
3784 start = std::min(array.size(), size_t(startInt));
3785 }
3786
3787 Value countVal = vpOperand[2]->evaluate(root, variables);
3788
3789 if (countVal.nullish()) {
3790 return Value(BSONNULL);
3791 }
3792
3793 uassert(28727,
3794 str::stream() << "Third argument to $slice must be numeric, but "
3795 << "is of type: "
3796 << typeName(countVal.getType()),
3797 countVal.numeric());
3798 uassert(28728,
3799 str::stream() << "Third argument to $slice can't be represented"
3800 << " as a 32-bit integer: "
3801 << countVal.coerceToDouble(),
3802 countVal.integral());
3803 uassert(28729,
3804 str::stream() << "Third argument to $slice must be positive: "
3805 << countVal.coerceToInt(),
3806 countVal.coerceToInt() > 0);
3807
3808 size_t count = size_t(countVal.coerceToInt());
3809 end = std::min(start + count, array.size());
3810 }
3811
3812 return Value(vector<Value>(array.begin() + start, array.begin() + end));
3813 }
3814
3815 REGISTER_EXPRESSION(slice, ExpressionSlice::parse);
getOpName() const3816 const char* ExpressionSlice::getOpName() const {
3817 return "$slice";
3818 }
3819
3820 /* ----------------------- ExpressionSize ---------------------------- */
3821
evaluate(const Document & root,Variables * variables) const3822 Value ExpressionSize::evaluate(const Document& root, Variables* variables) const {
3823 Value array = vpOperand[0]->evaluate(root, variables);
3824
3825 uassert(17124,
3826 str::stream() << "The argument to $size must be an array, but was of type: "
3827 << typeName(array.getType()),
3828 array.isArray());
3829 return Value::createIntOrLong(array.getArray().size());
3830 }
3831
3832 REGISTER_EXPRESSION(size, ExpressionSize::parse);
getOpName() const3833 const char* ExpressionSize::getOpName() const {
3834 return "$size";
3835 }
3836
3837 /* ----------------------- ExpressionSplit --------------------------- */
3838
evaluate(const Document & root,Variables * variables) const3839 Value ExpressionSplit::evaluate(const Document& root, Variables* variables) const {
3840 Value inputArg = vpOperand[0]->evaluate(root, variables);
3841 Value separatorArg = vpOperand[1]->evaluate(root, variables);
3842
3843 if (inputArg.nullish() || separatorArg.nullish()) {
3844 return Value(BSONNULL);
3845 }
3846
3847 uassert(40085,
3848 str::stream() << "$split requires an expression that evaluates to a string as a first "
3849 "argument, found: "
3850 << typeName(inputArg.getType()),
3851 inputArg.getType() == BSONType::String);
3852 uassert(40086,
3853 str::stream() << "$split requires an expression that evaluates to a string as a second "
3854 "argument, found: "
3855 << typeName(separatorArg.getType()),
3856 separatorArg.getType() == BSONType::String);
3857
3858 std::string input = inputArg.getString();
3859 std::string separator = separatorArg.getString();
3860
3861 uassert(40087, "$split requires a non-empty separator", !separator.empty());
3862
3863 std::vector<Value> output;
3864
3865 // Keep track of the index at which the current output string began.
3866 size_t splitStartIndex = 0;
3867
3868 // Iterate through 'input' and check to see if 'separator' matches at any point.
3869 for (size_t i = 0; i < input.size();) {
3870 if (stringHasTokenAtIndex(i, input, separator)) {
3871 // We matched; add the current string to our output and jump ahead.
3872 StringData splitString(input.c_str() + splitStartIndex, i - splitStartIndex);
3873 output.push_back(Value(splitString));
3874 i += separator.size();
3875 splitStartIndex = i;
3876 } else {
3877 // We did not match, continue to the next character.
3878 ++i;
3879 }
3880 }
3881
3882 StringData splitString(input.c_str() + splitStartIndex, input.size() - splitStartIndex);
3883 output.push_back(Value(splitString));
3884
3885 return Value(output);
3886 }
3887
3888 REGISTER_EXPRESSION(split, ExpressionSplit::parse);
getOpName() const3889 const char* ExpressionSplit::getOpName() const {
3890 return "$split";
3891 }
3892
3893 /* ----------------------- ExpressionSqrt ---------------------------- */
3894
evaluateNumericArg(const Value & numericArg) const3895 Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const {
3896 auto checkArg = [](bool nonNegative) {
3897 uassert(28714, "$sqrt's argument must be greater than or equal to 0", nonNegative);
3898 };
3899
3900 if (numericArg.getType() == NumberDecimal) {
3901 Decimal128 argDec = numericArg.getDecimal();
3902 checkArg(!argDec.isLess(Decimal128::kNormalizedZero)); // NaN returns Nan without error
3903 return Value(argDec.squareRoot());
3904 }
3905 double argDouble = numericArg.coerceToDouble();
3906 checkArg(!(argDouble < 0)); // NaN returns Nan without error
3907 return Value(sqrt(argDouble));
3908 }
3909
3910 REGISTER_EXPRESSION(sqrt, ExpressionSqrt::parse);
getOpName() const3911 const char* ExpressionSqrt::getOpName() const {
3912 return "$sqrt";
3913 }
3914
3915 /* ----------------------- ExpressionStrcasecmp ---------------------------- */
3916
evaluate(const Document & root,Variables * variables) const3917 Value ExpressionStrcasecmp::evaluate(const Document& root, Variables* variables) const {
3918 Value pString1(vpOperand[0]->evaluate(root, variables));
3919 Value pString2(vpOperand[1]->evaluate(root, variables));
3920
3921 /* boost::iequals returns a bool not an int so strings must actually be allocated */
3922 string str1 = boost::to_upper_copy(pString1.coerceToString());
3923 string str2 = boost::to_upper_copy(pString2.coerceToString());
3924 int result = str1.compare(str2);
3925
3926 if (result == 0)
3927 return Value(0);
3928 else if (result > 0)
3929 return Value(1);
3930 else
3931 return Value(-1);
3932 }
3933
3934 REGISTER_EXPRESSION(strcasecmp, ExpressionStrcasecmp::parse);
getOpName() const3935 const char* ExpressionStrcasecmp::getOpName() const {
3936 return "$strcasecmp";
3937 }
3938
3939 /* ----------------------- ExpressionSubstrBytes ---------------------------- */
3940
evaluate(const Document & root,Variables * variables) const3941 Value ExpressionSubstrBytes::evaluate(const Document& root, Variables* variables) const {
3942 Value pString(vpOperand[0]->evaluate(root, variables));
3943 Value pLower(vpOperand[1]->evaluate(root, variables));
3944 Value pLength(vpOperand[2]->evaluate(root, variables));
3945
3946 string str = pString.coerceToString();
3947 uassert(16034,
3948 str::stream() << getOpName()
3949 << ": starting index must be a numeric type (is BSON type "
3950 << typeName(pLower.getType())
3951 << ")",
3952 (pLower.getType() == NumberInt || pLower.getType() == NumberLong ||
3953 pLower.getType() == NumberDouble));
3954 uassert(16035,
3955 str::stream() << getOpName() << ": length must be a numeric type (is BSON type "
3956 << typeName(pLength.getType())
3957 << ")",
3958 (pLength.getType() == NumberInt || pLength.getType() == NumberLong ||
3959 pLength.getType() == NumberDouble));
3960
3961 string::size_type lower = static_cast<string::size_type>(pLower.coerceToLong());
3962 string::size_type length = static_cast<string::size_type>(pLength.coerceToLong());
3963
3964 uassert(28656,
3965 str::stream() << getOpName()
3966 << ": Invalid range, starting index is a UTF-8 continuation byte.",
3967 (lower >= str.length() || !str::isUTF8ContinuationByte(str[lower])));
3968
3969 // Check the byte after the last character we'd return. If it is a continuation byte, that
3970 // means we're in the middle of a UTF-8 character.
3971 uassert(
3972 28657,
3973 str::stream() << getOpName()
3974 << ": Invalid range, ending index is in the middle of a UTF-8 character.",
3975 (lower + length >= str.length() || !str::isUTF8ContinuationByte(str[lower + length])));
3976
3977 if (lower >= str.length()) {
3978 // If lower > str.length() then string::substr() will throw out_of_range, so return an
3979 // empty string if lower is not a valid string index.
3980 return Value(StringData());
3981 }
3982 return Value(str.substr(lower, length));
3983 }
3984
3985 // $substr is deprecated in favor of $substrBytes, but for now will just parse into a $substrBytes.
3986 REGISTER_EXPRESSION(substrBytes, ExpressionSubstrBytes::parse);
3987 REGISTER_EXPRESSION(substr, ExpressionSubstrBytes::parse);
getOpName() const3988 const char* ExpressionSubstrBytes::getOpName() const {
3989 return "$substrBytes";
3990 }
3991
3992 /* ----------------------- ExpressionSubstrCP ---------------------------- */
3993
evaluate(const Document & root,Variables * variables) const3994 Value ExpressionSubstrCP::evaluate(const Document& root, Variables* variables) const {
3995 Value inputVal(vpOperand[0]->evaluate(root, variables));
3996 Value lowerVal(vpOperand[1]->evaluate(root, variables));
3997 Value lengthVal(vpOperand[2]->evaluate(root, variables));
3998
3999 std::string str = inputVal.coerceToString();
4000 uassert(34450,
4001 str::stream() << getOpName() << ": starting index must be a numeric type (is BSON type "
4002 << typeName(lowerVal.getType())
4003 << ")",
4004 lowerVal.numeric());
4005 uassert(34451,
4006 str::stream() << getOpName()
4007 << ": starting index cannot be represented as a 32-bit integral value: "
4008 << lowerVal.toString(),
4009 lowerVal.integral());
4010 uassert(34452,
4011 str::stream() << getOpName() << ": length must be a numeric type (is BSON type "
4012 << typeName(lengthVal.getType())
4013 << ")",
4014 lengthVal.numeric());
4015 uassert(34453,
4016 str::stream() << getOpName()
4017 << ": length cannot be represented as a 32-bit integral value: "
4018 << lengthVal.toString(),
4019 lengthVal.integral());
4020
4021 int startIndexCodePoints = lowerVal.coerceToInt();
4022 int length = lengthVal.coerceToInt();
4023
4024 uassert(34454,
4025 str::stream() << getOpName() << ": length must be a nonnegative integer.",
4026 length >= 0);
4027
4028 uassert(34455,
4029 str::stream() << getOpName() << ": the starting index must be nonnegative integer.",
4030 startIndexCodePoints >= 0);
4031
4032 size_t startIndexBytes = 0;
4033
4034 for (int i = 0; i < startIndexCodePoints; i++) {
4035 if (startIndexBytes >= str.size()) {
4036 return Value(StringData());
4037 }
4038 uassert(34456,
4039 str::stream() << getOpName() << ": invalid UTF-8 string",
4040 !str::isUTF8ContinuationByte(str[startIndexBytes]));
4041 size_t codePointLength = getCodePointLength(str[startIndexBytes]);
4042 uassert(
4043 34457, str::stream() << getOpName() << ": invalid UTF-8 string", codePointLength <= 4);
4044 startIndexBytes += codePointLength;
4045 }
4046
4047 size_t endIndexBytes = startIndexBytes;
4048
4049 for (int i = 0; i < length && endIndexBytes < str.size(); i++) {
4050 uassert(34458,
4051 str::stream() << getOpName() << ": invalid UTF-8 string",
4052 !str::isUTF8ContinuationByte(str[endIndexBytes]));
4053 size_t codePointLength = getCodePointLength(str[endIndexBytes]);
4054 uassert(
4055 34459, str::stream() << getOpName() << ": invalid UTF-8 string", codePointLength <= 4);
4056 endIndexBytes += codePointLength;
4057 }
4058
4059 return Value(std::string(str, startIndexBytes, endIndexBytes - startIndexBytes));
4060 }
4061
4062 REGISTER_EXPRESSION(substrCP, ExpressionSubstrCP::parse);
getOpName() const4063 const char* ExpressionSubstrCP::getOpName() const {
4064 return "$substrCP";
4065 }
4066
4067 /* ----------------------- ExpressionStrLenBytes ------------------------- */
4068
evaluate(const Document & root,Variables * variables) const4069 Value ExpressionStrLenBytes::evaluate(const Document& root, Variables* variables) const {
4070 Value str(vpOperand[0]->evaluate(root, variables));
4071
4072 uassert(34473,
4073 str::stream() << "$strLenBytes requires a string argument, found: "
4074 << typeName(str.getType()),
4075 str.getType() == String);
4076
4077 size_t strLen = str.getString().size();
4078
4079 uassert(34470,
4080 "string length could not be represented as an int.",
4081 strLen <= std::numeric_limits<int>::max());
4082 return Value(static_cast<int>(strLen));
4083 }
4084
4085 REGISTER_EXPRESSION(strLenBytes, ExpressionStrLenBytes::parse);
getOpName() const4086 const char* ExpressionStrLenBytes::getOpName() const {
4087 return "$strLenBytes";
4088 }
4089
4090 /* ----------------------- ExpressionStrLenCP ------------------------- */
4091
evaluate(const Document & root,Variables * variables) const4092 Value ExpressionStrLenCP::evaluate(const Document& root, Variables* variables) const {
4093 Value val(vpOperand[0]->evaluate(root, variables));
4094
4095 uassert(34471,
4096 str::stream() << "$strLenCP requires a string argument, found: "
4097 << typeName(val.getType()),
4098 val.getType() == String);
4099
4100 std::string stringVal = val.getString();
4101 size_t strLen = str::lengthInUTF8CodePoints(stringVal);
4102
4103 uassert(34472,
4104 "string length could not be represented as an int.",
4105 strLen <= std::numeric_limits<int>::max());
4106
4107 return Value(static_cast<int>(strLen));
4108 }
4109
4110 REGISTER_EXPRESSION(strLenCP, ExpressionStrLenCP::parse);
getOpName() const4111 const char* ExpressionStrLenCP::getOpName() const {
4112 return "$strLenCP";
4113 }
4114
4115 /* ----------------------- ExpressionSubtract ---------------------------- */
4116
evaluate(const Document & root,Variables * variables) const4117 Value ExpressionSubtract::evaluate(const Document& root, Variables* variables) const {
4118 Value lhs = vpOperand[0]->evaluate(root, variables);
4119 Value rhs = vpOperand[1]->evaluate(root, variables);
4120
4121 BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType());
4122
4123 if (diffType == NumberDecimal) {
4124 Decimal128 right = rhs.coerceToDecimal();
4125 Decimal128 left = lhs.coerceToDecimal();
4126 return Value(left.subtract(right));
4127 } else if (diffType == NumberDouble) {
4128 double right = rhs.coerceToDouble();
4129 double left = lhs.coerceToDouble();
4130 return Value(left - right);
4131 } else if (diffType == NumberLong) {
4132 long long result;
4133
4134 // If there is an overflow, convert the values to doubles.
4135 if (mongoSignedSubtractOverflow64(lhs.coerceToLong(), rhs.coerceToLong(), &result)) {
4136 return Value(lhs.coerceToDouble() - rhs.coerceToDouble());
4137 }
4138 return Value(result);
4139 } else if (diffType == NumberInt) {
4140 long long right = rhs.coerceToLong();
4141 long long left = lhs.coerceToLong();
4142 return Value::createIntOrLong(left - right);
4143 } else if (lhs.nullish() || rhs.nullish()) {
4144 return Value(BSONNULL);
4145 } else if (lhs.getType() == Date) {
4146 if (rhs.getType() == Date) {
4147 return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
4148 } else if (rhs.numeric()) {
4149 return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong()));
4150 } else {
4151 uasserted(16613,
4152 str::stream() << "cant $subtract a " << typeName(rhs.getType())
4153 << " from a Date");
4154 }
4155 } else {
4156 uasserted(16556,
4157 str::stream() << "cant $subtract a" << typeName(rhs.getType()) << " from a "
4158 << typeName(lhs.getType()));
4159 }
4160 }
4161
4162 REGISTER_EXPRESSION(subtract, ExpressionSubtract::parse);
getOpName() const4163 const char* ExpressionSubtract::getOpName() const {
4164 return "$subtract";
4165 }
4166
4167 /* ------------------------- ExpressionSwitch ------------------------------ */
4168
4169 REGISTER_EXPRESSION(switch, ExpressionSwitch::parse);
4170
evaluate(const Document & root,Variables * variables) const4171 Value ExpressionSwitch::evaluate(const Document& root, Variables* variables) const {
4172 for (auto&& branch : _branches) {
4173 Value caseExpression(branch.first->evaluate(root, variables));
4174
4175 if (caseExpression.coerceToBool()) {
4176 return branch.second->evaluate(root, variables);
4177 }
4178 }
4179
4180 uassert(40066,
4181 "$switch could not find a matching branch for an input, and no default was specified.",
4182 _default);
4183
4184 return _default->evaluate(root, variables);
4185 }
4186
/**
 * Parses a $switch specification of the form
 *     {branches: [{case: <expr>, then: <expr>}, ...], default: <expr>}
 * where 'default' is optional. Every 'case', 'then' and 'default' is parsed as an operand.
 *
 * Raises a user assertion when: the argument is not an object (40060); 'branches' is not an
 * array (40061); a branch is not an object (40062); a branch has an unknown field (40063);
 * a branch lacks 'case' (40064) or 'then' (40065); the spec has an unknown field (40067);
 * or no branch was supplied (40068).
 */
boost::intrusive_ptr<Expression> ExpressionSwitch::parse(
    const boost::intrusive_ptr<ExpressionContext>& expCtx,
    BSONElement expr,
    const VariablesParseState& vps) {
    uassert(40060,
            str::stream() << "$switch requires an object as an argument, found: "
                          << typeName(expr.type()),
            expr.type() == Object);

    intrusive_ptr<ExpressionSwitch> switchExpr(new ExpressionSwitch(expCtx));

    for (auto&& arg : expr.Obj()) {
        const auto argName = arg.fieldNameStringData();

        if (argName == "default") {
            // Optional, arbitrary expression used when no branch matches.
            switchExpr->_default = parseOperand(expCtx, arg, vps);
            continue;
        }

        if (argName != "branches") {
            uasserted(40067, str::stream() << "$switch found an unknown argument: " << argName);
        }

        uassert(40061,
                str::stream() << "$switch expected an array for 'branches', found: "
                              << typeName(arg.type()),
                arg.type() == Array);

        // Each array element is one {case, then} branch, parsed independently.
        for (auto&& branchSpec : arg.Array()) {
            uassert(40062,
                    str::stream() << "$switch expected each branch to be an object, found: "
                                  << typeName(branchSpec.type()),
                    branchSpec.type() == Object);

            ExpressionPair caseAndThen;

            for (auto&& part : branchSpec.Obj()) {
                const auto partName = part.fieldNameStringData();

                if (partName == "case") {
                    caseAndThen.first = parseOperand(expCtx, part, vps);
                } else if (partName == "then") {
                    caseAndThen.second = parseOperand(expCtx, part, vps);
                } else {
                    uasserted(40063,
                              str::stream() << "$switch found an unknown argument to a branch: "
                                            << partName);
                }
            }

            uassert(40064, "$switch requires each branch have a 'case' expression", caseAndThen.first);
            uassert(
                40065, "$switch requires each branch have a 'then' expression.", caseAndThen.second);

            switchExpr->_branches.push_back(caseAndThen);
        }
    }

    uassert(40068, "$switch requires at least one branch.", !switchExpr->_branches.empty());

    return switchExpr;
}
4251
_doAddDependencies(DepsTracker * deps) const4252 void ExpressionSwitch::_doAddDependencies(DepsTracker* deps) const {
4253 for (auto&& branch : _branches) {
4254 branch.first->addDependencies(deps);
4255 branch.second->addDependencies(deps);
4256 }
4257
4258 if (_default) {
4259 _default->addDependencies(deps);
4260 }
4261 }
4262
optimize()4263 boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() {
4264 if (_default) {
4265 _default = _default->optimize();
4266 }
4267
4268 std::transform(_branches.begin(),
4269 _branches.end(),
4270 _branches.begin(),
4271 [](ExpressionPair branch) -> ExpressionPair {
4272 return {branch.first->optimize(), branch.second->optimize()};
4273 });
4274
4275 return this;
4276 }
4277
serialize(bool explain) const4278 Value ExpressionSwitch::serialize(bool explain) const {
4279 std::vector<Value> serializedBranches;
4280 serializedBranches.reserve(_branches.size());
4281
4282 for (auto&& branch : _branches) {
4283 serializedBranches.push_back(Value(Document{{"case", branch.first->serialize(explain)},
4284 {"then", branch.second->serialize(explain)}}));
4285 }
4286
4287 if (_default) {
4288 return Value(Document{{"$switch",
4289 Document{{"branches", Value(serializedBranches)},
4290 {"default", _default->serialize(explain)}}}});
4291 }
4292
4293 return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}});
4294 }
4295
4296 /* ------------------------- ExpressionToLower ----------------------------- */
4297
evaluate(const Document & root,Variables * variables) const4298 Value ExpressionToLower::evaluate(const Document& root, Variables* variables) const {
4299 Value pString(vpOperand[0]->evaluate(root, variables));
4300 string str = pString.coerceToString();
4301 boost::to_lower(str);
4302 return Value(str);
4303 }
4304
4305 REGISTER_EXPRESSION(toLower, ExpressionToLower::parse);
getOpName() const4306 const char* ExpressionToLower::getOpName() const {
4307 return "$toLower";
4308 }
4309
4310 /* ------------------------- ExpressionToUpper -------------------------- */
4311
evaluate(const Document & root,Variables * variables) const4312 Value ExpressionToUpper::evaluate(const Document& root, Variables* variables) const {
4313 Value pString(vpOperand[0]->evaluate(root, variables));
4314 string str(pString.coerceToString());
4315 boost::to_upper(str);
4316 return Value(str);
4317 }
4318
4319 REGISTER_EXPRESSION(toUpper, ExpressionToUpper::parse);
getOpName() const4320 const char* ExpressionToUpper::getOpName() const {
4321 return "$toUpper";
4322 }
4323
4324 /* ------------------------- ExpressionTrunc -------------------------- */
4325
evaluateNumericArg(const Value & numericArg) const4326 Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const {
4327 // There's no point in truncating integers or longs, it will have no effect.
4328
4329 switch (numericArg.getType()) {
4330 case NumberDecimal:
4331 return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
4332 Decimal128::kRoundTowardZero));
4333 case NumberDouble:
4334 return Value(std::trunc(numericArg.getDouble()));
4335 default:
4336 return numericArg;
4337 }
4338 }
4339
4340 REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse);
getOpName() const4341 const char* ExpressionTrunc::getOpName() const {
4342 return "$trunc";
4343 }
4344
4345 /* ------------------------- ExpressionType ----------------------------- */
4346
evaluate(const Document & root,Variables * variables) const4347 Value ExpressionType::evaluate(const Document& root, Variables* variables) const {
4348 Value val(vpOperand[0]->evaluate(root, variables));
4349 return Value(StringData(typeName(val.getType())));
4350 }
4351
4352 REGISTER_EXPRESSION(type, ExpressionType::parse);
getOpName() const4353 const char* ExpressionType::getOpName() const {
4354 return "$type";
4355 }
4356
4357 /* -------------------------- ExpressionZip ------------------------------ */
4358
4359 REGISTER_EXPRESSION(zip, ExpressionZip::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)4360 intrusive_ptr<Expression> ExpressionZip::parse(
4361 const boost::intrusive_ptr<ExpressionContext>& expCtx,
4362 BSONElement expr,
4363 const VariablesParseState& vps) {
4364 uassert(34460,
4365 str::stream() << "$zip only supports an object as an argument, found "
4366 << typeName(expr.type()),
4367 expr.type() == Object);
4368
4369 intrusive_ptr<ExpressionZip> newZip(new ExpressionZip(expCtx));
4370
4371 for (auto&& elem : expr.Obj()) {
4372 const auto field = elem.fieldNameStringData();
4373 if (field == "inputs") {
4374 uassert(34461,
4375 str::stream() << "inputs must be an array of expressions, found "
4376 << typeName(elem.type()),
4377 elem.type() == Array);
4378 for (auto&& subExpr : elem.Array()) {
4379 newZip->_inputs.push_back(parseOperand(expCtx, subExpr, vps));
4380 }
4381 } else if (field == "defaults") {
4382 uassert(34462,
4383 str::stream() << "defaults must be an array of expressions, found "
4384 << typeName(elem.type()),
4385 elem.type() == Array);
4386 for (auto&& subExpr : elem.Array()) {
4387 newZip->_defaults.push_back(parseOperand(expCtx, subExpr, vps));
4388 }
4389 } else if (field == "useLongestLength") {
4390 uassert(34463,
4391 str::stream() << "useLongestLength must be a bool, found "
4392 << typeName(expr.type()),
4393 elem.type() == Bool);
4394 newZip->_useLongestLength = elem.Bool();
4395 } else {
4396 uasserted(34464,
4397 str::stream() << "$zip found an unknown argument: " << elem.fieldName());
4398 }
4399 }
4400
4401 uassert(34465, "$zip requires at least one input array", !newZip->_inputs.empty());
4402 uassert(34466,
4403 "cannot specify defaults unless useLongestLength is true",
4404 (newZip->_useLongestLength || newZip->_defaults.empty()));
4405 uassert(34467,
4406 "defaults and inputs must have the same length",
4407 (newZip->_defaults.empty() || newZip->_defaults.size() == newZip->_inputs.size()));
4408
4409 return std::move(newZip);
4410 }
4411
evaluate(const Document & root,Variables * variables) const4412 Value ExpressionZip::evaluate(const Document& root, Variables* variables) const {
4413 // Evaluate input values.
4414 vector<vector<Value>> inputValues;
4415 inputValues.reserve(_inputs.size());
4416
4417 size_t minArraySize = 0;
4418 size_t maxArraySize = 0;
4419 for (size_t i = 0; i < _inputs.size(); i++) {
4420 Value evalExpr = _inputs[i].get()->evaluate(root, variables);
4421 if (evalExpr.nullish()) {
4422 return Value(BSONNULL);
4423 }
4424
4425 uassert(34468,
4426 str::stream() << "$zip found a non-array expression in input: "
4427 << evalExpr.toString(),
4428 evalExpr.isArray());
4429
4430 inputValues.push_back(evalExpr.getArray());
4431
4432 size_t arraySize = evalExpr.getArrayLength();
4433
4434 if (i == 0) {
4435 minArraySize = arraySize;
4436 maxArraySize = arraySize;
4437 } else {
4438 auto arraySizes = std::minmax({minArraySize, arraySize, maxArraySize});
4439 minArraySize = arraySizes.first;
4440 maxArraySize = arraySizes.second;
4441 }
4442 }
4443
4444 vector<Value> evaluatedDefaults(_inputs.size(), Value(BSONNULL));
4445
4446 // If we need default values, evaluate each expression.
4447 if (minArraySize != maxArraySize) {
4448 for (size_t i = 0; i < _defaults.size(); i++) {
4449 evaluatedDefaults[i] = _defaults[i].get()->evaluate(root, variables);
4450 }
4451 }
4452
4453 size_t outputLength = _useLongestLength ? maxArraySize : minArraySize;
4454
4455 // The final output array, e.g. [[1, 2, 3], [2, 3, 4]].
4456 vector<Value> output;
4457
4458 // Used to construct each array in the output, e.g. [1, 2, 3].
4459 vector<Value> outputChild;
4460
4461 output.reserve(outputLength);
4462 outputChild.reserve(_inputs.size());
4463
4464 for (size_t row = 0; row < outputLength; row++) {
4465 outputChild.clear();
4466 for (size_t col = 0; col < _inputs.size(); col++) {
4467 if (inputValues[col].size() > row) {
4468 // Add the value from the appropriate input array.
4469 outputChild.push_back(inputValues[col][row]);
4470 } else {
4471 // Add the corresponding default value.
4472 outputChild.push_back(evaluatedDefaults[col]);
4473 }
4474 }
4475 output.push_back(Value(outputChild));
4476 }
4477
4478 return Value(output);
4479 }
4480
optimize()4481 boost::intrusive_ptr<Expression> ExpressionZip::optimize() {
4482 std::transform(_inputs.begin(),
4483 _inputs.end(),
4484 _inputs.begin(),
4485 [](intrusive_ptr<Expression> inputExpression) -> intrusive_ptr<Expression> {
4486 return inputExpression->optimize();
4487 });
4488
4489 std::transform(_defaults.begin(),
4490 _defaults.end(),
4491 _defaults.begin(),
4492 [](intrusive_ptr<Expression> defaultExpression) -> intrusive_ptr<Expression> {
4493 return defaultExpression->optimize();
4494 });
4495
4496 return this;
4497 }
4498
serialize(bool explain) const4499 Value ExpressionZip::serialize(bool explain) const {
4500 vector<Value> serializedInput;
4501 vector<Value> serializedDefaults;
4502 Value serializedUseLongestLength = Value(_useLongestLength);
4503
4504 for (auto&& expr : _inputs) {
4505 serializedInput.push_back(expr->serialize(explain));
4506 }
4507
4508 for (auto&& expr : _defaults) {
4509 serializedDefaults.push_back(expr->serialize(explain));
4510 }
4511
4512 return Value(DOC("$zip" << DOC("inputs" << Value(serializedInput) << "defaults"
4513 << Value(serializedDefaults)
4514 << "useLongestLength"
4515 << serializedUseLongestLength)));
4516 }
4517
_doAddDependencies(DepsTracker * deps) const4518 void ExpressionZip::_doAddDependencies(DepsTracker* deps) const {
4519 std::for_each(
4520 _inputs.begin(), _inputs.end(), [&deps](intrusive_ptr<Expression> inputExpression) -> void {
4521 inputExpression->addDependencies(deps);
4522 });
4523 std::for_each(_defaults.begin(),
4524 _defaults.end(),
4525 [&deps](intrusive_ptr<Expression> defaultExpression) -> void {
4526 defaultExpression->addDependencies(deps);
4527 });
4528 }
4529
4530 } // namespace mongo
4531