1 
2 /**
3  *    Copyright (C) 2018-present MongoDB, Inc.
4  *
5  *    This program is free software: you can redistribute it and/or modify
6  *    it under the terms of the Server Side Public License, version 1,
7  *    as published by MongoDB, Inc.
8  *
9  *    This program is distributed in the hope that it will be useful,
10  *    but WITHOUT ANY WARRANTY; without even the implied warranty of
11  *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
12  *    Server Side Public License for more details.
13  *
14  *    You should have received a copy of the Server Side Public License
15  *    along with this program. If not, see
16  *    <http://www.mongodb.com/licensing/server-side-public-license>.
17  *
18  *    As a special exception, the copyright holders give permission to link the
19  *    code of portions of this program with the OpenSSL library under certain
20  *    conditions as described in each individual source file and distribute
21  *    linked combinations including the program with the OpenSSL library. You
22  *    must comply with the Server Side Public License in all respects for
23  *    all of the code used other than as permitted herein. If you modify file(s)
24  *    with this exception, you may extend this exception to your version of the
25  *    file(s), but you are not obligated to do so. If you do not wish to do so,
26  *    delete this exception statement from your version. If you delete this
27  *    exception statement from all source files in the program, then also delete
28  *    it in the license file.
29  */
30 
31 
32 #include "mongo/platform/basic.h"
33 
34 #include "mongo/db/pipeline/expression.h"
35 
36 #include <algorithm>
37 #include <boost/algorithm/string.hpp>
38 #include <cstdio>
39 #include <vector>
40 
41 #include "mongo/db/jsobj.h"
42 #include "mongo/db/pipeline/document.h"
43 #include "mongo/db/pipeline/expression_context.h"
44 #include "mongo/db/pipeline/value.h"
45 #include "mongo/db/query/datetime/date_time_support.h"
46 #include "mongo/platform/bits.h"
47 #include "mongo/platform/decimal128.h"
48 #include "mongo/util/mongoutils/str.h"
49 #include "mongo/util/string_map.h"
50 #include "mongo/util/summation.h"
51 
52 namespace mongo {
53 using Parser = Expression::Parser;
54 
55 using namespace mongoutils;
56 
57 using boost::intrusive_ptr;
58 using std::map;
59 using std::move;
60 using std::pair;
61 using std::set;
62 using std::string;
63 using std::vector;
64 
65 /// Helper function to easily wrap constants with $const.
serializeConstant(Value val)66 static Value serializeConstant(Value val) {
67     if (val.missing()) {
68         return Value("$$REMOVE"_sd);
69     }
70 
71     return Value(DOC("$const" << val));
72 }
73 
74 /* --------------------------- Expression ------------------------------ */
75 
/**
 * Strips the leading '$' from a field path reference (e.g. "$a.b") and returns the remainder
 * ("a.b").
 *
 * Fails with code 16419 if 'prefixedField' contains an embedded null byte, and with code 15982
 * if it does not begin with '$' (which includes the empty string).
 */
string Expression::removeFieldPrefix(const string& prefixedField) {
    // Search for the null *character*. Searching for the C-string "\0" looks for an empty
    // string and therefore always reports offset 0.
    const size_t nullPos = prefixedField.find('\0');
    uassert(16419,
            str::stream() << "field path must not contain embedded null characters (found at "
                             "offset "
                          << nullPos
                          << ")",
            nullPos == string::npos);

    uassert(15982,
            str::stream() << "field path references must be prefixed with a '$' ('" << prefixedField
                          << "')",
            !prefixedField.empty() && prefixedField.front() == '$');

    // No embedded nulls remain at this point, so this is everything after the '$'.
    return prefixedField.substr(1);
}
91 
parseObject(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)92 intrusive_ptr<Expression> Expression::parseObject(
93     const boost::intrusive_ptr<ExpressionContext>& expCtx,
94     BSONObj obj,
95     const VariablesParseState& vps) {
96     if (obj.isEmpty()) {
97         return ExpressionObject::create(expCtx, {});
98     }
99 
100     if (obj.firstElementFieldName()[0] == '$') {
101         // Assume this is an expression like {$add: [...]}.
102         return parseExpression(expCtx, obj, vps);
103     }
104 
105     return ExpressionObject::parse(expCtx, obj, vps);
106 }
107 
namespace {
// File-local table from expression name to its parser. Entries are added by
// Expression::registerExpression() and looked up by Expression::parseExpression().
StringMap<Parser> parserMap;
}
111 
registerExpression(string key,Parser parser)112 void Expression::registerExpression(string key, Parser parser) {
113     auto op = parserMap.find(key);
114     massert(17064,
115             str::stream() << "Duplicate expression (" << key << ") registered.",
116             op == parserMap.end());
117     parserMap[key] = parser;
118 }
119 
parseExpression(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)120 intrusive_ptr<Expression> Expression::parseExpression(
121     const boost::intrusive_ptr<ExpressionContext>& expCtx,
122     BSONObj obj,
123     const VariablesParseState& vps) {
124     uassert(15983,
125             str::stream() << "An object representing an expression must have exactly one "
126                              "field: "
127                           << obj.toString(),
128             obj.nFields() == 1);
129 
130     // Look up the parser associated with the expression name.
131     const char* opName = obj.firstElementFieldName();
132     auto op = parserMap.find(opName);
133     uassert(ErrorCodes::InvalidPipelineOperator,
134             str::stream() << "Unrecognized expression '" << opName << "'",
135             op != parserMap.end());
136     return op->second(expCtx, obj.firstElement(), vps);
137 }
138 
parseArguments(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)139 Expression::ExpressionVector ExpressionNary::parseArguments(
140     const boost::intrusive_ptr<ExpressionContext>& expCtx,
141     BSONElement exprElement,
142     const VariablesParseState& vps) {
143     ExpressionVector out;
144     if (exprElement.type() == Array) {
145         BSONForEach(elem, exprElement.Obj()) {
146             out.push_back(Expression::parseOperand(expCtx, elem, vps));
147         }
148     } else {  // Assume it's an operand that accepts a single argument.
149         out.push_back(Expression::parseOperand(expCtx, exprElement, vps));
150     }
151 
152     return out;
153 }
154 
parseOperand(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)155 intrusive_ptr<Expression> Expression::parseOperand(
156     const boost::intrusive_ptr<ExpressionContext>& expCtx,
157     BSONElement exprElement,
158     const VariablesParseState& vps) {
159     BSONType type = exprElement.type();
160 
161     if (type == String && exprElement.valuestr()[0] == '$') {
162         /* if we got here, this is a field path expression */
163         return ExpressionFieldPath::parse(expCtx, exprElement.str(), vps);
164     } else if (type == Object) {
165         return Expression::parseObject(expCtx, exprElement.Obj(), vps);
166     } else if (type == Array) {
167         return ExpressionArray::parse(expCtx, exprElement, vps);
168     } else {
169         return ExpressionConstant::parse(expCtx, exprElement, vps);
170     }
171 }
172 
173 namespace {
/**
 * Returns true when 'charByte' has the form 11xxxxxx, i.e. it is the leading byte of a UTF-8
 * multi-byte code point. Continuation bytes (10xxxxxx) and single-byte code points (0xxxxxxx)
 * return false.
 */
bool isLeadingByte(char charByte) {
    // Work on the unsigned value so the result does not depend on the signedness of 'char'.
    const auto topTwoBits = static_cast<unsigned char>(charByte) >> 6;
    return topTwoBits == 0x3;
}
182 
/**
 * Returns true when 'charByte' has the form 0xxxxxxx, i.e. it is a complete single-byte UTF-8
 * code point.
 */
bool isSingleByte(char charByte) {
    // The high bit is clear exactly when the unsigned byte value is below 0x80.
    return static_cast<unsigned char>(charByte) < 0x80;
}
190 
getCodePointLength(char charByte)191 size_t getCodePointLength(char charByte) {
192     if (isSingleByte(charByte)) {
193         return 1;
194     }
195 
196     invariant(isLeadingByte(charByte));
197 
198     // In UTF-8, the number of leading ones is the number of bytes the code point takes up.
199     return countLeadingZeros64(~(uint64_t(charByte) << (64 - 8)));
200 }
201 }  // namespace
202 
203 /* ------------------------- Register Date Expressions ----------------------------- */
204 
// Each line pairs a date-part operator name with the parse function of its expression class.
// NOTE(review): REGISTER_EXPRESSION is defined outside this section; presumably it registers the
// parser under the '$'-prefixed name via Expression::registerExpression() -- confirm.
REGISTER_EXPRESSION(dayOfMonth, ExpressionDayOfMonth::parse);
REGISTER_EXPRESSION(dayOfWeek, ExpressionDayOfWeek::parse);
REGISTER_EXPRESSION(dayOfYear, ExpressionDayOfYear::parse);
REGISTER_EXPRESSION(hour, ExpressionHour::parse);
REGISTER_EXPRESSION(isoDayOfWeek, ExpressionIsoDayOfWeek::parse);
REGISTER_EXPRESSION(isoWeek, ExpressionIsoWeek::parse);
REGISTER_EXPRESSION(isoWeekYear, ExpressionIsoWeekYear::parse);
REGISTER_EXPRESSION(millisecond, ExpressionMillisecond::parse);
REGISTER_EXPRESSION(minute, ExpressionMinute::parse);
REGISTER_EXPRESSION(month, ExpressionMonth::parse);
REGISTER_EXPRESSION(second, ExpressionSecond::parse);
REGISTER_EXPRESSION(week, ExpressionWeek::parse);
REGISTER_EXPRESSION(year, ExpressionYear::parse);
218 
219 /* ----------------------- ExpressionAbs ---------------------------- */
220 
evaluateNumericArg(const Value & numericArg) const221 Value ExpressionAbs::evaluateNumericArg(const Value& numericArg) const {
222     BSONType type = numericArg.getType();
223     if (type == NumberDouble) {
224         return Value(std::abs(numericArg.getDouble()));
225     } else if (type == NumberDecimal) {
226         return Value(numericArg.getDecimal().toAbs());
227     } else {
228         long long num = numericArg.getLong();
229         uassert(28680,
230                 "can't take $abs of long long min",
231                 num != std::numeric_limits<long long>::min());
232         long long absVal = std::abs(num);
233         return type == NumberLong ? Value(absVal) : Value::createIntOrLong(absVal);
234     }
235 }
236 
237 REGISTER_EXPRESSION(abs, ExpressionAbs::parse);
getOpName() const238 const char* ExpressionAbs::getOpName() const {
239     return "$abs";
240 }
241 
242 /* ------------------------- ExpressionAdd ----------------------------- */
243 
evaluate(const Document & root,Variables * variables) const244 Value ExpressionAdd::evaluate(const Document& root, Variables* variables) const {
245     // We'll try to return the narrowest possible result value while avoiding overflow, loss
246     // of precision due to intermediate rounding or implicit use of decimal types. To do that,
247     // compute a compensated sum for non-decimal values and a separate decimal sum for decimal
248     // values, and track the current narrowest type.
249     DoubleDoubleSummation nonDecimalTotal;
250     Decimal128 decimalTotal;
251     BSONType totalType = NumberInt;
252     bool haveDate = false;
253 
254     const size_t n = vpOperand.size();
255     for (size_t i = 0; i < n; ++i) {
256         Value val = vpOperand[i]->evaluate(root, variables);
257 
258         switch (val.getType()) {
259             case NumberDecimal:
260                 decimalTotal = decimalTotal.add(val.getDecimal());
261                 totalType = NumberDecimal;
262                 break;
263             case NumberDouble:
264                 nonDecimalTotal.addDouble(val.getDouble());
265                 if (totalType != NumberDecimal)
266                     totalType = NumberDouble;
267                 break;
268             case NumberLong:
269                 nonDecimalTotal.addLong(val.getLong());
270                 if (totalType == NumberInt)
271                     totalType = NumberLong;
272                 break;
273             case NumberInt:
274                 nonDecimalTotal.addDouble(val.getInt());
275                 break;
276             case Date:
277                 uassert(16612, "only one date allowed in an $add expression", !haveDate);
278                 haveDate = true;
279                 nonDecimalTotal.addLong(val.getDate().toMillisSinceEpoch());
280                 break;
281             default:
282                 uassert(16554,
283                         str::stream() << "$add only supports numeric or date types, not "
284                                       << typeName(val.getType()),
285                         val.nullish());
286                 return Value(BSONNULL);
287         }
288     }
289 
290     if (haveDate) {
291         int64_t longTotal;
292         if (totalType == NumberDecimal) {
293             longTotal = decimalTotal.add(nonDecimalTotal.getDecimal()).toLong();
294         } else {
295             uassert(ErrorCodes::Overflow, "date overflow in $add", nonDecimalTotal.fitsLong());
296             longTotal = nonDecimalTotal.getLong();
297         }
298         return Value(Date_t::fromMillisSinceEpoch(longTotal));
299     }
300     switch (totalType) {
301         case NumberDecimal:
302             return Value(decimalTotal.add(nonDecimalTotal.getDecimal()));
303         case NumberLong:
304             dassert(nonDecimalTotal.isInteger());
305             if (nonDecimalTotal.fitsLong())
306                 return Value(nonDecimalTotal.getLong());
307         // Fallthrough.
308         case NumberInt:
309             if (nonDecimalTotal.fitsLong())
310                 return Value::createIntOrLong(nonDecimalTotal.getLong());
311         // Fallthrough.
312         case NumberDouble:
313             return Value(nonDecimalTotal.getDouble());
314         default:
315             massert(16417, "$add resulted in a non-numeric type", false);
316     }
317 }
318 
319 REGISTER_EXPRESSION(add, ExpressionAdd::parse);
getOpName() const320 const char* ExpressionAdd::getOpName() const {
321     return "$add";
322 }
323 
324 /* ------------------------- ExpressionAllElementsTrue -------------------------- */
325 
evaluate(const Document & root,Variables * variables) const326 Value ExpressionAllElementsTrue::evaluate(const Document& root, Variables* variables) const {
327     const Value arr = vpOperand[0]->evaluate(root, variables);
328     uassert(17040,
329             str::stream() << getOpName() << "'s argument must be an array, but is "
330                           << typeName(arr.getType()),
331             arr.isArray());
332     const vector<Value>& array = arr.getArray();
333     for (vector<Value>::const_iterator it = array.begin(); it != array.end(); ++it) {
334         if (!it->coerceToBool()) {
335             return Value(false);
336         }
337     }
338     return Value(true);
339 }
340 
341 REGISTER_EXPRESSION(allElementsTrue, ExpressionAllElementsTrue::parse);
getOpName() const342 const char* ExpressionAllElementsTrue::getOpName() const {
343     return "$allElementsTrue";
344 }
345 
346 /* ------------------------- ExpressionAnd ----------------------------- */
347 
optimize()348 intrusive_ptr<Expression> ExpressionAnd::optimize() {
349     /* optimize the conjunction as much as possible */
350     intrusive_ptr<Expression> pE(ExpressionNary::optimize());
351 
352     /* if the result isn't a conjunction, we can't do anything */
353     ExpressionAnd* pAnd = dynamic_cast<ExpressionAnd*>(pE.get());
354     if (!pAnd)
355         return pE;
356 
357     /*
358       Check the last argument on the result; if it's not constant (as
359       promised by ExpressionNary::optimize(),) then there's nothing
360       we can do.
361     */
362     const size_t n = pAnd->vpOperand.size();
363     // ExpressionNary::optimize() generates an ExpressionConstant for {$and:[]}.
364     verify(n > 0);
365     intrusive_ptr<Expression> pLast(pAnd->vpOperand[n - 1]);
366     const ExpressionConstant* pConst = dynamic_cast<ExpressionConstant*>(pLast.get());
367     if (!pConst)
368         return pE;
369 
370     /*
371       Evaluate and coerce the last argument to a boolean.  If it's false,
372       then we can replace this entire expression.
373      */
374     bool last = pConst->getValue().coerceToBool();
375     if (!last) {
376         intrusive_ptr<ExpressionConstant> pFinal(
377             ExpressionConstant::create(getExpressionContext(), Value(false)));
378         return pFinal;
379     }
380 
381     /*
382       If we got here, the final operand was true, so we don't need it
383       anymore.  If there was only one other operand, we don't need the
384       conjunction either.  Note we still need to keep the promise that
385       the result will be a boolean.
386      */
387     if (n == 2) {
388         intrusive_ptr<Expression> pFinal(
389             ExpressionCoerceToBool::create(getExpressionContext(), pAnd->vpOperand[0]));
390         return pFinal;
391     }
392 
393     /*
394       Remove the final "true" value, and return the new expression.
395 
396       CW TODO:
397       Note that because of any implicit conversions, we may need to
398       apply an implicit boolean conversion.
399     */
400     pAnd->vpOperand.resize(n - 1);
401     return pE;
402 }
403 
evaluate(const Document & root,Variables * variables) const404 Value ExpressionAnd::evaluate(const Document& root, Variables* variables) const {
405     const size_t n = vpOperand.size();
406     for (size_t i = 0; i < n; ++i) {
407         Value pValue(vpOperand[i]->evaluate(root, variables));
408         if (!pValue.coerceToBool())
409             return Value(false);
410     }
411 
412     return Value(true);
413 }
414 
415 REGISTER_EXPRESSION(and, ExpressionAnd::parse);
getOpName() const416 const char* ExpressionAnd::getOpName() const {
417     return "$and";
418 }
419 
420 /* ------------------------- ExpressionAnyElementTrue -------------------------- */
421 
evaluate(const Document & root,Variables * variables) const422 Value ExpressionAnyElementTrue::evaluate(const Document& root, Variables* variables) const {
423     const Value arr = vpOperand[0]->evaluate(root, variables);
424     uassert(17041,
425             str::stream() << getOpName() << "'s argument must be an array, but is "
426                           << typeName(arr.getType()),
427             arr.isArray());
428     const vector<Value>& array = arr.getArray();
429     for (vector<Value>::const_iterator it = array.begin(); it != array.end(); ++it) {
430         if (it->coerceToBool()) {
431             return Value(true);
432         }
433     }
434     return Value(false);
435 }
436 
437 REGISTER_EXPRESSION(anyElementTrue, ExpressionAnyElementTrue::parse);
getOpName() const438 const char* ExpressionAnyElementTrue::getOpName() const {
439     return "$anyElementTrue";
440 }
441 
442 /* ---------------------- ExpressionArray --------------------------- */
443 
evaluate(const Document & root,Variables * variables) const444 Value ExpressionArray::evaluate(const Document& root, Variables* variables) const {
445     vector<Value> values;
446     values.reserve(vpOperand.size());
447     for (auto&& expr : vpOperand) {
448         Value elemVal = expr->evaluate(root, variables);
449         values.push_back(elemVal.missing() ? Value(BSONNULL) : std::move(elemVal));
450     }
451     return Value(std::move(values));
452 }
453 
serialize(bool explain) const454 Value ExpressionArray::serialize(bool explain) const {
455     vector<Value> expressions;
456     expressions.reserve(vpOperand.size());
457     for (auto&& expr : vpOperand) {
458         expressions.push_back(expr->serialize(explain));
459     }
460     return Value(std::move(expressions));
461 }
462 
getOpName() const463 const char* ExpressionArray::getOpName() const {
464     // This should never be called, but is needed to inherit from ExpressionNary.
465     return "$array";
466 }
467 
468 /* ------------------------- ExpressionArrayElemAt -------------------------- */
469 
evaluate(const Document & root,Variables * variables) const470 Value ExpressionArrayElemAt::evaluate(const Document& root, Variables* variables) const {
471     const Value array = vpOperand[0]->evaluate(root, variables);
472     const Value indexArg = vpOperand[1]->evaluate(root, variables);
473 
474     if (array.nullish() || indexArg.nullish()) {
475         return Value(BSONNULL);
476     }
477 
478     uassert(28689,
479             str::stream() << getOpName() << "'s first argument must be an array, but is "
480                           << typeName(array.getType()),
481             array.isArray());
482     uassert(28690,
483             str::stream() << getOpName() << "'s second argument must be a numeric value,"
484                           << " but is "
485                           << typeName(indexArg.getType()),
486             indexArg.numeric());
487     uassert(28691,
488             str::stream() << getOpName() << "'s second argument must be representable as"
489                           << " a 32-bit integer: "
490                           << indexArg.coerceToDouble(),
491             indexArg.integral());
492 
493     long long i = indexArg.coerceToLong();
494     if (i < 0 && static_cast<size_t>(std::abs(i)) > array.getArrayLength()) {
495         // Positive indices that are too large are handled automatically by Value.
496         return Value();
497     } else if (i < 0) {
498         // Index from the back of the array.
499         i = array.getArrayLength() + i;
500     }
501     const size_t index = static_cast<size_t>(i);
502     return array[index];
503 }
504 
505 REGISTER_EXPRESSION(arrayElemAt, ExpressionArrayElemAt::parse);
getOpName() const506 const char* ExpressionArrayElemAt::getOpName() const {
507     return "$arrayElemAt";
508 }
509 
510 /* ------------------------- ExpressionObjectToArray -------------------------- */
511 
512 
evaluate(const Document & root,Variables * variables) const513 Value ExpressionObjectToArray::evaluate(const Document& root, Variables* variables) const {
514     const Value targetVal = vpOperand[0]->evaluate(root, variables);
515 
516     if (targetVal.nullish()) {
517         return Value(BSONNULL);
518     }
519 
520     uassert(40390,
521             str::stream() << "$objectToArray requires a document input, found: "
522                           << typeName(targetVal.getType()),
523             (targetVal.getType() == BSONType::Object));
524 
525     vector<Value> output;
526 
527     FieldIterator iter = targetVal.getDocument().fieldIterator();
528     while (iter.more()) {
529         Document::FieldPair pair = iter.next();
530         MutableDocument keyvalue;
531         keyvalue.addField("k", Value(pair.first));
532         keyvalue.addField("v", pair.second);
533         output.push_back(keyvalue.freezeToValue());
534     }
535 
536     return Value(output);
537 }
538 
539 REGISTER_EXPRESSION(objectToArray, ExpressionObjectToArray::parse);
getOpName() const540 const char* ExpressionObjectToArray::getOpName() const {
541     return "$objectToArray";
542 }
543 
544 /* ------------------------- ExpressionArrayToObject -------------------------- */
545 
evaluate(const Document & root,Variables * variables) const546 Value ExpressionArrayToObject::evaluate(const Document& root, Variables* variables) const {
547     const Value input = vpOperand[0]->evaluate(root, variables);
548     if (input.nullish()) {
549         return Value(BSONNULL);
550     }
551 
552     uassert(40386,
553             str::stream() << "$arrayToObject requires an array input, found: "
554                           << typeName(input.getType()),
555             input.isArray());
556 
557     MutableDocument output;
558     const vector<Value>& array = input.getArray();
559     if (array.empty()) {
560         return output.freezeToValue();
561     }
562 
563     // There are two accepted input formats in an array: [ [key, val] ] or [ {k:key, v:val} ]. The
564     // first array element determines the format for the rest of the array. Mixing input formats is
565     // not allowed.
566     bool inputArrayFormat;
567     if (array[0].isArray()) {
568         inputArrayFormat = true;
569     } else if (array[0].getType() == BSONType::Object) {
570         inputArrayFormat = false;
571     } else {
572         uasserted(40398,
573                   str::stream() << "Unrecognised input type format for $arrayToObject: "
574                                 << typeName(array[0].getType()));
575     }
576 
577     for (auto&& elem : array) {
578         if (inputArrayFormat == true) {
579             uassert(
580                 40396,
581                 str::stream() << "$arrayToObject requires a consistent input format. Elements must"
582                                  "all be arrays or all be objects. Array was detected, now found: "
583                               << typeName(elem.getType()),
584                 elem.isArray());
585 
586             const vector<Value>& valArray = elem.getArray();
587 
588             uassert(40397,
589                     str::stream() << "$arrayToObject requires an array of size 2 arrays,"
590                                      "found array of size: "
591                                   << valArray.size(),
592                     (valArray.size() == 2));
593 
594             uassert(40395,
595                     str::stream() << "$arrayToObject requires an array of key-value pairs, where "
596                                      "the key must be of type string. Found key type: "
597                                   << typeName(valArray[0].getType()),
598                     (valArray[0].getType() == BSONType::String));
599 
600             auto keyName = valArray[0].getString();
601 
602             uassert(4940400,
603                     "Key field cannot contain an embedded null byte",
604                     keyName.find('\0') == std::string::npos);
605 
606             output[keyName] = valArray[1];
607 
608         } else {
609             uassert(
610                 40391,
611                 str::stream() << "$arrayToObject requires a consistent input format. Elements must"
612                                  "all be arrays or all be objects. Object was detected, now found: "
613                               << typeName(elem.getType()),
614                 (elem.getType() == BSONType::Object));
615 
616             uassert(40392,
617                     str::stream() << "$arrayToObject requires an object keys of 'k' and 'v'. "
618                                      "Found incorrect number of keys:"
619                                   << elem.getDocument().size(),
620                     (elem.getDocument().size() == 2));
621 
622             Value key = elem.getDocument().getField("k");
623             Value value = elem.getDocument().getField("v");
624 
625             uassert(40393,
626                     str::stream() << "$arrayToObject requires an object with keys 'k' and 'v'. "
627                                      "Missing either or both keys from: "
628                                   << elem.toString(),
629                     (!key.missing() && !value.missing()));
630 
631             uassert(
632                 40394,
633                 str::stream() << "$arrayToObject requires an object with keys 'k' and 'v', where "
634                                  "the value of 'k' must be of type string. Found type: "
635                               << typeName(key.getType()),
636                 (key.getType() == BSONType::String));
637 
638             auto keyName = key.getString();
639 
640             uassert(4940401,
641                     "Key field cannot contain an embedded null byte",
642                     keyName.find('\0') == std::string::npos);
643 
644             output[keyName] = value;
645         }
646     }
647 
648     return output.freezeToValue();
649 }
650 
651 REGISTER_EXPRESSION(arrayToObject, ExpressionArrayToObject::parse);
getOpName() const652 const char* ExpressionArrayToObject::getOpName() const {
653     return "$arrayToObject";
654 }
655 
656 /* ------------------------- ExpressionCeil -------------------------- */
657 
evaluateNumericArg(const Value & numericArg) const658 Value ExpressionCeil::evaluateNumericArg(const Value& numericArg) const {
659     // There's no point in taking the ceiling of integers or longs, it will have no effect.
660     switch (numericArg.getType()) {
661         case NumberDouble:
662             return Value(std::ceil(numericArg.getDouble()));
663         case NumberDecimal:
664             // Round toward the nearest decimal with a zero exponent in the positive direction.
665             return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
666                                                           Decimal128::kRoundTowardPositive));
667         default:
668             return numericArg;
669     }
670 }
671 
672 REGISTER_EXPRESSION(ceil, ExpressionCeil::parse);
getOpName() const673 const char* ExpressionCeil::getOpName() const {
674     return "$ceil";
675 }
676 
677 /* -------------------- ExpressionCoerceToBool ------------------------- */
678 
create(const intrusive_ptr<ExpressionContext> & expCtx,const intrusive_ptr<Expression> & pExpression)679 intrusive_ptr<ExpressionCoerceToBool> ExpressionCoerceToBool::create(
680     const intrusive_ptr<ExpressionContext>& expCtx, const intrusive_ptr<Expression>& pExpression) {
681     intrusive_ptr<ExpressionCoerceToBool> pNew(new ExpressionCoerceToBool(expCtx, pExpression));
682     return pNew;
683 }
684 
ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext> & expCtx,const intrusive_ptr<Expression> & pTheExpression)685 ExpressionCoerceToBool::ExpressionCoerceToBool(const intrusive_ptr<ExpressionContext>& expCtx,
686                                                const intrusive_ptr<Expression>& pTheExpression)
687     : Expression(expCtx), pExpression(pTheExpression) {}
688 
optimize()689 intrusive_ptr<Expression> ExpressionCoerceToBool::optimize() {
690     /* optimize the operand */
691     pExpression = pExpression->optimize();
692 
693     /* if the operand already produces a boolean, then we don't need this */
694     /* LATER - Expression to support a "typeof" query? */
695     Expression* pE = pExpression.get();
696     if (dynamic_cast<ExpressionAnd*>(pE) || dynamic_cast<ExpressionOr*>(pE) ||
697         dynamic_cast<ExpressionNot*>(pE) || dynamic_cast<ExpressionCoerceToBool*>(pE))
698         return pExpression;
699 
700     return intrusive_ptr<Expression>(this);
701 }
702 
_doAddDependencies(DepsTracker * deps) const703 void ExpressionCoerceToBool::_doAddDependencies(DepsTracker* deps) const {
704     pExpression->addDependencies(deps);
705 }
706 
evaluate(const Document & root,Variables * variables) const707 Value ExpressionCoerceToBool::evaluate(const Document& root, Variables* variables) const {
708     Value pResult(pExpression->evaluate(root, variables));
709     bool b = pResult.coerceToBool();
710     if (b)
711         return Value(true);
712     return Value(false);
713 }
714 
serialize(bool explain) const715 Value ExpressionCoerceToBool::serialize(bool explain) const {
716     // When not explaining, serialize to an $and expression. When parsed, the $and expression
717     // will be optimized back into a ExpressionCoerceToBool.
718     const char* name = explain ? "$coerceToBool" : "$and";
719     return Value(DOC(name << DOC_ARRAY(pExpression->serialize(explain))));
720 }
721 
722 /* ----------------------- ExpressionCompare --------------------------- */
723 
// All comparison operators share ExpressionCompare::parse(), which takes the comparison kind
// (CmpOp) as a fourth argument. stdx::bind fixes that argument for each operator and forwards
// the remaining three parser arguments through the placeholders.
REGISTER_EXPRESSION(cmp,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::CMP));
REGISTER_EXPRESSION(eq,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::EQ));
REGISTER_EXPRESSION(gt,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::GT));
REGISTER_EXPRESSION(gte,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::GTE));
REGISTER_EXPRESSION(lt,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::LT));
REGISTER_EXPRESSION(lte,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::LTE));
REGISTER_EXPRESSION(ne,
                    stdx::bind(ExpressionCompare::parse,
                               stdx::placeholders::_1,
                               stdx::placeholders::_2,
                               stdx::placeholders::_3,
                               ExpressionCompare::NE));
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement bsonExpr,const VariablesParseState & vps,CmpOp op)766 intrusive_ptr<Expression> ExpressionCompare::parse(
767     const boost::intrusive_ptr<ExpressionContext>& expCtx,
768     BSONElement bsonExpr,
769     const VariablesParseState& vps,
770     CmpOp op) {
771     intrusive_ptr<ExpressionCompare> expr = new ExpressionCompare(expCtx, op);
772     ExpressionVector args = parseArguments(expCtx, bsonExpr, vps);
773     expr->validateArguments(args);
774     expr->vpOperand = args;
775     return expr;
776 }
777 
// Programmatic factory: builds a comparison of 'exprLeft' against 'exprRight' using 'cmpOp'.
// Unlike parse(), no argument validation is performed here.
boost::intrusive_ptr<ExpressionCompare> ExpressionCompare::create(
    const boost::intrusive_ptr<ExpressionContext>& expCtx,
    CmpOp cmpOp,
    const boost::intrusive_ptr<Expression>& exprLeft,
    const boost::intrusive_ptr<Expression>& exprRight) {
    boost::intrusive_ptr<ExpressionCompare> compare(new ExpressionCompare(expCtx, cmpOp));
    compare->vpOperand = ExpressionVector{exprLeft, exprRight};
    return compare;
}
787 
788 namespace {
789 // Lookup table for truth value returns
790 struct CmpLookup {
791     const bool truthValue[3];                // truth value for -1, 0, 1
792     const ExpressionCompare::CmpOp reverse;  // reverse(b,a) returns the same as op(a,b)
793     const char name[5];                      // string name with trailing '\0'
794 };
795 static const CmpLookup cmpLookup[7] = {
796     /*             -1      0      1      reverse                  name   */
797     /* EQ  */ {{false, true, false}, ExpressionCompare::EQ, "$eq"},
798     /* NE  */ {{true, false, true}, ExpressionCompare::NE, "$ne"},
799     /* GT  */ {{false, false, true}, ExpressionCompare::LT, "$gt"},
800     /* GTE */ {{false, true, true}, ExpressionCompare::LTE, "$gte"},
801     /* LT  */ {{true, false, false}, ExpressionCompare::GT, "$lt"},
802     /* LTE */ {{true, true, false}, ExpressionCompare::GTE, "$lte"},
803 
804     // CMP is special. Only name is used.
805     /* CMP */ {{false, false, false}, ExpressionCompare::CMP, "$cmp"},
806 };
807 }  // namespace
808 
809 
evaluate(const Document & root,Variables * variables) const810 Value ExpressionCompare::evaluate(const Document& root, Variables* variables) const {
811     Value pLeft(vpOperand[0]->evaluate(root, variables));
812     Value pRight(vpOperand[1]->evaluate(root, variables));
813 
814     int cmp = getExpressionContext()->getValueComparator().compare(pLeft, pRight);
815 
816     // Make cmp one of 1, 0, or -1.
817     if (cmp == 0) {
818         // leave as 0
819     } else if (cmp < 0) {
820         cmp = -1;
821     } else if (cmp > 0) {
822         cmp = 1;
823     }
824 
825     if (cmpOp == CMP)
826         return Value(cmp);
827 
828     bool returnValue = cmpLookup[cmpOp].truthValue[cmp + 1];
829     return Value(returnValue);
830 }
831 
getOpName() const832 const char* ExpressionCompare::getOpName() const {
833     return cmpLookup[cmpOp].name;
834 }
835 
836 /* ------------------------- ExpressionConcat ----------------------------- */
837 
evaluate(const Document & root,Variables * variables) const838 Value ExpressionConcat::evaluate(const Document& root, Variables* variables) const {
839     const size_t n = vpOperand.size();
840 
841     StringBuilder result;
842     for (size_t i = 0; i < n; ++i) {
843         Value val = vpOperand[i]->evaluate(root, variables);
844         if (val.nullish())
845             return Value(BSONNULL);
846 
847         uassert(16702,
848                 str::stream() << "$concat only supports strings, not " << typeName(val.getType()),
849                 val.getType() == String);
850 
851         result << val.coerceToString();
852     }
853 
854     return Value(result.str());
855 }
856 
857 REGISTER_EXPRESSION(concat, ExpressionConcat::parse);
getOpName() const858 const char* ExpressionConcat::getOpName() const {
859     return "$concat";
860 }
861 
862 /* ------------------------- ExpressionConcatArrays ----------------------------- */
863 
evaluate(const Document & root,Variables * variables) const864 Value ExpressionConcatArrays::evaluate(const Document& root, Variables* variables) const {
865     const size_t n = vpOperand.size();
866     vector<Value> values;
867 
868     for (size_t i = 0; i < n; ++i) {
869         Value val = vpOperand[i]->evaluate(root, variables);
870         if (val.nullish()) {
871             return Value(BSONNULL);
872         }
873 
874         uassert(28664,
875                 str::stream() << "$concatArrays only supports arrays, not "
876                               << typeName(val.getType()),
877                 val.isArray());
878 
879         const auto& subValues = val.getArray();
880         values.insert(values.end(), subValues.begin(), subValues.end());
881     }
882     return Value(std::move(values));
883 }
884 
885 REGISTER_EXPRESSION(concatArrays, ExpressionConcatArrays::parse);
getOpName() const886 const char* ExpressionConcatArrays::getOpName() const {
887     return "$concatArrays";
888 }
889 
890 /* ----------------------- ExpressionCond ------------------------------ */
891 
evaluate(const Document & root,Variables * variables) const892 Value ExpressionCond::evaluate(const Document& root, Variables* variables) const {
893     Value pCond(vpOperand[0]->evaluate(root, variables));
894     int idx = pCond.coerceToBool() ? 1 : 2;
895     return vpOperand[idx]->evaluate(root, variables);
896 }
897 
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)898 intrusive_ptr<Expression> ExpressionCond::parse(
899     const boost::intrusive_ptr<ExpressionContext>& expCtx,
900     BSONElement expr,
901     const VariablesParseState& vps) {
902     if (expr.type() != Object) {
903         return Base::parse(expCtx, expr, vps);
904     }
905     verify(str::equals(expr.fieldName(), "$cond"));
906 
907     intrusive_ptr<ExpressionCond> ret = new ExpressionCond(expCtx);
908     ret->vpOperand.resize(3);
909 
910     const BSONObj args = expr.embeddedObject();
911     BSONForEach(arg, args) {
912         if (str::equals(arg.fieldName(), "if")) {
913             ret->vpOperand[0] = parseOperand(expCtx, arg, vps);
914         } else if (str::equals(arg.fieldName(), "then")) {
915             ret->vpOperand[1] = parseOperand(expCtx, arg, vps);
916         } else if (str::equals(arg.fieldName(), "else")) {
917             ret->vpOperand[2] = parseOperand(expCtx, arg, vps);
918         } else {
919             uasserted(17083,
920                       str::stream() << "Unrecognized parameter to $cond: " << arg.fieldName());
921         }
922     }
923 
924     uassert(17080, "Missing 'if' parameter to $cond", ret->vpOperand[0]);
925     uassert(17081, "Missing 'then' parameter to $cond", ret->vpOperand[1]);
926     uassert(17082, "Missing 'else' parameter to $cond", ret->vpOperand[2]);
927 
928     return ret;
929 }
930 
931 REGISTER_EXPRESSION(cond, ExpressionCond::parse);
getOpName() const932 const char* ExpressionCond::getOpName() const {
933     return "$cond";
934 }
935 
936 /* ---------------------- ExpressionConstant --------------------------- */
937 
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement exprElement,const VariablesParseState & vps)938 intrusive_ptr<Expression> ExpressionConstant::parse(
939     const boost::intrusive_ptr<ExpressionContext>& expCtx,
940     BSONElement exprElement,
941     const VariablesParseState& vps) {
942     return new ExpressionConstant(expCtx, Value(exprElement));
943 }
944 
945 
// Factory for a constant expression holding 'value'.
intrusive_ptr<ExpressionConstant> ExpressionConstant::create(
    const intrusive_ptr<ExpressionContext>& expCtx, const Value& value) {
    return intrusive_ptr<ExpressionConstant>(new ExpressionConstant(expCtx, value));
}
951 
ExpressionConstant(const boost::intrusive_ptr<ExpressionContext> & expCtx,const Value & value)952 ExpressionConstant::ExpressionConstant(const boost::intrusive_ptr<ExpressionContext>& expCtx,
953                                        const Value& value)
954     : Expression(expCtx), _value(value) {}
955 
956 
optimize()957 intrusive_ptr<Expression> ExpressionConstant::optimize() {
958     /* nothing to do */
959     return intrusive_ptr<Expression>(this);
960 }
961 
_doAddDependencies(DepsTracker * deps) const962 void ExpressionConstant::_doAddDependencies(DepsTracker* deps) const {
963     /* nothing to do */
964 }
965 
evaluate(const Document & root,Variables * variables) const966 Value ExpressionConstant::evaluate(const Document& root, Variables* variables) const {
967     return _value;
968 }
969 
serialize(bool explain) const970 Value ExpressionConstant::serialize(bool explain) const {
971     return serializeConstant(_value);
972 }
973 
974 REGISTER_EXPRESSION(const, ExpressionConstant::parse);
975 REGISTER_EXPRESSION(literal, ExpressionConstant::parse);  // alias
getOpName() const976 const char* ExpressionConstant::getOpName() const {
977     return "$const";
978 }
979 
980 /* ---------------------- ExpressionDateFromParts ----------------------- */
981 
982 /* Helper functions also shared with ExpressionDateToParts */
983 
984 namespace {
985 
makeTimeZone(const TimeZoneDatabase * tzdb,const Document & root,const Expression * timeZone,Variables * variables)986 boost::optional<TimeZone> makeTimeZone(const TimeZoneDatabase* tzdb,
987                                        const Document& root,
988                                        const Expression* timeZone,
989                                        Variables* variables) {
990     invariant(tzdb);
991 
992     if (!timeZone) {
993         return mongo::TimeZoneDatabase::utcZone();
994     }
995 
996     auto timeZoneId = timeZone->evaluate(root, variables);
997 
998     if (timeZoneId.nullish()) {
999         return boost::none;
1000     }
1001 
1002     uassert(40517,
1003             str::stream() << "timezone must evaluate to a string, found "
1004                           << typeName(timeZoneId.getType()),
1005             timeZoneId.getType() == BSONType::String);
1006 
1007     return tzdb->getTimeZone(timeZoneId.getString());
1008 }
1009 
1010 }  // namespace
1011 
1012 
1013 REGISTER_EXPRESSION(dateFromParts, ExpressionDateFromParts::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1014 intrusive_ptr<Expression> ExpressionDateFromParts::parse(
1015     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1016     BSONElement expr,
1017     const VariablesParseState& vps) {
1018 
1019     uassert(40519,
1020             "$dateFromParts only supports an object as its argument",
1021             expr.type() == BSONType::Object);
1022 
1023     BSONElement yearElem;
1024     BSONElement monthElem;
1025     BSONElement dayElem;
1026     BSONElement hourElem;
1027     BSONElement minuteElem;
1028     BSONElement secondElem;
1029     BSONElement millisecondElem;
1030     BSONElement isoWeekYearElem;
1031     BSONElement isoWeekElem;
1032     BSONElement isoDayOfWeekElem;
1033     BSONElement timeZoneElem;
1034 
1035     const BSONObj args = expr.embeddedObject();
1036     for (auto&& arg : args) {
1037         auto field = arg.fieldNameStringData();
1038 
1039         if (field == "year"_sd) {
1040             yearElem = arg;
1041         } else if (field == "month"_sd) {
1042             monthElem = arg;
1043         } else if (field == "day"_sd) {
1044             dayElem = arg;
1045         } else if (field == "hour"_sd) {
1046             hourElem = arg;
1047         } else if (field == "minute"_sd) {
1048             minuteElem = arg;
1049         } else if (field == "second"_sd) {
1050             secondElem = arg;
1051         } else if (field == "millisecond"_sd) {
1052             millisecondElem = arg;
1053         } else if (field == "isoWeekYear"_sd) {
1054             isoWeekYearElem = arg;
1055         } else if (field == "isoWeek"_sd) {
1056             isoWeekElem = arg;
1057         } else if (field == "isoDayOfWeek"_sd) {
1058             isoDayOfWeekElem = arg;
1059         } else if (field == "timezone"_sd) {
1060             timeZoneElem = arg;
1061         } else {
1062             uasserted(40518,
1063                       str::stream() << "Unrecognized argument to $dateFromParts: "
1064                                     << arg.fieldName());
1065         }
1066     }
1067 
1068     if (!yearElem && !isoWeekYearElem) {
1069         uasserted(40516, "$dateFromParts requires either 'year' or 'isoWeekYear' to be present");
1070     }
1071 
1072     if (yearElem && (isoWeekYearElem || isoWeekElem || isoDayOfWeekElem)) {
1073         uasserted(40489, "$dateFromParts does not allow mixing natural dates with ISO dates");
1074     }
1075 
1076     if (isoWeekYearElem && (yearElem || monthElem || dayElem)) {
1077         uasserted(40525, "$dateFromParts does not allow mixing ISO dates with natural dates");
1078     }
1079 
1080     return new ExpressionDateFromParts(
1081         expCtx,
1082         yearElem ? parseOperand(expCtx, yearElem, vps) : nullptr,
1083         monthElem ? parseOperand(expCtx, monthElem, vps) : nullptr,
1084         dayElem ? parseOperand(expCtx, dayElem, vps) : nullptr,
1085         hourElem ? parseOperand(expCtx, hourElem, vps) : nullptr,
1086         minuteElem ? parseOperand(expCtx, minuteElem, vps) : nullptr,
1087         secondElem ? parseOperand(expCtx, secondElem, vps) : nullptr,
1088         millisecondElem ? parseOperand(expCtx, millisecondElem, vps) : nullptr,
1089         isoWeekYearElem ? parseOperand(expCtx, isoWeekYearElem, vps) : nullptr,
1090         isoWeekElem ? parseOperand(expCtx, isoWeekElem, vps) : nullptr,
1091         isoDayOfWeekElem ? parseOperand(expCtx, isoDayOfWeekElem, vps) : nullptr,
1092         timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps) : nullptr);
1093 }
1094 
ExpressionDateFromParts(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> year,intrusive_ptr<Expression> month,intrusive_ptr<Expression> day,intrusive_ptr<Expression> hour,intrusive_ptr<Expression> minute,intrusive_ptr<Expression> second,intrusive_ptr<Expression> millisecond,intrusive_ptr<Expression> isoWeekYear,intrusive_ptr<Expression> isoWeek,intrusive_ptr<Expression> isoDayOfWeek,intrusive_ptr<Expression> timeZone)1095 ExpressionDateFromParts::ExpressionDateFromParts(
1096     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1097     intrusive_ptr<Expression> year,
1098     intrusive_ptr<Expression> month,
1099     intrusive_ptr<Expression> day,
1100     intrusive_ptr<Expression> hour,
1101     intrusive_ptr<Expression> minute,
1102     intrusive_ptr<Expression> second,
1103     intrusive_ptr<Expression> millisecond,
1104     intrusive_ptr<Expression> isoWeekYear,
1105     intrusive_ptr<Expression> isoWeek,
1106     intrusive_ptr<Expression> isoDayOfWeek,
1107     intrusive_ptr<Expression> timeZone)
1108     : Expression(expCtx),
1109       _year(std::move(year)),
1110       _month(std::move(month)),
1111       _day(std::move(day)),
1112       _hour(std::move(hour)),
1113       _minute(std::move(minute)),
1114       _second(std::move(second)),
1115       _millisecond(std::move(millisecond)),
1116       _isoWeekYear(std::move(isoWeekYear)),
1117       _isoWeek(std::move(isoWeek)),
1118       _isoDayOfWeek(std::move(isoDayOfWeek)),
1119       _timeZone(std::move(timeZone)) {}
1120 
optimize()1121 intrusive_ptr<Expression> ExpressionDateFromParts::optimize() {
1122     if (_year) {
1123         _year = _year->optimize();
1124     }
1125     if (_month) {
1126         _month = _month->optimize();
1127     }
1128     if (_day) {
1129         _day = _day->optimize();
1130     }
1131     if (_hour) {
1132         _hour = _hour->optimize();
1133     }
1134     if (_minute) {
1135         _minute = _minute->optimize();
1136     }
1137     if (_second) {
1138         _second = _second->optimize();
1139     }
1140     if (_millisecond) {
1141         _millisecond = _millisecond->optimize();
1142     }
1143     if (_isoWeekYear) {
1144         _isoWeekYear = _isoWeekYear->optimize();
1145     }
1146     if (_isoWeek) {
1147         _isoWeek = _isoWeek->optimize();
1148     }
1149     if (_isoDayOfWeek) {
1150         _isoDayOfWeek = _isoDayOfWeek->optimize();
1151     }
1152     if (_timeZone) {
1153         _timeZone = _timeZone->optimize();
1154     }
1155 
1156     if (ExpressionConstant::allNullOrConstant({_year,
1157                                                _month,
1158                                                _day,
1159                                                _hour,
1160                                                _minute,
1161                                                _second,
1162                                                _millisecond,
1163                                                _isoWeekYear,
1164                                                _isoWeek,
1165                                                _isoDayOfWeek,
1166                                                _timeZone})) {
1167 
1168         // Everything is a constant, so we can turn into a constant.
1169         return ExpressionConstant::create(
1170             getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1171     }
1172 
1173     return this;
1174 }
1175 
serialize(bool explain) const1176 Value ExpressionDateFromParts::serialize(bool explain) const {
1177     return Value(Document{
1178         {"$dateFromParts",
1179          Document{{"year", _year ? _year->serialize(explain) : Value()},
1180                   {"month", _month ? _month->serialize(explain) : Value()},
1181                   {"day", _day ? _day->serialize(explain) : Value()},
1182                   {"hour", _hour ? _hour->serialize(explain) : Value()},
1183                   {"minute", _minute ? _minute->serialize(explain) : Value()},
1184                   {"second", _second ? _second->serialize(explain) : Value()},
1185                   {"millisecond", _millisecond ? _millisecond->serialize(explain) : Value()},
1186                   {"isoWeekYear", _isoWeekYear ? _isoWeekYear->serialize(explain) : Value()},
1187                   {"isoWeek", _isoWeek ? _isoWeek->serialize(explain) : Value()},
1188                   {"isoDayOfWeek", _isoDayOfWeek ? _isoDayOfWeek->serialize(explain) : Value()},
1189                   {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1190 }
1191 
1192 /**
1193  * This function checks whether a field is a number, and fits in the given range.
1194  *
1195  * If the field does not exist, the default value is returned trough the returnValue out parameter
1196  * and the function returns true.
1197  *
1198  * If the field exists:
1199  * - if the value is "nullish", the function returns false, so that the calling function can return
1200  *   a BSONNULL value.
1201  * - if the value can not be coerced to an integral value, an exception is returned.
1202  * - if the value is out of the range [minValue..maxValue], an exception is returned.
1203  * - otherwise, the coerced integral value is returned through the returnValue
1204  *   out parameter, and the function returns true.
1205  */
evaluateNumberWithinRange(const Document & root,const Expression * field,StringData fieldName,int defaultValue,int minValue,int maxValue,int * returnValue,Variables * variables) const1206 bool ExpressionDateFromParts::evaluateNumberWithinRange(const Document& root,
1207                                                         const Expression* field,
1208                                                         StringData fieldName,
1209                                                         int defaultValue,
1210                                                         int minValue,
1211                                                         int maxValue,
1212                                                         int* returnValue,
1213                                                         Variables* variables) const {
1214     if (!field) {
1215         *returnValue = defaultValue;
1216         return true;
1217     }
1218 
1219     auto fieldValue = field->evaluate(root, variables);
1220 
1221     if (fieldValue.nullish()) {
1222         return false;
1223     }
1224 
1225     uassert(40515,
1226             str::stream() << "'" << fieldName << "' must evaluate to an integer, found "
1227                           << typeName(fieldValue.getType())
1228                           << " with value "
1229                           << fieldValue.toString(),
1230             fieldValue.integral());
1231 
1232     *returnValue = fieldValue.coerceToInt();
1233 
1234     uassert(40523,
1235             str::stream() << "'" << fieldName << "' must evaluate to an integer in the range "
1236                           << minValue
1237                           << " to "
1238                           << maxValue
1239                           << ", found "
1240                           << *returnValue,
1241             *returnValue >= minValue && *returnValue <= maxValue);
1242 
1243     return true;
1244 }
1245 
evaluate(const Document & root,Variables * variables) const1246 Value ExpressionDateFromParts::evaluate(const Document& root, Variables* variables) const {
1247     int hour, minute, second, millisecond;
1248 
1249     if (!evaluateNumberWithinRange(root, _hour.get(), "hour"_sd, 0, 0, 24, &hour, variables) ||
1250         !evaluateNumberWithinRange(
1251             root, _minute.get(), "minute"_sd, 0, 0, 59, &minute, variables) ||
1252         !evaluateNumberWithinRange(
1253             root, _second.get(), "second"_sd, 0, 0, 59, &second, variables) ||
1254         !evaluateNumberWithinRange(
1255             root, _millisecond.get(), "millisecond"_sd, 0, 0, 999, &millisecond, variables)) {
1256         return Value(BSONNULL);
1257     }
1258 
1259     auto timeZone =
1260         makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1261 
1262     if (!timeZone) {
1263         return Value(BSONNULL);
1264     }
1265 
1266     if (_year) {
1267         int year, month, day;
1268 
1269         if (!evaluateNumberWithinRange(
1270                 root, _year.get(), "year"_sd, 1970, 0, 9999, &year, variables) ||
1271             !evaluateNumberWithinRange(
1272                 root, _month.get(), "month"_sd, 1, 1, 12, &month, variables) ||
1273             !evaluateNumberWithinRange(root, _day.get(), "day"_sd, 1, 1, 31, &day, variables)) {
1274             return Value(BSONNULL);
1275         }
1276 
1277         return Value(
1278             timeZone->createFromDateParts(year, month, day, hour, minute, second, millisecond));
1279     }
1280 
1281     if (_isoWeekYear) {
1282         int isoWeekYear, isoWeek, isoDayOfWeek;
1283 
1284         if (!evaluateNumberWithinRange(root,
1285                                        _isoWeekYear.get(),
1286                                        "isoWeekYear"_sd,
1287                                        1970,
1288                                        0,
1289                                        9999,
1290                                        &isoWeekYear,
1291                                        variables) ||
1292             !evaluateNumberWithinRange(
1293                 root, _isoWeek.get(), "isoWeek"_sd, 1, 1, 53, &isoWeek, variables) ||
1294             !evaluateNumberWithinRange(
1295                 root, _isoDayOfWeek.get(), "isoDayOfWeek"_sd, 1, 1, 7, &isoDayOfWeek, variables)) {
1296             return Value(BSONNULL);
1297         }
1298 
1299         return Value(timeZone->createFromIso8601DateParts(
1300             isoWeekYear, isoWeek, isoDayOfWeek, hour, minute, second, millisecond));
1301     }
1302 
1303     MONGO_UNREACHABLE;
1304 }
1305 
_doAddDependencies(DepsTracker * deps) const1306 void ExpressionDateFromParts::_doAddDependencies(DepsTracker* deps) const {
1307     if (_year) {
1308         _year->addDependencies(deps);
1309     }
1310     if (_month) {
1311         _month->addDependencies(deps);
1312     }
1313     if (_day) {
1314         _day->addDependencies(deps);
1315     }
1316     if (_hour) {
1317         _hour->addDependencies(deps);
1318     }
1319     if (_minute) {
1320         _minute->addDependencies(deps);
1321     }
1322     if (_second) {
1323         _second->addDependencies(deps);
1324     }
1325     if (_millisecond) {
1326         _millisecond->addDependencies(deps);
1327     }
1328     if (_isoWeekYear) {
1329         _isoWeekYear->addDependencies(deps);
1330     }
1331     if (_isoWeek) {
1332         _isoWeek->addDependencies(deps);
1333     }
1334     if (_isoDayOfWeek) {
1335         _isoDayOfWeek->addDependencies(deps);
1336     }
1337     if (_timeZone) {
1338         _timeZone->addDependencies(deps);
1339     }
1340 }
1341 
1342 /* ---------------------- ExpressionDateFromString --------------------- */
1343 
1344 REGISTER_EXPRESSION(dateFromString, ExpressionDateFromString::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1345 intrusive_ptr<Expression> ExpressionDateFromString::parse(
1346     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1347     BSONElement expr,
1348     const VariablesParseState& vps) {
1349 
1350     uassert(40540,
1351             str::stream() << "$dateFromString only supports an object as an argument, found: "
1352                           << typeName(expr.type()),
1353             expr.type() == BSONType::Object);
1354 
1355     BSONElement dateStringElem;
1356     BSONElement timeZoneElem;
1357 
1358     const BSONObj args = expr.embeddedObject();
1359     for (auto&& arg : args) {
1360         auto field = arg.fieldNameStringData();
1361 
1362         if (field == "dateString"_sd) {
1363             dateStringElem = arg;
1364         } else if (field == "timezone"_sd) {
1365             timeZoneElem = arg;
1366         } else {
1367             uasserted(40541,
1368                       str::stream() << "Unrecognized argument to $dateFromString: "
1369                                     << arg.fieldName());
1370         }
1371     }
1372 
1373     uassert(40542, "Missing 'dateString' parameter to $dateFromString", dateStringElem);
1374 
1375     return new ExpressionDateFromString(expCtx,
1376                                         parseOperand(expCtx, dateStringElem, vps),
1377                                         timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps)
1378                                                      : nullptr);
1379 }
1380 
ExpressionDateFromString(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> dateString,intrusive_ptr<Expression> timeZone)1381 ExpressionDateFromString::ExpressionDateFromString(
1382     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1383     intrusive_ptr<Expression> dateString,
1384     intrusive_ptr<Expression> timeZone)
1385     : Expression(expCtx), _dateString(std::move(dateString)), _timeZone(std::move(timeZone)) {}
1386 
optimize()1387 intrusive_ptr<Expression> ExpressionDateFromString::optimize() {
1388     _dateString = _dateString->optimize();
1389     if (_timeZone) {
1390         _timeZone = _timeZone->optimize();
1391     }
1392 
1393     if (ExpressionConstant::allNullOrConstant({_dateString, _timeZone})) {
1394         // Everything is a constant, so we can turn into a constant.
1395         return ExpressionConstant::create(
1396             getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1397     }
1398     return this;
1399 }
1400 
serialize(bool explain) const1401 Value ExpressionDateFromString::serialize(bool explain) const {
1402     return Value(
1403         Document{{"$dateFromString",
1404                   Document{{"dateString", _dateString->serialize(explain)},
1405                            {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1406 }
1407 
evaluate(const Document & root,Variables * variables) const1408 Value ExpressionDateFromString::evaluate(const Document& root, Variables* variables) const {
1409     const Value dateString = _dateString->evaluate(root, variables);
1410 
1411     auto timeZone =
1412         makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1413 
1414     if (!timeZone || dateString.nullish()) {
1415         return Value(BSONNULL);
1416     }
1417 
1418     uassert(40543,
1419             str::stream() << "$dateFromString requires that 'dateString' be a string, found: "
1420                           << typeName(dateString.getType())
1421                           << " with value "
1422                           << dateString.toString(),
1423             dateString.getType() == BSONType::String);
1424     const std::string& dateTimeString = dateString.getString();
1425 
1426     return Value(getExpressionContext()->timeZoneDatabase->fromString(dateTimeString, timeZone));
1427 }
1428 
_doAddDependencies(DepsTracker * deps) const1429 void ExpressionDateFromString::_doAddDependencies(DepsTracker* deps) const {
1430     _dateString->addDependencies(deps);
1431     if (_timeZone) {
1432         _timeZone->addDependencies(deps);
1433     }
1434 }
1435 
1436 /* ---------------------- ExpressionDateToParts ----------------------- */
1437 
1438 REGISTER_EXPRESSION(dateToParts, ExpressionDateToParts::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1439 intrusive_ptr<Expression> ExpressionDateToParts::parse(
1440     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1441     BSONElement expr,
1442     const VariablesParseState& vps) {
1443 
1444     uassert(40524,
1445             "$dateToParts only supports an object as its argument",
1446             expr.type() == BSONType::Object);
1447 
1448     BSONElement dateElem;
1449     BSONElement timeZoneElem;
1450     BSONElement isoDateElem;
1451 
1452     const BSONObj args = expr.embeddedObject();
1453     for (auto&& arg : args) {
1454         auto field = arg.fieldNameStringData();
1455 
1456         if (field == "date"_sd) {
1457             dateElem = arg;
1458         } else if (field == "timezone"_sd) {
1459             timeZoneElem = arg;
1460         } else if (field == "iso8601"_sd) {
1461             isoDateElem = arg;
1462         } else {
1463             uasserted(40520,
1464                       str::stream() << "Unrecognized argument to $dateToParts: "
1465                                     << arg.fieldName());
1466         }
1467     }
1468 
1469     uassert(40522, "Missing 'date' parameter to $dateToParts", dateElem);
1470 
1471     return new ExpressionDateToParts(
1472         expCtx,
1473         parseOperand(expCtx, dateElem, vps),
1474         timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps) : nullptr,
1475         isoDateElem ? parseOperand(expCtx, isoDateElem, vps) : nullptr);
1476 }
1477 
ExpressionDateToParts(const boost::intrusive_ptr<ExpressionContext> & expCtx,intrusive_ptr<Expression> date,intrusive_ptr<Expression> timeZone,intrusive_ptr<Expression> iso8601)1478 ExpressionDateToParts::ExpressionDateToParts(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1479                                              intrusive_ptr<Expression> date,
1480                                              intrusive_ptr<Expression> timeZone,
1481                                              intrusive_ptr<Expression> iso8601)
1482     : Expression(expCtx),
1483       _date(std::move(date)),
1484       _timeZone(std::move(timeZone)),
1485       _iso8601(std::move(iso8601)) {}
1486 
optimize()1487 intrusive_ptr<Expression> ExpressionDateToParts::optimize() {
1488     _date = _date->optimize();
1489     if (_timeZone) {
1490         _timeZone = _timeZone->optimize();
1491     }
1492     if (_iso8601) {
1493         _iso8601 = _iso8601->optimize();
1494     }
1495 
1496     if (ExpressionConstant::allNullOrConstant({_date, _iso8601, _timeZone})) {
1497         // Everything is a constant, so we can turn into a constant.
1498         return ExpressionConstant::create(
1499             getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1500     }
1501 
1502     return this;
1503 }
1504 
serialize(bool explain) const1505 Value ExpressionDateToParts::serialize(bool explain) const {
1506     return Value(
1507         Document{{"$dateToParts",
1508                   Document{{"date", _date->serialize(explain)},
1509                            {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()},
1510                            {"iso8601", _iso8601 ? _iso8601->serialize(explain) : Value()}}}});
1511 }
1512 
evaluateIso8601Flag(const Document & root,Variables * variables) const1513 boost::optional<int> ExpressionDateToParts::evaluateIso8601Flag(const Document& root,
1514                                                                 Variables* variables) const {
1515     if (!_iso8601) {
1516         return false;
1517     }
1518 
1519     auto iso8601Output = _iso8601->evaluate(root, variables);
1520 
1521     if (iso8601Output.nullish()) {
1522         return boost::none;
1523     }
1524 
1525     uassert(40521,
1526             str::stream() << "iso8601 must evaluate to a bool, found "
1527                           << typeName(iso8601Output.getType()),
1528             iso8601Output.getType() == BSONType::Bool);
1529 
1530     return iso8601Output.getBool();
1531 }
1532 
evaluate(const Document & root,Variables * variables) const1533 Value ExpressionDateToParts::evaluate(const Document& root, Variables* variables) const {
1534     const Value date = _date->evaluate(root, variables);
1535 
1536     auto timeZone =
1537         makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1538     if (!timeZone) {
1539         return Value(BSONNULL);
1540     }
1541 
1542     auto iso8601 = evaluateIso8601Flag(root, variables);
1543     if (!iso8601) {
1544         return Value(BSONNULL);
1545     }
1546 
1547     if (date.nullish()) {
1548         return Value(BSONNULL);
1549     }
1550 
1551     auto dateValue = date.coerceToDate();
1552 
1553     if (*iso8601) {
1554         auto parts = timeZone->dateIso8601Parts(dateValue);
1555         return Value(Document{{"isoWeekYear", parts.year},
1556                               {"isoWeek", parts.weekOfYear},
1557                               {"isoDayOfWeek", parts.dayOfWeek},
1558                               {"hour", parts.hour},
1559                               {"minute", parts.minute},
1560                               {"second", parts.second},
1561                               {"millisecond", parts.millisecond}});
1562     } else {
1563         auto parts = timeZone->dateParts(dateValue);
1564         return Value(Document{{"year", parts.year},
1565                               {"month", parts.month},
1566                               {"day", parts.dayOfMonth},
1567                               {"hour", parts.hour},
1568                               {"minute", parts.minute},
1569                               {"second", parts.second},
1570                               {"millisecond", parts.millisecond}});
1571     }
1572 }
1573 
_doAddDependencies(DepsTracker * deps) const1574 void ExpressionDateToParts::_doAddDependencies(DepsTracker* deps) const {
1575     _date->addDependencies(deps);
1576     if (_timeZone) {
1577         _timeZone->addDependencies(deps);
1578     }
1579     if (_iso8601) {
1580         _iso8601->addDependencies(deps);
1581     }
1582 }
1583 
1584 
1585 /* ---------------------- ExpressionDateToString ----------------------- */
1586 
1587 REGISTER_EXPRESSION(dateToString, ExpressionDateToString::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)1588 intrusive_ptr<Expression> ExpressionDateToString::parse(
1589     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1590     BSONElement expr,
1591     const VariablesParseState& vps) {
1592     verify(str::equals(expr.fieldName(), "$dateToString"));
1593 
1594     uassert(18629, "$dateToString only supports an object as its argument", expr.type() == Object);
1595 
1596     BSONElement formatElem;
1597     BSONElement dateElem;
1598     BSONElement timeZoneElem;
1599     const BSONObj args = expr.embeddedObject();
1600     BSONForEach(arg, args) {
1601         if (str::equals(arg.fieldName(), "format")) {
1602             formatElem = arg;
1603         } else if (str::equals(arg.fieldName(), "date")) {
1604             dateElem = arg;
1605         } else if (str::equals(arg.fieldName(), "timezone")) {
1606             timeZoneElem = arg;
1607         } else {
1608             uasserted(18534,
1609                       str::stream() << "Unrecognized argument to $dateToString: "
1610                                     << arg.fieldName());
1611         }
1612     }
1613 
1614     uassert(18627, "Missing 'format' parameter to $dateToString", !formatElem.eoo());
1615     uassert(18628, "Missing 'date' parameter to $dateToString", !dateElem.eoo());
1616 
1617     uassert(18533,
1618             "The 'format' parameter to $dateToString must be a string literal",
1619             formatElem.type() == String);
1620 
1621     const string format = formatElem.str();
1622 
1623     TimeZone::validateFormat(format);
1624 
1625     return new ExpressionDateToString(expCtx,
1626                                       format,
1627                                       parseOperand(expCtx, dateElem, vps),
1628                                       timeZoneElem ? parseOperand(expCtx, timeZoneElem, vps)
1629                                                    : nullptr);
1630 }
1631 
ExpressionDateToString(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & format,intrusive_ptr<Expression> date,intrusive_ptr<Expression> timeZone)1632 ExpressionDateToString::ExpressionDateToString(
1633     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1634     const string& format,
1635     intrusive_ptr<Expression> date,
1636     intrusive_ptr<Expression> timeZone)
1637     : Expression(expCtx), _format(format), _date(std::move(date)), _timeZone(std::move(timeZone)) {}
1638 
optimize()1639 intrusive_ptr<Expression> ExpressionDateToString::optimize() {
1640     _date = _date->optimize();
1641     if (_timeZone) {
1642         _timeZone = _timeZone->optimize();
1643     }
1644 
1645     if (ExpressionConstant::allNullOrConstant({_date, _timeZone})) {
1646         // Everything is a constant, so we can turn into a constant.
1647         return ExpressionConstant::create(
1648             getExpressionContext(), evaluate(Document{}, &(getExpressionContext()->variables)));
1649     }
1650 
1651     return this;
1652 }
1653 
serialize(bool explain) const1654 Value ExpressionDateToString::serialize(bool explain) const {
1655     return Value(
1656         Document{{"$dateToString",
1657                   Document{{"format", _format},
1658                            {"date", _date->serialize(explain)},
1659                            {"timezone", _timeZone ? _timeZone->serialize(explain) : Value()}}}});
1660 }
1661 
evaluate(const Document & root,Variables * variables) const1662 Value ExpressionDateToString::evaluate(const Document& root, Variables* variables) const {
1663     const Value date = _date->evaluate(root, variables);
1664 
1665     auto timeZone =
1666         makeTimeZone(getExpressionContext()->timeZoneDatabase, root, _timeZone.get(), variables);
1667     if (!timeZone) {
1668         return Value(BSONNULL);
1669     }
1670 
1671     if (date.nullish()) {
1672         return Value(BSONNULL);
1673     }
1674 
1675     return Value(timeZone->formatDate(_format, date.coerceToDate()));
1676 }
1677 
_doAddDependencies(DepsTracker * deps) const1678 void ExpressionDateToString::_doAddDependencies(DepsTracker* deps) const {
1679     _date->addDependencies(deps);
1680     if (_timeZone) {
1681         _timeZone->addDependencies(deps);
1682     }
1683 }
1684 
1685 /* ----------------------- ExpressionDivide ---------------------------- */
1686 
evaluate(const Document & root,Variables * variables) const1687 Value ExpressionDivide::evaluate(const Document& root, Variables* variables) const {
1688     Value lhs = vpOperand[0]->evaluate(root, variables);
1689     Value rhs = vpOperand[1]->evaluate(root, variables);
1690 
1691     auto assertNonZero = [](bool nonZero) { uassert(16608, "can't $divide by zero", nonZero); };
1692 
1693     if (lhs.numeric() && rhs.numeric()) {
1694         // If, and only if, either side is decimal, return decimal.
1695         if (lhs.getType() == NumberDecimal || rhs.getType() == NumberDecimal) {
1696             Decimal128 numer = lhs.coerceToDecimal();
1697             Decimal128 denom = rhs.coerceToDecimal();
1698             assertNonZero(!denom.isZero());
1699             return Value(numer.divide(denom));
1700         }
1701 
1702         double numer = lhs.coerceToDouble();
1703         double denom = rhs.coerceToDouble();
1704         assertNonZero(denom != 0.0);
1705 
1706         return Value(numer / denom);
1707     } else if (lhs.nullish() || rhs.nullish()) {
1708         return Value(BSONNULL);
1709     } else {
1710         uasserted(16609,
1711                   str::stream() << "$divide only supports numeric types, not "
1712                                 << typeName(lhs.getType())
1713                                 << " and "
1714                                 << typeName(rhs.getType()));
1715     }
1716 }
1717 
1718 REGISTER_EXPRESSION(divide, ExpressionDivide::parse);
getOpName() const1719 const char* ExpressionDivide::getOpName() const {
1720     return "$divide";
1721 }
1722 
1723 /* ----------------------- ExpressionExp ---------------------------- */
1724 
// Raises e to the power of 'numericArg'. The result is a decimal for a decimal argument and a
// double for every other numeric type -- never an integer, as e is irrational.
Value ExpressionExp::evaluateNumericArg(const Value& numericArg) const {
    if (numericArg.getType() != NumberDecimal) {
        return Value(exp(numericArg.coerceToDouble()));
    }

    return Value(numericArg.coerceToDecimal().exponential());
}
1732 
1733 REGISTER_EXPRESSION(exp, ExpressionExp::parse);
getOpName() const1734 const char* ExpressionExp::getOpName() const {
1735     return "$exp";
1736 }
1737 
1738 /* ---------------------- ExpressionObject --------------------------- */
1739 
ExpressionObject(const boost::intrusive_ptr<ExpressionContext> & expCtx,vector<pair<string,intrusive_ptr<Expression>>> && expressions)1740 ExpressionObject::ExpressionObject(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1741                                    vector<pair<string, intrusive_ptr<Expression>>>&& expressions)
1742     : Expression(expCtx), _expressions(std::move(expressions)) {}
1743 
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,vector<pair<string,intrusive_ptr<Expression>>> && expressions)1744 intrusive_ptr<ExpressionObject> ExpressionObject::create(
1745     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1746     vector<pair<string, intrusive_ptr<Expression>>>&& expressions) {
1747     return new ExpressionObject(expCtx, std::move(expressions));
1748 }
1749 
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONObj obj,const VariablesParseState & vps)1750 intrusive_ptr<ExpressionObject> ExpressionObject::parse(
1751     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1752     BSONObj obj,
1753     const VariablesParseState& vps) {
1754     // Make sure we don't have any duplicate field names.
1755     stdx::unordered_set<string> specifiedFields;
1756 
1757     vector<pair<string, intrusive_ptr<Expression>>> expressions;
1758     for (auto&& elem : obj) {
1759         // Make sure this element has a valid field name. Use StringData here so that we can detect
1760         // if the field name contains a null byte.
1761         FieldPath::uassertValidFieldName(elem.fieldNameStringData());
1762 
1763         auto fieldName = elem.fieldName();
1764         uassert(16406,
1765                 str::stream() << "duplicate field name specified in object literal: "
1766                               << obj.toString(),
1767                 specifiedFields.find(fieldName) == specifiedFields.end());
1768         specifiedFields.insert(fieldName);
1769         expressions.emplace_back(fieldName, parseOperand(expCtx, elem, vps));
1770     }
1771 
1772     return new ExpressionObject{expCtx, std::move(expressions)};
1773 }
1774 
optimize()1775 intrusive_ptr<Expression> ExpressionObject::optimize() {
1776     for (auto&& pair : _expressions) {
1777         pair.second = pair.second->optimize();
1778     }
1779     return this;
1780 }
1781 
_doAddDependencies(DepsTracker * deps) const1782 void ExpressionObject::_doAddDependencies(DepsTracker* deps) const {
1783     for (auto&& pair : _expressions) {
1784         pair.second->addDependencies(deps);
1785     }
1786 }
1787 
evaluate(const Document & root,Variables * variables) const1788 Value ExpressionObject::evaluate(const Document& root, Variables* variables) const {
1789     MutableDocument outputDoc;
1790     for (auto&& pair : _expressions) {
1791         outputDoc.addField(pair.first, pair.second->evaluate(root, variables));
1792     }
1793     return outputDoc.freezeToValue();
1794 }
1795 
serialize(bool explain) const1796 Value ExpressionObject::serialize(bool explain) const {
1797     MutableDocument outputDoc;
1798     for (auto&& pair : _expressions) {
1799         outputDoc.addField(pair.first, pair.second->serialize(explain));
1800     }
1801     return outputDoc.freezeToValue();
1802 }
1803 
getComputedPaths(const std::string & exprFieldPath,Variables::Id renamingVar) const1804 Expression::ComputedPaths ExpressionObject::getComputedPaths(const std::string& exprFieldPath,
1805                                                              Variables::Id renamingVar) const {
1806     ComputedPaths outputPaths;
1807     for (auto&& pair : _expressions) {
1808         auto exprComputedPaths = pair.second->getComputedPaths(pair.first, renamingVar);
1809         for (auto&& renames : exprComputedPaths.renames) {
1810             auto newPath = FieldPath::getFullyQualifiedPath(exprFieldPath, renames.first);
1811             outputPaths.renames[std::move(newPath)] = renames.second;
1812         }
1813         for (auto&& path : exprComputedPaths.paths) {
1814             outputPaths.paths.insert(FieldPath::getFullyQualifiedPath(exprFieldPath, path));
1815         }
1816     }
1817 
1818     return outputPaths;
1819 }
1820 
1821 /* --------------------- ExpressionFieldPath --------------------------- */
1822 
1823 // this is the old deprecated version only used by tests not using variables
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & fieldPath)1824 intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::create(
1825     const boost::intrusive_ptr<ExpressionContext>& expCtx, const string& fieldPath) {
1826     return new ExpressionFieldPath(expCtx, "CURRENT." + fieldPath, Variables::kRootId);
1827 }
1828 
1829 // this is the new version that supports every syntax
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & raw,const VariablesParseState & vps)1830 intrusive_ptr<ExpressionFieldPath> ExpressionFieldPath::parse(
1831     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1832     const string& raw,
1833     const VariablesParseState& vps) {
1834     uassert(16873,
1835             str::stream() << "FieldPath '" << raw << "' doesn't start with $",
1836             raw.c_str()[0] == '$');  // c_str()[0] is always a valid reference.
1837 
1838     uassert(16872,
1839             str::stream() << "'$' by itself is not a valid FieldPath",
1840             raw.size() >= 2);  // need at least "$" and either "$" or a field name
1841 
1842     if (raw[1] == '$') {
1843         const StringData rawSD = raw;
1844         const StringData fieldPath = rawSD.substr(2);  // strip off $$
1845         const StringData varName = fieldPath.substr(0, fieldPath.find('.'));
1846         Variables::uassertValidNameForUserRead(varName);
1847         return new ExpressionFieldPath(expCtx, fieldPath.toString(), vps.getVariable(varName));
1848     } else {
1849         return new ExpressionFieldPath(expCtx,
1850                                        "CURRENT." + raw.substr(1),  // strip the "$" prefix
1851                                        vps.getVariable("CURRENT"));
1852     }
1853 }
1854 
ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & theFieldPath,Variables::Id variable)1855 ExpressionFieldPath::ExpressionFieldPath(const boost::intrusive_ptr<ExpressionContext>& expCtx,
1856                                          const string& theFieldPath,
1857                                          Variables::Id variable)
1858     : Expression(expCtx), _fieldPath(theFieldPath), _variable(variable) {}
1859 
optimize()1860 intrusive_ptr<Expression> ExpressionFieldPath::optimize() {
1861     if (_variable == Variables::kRemoveId) {
1862         // The REMOVE system variable optimizes to a constant missing value.
1863         return ExpressionConstant::create(getExpressionContext(), Value());
1864     }
1865 
1866     if (getExpressionContext()->variables.hasConstantValue(_variable)) {
1867         return ExpressionConstant::create(
1868             getExpressionContext(), evaluate(Document(), &(getExpressionContext()->variables)));
1869     }
1870 
1871     return intrusive_ptr<Expression>(this);
1872 }
1873 
_doAddDependencies(DepsTracker * deps) const1874 void ExpressionFieldPath::_doAddDependencies(DepsTracker* deps) const {
1875     if (_variable == Variables::kRootId) {  // includes CURRENT when it is equivalent to ROOT.
1876         if (_fieldPath.getPathLength() == 1) {
1877             deps->needWholeDocument = true;  // need full doc if just "$$ROOT"
1878         } else {
1879             deps->fields.insert(_fieldPath.tail().fullPath());
1880         }
1881     } else if (Variables::isUserDefinedVariable(_variable)) {
1882         deps->vars.insert(_variable);
1883     }
1884 }
1885 
// Applies the remaining path (starting at component 'index') to every element of the array
// 'input' and returns the non-missing results as a new array. Elements that are not objects are
// skipped.
Value ExpressionFieldPath::evaluatePathArray(size_t index, const Value& input) const {
    dassert(input.isArray());

    vector<Value> collected;
    for (const Value& element : input.getArray()) {
        // Only objects can contain the rest of the path.
        if (element.getType() != Object)
            continue;

        const Value nested = evaluatePath(index, element.getDocument());
        if (nested.missing())
            continue;

        collected.push_back(nested);
    }

    return Value(std::move(collected));
}
evaluatePath(size_t index,const Document & input) const1903 Value ExpressionFieldPath::evaluatePath(size_t index, const Document& input) const {
1904     // Note this function is very hot so it is important that is is well optimized.
1905     // In particular, all return paths should support RVO.
1906 
1907     /* if we've hit the end of the path, stop */
1908     if (index == _fieldPath.getPathLength() - 1)
1909         return input[_fieldPath.getFieldName(index)];
1910 
1911     // Try to dive deeper
1912     const Value val = input[_fieldPath.getFieldName(index)];
1913     switch (val.getType()) {
1914         case Object:
1915             return evaluatePath(index + 1, val.getDocument());
1916 
1917         case Array:
1918             return evaluatePathArray(index + 1, val);
1919 
1920         default:
1921             return Value();
1922     }
1923 }
1924 
evaluate(const Document & root,Variables * variables) const1925 Value ExpressionFieldPath::evaluate(const Document& root, Variables* variables) const {
1926     if (_fieldPath.getPathLength() == 1)  // get the whole variable
1927         return variables->getValue(_variable, root);
1928 
1929     if (_variable == Variables::kRootId) {
1930         // ROOT is always a document so use optimized code path
1931         return evaluatePath(1, root);
1932     }
1933 
1934     Value var = variables->getValue(_variable, root);
1935     switch (var.getType()) {
1936         case Object:
1937             return evaluatePath(1, var.getDocument());
1938         case Array:
1939             return evaluatePathArray(1, var);
1940         default:
1941             return Value();
1942     }
1943 }
1944 
serialize(bool explain) const1945 Value ExpressionFieldPath::serialize(bool explain) const {
1946     if (_fieldPath.getFieldName(0) == "CURRENT" && _fieldPath.getPathLength() > 1) {
1947         // use short form for "$$CURRENT.foo" but not just "$$CURRENT"
1948         return Value("$" + _fieldPath.tail().fullPath());
1949     } else {
1950         return Value("$$" + _fieldPath.fullPath());
1951     }
1952 }
1953 
getComputedPaths(const std::string & exprFieldPath,Variables::Id renamingVar) const1954 Expression::ComputedPaths ExpressionFieldPath::getComputedPaths(const std::string& exprFieldPath,
1955                                                                 Variables::Id renamingVar) const {
1956     // An expression field path is either considered a rename or a computed path. We need to find
1957     // out which case we fall into.
1958     //
1959     // The caller has told us that renames must have 'varId' as the first component. We also check
1960     // that there is only one additional component---no dotted field paths are allowed!  This is
1961     // because dotted ExpressionFieldPaths can actually reshape the document rather than just
1962     // changing the field names. This can happen only if there are arrays along the dotted path.
1963     //
1964     // For example, suppose you have document {a: [{b: 1}, {b: 2}]}. The projection {"c.d": "$a.b"}
1965     // does *not* perform the strict rename to yield document {c: [{d: 1}, {d: 2}]}. Instead, it
1966     // results in the document {c: {d: [1, 2]}}. Due to this reshaping, matches expressed over "a.b"
1967     // before the $project is applied may not have the same behavior when expressed over "c.d" after
1968     // the $project is applied.
1969     ComputedPaths outputPaths;
1970     if (_variable == renamingVar && _fieldPath.getPathLength() == 2u) {
1971         outputPaths.renames[exprFieldPath] = _fieldPath.tail().fullPath();
1972     } else {
1973         outputPaths.paths.insert(exprFieldPath);
1974     }
1975 
1976     return outputPaths;
1977 }
1978 
1979 /* ------------------------- ExpressionFilter ----------------------------- */
1980 
1981 REGISTER_EXPRESSION(filter, ExpressionFilter::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)1982 intrusive_ptr<Expression> ExpressionFilter::parse(
1983     const boost::intrusive_ptr<ExpressionContext>& expCtx,
1984     BSONElement expr,
1985     const VariablesParseState& vpsIn) {
1986     verify(str::equals(expr.fieldName(), "$filter"));
1987 
1988     uassert(28646, "$filter only supports an object as its argument", expr.type() == Object);
1989 
1990     // "cond" must be parsed after "as" regardless of BSON order.
1991     BSONElement inputElem;
1992     BSONElement asElem;
1993     BSONElement condElem;
1994     for (auto elem : expr.Obj()) {
1995         if (str::equals(elem.fieldName(), "input")) {
1996             inputElem = elem;
1997         } else if (str::equals(elem.fieldName(), "as")) {
1998             asElem = elem;
1999         } else if (str::equals(elem.fieldName(), "cond")) {
2000             condElem = elem;
2001         } else {
2002             uasserted(28647,
2003                       str::stream() << "Unrecognized parameter to $filter: " << elem.fieldName());
2004         }
2005     }
2006 
2007     uassert(28648, "Missing 'input' parameter to $filter", !inputElem.eoo());
2008     uassert(28650, "Missing 'cond' parameter to $filter", !condElem.eoo());
2009 
2010     // Parse "input", only has outer variables.
2011     intrusive_ptr<Expression> input = parseOperand(expCtx, inputElem, vpsIn);
2012 
2013     // Parse "as".
2014     VariablesParseState vpsSub(vpsIn);  // vpsSub gets our variable, vpsIn doesn't.
2015 
2016     // If "as" is not specified, then use "this" by default.
2017     auto varName = asElem.eoo() ? "this" : asElem.str();
2018 
2019     Variables::uassertValidNameForUserWrite(varName);
2020     Variables::Id varId = vpsSub.defineVariable(varName);
2021 
2022     // Parse "cond", has access to "as" variable.
2023     intrusive_ptr<Expression> cond = parseOperand(expCtx, condElem, vpsSub);
2024 
2025     return new ExpressionFilter(
2026         expCtx, std::move(varName), varId, std::move(input), std::move(cond));
2027 }
2028 
ExpressionFilter(const boost::intrusive_ptr<ExpressionContext> & expCtx,string varName,Variables::Id varId,intrusive_ptr<Expression> input,intrusive_ptr<Expression> filter)2029 ExpressionFilter::ExpressionFilter(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2030                                    string varName,
2031                                    Variables::Id varId,
2032                                    intrusive_ptr<Expression> input,
2033                                    intrusive_ptr<Expression> filter)
2034     : Expression(expCtx),
2035       _varName(std::move(varName)),
2036       _varId(varId),
2037       _input(std::move(input)),
2038       _filter(std::move(filter)) {}
2039 
optimize()2040 intrusive_ptr<Expression> ExpressionFilter::optimize() {
2041     // TODO handle when _input is constant.
2042     _input = _input->optimize();
2043     _filter = _filter->optimize();
2044     return this;
2045 }
2046 
serialize(bool explain) const2047 Value ExpressionFilter::serialize(bool explain) const {
2048     return Value(
2049         DOC("$filter" << DOC("input" << _input->serialize(explain) << "as" << _varName << "cond"
2050                                      << _filter->serialize(explain))));
2051 }
2052 
evaluate(const Document & root,Variables * variables) const2053 Value ExpressionFilter::evaluate(const Document& root, Variables* variables) const {
2054     // We are guaranteed at parse time that this isn't using our _varId.
2055     const Value inputVal = _input->evaluate(root, variables);
2056     if (inputVal.nullish())
2057         return Value(BSONNULL);
2058 
2059     uassert(28651,
2060             str::stream() << "input to $filter must be an array not "
2061                           << typeName(inputVal.getType()),
2062             inputVal.isArray());
2063 
2064     const vector<Value>& input = inputVal.getArray();
2065 
2066     if (input.empty())
2067         return inputVal;
2068 
2069     vector<Value> output;
2070     for (const auto& elem : input) {
2071         variables->setValue(_varId, elem);
2072 
2073         if (_filter->evaluate(root, variables).coerceToBool()) {
2074             output.push_back(std::move(elem));
2075         }
2076     }
2077 
2078     return Value(std::move(output));
2079 }
2080 
_doAddDependencies(DepsTracker * deps) const2081 void ExpressionFilter::_doAddDependencies(DepsTracker* deps) const {
2082     _input->addDependencies(deps);
2083     _filter->addDependencies(deps);
2084 }
2085 
2086 /* ------------------------- ExpressionFloor -------------------------- */
2087 
// Rounds 'numericArg' down. Doubles go through std::floor, decimals are quantized to a zero
// exponent rounding toward negative, and every other type is returned unchanged (there is no
// point in taking the floor of integers or longs).
Value ExpressionFloor::evaluateNumericArg(const Value& numericArg) const {
    const auto type = numericArg.getType();

    if (type == NumberDouble) {
        return Value(std::floor(numericArg.getDouble()));
    }

    if (type == NumberDecimal) {
        // Round toward the nearest decimal with a zero exponent in the negative direction.
        return Value(numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
                                                      Decimal128::kRoundTowardNegative));
    }

    return numericArg;
}
2101 
2102 REGISTER_EXPRESSION(floor, ExpressionFloor::parse);
getOpName() const2103 const char* ExpressionFloor::getOpName() const {
2104     return "$floor";
2105 }
2106 
2107 /* ------------------------- ExpressionLet ----------------------------- */
2108 
2109 REGISTER_EXPRESSION(let, ExpressionLet::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2110 intrusive_ptr<Expression> ExpressionLet::parse(
2111     const boost::intrusive_ptr<ExpressionContext>& expCtx,
2112     BSONElement expr,
2113     const VariablesParseState& vpsIn) {
2114     verify(str::equals(expr.fieldName(), "$let"));
2115 
2116     uassert(16874, "$let only supports an object as its argument", expr.type() == Object);
2117     const BSONObj args = expr.embeddedObject();
2118 
2119     // varsElem must be parsed before inElem regardless of BSON order.
2120     BSONElement varsElem;
2121     BSONElement inElem;
2122     BSONForEach(arg, args) {
2123         if (str::equals(arg.fieldName(), "vars")) {
2124             varsElem = arg;
2125         } else if (str::equals(arg.fieldName(), "in")) {
2126             inElem = arg;
2127         } else {
2128             uasserted(16875,
2129                       str::stream() << "Unrecognized parameter to $let: " << arg.fieldName());
2130         }
2131     }
2132 
2133     uassert(16876, "Missing 'vars' parameter to $let", !varsElem.eoo());
2134     uassert(16877, "Missing 'in' parameter to $let", !inElem.eoo());
2135 
2136     // parse "vars"
2137     VariablesParseState vpsSub(vpsIn);  // vpsSub gets our vars, vpsIn doesn't.
2138     VariableMap vars;
2139     BSONForEach(varElem, varsElem.embeddedObjectUserCheck()) {
2140         const string varName = varElem.fieldName();
2141         Variables::uassertValidNameForUserWrite(varName);
2142         Variables::Id id = vpsSub.defineVariable(varName);
2143 
2144         vars[id] = NameAndExpression(varName,
2145                                      parseOperand(expCtx, varElem, vpsIn));  // only has outer vars
2146     }
2147 
2148     // parse "in"
2149     intrusive_ptr<Expression> subExpression = parseOperand(expCtx, inElem, vpsSub);  // has our vars
2150 
2151     return new ExpressionLet(expCtx, vars, subExpression);
2152 }
2153 
ExpressionLet(const boost::intrusive_ptr<ExpressionContext> & expCtx,const VariableMap & vars,intrusive_ptr<Expression> subExpression)2154 ExpressionLet::ExpressionLet(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2155                              const VariableMap& vars,
2156                              intrusive_ptr<Expression> subExpression)
2157     : Expression(expCtx), _variables(vars), _subExpression(subExpression) {}
2158 
optimize()2159 intrusive_ptr<Expression> ExpressionLet::optimize() {
2160     if (_variables.empty()) {
2161         // we aren't binding any variables so just return the subexpression
2162         return _subExpression->optimize();
2163     }
2164 
2165     for (VariableMap::iterator it = _variables.begin(), end = _variables.end(); it != end; ++it) {
2166         it->second.expression = it->second.expression->optimize();
2167     }
2168 
2169     _subExpression = _subExpression->optimize();
2170 
2171     return this;
2172 }
2173 
serialize(bool explain) const2174 Value ExpressionLet::serialize(bool explain) const {
2175     MutableDocument vars;
2176     for (VariableMap::const_iterator it = _variables.begin(), end = _variables.end(); it != end;
2177          ++it) {
2178         vars[it->second.name] = it->second.expression->serialize(explain);
2179     }
2180 
2181     return Value(
2182         DOC("$let" << DOC("vars" << vars.freeze() << "in" << _subExpression->serialize(explain))));
2183 }
2184 
evaluate(const Document & root,Variables * variables) const2185 Value ExpressionLet::evaluate(const Document& root, Variables* variables) const {
2186     for (const auto& item : _variables) {
2187         // It is guaranteed at parse-time that these expressions don't use the variable ids we
2188         // are setting
2189         variables->setValue(item.first, item.second.expression->evaluate(root, variables));
2190     }
2191 
2192     return _subExpression->evaluate(root, variables);
2193 }
2194 
_doAddDependencies(DepsTracker * deps) const2195 void ExpressionLet::_doAddDependencies(DepsTracker* deps) const {
2196     for (auto&& idToNameExp : _variables) {
2197         // Add the external dependencies from the 'vars' statement.
2198         idToNameExp.second.expression->addDependencies(deps);
2199     }
2200 
2201     // Add subexpression dependencies, which may contain a mix of local and external variable refs.
2202     _subExpression->addDependencies(deps);
2203 }
2204 
2205 /* ------------------------- ExpressionMap ----------------------------- */
2206 
2207 REGISTER_EXPRESSION(map, ExpressionMap::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2208 intrusive_ptr<Expression> ExpressionMap::parse(
2209     const boost::intrusive_ptr<ExpressionContext>& expCtx,
2210     BSONElement expr,
2211     const VariablesParseState& vpsIn) {
2212     verify(str::equals(expr.fieldName(), "$map"));
2213 
2214     uassert(16878, "$map only supports an object as its argument", expr.type() == Object);
2215 
2216     // "in" must be parsed after "as" regardless of BSON order
2217     BSONElement inputElem;
2218     BSONElement asElem;
2219     BSONElement inElem;
2220     const BSONObj args = expr.embeddedObject();
2221     BSONForEach(arg, args) {
2222         if (str::equals(arg.fieldName(), "input")) {
2223             inputElem = arg;
2224         } else if (str::equals(arg.fieldName(), "as")) {
2225             asElem = arg;
2226         } else if (str::equals(arg.fieldName(), "in")) {
2227             inElem = arg;
2228         } else {
2229             uasserted(16879,
2230                       str::stream() << "Unrecognized parameter to $map: " << arg.fieldName());
2231         }
2232     }
2233 
2234     uassert(16880, "Missing 'input' parameter to $map", !inputElem.eoo());
2235     uassert(16882, "Missing 'in' parameter to $map", !inElem.eoo());
2236 
2237     // parse "input"
2238     intrusive_ptr<Expression> input =
2239         parseOperand(expCtx, inputElem, vpsIn);  // only has outer vars
2240 
2241     // parse "as"
2242     VariablesParseState vpsSub(vpsIn);  // vpsSub gets our vars, vpsIn doesn't.
2243 
2244     // If "as" is not specified, then use "this" by default.
2245     auto varName = asElem.eoo() ? "this" : asElem.str();
2246 
2247     Variables::uassertValidNameForUserWrite(varName);
2248     Variables::Id varId = vpsSub.defineVariable(varName);
2249 
2250     // parse "in"
2251     intrusive_ptr<Expression> in =
2252         parseOperand(expCtx, inElem, vpsSub);  // has access to map variable
2253 
2254     return new ExpressionMap(expCtx, varName, varId, input, in);
2255 }
2256 
ExpressionMap(const boost::intrusive_ptr<ExpressionContext> & expCtx,const string & varName,Variables::Id varId,intrusive_ptr<Expression> input,intrusive_ptr<Expression> each)2257 ExpressionMap::ExpressionMap(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2258                              const string& varName,
2259                              Variables::Id varId,
2260                              intrusive_ptr<Expression> input,
2261                              intrusive_ptr<Expression> each)
2262     : Expression(expCtx), _varName(varName), _varId(varId), _input(input), _each(each) {}
2263 
/**
 * Optimizes the input expression and the per-element expression in place, then returns this
 * expression.
 */
intrusive_ptr<Expression> ExpressionMap::optimize() {
    // TODO handle when _input is constant
    _input = _input->optimize();
    _each = _each->optimize();
    return this;
}
2270 
serialize(bool explain) const2271 Value ExpressionMap::serialize(bool explain) const {
2272     return Value(DOC("$map" << DOC("input" << _input->serialize(explain) << "as" << _varName << "in"
2273                                            << _each->serialize(explain))));
2274 }
2275 
evaluate(const Document & root,Variables * variables) const2276 Value ExpressionMap::evaluate(const Document& root, Variables* variables) const {
2277     // guaranteed at parse time that this isn't using our _varId
2278     const Value inputVal = _input->evaluate(root, variables);
2279     if (inputVal.nullish())
2280         return Value(BSONNULL);
2281 
2282     uassert(16883,
2283             str::stream() << "input to $map must be an array not " << typeName(inputVal.getType()),
2284             inputVal.isArray());
2285 
2286     const vector<Value>& input = inputVal.getArray();
2287 
2288     if (input.empty())
2289         return inputVal;
2290 
2291     vector<Value> output;
2292     output.reserve(input.size());
2293     for (size_t i = 0; i < input.size(); i++) {
2294         variables->setValue(_varId, input[i]);
2295 
2296         Value toInsert = _each->evaluate(root, variables);
2297         if (toInsert.missing())
2298             toInsert = Value(BSONNULL);  // can't insert missing values into array
2299 
2300         output.push_back(toInsert);
2301     }
2302 
2303     return Value(std::move(output));
2304 }
2305 
/**
 * Adds to 'deps' the dependencies of the input expression followed by those of the per-element
 * expression.
 */
void ExpressionMap::_doAddDependencies(DepsTracker* deps) const {
    _input->addDependencies(deps);
    _each->addDependencies(deps);
}
2310 
/**
 * Computes the ComputedPaths for this $map when its result is assigned to 'exprFieldPath'.
 *
 * Renames are only reported when the input is an ExpressionFieldPath that itself reports a
 * rename and the per-element expression reports renames as well. In every other case the result
 * is {{exprFieldPath}, {}}, whose second member is empty (presumably the 'renames' map -- confirm
 * against the declaration of ComputedPaths).
 */
Expression::ComputedPaths ExpressionMap::getComputedPaths(const std::string& exprFieldPath,
                                                          Variables::Id renamingVar) const {
    // Only an input that is a field path expression is considered as a rename source.
    auto inputFieldPath = dynamic_cast<ExpressionFieldPath*>(_input.get());
    if (!inputFieldPath) {
        return {{exprFieldPath}, {}};
    }

    // Query the input with the empty string as the new path, so any rename is keyed by "".
    auto inputComputedPaths = inputFieldPath->getComputedPaths("", renamingVar);
    if (inputComputedPaths.renames.empty()) {
        return {{exprFieldPath}, {}};
    }
    // Exactly one rename is expected, stored under the "" key passed above.
    invariant(inputComputedPaths.renames.size() == 1u);
    auto fieldPathRenameIter = inputComputedPaths.renames.find("");
    invariant(fieldPathRenameIter != inputComputedPaths.renames.end());
    const auto& oldArrayName = fieldPathRenameIter->second;

    // The per-element expression is queried relative to this $map's own element variable.
    auto eachComputedPaths = _each->getComputedPaths(exprFieldPath, _varId);
    if (eachComputedPaths.renames.empty()) {
        return {{exprFieldPath}, {}};
    }

    // Append the name of the array to the beginning of the old field path.
    // Entries are rewritten through operator[] using keys that already exist in the map.
    for (auto&& rename : eachComputedPaths.renames) {
        eachComputedPaths.renames[rename.first] =
            FieldPath::getFullyQualifiedPath(oldArrayName, rename.second);
    }
    return eachComputedPaths;
}
2339 
2340 /* ------------------------- ExpressionMeta ----------------------------- */
2341 
2342 REGISTER_EXPRESSION(meta, ExpressionMeta::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vpsIn)2343 intrusive_ptr<Expression> ExpressionMeta::parse(
2344     const boost::intrusive_ptr<ExpressionContext>& expCtx,
2345     BSONElement expr,
2346     const VariablesParseState& vpsIn) {
2347     uassert(17307, "$meta only supports string arguments", expr.type() == String);
2348     if (expr.valueStringData() == "textScore") {
2349         return new ExpressionMeta(expCtx, MetaType::TEXT_SCORE);
2350     } else if (expr.valueStringData() == "randVal") {
2351         return new ExpressionMeta(expCtx, MetaType::RAND_VAL);
2352     } else {
2353         uasserted(17308, "Unsupported argument to $meta: " + expr.String());
2354     }
2355 }
2356 
/**
 * Constructs a $meta expression that reads the kind of metadata identified by 'metaType'.
 */
ExpressionMeta::ExpressionMeta(const boost::intrusive_ptr<ExpressionContext>& expCtx,
                               MetaType metaType)
    : Expression(expCtx), _metaType(metaType) {}
2360 
serialize(bool explain) const2361 Value ExpressionMeta::serialize(bool explain) const {
2362     switch (_metaType) {
2363         case MetaType::TEXT_SCORE:
2364             return Value(DOC("$meta"
2365                              << "textScore"_sd));
2366         case MetaType::RAND_VAL:
2367             return Value(DOC("$meta"
2368                              << "randVal"_sd));
2369     }
2370     MONGO_UNREACHABLE;
2371 }
2372 
evaluate(const Document & root,Variables * variables) const2373 Value ExpressionMeta::evaluate(const Document& root, Variables* variables) const {
2374     switch (_metaType) {
2375         case MetaType::TEXT_SCORE:
2376             return root.hasTextScore() ? Value(root.getTextScore()) : Value();
2377         case MetaType::RAND_VAL:
2378             return root.hasRandMetaField() ? Value(root.getRandMetaField()) : Value();
2379     }
2380     MONGO_UNREACHABLE;
2381 }
2382 
_doAddDependencies(DepsTracker * deps) const2383 void ExpressionMeta::_doAddDependencies(DepsTracker* deps) const {
2384     if (_metaType == MetaType::TEXT_SCORE) {
2385         deps->setNeedTextScore(true);
2386     }
2387 }
2388 
2389 /* ----------------------- ExpressionMod ---------------------------- */
2390 
evaluate(const Document & root,Variables * variables) const2391 Value ExpressionMod::evaluate(const Document& root, Variables* variables) const {
2392     Value lhs = vpOperand[0]->evaluate(root, variables);
2393     Value rhs = vpOperand[1]->evaluate(root, variables);
2394 
2395     BSONType leftType = lhs.getType();
2396     BSONType rightType = rhs.getType();
2397 
2398     if (lhs.numeric() && rhs.numeric()) {
2399         auto assertNonZero = [](bool isZero) { uassert(16610, "can't $mod by zero", !isZero); };
2400 
2401         // If either side is decimal, perform the operation in decimal.
2402         if (leftType == NumberDecimal || rightType == NumberDecimal) {
2403             Decimal128 left = lhs.coerceToDecimal();
2404             Decimal128 right = rhs.coerceToDecimal();
2405             assertNonZero(right.isZero());
2406 
2407             return Value(left.modulo(right));
2408         }
2409 
2410         // ensure we aren't modding by 0
2411         double right = rhs.coerceToDouble();
2412         assertNonZero(right == 0);
2413 
2414         if (leftType == NumberDouble || (rightType == NumberDouble && !rhs.integral())) {
2415             // Need to do fmod. Integer-valued double case is handled below.
2416 
2417             double left = lhs.coerceToDouble();
2418             return Value(fmod(left, right));
2419         }
2420 
2421         if (leftType == NumberLong || rightType == NumberLong) {
2422             // if either is long, return long
2423             long long left = lhs.coerceToLong();
2424             long long rightLong = rhs.coerceToLong();
2425             return Value(mongoSafeMod(left, rightLong));
2426         }
2427 
2428         // lastly they must both be ints, return int
2429         int left = lhs.coerceToInt();
2430         int rightInt = rhs.coerceToInt();
2431         return Value(mongoSafeMod(left, rightInt));
2432     } else if (lhs.nullish() || rhs.nullish()) {
2433         return Value(BSONNULL);
2434     } else {
2435         uasserted(16611,
2436                   str::stream() << "$mod only supports numeric types, not "
2437                                 << typeName(lhs.getType())
2438                                 << " and "
2439                                 << typeName(rhs.getType()));
2440     }
2441 }
2442 
// Associates the "mod" key with ExpressionMod::parse through the REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(mod, ExpressionMod::parse);
/**
 * Returns the operator name for this expression, "$mod".
 */
const char* ExpressionMod::getOpName() const {
    return "$mod";
}
2447 
2448 /* ------------------------- ExpressionMultiply ----------------------------- */
2449 
evaluate(const Document & root,Variables * variables) const2450 Value ExpressionMultiply::evaluate(const Document& root, Variables* variables) const {
2451     /*
2452       We'll try to return the narrowest possible result value.  To do that
2453       without creating intermediate Values, do the arithmetic for double
2454       and integral types in parallel, tracking the current narrowest
2455       type.
2456      */
2457     double doubleProduct = 1;
2458     long long longProduct = 1;
2459     Decimal128 decimalProduct;  // This will be initialized on encountering the first decimal.
2460 
2461     BSONType productType = NumberInt;
2462 
2463     const size_t n = vpOperand.size();
2464     for (size_t i = 0; i < n; ++i) {
2465         Value val = vpOperand[i]->evaluate(root, variables);
2466 
2467         if (val.numeric()) {
2468             BSONType oldProductType = productType;
2469             productType = Value::getWidestNumeric(productType, val.getType());
2470             if (productType == NumberDecimal) {
2471                 // On finding the first decimal, convert the partial product to decimal.
2472                 if (oldProductType != NumberDecimal) {
2473                     decimalProduct = oldProductType == NumberDouble
2474                         ? Decimal128(doubleProduct, Decimal128::kRoundTo15Digits)
2475                         : Decimal128(static_cast<int64_t>(longProduct));
2476                 }
2477                 decimalProduct = decimalProduct.multiply(val.coerceToDecimal());
2478             } else {
2479                 doubleProduct *= val.coerceToDouble();
2480                 if (!std::isfinite(val.coerceToDouble()) ||
2481                     mongoSignedMultiplyOverflow64(longProduct, val.coerceToLong(), &longProduct)) {
2482                     // The number is either Infinity or NaN, or the 'longProduct' would have
2483                     // overflowed, so we're abandoning it.
2484                     productType = NumberDouble;
2485                 }
2486             }
2487         } else if (val.nullish()) {
2488             return Value(BSONNULL);
2489         } else {
2490             uasserted(16555,
2491                       str::stream() << "$multiply only supports numeric types, not "
2492                                     << typeName(val.getType()));
2493         }
2494     }
2495 
2496     if (productType == NumberDouble)
2497         return Value(doubleProduct);
2498     else if (productType == NumberLong)
2499         return Value(longProduct);
2500     else if (productType == NumberInt)
2501         return Value::createIntOrLong(longProduct);
2502     else if (productType == NumberDecimal)
2503         return Value(decimalProduct);
2504     else
2505         massert(16418, "$multiply resulted in a non-numeric type", false);
2506 }
2507 
// Associates the "multiply" key with ExpressionMultiply::parse through the REGISTER_EXPRESSION
// macro.
REGISTER_EXPRESSION(multiply, ExpressionMultiply::parse);
/**
 * Returns the operator name for this expression, "$multiply".
 */
const char* ExpressionMultiply::getOpName() const {
    return "$multiply";
}
2512 
2513 /* ----------------------- ExpressionIfNull ---------------------------- */
2514 
evaluate(const Document & root,Variables * variables) const2515 Value ExpressionIfNull::evaluate(const Document& root, Variables* variables) const {
2516     Value pLeft(vpOperand[0]->evaluate(root, variables));
2517     if (!pLeft.nullish())
2518         return pLeft;
2519 
2520     Value pRight(vpOperand[1]->evaluate(root, variables));
2521     return pRight;
2522 }
2523 
// Associates the "ifNull" key with ExpressionIfNull::parse through the REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(ifNull, ExpressionIfNull::parse);
/**
 * Returns the operator name for this expression, "$ifNull".
 */
const char* ExpressionIfNull::getOpName() const {
    return "$ifNull";
}
2528 
2529 /* ----------------------- ExpressionIn ---------------------------- */
2530 
evaluate(const Document & root,Variables * variables) const2531 Value ExpressionIn::evaluate(const Document& root, Variables* variables) const {
2532     Value argument(vpOperand[0]->evaluate(root, variables));
2533     Value arrayOfValues(vpOperand[1]->evaluate(root, variables));
2534 
2535     uassert(40081,
2536             str::stream() << "$in requires an array as a second argument, found: "
2537                           << typeName(arrayOfValues.getType()),
2538             arrayOfValues.isArray());
2539     for (auto&& value : arrayOfValues.getArray()) {
2540         if (getExpressionContext()->getValueComparator().evaluate(argument == value)) {
2541             return Value(true);
2542         }
2543     }
2544     return Value(false);
2545 }
2546 
// Associates the "in" key with ExpressionIn::parse through the REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(in, ExpressionIn::parse);
/**
 * Returns the operator name for this expression, "$in".
 */
const char* ExpressionIn::getOpName() const {
    return "$in";
}
2551 
2552 /* ----------------------- ExpressionIndexOfArray ------------------ */
2553 
2554 namespace {
2555 
uassertIfNotIntegralAndNonNegative(Value val,StringData expressionName,StringData argumentName)2556 void uassertIfNotIntegralAndNonNegative(Value val,
2557                                         StringData expressionName,
2558                                         StringData argumentName) {
2559     uassert(40096,
2560             str::stream() << expressionName << "requires an integral " << argumentName
2561                           << ", found a value of type: "
2562                           << typeName(val.getType())
2563                           << ", with value: "
2564                           << val.toString(),
2565             val.integral());
2566     uassert(40097,
2567             str::stream() << expressionName << " requires a nonnegative " << argumentName
2568                           << ", found: "
2569                           << val.toString(),
2570             val.coerceToInt() >= 0);
2571 }
2572 
2573 }  // namespace
2574 
evaluate(const Document & root,Variables * variables) const2575 Value ExpressionIndexOfArray::evaluate(const Document& root, Variables* variables) const {
2576     Value arrayArg = vpOperand[0]->evaluate(root, variables);
2577 
2578     if (arrayArg.nullish()) {
2579         return Value(BSONNULL);
2580     }
2581 
2582     uassert(40090,
2583             str::stream() << "$indexOfArray requires an array as a first argument, found: "
2584                           << typeName(arrayArg.getType()),
2585             arrayArg.isArray());
2586 
2587     std::vector<Value> array = arrayArg.getArray();
2588 
2589     Value searchItem = vpOperand[1]->evaluate(root, variables);
2590 
2591     size_t startIndex = 0;
2592     if (vpOperand.size() > 2) {
2593         Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2594         uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2595         startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2596     }
2597 
2598     size_t endIndex = array.size();
2599     if (vpOperand.size() > 3) {
2600         Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2601         uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2602         // Don't let 'endIndex' exceed the length of the array.
2603         endIndex = std::min(array.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
2604     }
2605 
2606     for (size_t i = startIndex; i < endIndex; i++) {
2607         if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) {
2608             return Value(static_cast<int>(i));
2609         }
2610     }
2611 
2612     return Value(-1);
2613 }
2614 
// Associates the "indexOfArray" key with ExpressionIndexOfArray::parse through the
// REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse);
/**
 * Returns the operator name for this expression, "$indexOfArray".
 */
const char* ExpressionIndexOfArray::getOpName() const {
    return "$indexOfArray";
}
2619 
2620 /* ----------------------- ExpressionIndexOfBytes ------------------ */
2621 
namespace {

/**
 * Returns true when 'token' occurs in 'input' starting exactly at byte offset 'index'.
 *
 * Returns false, without reading out of bounds, when 'token' would extend past the end of
 * 'input' from that offset.
 */
bool stringHasTokenAtIndex(size_t index, const std::string& input, const std::string& token) {
    const bool tokenFits = token.size() + index <= input.size();
    return tokenFits && input.compare(index, token.size(), token) == 0;
}

}  // namespace
2632 
evaluate(const Document & root,Variables * variables) const2633 Value ExpressionIndexOfBytes::evaluate(const Document& root, Variables* variables) const {
2634     Value stringArg = vpOperand[0]->evaluate(root, variables);
2635 
2636     if (stringArg.nullish()) {
2637         return Value(BSONNULL);
2638     }
2639 
2640     uassert(40091,
2641             str::stream() << "$indexOfBytes requires a string as the first argument, found: "
2642                           << typeName(stringArg.getType()),
2643             stringArg.getType() == String);
2644     const std::string& input = stringArg.getString();
2645 
2646     Value tokenArg = vpOperand[1]->evaluate(root, variables);
2647     uassert(40092,
2648             str::stream() << "$indexOfBytes requires a string as the second argument, found: "
2649                           << typeName(tokenArg.getType()),
2650             tokenArg.getType() == String);
2651     const std::string& token = tokenArg.getString();
2652 
2653     size_t startIndex = 0;
2654     if (vpOperand.size() > 2) {
2655         Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2656         uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2657         startIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2658     }
2659 
2660     size_t endIndex = input.size();
2661     if (vpOperand.size() > 3) {
2662         Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2663         uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2664         // Don't let 'endIndex' exceed the length of the string.
2665         endIndex = std::min(input.size(), static_cast<size_t>(endIndexArg.coerceToInt()));
2666     }
2667 
2668     if (startIndex > input.length() || endIndex < startIndex) {
2669         return Value(-1);
2670     }
2671 
2672     size_t position = input.substr(0, endIndex).find(token, startIndex);
2673     if (position == std::string::npos) {
2674         return Value(-1);
2675     }
2676 
2677     return Value(static_cast<int>(position));
2678 }
2679 
// Associates the "indexOfBytes" key with ExpressionIndexOfBytes::parse through the
// REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(indexOfBytes, ExpressionIndexOfBytes::parse);
/**
 * Returns the operator name for this expression, "$indexOfBytes".
 */
const char* ExpressionIndexOfBytes::getOpName() const {
    return "$indexOfBytes";
}
2684 
2685 /* ----------------------- ExpressionIndexOfCP --------------------- */
2686 
evaluate(const Document & root,Variables * variables) const2687 Value ExpressionIndexOfCP::evaluate(const Document& root, Variables* variables) const {
2688     Value stringArg = vpOperand[0]->evaluate(root, variables);
2689 
2690     if (stringArg.nullish()) {
2691         return Value(BSONNULL);
2692     }
2693 
2694     uassert(40093,
2695             str::stream() << "$indexOfCP requires a string as the first argument, found: "
2696                           << typeName(stringArg.getType()),
2697             stringArg.getType() == String);
2698     const std::string& input = stringArg.getString();
2699 
2700     Value tokenArg = vpOperand[1]->evaluate(root, variables);
2701     uassert(40094,
2702             str::stream() << "$indexOfCP requires a string as the second argument, found: "
2703                           << typeName(tokenArg.getType()),
2704             tokenArg.getType() == String);
2705     const std::string& token = tokenArg.getString();
2706 
2707     size_t startCodePointIndex = 0;
2708     if (vpOperand.size() > 2) {
2709         Value startIndexArg = vpOperand[2]->evaluate(root, variables);
2710         uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");
2711         startCodePointIndex = static_cast<size_t>(startIndexArg.coerceToInt());
2712     }
2713 
2714     // Compute the length (in code points) of the input, and convert 'startCodePointIndex' to a byte
2715     // index.
2716     size_t codePointLength = 0;
2717     size_t startByteIndex = 0;
2718     for (size_t byteIx = 0; byteIx < input.size(); ++codePointLength) {
2719         if (codePointLength == startCodePointIndex) {
2720             // We have determined the byte at which our search will start.
2721             startByteIndex = byteIx;
2722         }
2723 
2724         uassert(40095,
2725                 "$indexOfCP found bad UTF-8 in the input",
2726                 !str::isUTF8ContinuationByte(input[byteIx]));
2727         byteIx += getCodePointLength(input[byteIx]);
2728     }
2729 
2730     size_t endCodePointIndex = codePointLength;
2731     if (vpOperand.size() > 3) {
2732         Value endIndexArg = vpOperand[3]->evaluate(root, variables);
2733         uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");
2734 
2735         // Don't let 'endCodePointIndex' exceed the number of code points in the string.
2736         endCodePointIndex =
2737             std::min(codePointLength, static_cast<size_t>(endIndexArg.coerceToInt()));
2738     }
2739 
2740     if (startByteIndex == 0 && input.empty() && token.empty()) {
2741         // If we are finding the index of "" in the string "", the below loop will not loop, so we
2742         // need a special case for this.
2743         return Value(0);
2744     }
2745 
2746     // We must keep track of which byte, and which code point, we are examining, being careful not
2747     // to overflow either the length of the string or the ending code point.
2748 
2749     size_t currentCodePointIndex = startCodePointIndex;
2750     for (size_t byteIx = startByteIndex; currentCodePointIndex < endCodePointIndex;
2751          ++currentCodePointIndex) {
2752         if (stringHasTokenAtIndex(byteIx, input, token)) {
2753             return Value(static_cast<int>(currentCodePointIndex));
2754         }
2755         byteIx += getCodePointLength(input[byteIx]);
2756     }
2757 
2758     return Value(-1);
2759 }
2760 
// Associates the "indexOfCP" key with ExpressionIndexOfCP::parse through the REGISTER_EXPRESSION
// macro.
REGISTER_EXPRESSION(indexOfCP, ExpressionIndexOfCP::parse);
/**
 * Returns the operator name for this expression, "$indexOfCP".
 */
const char* ExpressionIndexOfCP::getOpName() const {
    return "$indexOfCP";
}
2765 
2766 /* ----------------------- ExpressionLn ---------------------------- */
2767 
/**
 * Returns the natural logarithm of 'numericArg'.
 *
 * A decimal argument greater than zero is handled in decimal; every other argument is converted
 * to double. Raises user assertion 28766 when the double value is neither NaN nor greater than
 * zero.
 */
Value ExpressionLn::evaluateNumericArg(const Value& numericArg) const {
    const bool isDecimal = numericArg.getType() == NumberDecimal;
    if (isDecimal) {
        const Decimal128 decimalArg = numericArg.getDecimal();
        if (decimalArg.isGreater(Decimal128::kNormalizedZero)) {
            return Value(decimalArg.logarithm());
        }
        // Any other decimal goes through the double path below, which reports the error case.
    }

    const double doubleArg = numericArg.coerceToDouble();
    const bool isAcceptable = std::isnan(doubleArg) || doubleArg > 0;
    uassert(28766,
            str::stream() << "$ln's argument must be a positive number, but is " << doubleArg,
            isAcceptable);
    return Value(std::log(doubleArg));
}
2781 
// Associates the "ln" key with ExpressionLn::parse through the REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(ln, ExpressionLn::parse);
/**
 * Returns the operator name for this expression, "$ln".
 */
const char* ExpressionLn::getOpName() const {
    return "$ln";
}
2786 
2787 /* ----------------------- ExpressionLog ---------------------------- */
2788 
evaluate(const Document & root,Variables * variables) const2789 Value ExpressionLog::evaluate(const Document& root, Variables* variables) const {
2790     Value argVal = vpOperand[0]->evaluate(root, variables);
2791     Value baseVal = vpOperand[1]->evaluate(root, variables);
2792     if (argVal.nullish() || baseVal.nullish())
2793         return Value(BSONNULL);
2794 
2795     uassert(28756,
2796             str::stream() << "$log's argument must be numeric, not " << typeName(argVal.getType()),
2797             argVal.numeric());
2798     uassert(28757,
2799             str::stream() << "$log's base must be numeric, not " << typeName(baseVal.getType()),
2800             baseVal.numeric());
2801 
2802     if (argVal.getType() == NumberDecimal || baseVal.getType() == NumberDecimal) {
2803         Decimal128 argDecimal = argVal.coerceToDecimal();
2804         Decimal128 baseDecimal = baseVal.coerceToDecimal();
2805 
2806         if (argDecimal.isGreater(Decimal128::kNormalizedZero) &&
2807             baseDecimal.isNotEqual(Decimal128(1)) &&
2808             baseDecimal.isGreater(Decimal128::kNormalizedZero)) {
2809             return Value(argDecimal.logarithm(baseDecimal));
2810         }
2811         // Fall through for error cases.
2812     }
2813 
2814     double argDouble = argVal.coerceToDouble();
2815     double baseDouble = baseVal.coerceToDouble();
2816     uassert(28758,
2817             str::stream() << "$log's argument must be a positive number, but is " << argDouble,
2818             argDouble > 0 || std::isnan(argDouble));
2819     uassert(28759,
2820             str::stream() << "$log's base must be a positive number not equal to 1, but is "
2821                           << baseDouble,
2822             (baseDouble > 0 && baseDouble != 1) || std::isnan(baseDouble));
2823     return Value(std::log(argDouble) / std::log(baseDouble));
2824 }
2825 
// Associates the "log" key with ExpressionLog::parse through the REGISTER_EXPRESSION macro.
REGISTER_EXPRESSION(log, ExpressionLog::parse);
/**
 * Returns the operator name for this expression, "$log".
 */
const char* ExpressionLog::getOpName() const {
    return "$log";
}
2830 
2831 /* ----------------------- ExpressionLog10 ---------------------------- */
2832 
/**
 * Returns the base-10 logarithm of 'numericArg'.
 *
 * A decimal argument greater than zero is handled in decimal; every other argument is converted
 * to double. Raises user assertion 28761 when the double value is neither NaN nor greater than
 * zero.
 */
Value ExpressionLog10::evaluateNumericArg(const Value& numericArg) const {
    const bool isDecimal = numericArg.getType() == NumberDecimal;
    if (isDecimal) {
        const Decimal128 decimalArg = numericArg.getDecimal();
        if (decimalArg.isGreater(Decimal128::kNormalizedZero)) {
            return Value(decimalArg.logarithm(Decimal128(10)));
        }
        // Any other decimal goes through the double path below, which reports the error case.
    }

    const double doubleArg = numericArg.coerceToDouble();
    const bool isAcceptable = std::isnan(doubleArg) || doubleArg > 0;
    uassert(28761,
            str::stream() << "$log10's argument must be a positive number, but is " << doubleArg,
            isAcceptable);
    return Value(std::log10(doubleArg));
}
2847 
2848 REGISTER_EXPRESSION(log10, ExpressionLog10::parse);
getOpName() const2849 const char* ExpressionLog10::getOpName() const {
2850     return "$log10";
2851 }
2852 
2853 /* ------------------------ ExpressionNary ----------------------------- */
2854 
2855 /**
2856  * Optimize a general Nary expression.
2857  *
2858  * The optimization has the following properties:
2859  *   1) Optimize each of the operators.
2860  *   2) If the operand is associative, flatten internal operators of the same type. I.e.:
2861  *      A+B+(C+D)+E => A+B+C+D+E
2862  *   3) If the operand is commutative & associative, group all constant operators. For example:
2863  *      c1 + c2 + n1 + c3 + n2 => n1 + n2 + c1 + c2 + c3
2864  *   4) If the operand is associative, execute the operation over all the contiguous constant
2865  *      operators and replacing them by the result. For example: c1 + c2 + n1 + c3 + c4 + n5 =>
2866  *      c5 = c1 + c2, c6 = c3 + c4 => c5 + n1 + c6 + n5
2867  *
2868  * It returns the optimized expression. It can be exactly the same expression, a modified version
2869  * of the same expression or a completely different expression.
2870  */
optimize()2871 intrusive_ptr<Expression> ExpressionNary::optimize() {
2872     uint32_t constOperandCount = 0;
2873 
2874     for (auto& operand : vpOperand) {
2875         operand = operand->optimize();
2876         if (dynamic_cast<ExpressionConstant*>(operand.get())) {
2877             ++constOperandCount;
2878         }
2879     }
2880     // If all the operands are constant expressions, collapse the expression into one constant
2881     // expression.
2882     if (constOperandCount == vpOperand.size()) {
2883         return intrusive_ptr<Expression>(ExpressionConstant::create(
2884             getExpressionContext(), evaluate(Document(), &(getExpressionContext()->variables))));
2885     }
2886 
2887     // If the expression is associative, we can collapse all the consecutive constant operands into
2888     // one by applying the expression to those consecutive constant operands.
2889     // If the expression is also commutative we can reorganize all the operands so that all of the
2890     // constant ones are together (arbitrarily at the back) and we can collapse all of them into
2891     // one.
2892     if (isAssociative()) {
2893         ExpressionVector constExpressions;
2894         ExpressionVector optimizedOperands;
2895         for (size_t i = 0; i < vpOperand.size();) {
2896             intrusive_ptr<Expression> operand = vpOperand[i];
2897             // If the operand is a constant one, add it to the current list of consecutive constant
2898             // operands.
2899             if (dynamic_cast<ExpressionConstant*>(operand.get())) {
2900                 constExpressions.push_back(operand);
2901                 ++i;
2902                 continue;
2903             }
2904 
2905             // If the operand is exactly the same type as the one we are currently optimizing and
2906             // is also associative, replace the expression for the operands it has.
2907             // E.g: sum(a, b, sum(c, d), e) => sum(a, b, c, d, e)
2908             ExpressionNary* nary = dynamic_cast<ExpressionNary*>(operand.get());
2909             if (nary && str::equals(nary->getOpName(), getOpName()) && nary->isAssociative()) {
2910                 invariant(!nary->vpOperand.empty());
2911                 vpOperand[i] = std::move(nary->vpOperand[0]);
2912                 vpOperand.insert(
2913                     vpOperand.begin() + i + 1, nary->vpOperand.begin() + 1, nary->vpOperand.end());
2914                 continue;
2915             }
2916 
2917             // If the operand is not a constant nor a same-type expression and the expression is
2918             // not commutative, evaluate an expression of the same type as the one we are
2919             // optimizing on the list of consecutive constant operands and use the resulting value
2920             // as a constant expression operand.
2921             // If the list of consecutive constant operands has less than 2 operands just place
2922             // back the operands.
2923             if (!isCommutative()) {
2924                 if (constExpressions.size() > 1) {
2925                     ExpressionVector vpOperandSave = std::move(vpOperand);
2926                     vpOperand = std::move(constExpressions);
2927                     optimizedOperands.emplace_back(ExpressionConstant::create(
2928                         getExpressionContext(),
2929                         evaluate(Document(), &getExpressionContext()->variables)));
2930                     vpOperand = std::move(vpOperandSave);
2931                 } else {
2932                     optimizedOperands.insert(
2933                         optimizedOperands.end(), constExpressions.begin(), constExpressions.end());
2934                 }
2935                 constExpressions.clear();
2936             }
2937             optimizedOperands.push_back(operand);
2938             ++i;
2939         }
2940 
2941         if (constExpressions.size() > 1) {
2942             vpOperand = std::move(constExpressions);
2943             optimizedOperands.emplace_back(ExpressionConstant::create(
2944                 getExpressionContext(),
2945                 evaluate(Document(), &(getExpressionContext()->variables))));
2946         } else {
2947             optimizedOperands.insert(
2948                 optimizedOperands.end(), constExpressions.begin(), constExpressions.end());
2949         }
2950 
2951         vpOperand = std::move(optimizedOperands);
2952     }
2953     return this;
2954 }
2955 
_doAddDependencies(DepsTracker * deps) const2956 void ExpressionNary::_doAddDependencies(DepsTracker* deps) const {
2957     for (auto&& operand : vpOperand) {
2958         operand->addDependencies(deps);
2959     }
2960 }
2961 
addOperand(const intrusive_ptr<Expression> & pExpression)2962 void ExpressionNary::addOperand(const intrusive_ptr<Expression>& pExpression) {
2963     vpOperand.push_back(pExpression);
2964 }
2965 
serialize(bool explain) const2966 Value ExpressionNary::serialize(bool explain) const {
2967     const size_t nOperand = vpOperand.size();
2968     vector<Value> array;
2969     /* build up the array */
2970     for (size_t i = 0; i < nOperand; i++)
2971         array.push_back(vpOperand[i]->serialize(explain));
2972 
2973     return Value(DOC(getOpName() << array));
2974 }
2975 
2976 /* ------------------------- ExpressionNot ----------------------------- */
2977 
evaluate(const Document & root,Variables * variables) const2978 Value ExpressionNot::evaluate(const Document& root, Variables* variables) const {
2979     Value pOp(vpOperand[0]->evaluate(root, variables));
2980 
2981     bool b = pOp.coerceToBool();
2982     return Value(!b);
2983 }
2984 
2985 REGISTER_EXPRESSION(not, ExpressionNot::parse);
getOpName() const2986 const char* ExpressionNot::getOpName() const {
2987     return "$not";
2988 }
2989 
2990 /* -------------------------- ExpressionOr ----------------------------- */
2991 
evaluate(const Document & root,Variables * variables) const2992 Value ExpressionOr::evaluate(const Document& root, Variables* variables) const {
2993     const size_t n = vpOperand.size();
2994     for (size_t i = 0; i < n; ++i) {
2995         Value pValue(vpOperand[i]->evaluate(root, variables));
2996         if (pValue.coerceToBool())
2997             return Value(true);
2998     }
2999 
3000     return Value(false);
3001 }
3002 
optimize()3003 intrusive_ptr<Expression> ExpressionOr::optimize() {
3004     /* optimize the disjunction as much as possible */
3005     intrusive_ptr<Expression> pE(ExpressionNary::optimize());
3006 
3007     /* if the result isn't a disjunction, we can't do anything */
3008     ExpressionOr* pOr = dynamic_cast<ExpressionOr*>(pE.get());
3009     if (!pOr)
3010         return pE;
3011 
3012     /*
3013       Check the last argument on the result; if it's not constant (as
3014       promised by ExpressionNary::optimize(),) then there's nothing
3015       we can do.
3016     */
3017     const size_t n = pOr->vpOperand.size();
3018     // ExpressionNary::optimize() generates an ExpressionConstant for {$or:[]}.
3019     verify(n > 0);
3020     intrusive_ptr<Expression> pLast(pOr->vpOperand[n - 1]);
3021     const ExpressionConstant* pConst = dynamic_cast<ExpressionConstant*>(pLast.get());
3022     if (!pConst)
3023         return pE;
3024 
3025     /*
3026       Evaluate and coerce the last argument to a boolean.  If it's true,
3027       then we can replace this entire expression.
3028      */
3029     bool last = pConst->getValue().coerceToBool();
3030     if (last) {
3031         intrusive_ptr<ExpressionConstant> pFinal(
3032             ExpressionConstant::create(getExpressionContext(), Value(true)));
3033         return pFinal;
3034     }
3035 
3036     /*
3037       If we got here, the final operand was false, so we don't need it
3038       anymore.  If there was only one other operand, we don't need the
3039       conjunction either.  Note we still need to keep the promise that
3040       the result will be a boolean.
3041      */
3042     if (n == 2) {
3043         intrusive_ptr<Expression> pFinal(
3044             ExpressionCoerceToBool::create(getExpressionContext(), pOr->vpOperand[0]));
3045         return pFinal;
3046     }
3047 
3048     /*
3049       Remove the final "false" value, and return the new expression.
3050     */
3051     pOr->vpOperand.resize(n - 1);
3052     return pE;
3053 }
3054 
3055 REGISTER_EXPRESSION(or, ExpressionOr::parse);
getOpName() const3056 const char* ExpressionOr::getOpName() const {
3057     return "$or";
3058 }
3059 
3060 namespace {
3061 /**
3062  * Helper for ExpressionPow to determine whether base^exp can be represented in a 64 bit int.
3063  *
3064  *'base' and 'exp' are both integers. Assumes 'exp' is in the range [0, 63].
3065  */
representableAsLong(long long base,long long exp)3066 bool representableAsLong(long long base, long long exp) {
3067     invariant(exp <= 63);
3068     invariant(exp >= 0);
3069     struct MinMax {
3070         long long min;
3071         long long max;
3072     };
3073 
3074     // Array indices correspond to exponents 0 through 63. The values in each index are the min
3075     // and max bases, respectively, that can be raised to that exponent without overflowing a
3076     // 64-bit int. For max bases, this was computed by solving for b in
3077     // b = (2^63-1)^(1/exp) for exp = [0, 63] and truncating b. To calculate min bases, for even
3078     // exps the equation  used was b = (2^63-1)^(1/exp), and for odd exps the equation used was
3079     // b = (-2^63)^(1/exp). Since the magnitude of long min is greater than long max, the
3080     // magnitude of some of the min bases raised to odd exps is greater than the corresponding
3081     // max bases raised to the same exponents.
3082 
3083     static const MinMax kBaseLimits[] = {
3084         {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()},  // 0
3085         {std::numeric_limits<long long>::min(), std::numeric_limits<long long>::max()},
3086         {-3037000499LL, 3037000499LL},
3087         {-2097152, 2097151},
3088         {-55108, 55108},
3089         {-6208, 6208},
3090         {-1448, 1448},
3091         {-512, 511},
3092         {-234, 234},
3093         {-128, 127},
3094         {-78, 78},  // 10
3095         {-52, 52},
3096         {-38, 38},
3097         {-28, 28},
3098         {-22, 22},
3099         {-18, 18},
3100         {-15, 15},
3101         {-13, 13},
3102         {-11, 11},
3103         {-9, 9},
3104         {-8, 8},  // 20
3105         {-8, 7},
3106         {-7, 7},
3107         {-6, 6},
3108         {-6, 6},
3109         {-5, 5},
3110         {-5, 5},
3111         {-5, 5},
3112         {-4, 4},
3113         {-4, 4},
3114         {-4, 4},  // 30
3115         {-4, 4},
3116         {-3, 3},
3117         {-3, 3},
3118         {-3, 3},
3119         {-3, 3},
3120         {-3, 3},
3121         {-3, 3},
3122         {-3, 3},
3123         {-3, 3},
3124         {-2, 2},  // 40
3125         {-2, 2},
3126         {-2, 2},
3127         {-2, 2},
3128         {-2, 2},
3129         {-2, 2},
3130         {-2, 2},
3131         {-2, 2},
3132         {-2, 2},
3133         {-2, 2},
3134         {-2, 2},  // 50
3135         {-2, 2},
3136         {-2, 2},
3137         {-2, 2},
3138         {-2, 2},
3139         {-2, 2},
3140         {-2, 2},
3141         {-2, 2},
3142         {-2, 2},
3143         {-2, 2},
3144         {-2, 2},  // 60
3145         {-2, 2},
3146         {-2, 2},
3147         {-2, 1}};
3148 
3149     return base >= kBaseLimits[exp].min && base <= kBaseLimits[exp].max;
3150 };
3151 }  // namespace
3152 
3153 /* ----------------------- ExpressionPow ---------------------------- */
3154 
create(const boost::intrusive_ptr<ExpressionContext> & expCtx,Value base,Value exp)3155 intrusive_ptr<Expression> ExpressionPow::create(
3156     const boost::intrusive_ptr<ExpressionContext>& expCtx, Value base, Value exp) {
3157     intrusive_ptr<ExpressionPow> expr(new ExpressionPow(expCtx));
3158     expr->vpOperand.push_back(
3159         ExpressionConstant::create(expr->getExpressionContext(), std::move(base)));
3160     expr->vpOperand.push_back(
3161         ExpressionConstant::create(expr->getExpressionContext(), std::move(exp)));
3162     return expr;
3163 }
3164 
evaluate(const Document & root,Variables * variables) const3165 Value ExpressionPow::evaluate(const Document& root, Variables* variables) const {
3166     Value baseVal = vpOperand[0]->evaluate(root, variables);
3167     Value expVal = vpOperand[1]->evaluate(root, variables);
3168     if (baseVal.nullish() || expVal.nullish())
3169         return Value(BSONNULL);
3170 
3171     BSONType baseType = baseVal.getType();
3172     BSONType expType = expVal.getType();
3173 
3174     uassert(28762,
3175             str::stream() << "$pow's base must be numeric, not " << typeName(baseType),
3176             baseVal.numeric());
3177     uassert(28763,
3178             str::stream() << "$pow's exponent must be numeric, not " << typeName(expType),
3179             expVal.numeric());
3180 
3181     auto checkNonZeroAndNeg = [](bool isZeroAndNeg) {
3182         uassert(28764, "$pow cannot take a base of 0 and a negative exponent", !isZeroAndNeg);
3183     };
3184 
3185     // If either argument is decimal, return a decimal.
3186     if (baseType == NumberDecimal || expType == NumberDecimal) {
3187         Decimal128 baseDecimal = baseVal.coerceToDecimal();
3188         Decimal128 expDecimal = expVal.coerceToDecimal();
3189         checkNonZeroAndNeg(baseDecimal.isZero() && expDecimal.isNegative());
3190         return Value(baseDecimal.power(expDecimal));
3191     }
3192 
3193     // pow() will cast args to doubles.
3194     double baseDouble = baseVal.coerceToDouble();
3195     double expDouble = expVal.coerceToDouble();
3196     checkNonZeroAndNeg(baseDouble == 0 && expDouble < 0);
3197 
3198     // If either argument is a double, return a double.
3199     if (baseType == NumberDouble || expType == NumberDouble) {
3200         return Value(std::pow(baseDouble, expDouble));
3201     }
3202 
3203     // If either number is a long, return a long. If both numbers are ints, then return an int if
3204     // the result fits or a long if it is too big.
3205     const auto formatResult = [baseType, expType](long long res) {
3206         if (baseType == NumberLong || expType == NumberLong) {
3207             return Value(res);
3208         }
3209         return Value::createIntOrLong(res);
3210     };
3211 
3212     const long long baseLong = baseVal.getLong();
3213     const long long expLong = expVal.getLong();
3214 
3215     // Use this when the result cannot be represented as a long.
3216     const auto computeDoubleResult = [baseLong, expLong]() {
3217         return Value(std::pow(baseLong, expLong));
3218     };
3219 
3220     // Avoid doing repeated multiplication or using std::pow if the base is -1, 0, or 1.
3221     if (baseLong == 0) {
3222         if (expLong == 0) {
3223             // 0^0 = 1.
3224             return formatResult(1);
3225         } else if (expLong > 0) {
3226             // 0^x where x > 0 is 0.
3227             return formatResult(0);
3228         }
3229 
3230         // We should have checked earlier that 0 to a negative power is banned.
3231         MONGO_UNREACHABLE;
3232     } else if (baseLong == 1) {
3233         return formatResult(1);
3234     } else if (baseLong == -1) {
3235         // -1^0 = -1^2 = -1^4 = -1^6 ... = 1
3236         // -1^1 = -1^3 = -1^5 = -1^7 ... = -1
3237         return formatResult((expLong % 2 == 0) ? 1 : -1);
3238     } else if (expLong > 63 || expLong < 0) {
3239         // If the base is not 0, 1, or -1 and the exponent is too large, or negative,
3240         // the result cannot be represented as a long.
3241         return computeDoubleResult();
3242     }
3243 
3244     // It's still possible that the result cannot be represented as a long. If that's the case,
3245     // return a double.
3246     if (!representableAsLong(baseLong, expLong)) {
3247         return computeDoubleResult();
3248     }
3249 
3250     // Use repeated multiplication, since pow() casts args to doubles which could result in
3251     // loss of precision if arguments are very large.
3252     const auto computeWithRepeatedMultiplication = [](long long base, long long exp) {
3253         long long result = 1;
3254 
3255         while (exp > 1) {
3256             if (exp % 2 == 1) {
3257                 result *= base;
3258                 exp--;
3259             }
3260             // 'exp' is now guaranteed to be even.
3261             base *= base;
3262             exp /= 2;
3263         }
3264 
3265         if (exp) {
3266             invariant(exp == 1);
3267             result *= base;
3268         }
3269 
3270         return result;
3271     };
3272 
3273     return formatResult(computeWithRepeatedMultiplication(baseLong, expLong));
3274 }
3275 
3276 REGISTER_EXPRESSION(pow, ExpressionPow::parse);
getOpName() const3277 const char* ExpressionPow::getOpName() const {
3278     return "$pow";
3279 }
3280 
3281 /* ------------------------- ExpressionRange ------------------------------ */
3282 
evaluate(const Document & root,Variables * variables) const3283 Value ExpressionRange::evaluate(const Document& root, Variables* variables) const {
3284     Value startVal(vpOperand[0]->evaluate(root, variables));
3285     Value endVal(vpOperand[1]->evaluate(root, variables));
3286 
3287     uassert(34443,
3288             str::stream() << "$range requires a numeric starting value, found value of type: "
3289                           << typeName(startVal.getType()),
3290             startVal.numeric());
3291     uassert(34444,
3292             str::stream() << "$range requires a starting value that can be represented as a 32-bit "
3293                              "integer, found value: "
3294                           << startVal.toString(),
3295             startVal.integral());
3296     uassert(34445,
3297             str::stream() << "$range requires a numeric ending value, found value of type: "
3298                           << typeName(endVal.getType()),
3299             endVal.numeric());
3300     uassert(34446,
3301             str::stream() << "$range requires an ending value that can be represented as a 32-bit "
3302                              "integer, found value: "
3303                           << endVal.toString(),
3304             endVal.integral());
3305 
3306     int current = startVal.coerceToInt();
3307     int end = endVal.coerceToInt();
3308 
3309     int step = 1;
3310     if (vpOperand.size() == 3) {
3311         // A step was specified by the user.
3312         Value stepVal(vpOperand[2]->evaluate(root, variables));
3313 
3314         uassert(34447,
3315                 str::stream() << "$range requires a numeric step value, found value of type:"
3316                               << typeName(stepVal.getType()),
3317                 stepVal.numeric());
3318         uassert(34448,
3319                 str::stream() << "$range requires a step value that can be represented as a 32-bit "
3320                                  "integer, found value: "
3321                               << stepVal.toString(),
3322                 stepVal.integral());
3323         step = stepVal.coerceToInt();
3324 
3325         uassert(34449, "$range requires a non-zero step value", step != 0);
3326     }
3327 
3328     std::vector<Value> output;
3329 
3330     while ((step > 0 ? current < end : current > end)) {
3331         output.push_back(Value(current));
3332         current += step;
3333     }
3334 
3335     return Value(output);
3336 }
3337 
3338 REGISTER_EXPRESSION(range, ExpressionRange::parse);
getOpName() const3339 const char* ExpressionRange::getOpName() const {
3340     return "$range";
3341 }
3342 
3343 /* ------------------------ ExpressionReduce ------------------------------ */
3344 
3345 REGISTER_EXPRESSION(reduce, ExpressionReduce::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)3346 intrusive_ptr<Expression> ExpressionReduce::parse(
3347     const boost::intrusive_ptr<ExpressionContext>& expCtx,
3348     BSONElement expr,
3349     const VariablesParseState& vps) {
3350     uassert(40075,
3351             str::stream() << "$reduce requires an object as an argument, found: "
3352                           << typeName(expr.type()),
3353             expr.type() == Object);
3354 
3355     intrusive_ptr<ExpressionReduce> reduce(new ExpressionReduce(expCtx));
3356 
3357     // vpsSub is used only to parse 'in', which must have access to $$this and $$value.
3358     VariablesParseState vpsSub(vps);
3359     reduce->_thisVar = vpsSub.defineVariable("this");
3360     reduce->_valueVar = vpsSub.defineVariable("value");
3361 
3362     for (auto&& elem : expr.Obj()) {
3363         auto field = elem.fieldNameStringData();
3364 
3365         if (field == "input") {
3366             reduce->_input = parseOperand(expCtx, elem, vps);
3367         } else if (field == "initialValue") {
3368             reduce->_initial = parseOperand(expCtx, elem, vps);
3369         } else if (field == "in") {
3370             reduce->_in = parseOperand(expCtx, elem, vpsSub);
3371         } else {
3372             uasserted(40076, str::stream() << "$reduce found an unknown argument: " << field);
3373         }
3374     }
3375 
3376     uassert(40077, "$reduce requires 'input' to be specified", reduce->_input);
3377     uassert(40078, "$reduce requires 'initialValue' to be specified", reduce->_initial);
3378     uassert(40079, "$reduce requires 'in' to be specified", reduce->_in);
3379 
3380     return reduce;
3381 }
3382 
evaluate(const Document & root,Variables * variables) const3383 Value ExpressionReduce::evaluate(const Document& root, Variables* variables) const {
3384     Value inputVal = _input->evaluate(root, variables);
3385 
3386     if (inputVal.nullish()) {
3387         return Value(BSONNULL);
3388     }
3389 
3390     uassert(40080,
3391             str::stream() << "$reduce requires that 'input' be an array, found: "
3392                           << inputVal.toString(),
3393             inputVal.isArray());
3394 
3395     Value accumulatedValue = _initial->evaluate(root, variables);
3396 
3397     for (auto&& elem : inputVal.getArray()) {
3398         variables->setValue(_thisVar, elem);
3399         variables->setValue(_valueVar, accumulatedValue);
3400 
3401         accumulatedValue = _in->evaluate(root, variables);
3402     }
3403 
3404     return accumulatedValue;
3405 }
3406 
optimize()3407 intrusive_ptr<Expression> ExpressionReduce::optimize() {
3408     _input = _input->optimize();
3409     _initial = _initial->optimize();
3410     _in = _in->optimize();
3411     return this;
3412 }
3413 
_doAddDependencies(DepsTracker * deps) const3414 void ExpressionReduce::_doAddDependencies(DepsTracker* deps) const {
3415     _input->addDependencies(deps);
3416     _initial->addDependencies(deps);
3417     _in->addDependencies(deps);
3418 }
3419 
serialize(bool explain) const3420 Value ExpressionReduce::serialize(bool explain) const {
3421     return Value(Document{{"$reduce",
3422                            Document{{"input", _input->serialize(explain)},
3423                                     {"initialValue", _initial->serialize(explain)},
3424                                     {"in", _in->serialize(explain)}}}});
3425 }
3426 
3427 /* ------------------------ ExpressionReverseArray ------------------------ */
3428 
evaluate(const Document & root,Variables * variables) const3429 Value ExpressionReverseArray::evaluate(const Document& root, Variables* variables) const {
3430     Value input(vpOperand[0]->evaluate(root, variables));
3431 
3432     if (input.nullish()) {
3433         return Value(BSONNULL);
3434     }
3435 
3436     uassert(34435,
3437             str::stream() << "The argument to $reverseArray must be an array, but was of type: "
3438                           << typeName(input.getType()),
3439             input.isArray());
3440 
3441     if (input.getArrayLength() < 2) {
3442         return input;
3443     }
3444 
3445     std::vector<Value> array = input.getArray();
3446     std::reverse(array.begin(), array.end());
3447     return Value(array);
3448 }
3449 
3450 REGISTER_EXPRESSION(reverseArray, ExpressionReverseArray::parse);
getOpName() const3451 const char* ExpressionReverseArray::getOpName() const {
3452     return "$reverseArray";
3453 }
3454 
3455 namespace {
arrayToSet(const Value & val,const ValueComparator & valueComparator)3456 ValueSet arrayToSet(const Value& val, const ValueComparator& valueComparator) {
3457     const vector<Value>& array = val.getArray();
3458     ValueSet valueSet = valueComparator.makeOrderedValueSet();
3459     valueSet.insert(array.begin(), array.end());
3460     return valueSet;
3461 }
3462 }  // namespace
3463 
3464 /* ----------------------- ExpressionSetDifference ---------------------------- */
3465 
evaluate(const Document & root,Variables * variables) const3466 Value ExpressionSetDifference::evaluate(const Document& root, Variables* variables) const {
3467     const Value lhs = vpOperand[0]->evaluate(root, variables);
3468     const Value rhs = vpOperand[1]->evaluate(root, variables);
3469 
3470     if (lhs.nullish() || rhs.nullish()) {
3471         return Value(BSONNULL);
3472     }
3473 
3474     uassert(17048,
3475             str::stream() << "both operands of $setDifference must be arrays. First "
3476                           << "argument is of type: "
3477                           << typeName(lhs.getType()),
3478             lhs.isArray());
3479     uassert(17049,
3480             str::stream() << "both operands of $setDifference must be arrays. Second "
3481                           << "argument is of type: "
3482                           << typeName(rhs.getType()),
3483             rhs.isArray());
3484 
3485     ValueSet rhsSet = arrayToSet(rhs, getExpressionContext()->getValueComparator());
3486     const vector<Value>& lhsArray = lhs.getArray();
3487     vector<Value> returnVec;
3488 
3489     for (vector<Value>::const_iterator it = lhsArray.begin(); it != lhsArray.end(); ++it) {
3490         // rhsSet serves the dual role of filtering out elements that were originally present
3491         // in RHS and of eleminating duplicates from LHS
3492         if (rhsSet.insert(*it).second) {
3493             returnVec.push_back(*it);
3494         }
3495     }
3496     return Value(std::move(returnVec));
3497 }
3498 
3499 REGISTER_EXPRESSION(setDifference, ExpressionSetDifference::parse);
getOpName() const3500 const char* ExpressionSetDifference::getOpName() const {
3501     return "$setDifference";
3502 }
3503 
3504 /* ----------------------- ExpressionSetEquals ---------------------------- */
3505 
validateArguments(const ExpressionVector & args) const3506 void ExpressionSetEquals::validateArguments(const ExpressionVector& args) const {
3507     uassert(17045,
3508             str::stream() << "$setEquals needs at least two arguments had: " << args.size(),
3509             args.size() >= 2);
3510 }
3511 
evaluate(const Document & root,Variables * variables) const3512 Value ExpressionSetEquals::evaluate(const Document& root, Variables* variables) const {
3513     const size_t n = vpOperand.size();
3514     const auto& valueComparator = getExpressionContext()->getValueComparator();
3515     ValueSet lhs = valueComparator.makeOrderedValueSet();
3516 
3517     for (size_t i = 0; i < n; i++) {
3518         const Value nextEntry = vpOperand[i]->evaluate(root, variables);
3519         uassert(17044,
3520                 str::stream() << "All operands of $setEquals must be arrays. One "
3521                               << "argument is of type: "
3522                               << typeName(nextEntry.getType()),
3523                 nextEntry.isArray());
3524 
3525         if (i == 0) {
3526             lhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3527         } else {
3528             ValueSet rhs = valueComparator.makeOrderedValueSet();
3529             rhs.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3530             if (lhs.size() != rhs.size()) {
3531                 return Value(false);
3532             }
3533 
3534             if (!std::equal(lhs.begin(), lhs.end(), rhs.begin(), valueComparator.getEqualTo())) {
3535                 return Value(false);
3536             }
3537         }
3538     }
3539     return Value(true);
3540 }
3541 
3542 REGISTER_EXPRESSION(setEquals, ExpressionSetEquals::parse);
getOpName() const3543 const char* ExpressionSetEquals::getOpName() const {
3544     return "$setEquals";
3545 }
3546 
3547 /* ----------------------- ExpressionSetIntersection ---------------------------- */
3548 
evaluate(const Document & root,Variables * variables) const3549 Value ExpressionSetIntersection::evaluate(const Document& root, Variables* variables) const {
3550     const size_t n = vpOperand.size();
3551     const auto& valueComparator = getExpressionContext()->getValueComparator();
3552     ValueSet currentIntersection = valueComparator.makeOrderedValueSet();
3553     for (size_t i = 0; i < n; i++) {
3554         const Value nextEntry = vpOperand[i]->evaluate(root, variables);
3555         if (nextEntry.nullish()) {
3556             return Value(BSONNULL);
3557         }
3558         uassert(17047,
3559                 str::stream() << "All operands of $setIntersection must be arrays. One "
3560                               << "argument is of type: "
3561                               << typeName(nextEntry.getType()),
3562                 nextEntry.isArray());
3563 
3564         if (i == 0) {
3565             currentIntersection.insert(nextEntry.getArray().begin(), nextEntry.getArray().end());
3566         } else {
3567             ValueSet nextSet = arrayToSet(nextEntry, valueComparator);
3568             if (currentIntersection.size() > nextSet.size()) {
3569                 // to iterate over whichever is the smaller set
3570                 nextSet.swap(currentIntersection);
3571             }
3572             ValueSet::iterator it = currentIntersection.begin();
3573             while (it != currentIntersection.end()) {
3574                 if (!nextSet.count(*it)) {
3575                     ValueSet::iterator del = it;
3576                     ++it;
3577                     currentIntersection.erase(del);
3578                 } else {
3579                     ++it;
3580                 }
3581             }
3582         }
3583         if (currentIntersection.empty()) {
3584             break;
3585         }
3586     }
3587     return Value(vector<Value>(currentIntersection.begin(), currentIntersection.end()));
3588 }
3589 
3590 REGISTER_EXPRESSION(setIntersection, ExpressionSetIntersection::parse);
getOpName() const3591 const char* ExpressionSetIntersection::getOpName() const {
3592     return "$setIntersection";
3593 }
3594 
3595 /* ----------------------- ExpressionSetIsSubset ---------------------------- */
3596 
3597 namespace {
setIsSubsetHelper(const vector<Value> & lhs,const ValueSet & rhs)3598 Value setIsSubsetHelper(const vector<Value>& lhs, const ValueSet& rhs) {
3599     // do not shortcircuit when lhs.size() > rhs.size()
3600     // because lhs can have redundant entries
3601     for (vector<Value>::const_iterator it = lhs.begin(); it != lhs.end(); ++it) {
3602         if (!rhs.count(*it)) {
3603             return Value(false);
3604         }
3605     }
3606     return Value(true);
3607 }
3608 }  // namespace
3609 
evaluate(const Document & root,Variables * variables) const3610 Value ExpressionSetIsSubset::evaluate(const Document& root, Variables* variables) const {
3611     const Value lhs = vpOperand[0]->evaluate(root, variables);
3612     const Value rhs = vpOperand[1]->evaluate(root, variables);
3613 
3614     uassert(17046,
3615             str::stream() << "both operands of $setIsSubset must be arrays. First "
3616                           << "argument is of type: "
3617                           << typeName(lhs.getType()),
3618             lhs.isArray());
3619     uassert(17042,
3620             str::stream() << "both operands of $setIsSubset must be arrays. Second "
3621                           << "argument is of type: "
3622                           << typeName(rhs.getType()),
3623             rhs.isArray());
3624 
3625     return setIsSubsetHelper(lhs.getArray(),
3626                              arrayToSet(rhs, getExpressionContext()->getValueComparator()));
3627 }
3628 
3629 /**
3630  * This class handles the case where the RHS set is constant.
3631  *
3632  * Since it is constant we can construct the hashset once which makes the runtime performance
3633  * effectively constant with respect to the size of RHS. Large, constant RHS is expected to be a
3634  * major use case for $redact and this has been verified to improve performance significantly.
3635  */
3636 class ExpressionSetIsSubset::Optimized : public ExpressionSetIsSubset {
3637 public:
Optimized(const boost::intrusive_ptr<ExpressionContext> & expCtx,const ValueSet & cachedRhsSet,const ExpressionVector & operands)3638     Optimized(const boost::intrusive_ptr<ExpressionContext>& expCtx,
3639               const ValueSet& cachedRhsSet,
3640               const ExpressionVector& operands)
3641         : ExpressionSetIsSubset(expCtx), _cachedRhsSet(cachedRhsSet) {
3642         vpOperand = operands;
3643     }
3644 
evaluate(const Document & root,Variables * variables) const3645     virtual Value evaluate(const Document& root, Variables* variables) const {
3646         const Value lhs = vpOperand[0]->evaluate(root, variables);
3647 
3648         uassert(17310,
3649                 str::stream() << "both operands of $setIsSubset must be arrays. First "
3650                               << "argument is of type: "
3651                               << typeName(lhs.getType()),
3652                 lhs.isArray());
3653 
3654         return setIsSubsetHelper(lhs.getArray(), _cachedRhsSet);
3655     }
3656 
3657 private:
3658     const ValueSet _cachedRhsSet;
3659 };
3660 
optimize()3661 intrusive_ptr<Expression> ExpressionSetIsSubset::optimize() {
3662     // perfore basic optimizations
3663     intrusive_ptr<Expression> optimized = ExpressionNary::optimize();
3664 
3665     // if ExpressionNary::optimize() created a new value, return it directly
3666     if (optimized.get() != this)
3667         return optimized;
3668 
3669     if (ExpressionConstant* ec = dynamic_cast<ExpressionConstant*>(vpOperand[1].get())) {
3670         const Value rhs = ec->getValue();
3671         uassert(17311,
3672                 str::stream() << "both operands of $setIsSubset must be arrays. Second "
3673                               << "argument is of type: "
3674                               << typeName(rhs.getType()),
3675                 rhs.isArray());
3676 
3677         intrusive_ptr<Expression> optimizedWithConstant(
3678             new Optimized(this->getExpressionContext(),
3679                           arrayToSet(rhs, getExpressionContext()->getValueComparator()),
3680                           vpOperand));
3681         return optimizedWithConstant;
3682     }
3683     return optimized;
3684 }
3685 
3686 REGISTER_EXPRESSION(setIsSubset, ExpressionSetIsSubset::parse);
getOpName() const3687 const char* ExpressionSetIsSubset::getOpName() const {
3688     return "$setIsSubset";
3689 }
3690 
3691 /* ----------------------- ExpressionSetUnion ---------------------------- */
3692 
evaluate(const Document & root,Variables * variables) const3693 Value ExpressionSetUnion::evaluate(const Document& root, Variables* variables) const {
3694     ValueSet unionedSet = getExpressionContext()->getValueComparator().makeOrderedValueSet();
3695     const size_t n = vpOperand.size();
3696     for (size_t i = 0; i < n; i++) {
3697         const Value newEntries = vpOperand[i]->evaluate(root, variables);
3698         if (newEntries.nullish()) {
3699             return Value(BSONNULL);
3700         }
3701         uassert(17043,
3702                 str::stream() << "All operands of $setUnion must be arrays. One argument"
3703                               << " is of type: "
3704                               << typeName(newEntries.getType()),
3705                 newEntries.isArray());
3706 
3707         unionedSet.insert(newEntries.getArray().begin(), newEntries.getArray().end());
3708     }
3709     return Value(vector<Value>(unionedSet.begin(), unionedSet.end()));
3710 }
3711 
3712 REGISTER_EXPRESSION(setUnion, ExpressionSetUnion::parse);
getOpName() const3713 const char* ExpressionSetUnion::getOpName() const {
3714     return "$setUnion";
3715 }
3716 
3717 /* ----------------------- ExpressionIsArray ---------------------------- */
3718 
evaluate(const Document & root,Variables * variables) const3719 Value ExpressionIsArray::evaluate(const Document& root, Variables* variables) const {
3720     Value argument = vpOperand[0]->evaluate(root, variables);
3721     return Value(argument.isArray());
3722 }
3723 
3724 REGISTER_EXPRESSION(isArray, ExpressionIsArray::parse);
getOpName() const3725 const char* ExpressionIsArray::getOpName() const {
3726     return "$isArray";
3727 }
3728 
3729 /* ----------------------- ExpressionSlice ---------------------------- */
3730 
evaluate(const Document & root,Variables * variables) const3731 Value ExpressionSlice::evaluate(const Document& root, Variables* variables) const {
3732     const size_t n = vpOperand.size();
3733 
3734     Value arrayVal = vpOperand[0]->evaluate(root, variables);
3735     // Could be either a start index or the length from 0.
3736     Value arg2 = vpOperand[1]->evaluate(root, variables);
3737 
3738     if (arrayVal.nullish() || arg2.nullish()) {
3739         return Value(BSONNULL);
3740     }
3741 
3742     uassert(28724,
3743             str::stream() << "First argument to $slice must be an array, but is"
3744                           << " of type: "
3745                           << typeName(arrayVal.getType()),
3746             arrayVal.isArray());
3747     uassert(28725,
3748             str::stream() << "Second argument to $slice must be a numeric value,"
3749                           << " but is of type: "
3750                           << typeName(arg2.getType()),
3751             arg2.numeric());
3752     uassert(28726,
3753             str::stream() << "Second argument to $slice can't be represented as"
3754                           << " a 32-bit integer: "
3755                           << arg2.coerceToDouble(),
3756             arg2.integral());
3757 
3758     const auto& array = arrayVal.getArray();
3759     size_t start;
3760     size_t end;
3761 
3762     if (n == 2) {
3763         // Only count given.
3764         int count = arg2.coerceToInt();
3765         start = 0;
3766         end = array.size();
3767         if (count >= 0) {
3768             end = std::min(end, size_t(count));
3769         } else {
3770             // Negative count's start from the back. If a abs(count) is greater
3771             // than the
3772             // length of the array, return the whole array.
3773             start = std::max(0, static_cast<int>(array.size()) + count);
3774         }
3775     } else {
3776         // We have both a start index and a count.
3777         int startInt = arg2.coerceToInt();
3778         if (startInt < 0) {
3779             // Negative values start from the back. If a abs(start) is greater
3780             // than the length
3781             // of the array, start from 0.
3782             start = std::max(0, static_cast<int>(array.size()) + startInt);
3783         } else {
3784             start = std::min(array.size(), size_t(startInt));
3785         }
3786 
3787         Value countVal = vpOperand[2]->evaluate(root, variables);
3788 
3789         if (countVal.nullish()) {
3790             return Value(BSONNULL);
3791         }
3792 
3793         uassert(28727,
3794                 str::stream() << "Third argument to $slice must be numeric, but "
3795                               << "is of type: "
3796                               << typeName(countVal.getType()),
3797                 countVal.numeric());
3798         uassert(28728,
3799                 str::stream() << "Third argument to $slice can't be represented"
3800                               << " as a 32-bit integer: "
3801                               << countVal.coerceToDouble(),
3802                 countVal.integral());
3803         uassert(28729,
3804                 str::stream() << "Third argument to $slice must be positive: "
3805                               << countVal.coerceToInt(),
3806                 countVal.coerceToInt() > 0);
3807 
3808         size_t count = size_t(countVal.coerceToInt());
3809         end = std::min(start + count, array.size());
3810     }
3811 
3812     return Value(vector<Value>(array.begin() + start, array.begin() + end));
3813 }
3814 
3815 REGISTER_EXPRESSION(slice, ExpressionSlice::parse);
getOpName() const3816 const char* ExpressionSlice::getOpName() const {
3817     return "$slice";
3818 }
3819 
3820 /* ----------------------- ExpressionSize ---------------------------- */
3821 
evaluate(const Document & root,Variables * variables) const3822 Value ExpressionSize::evaluate(const Document& root, Variables* variables) const {
3823     Value array = vpOperand[0]->evaluate(root, variables);
3824 
3825     uassert(17124,
3826             str::stream() << "The argument to $size must be an array, but was of type: "
3827                           << typeName(array.getType()),
3828             array.isArray());
3829     return Value::createIntOrLong(array.getArray().size());
3830 }
3831 
3832 REGISTER_EXPRESSION(size, ExpressionSize::parse);
getOpName() const3833 const char* ExpressionSize::getOpName() const {
3834     return "$size";
3835 }
3836 
3837 /* ----------------------- ExpressionSplit --------------------------- */
3838 
evaluate(const Document & root,Variables * variables) const3839 Value ExpressionSplit::evaluate(const Document& root, Variables* variables) const {
3840     Value inputArg = vpOperand[0]->evaluate(root, variables);
3841     Value separatorArg = vpOperand[1]->evaluate(root, variables);
3842 
3843     if (inputArg.nullish() || separatorArg.nullish()) {
3844         return Value(BSONNULL);
3845     }
3846 
3847     uassert(40085,
3848             str::stream() << "$split requires an expression that evaluates to a string as a first "
3849                              "argument, found: "
3850                           << typeName(inputArg.getType()),
3851             inputArg.getType() == BSONType::String);
3852     uassert(40086,
3853             str::stream() << "$split requires an expression that evaluates to a string as a second "
3854                              "argument, found: "
3855                           << typeName(separatorArg.getType()),
3856             separatorArg.getType() == BSONType::String);
3857 
3858     std::string input = inputArg.getString();
3859     std::string separator = separatorArg.getString();
3860 
3861     uassert(40087, "$split requires a non-empty separator", !separator.empty());
3862 
3863     std::vector<Value> output;
3864 
3865     // Keep track of the index at which the current output string began.
3866     size_t splitStartIndex = 0;
3867 
3868     // Iterate through 'input' and check to see if 'separator' matches at any point.
3869     for (size_t i = 0; i < input.size();) {
3870         if (stringHasTokenAtIndex(i, input, separator)) {
3871             // We matched; add the current string to our output and jump ahead.
3872             StringData splitString(input.c_str() + splitStartIndex, i - splitStartIndex);
3873             output.push_back(Value(splitString));
3874             i += separator.size();
3875             splitStartIndex = i;
3876         } else {
3877             // We did not match, continue to the next character.
3878             ++i;
3879         }
3880     }
3881 
3882     StringData splitString(input.c_str() + splitStartIndex, input.size() - splitStartIndex);
3883     output.push_back(Value(splitString));
3884 
3885     return Value(output);
3886 }
3887 
3888 REGISTER_EXPRESSION(split, ExpressionSplit::parse);
getOpName() const3889 const char* ExpressionSplit::getOpName() const {
3890     return "$split";
3891 }
3892 
3893 /* ----------------------- ExpressionSqrt ---------------------------- */
3894 
/**
 * Returns the square root of 'numericArg'. NumberDecimal input is handled with Decimal128
 * arithmetic; every other input is coerced to double. Raises 28714 for a negative argument.
 */
Value ExpressionSqrt::evaluateNumericArg(const Value& numericArg) const {
    auto requireNonNegative = [](bool isNonNegative) {
        uassert(28714, "$sqrt's argument must be greater than or equal to 0", isNonNegative);
    };

    if (numericArg.getType() != NumberDecimal) {
        double operand = numericArg.coerceToDouble();
        // Phrased as "not less than zero" so that NaN passes the check and yields NaN.
        requireNonNegative(!(operand < 0));
        return Value(sqrt(operand));
    }

    Decimal128 operand = numericArg.getDecimal();
    // Same "not less than" phrasing: NaN returns NaN without an error.
    requireNonNegative(!operand.isLess(Decimal128::kNormalizedZero));
    return Value(operand.squareRoot());
}
3909 
3910 REGISTER_EXPRESSION(sqrt, ExpressionSqrt::parse);
getOpName() const3911 const char* ExpressionSqrt::getOpName() const {
3912     return "$sqrt";
3913 }
3914 
3915 /* ----------------------- ExpressionStrcasecmp ---------------------------- */
3916 
evaluate(const Document & root,Variables * variables) const3917 Value ExpressionStrcasecmp::evaluate(const Document& root, Variables* variables) const {
3918     Value pString1(vpOperand[0]->evaluate(root, variables));
3919     Value pString2(vpOperand[1]->evaluate(root, variables));
3920 
3921     /* boost::iequals returns a bool not an int so strings must actually be allocated */
3922     string str1 = boost::to_upper_copy(pString1.coerceToString());
3923     string str2 = boost::to_upper_copy(pString2.coerceToString());
3924     int result = str1.compare(str2);
3925 
3926     if (result == 0)
3927         return Value(0);
3928     else if (result > 0)
3929         return Value(1);
3930     else
3931         return Value(-1);
3932 }
3933 
3934 REGISTER_EXPRESSION(strcasecmp, ExpressionStrcasecmp::parse);
getOpName() const3935 const char* ExpressionStrcasecmp::getOpName() const {
3936     return "$strcasecmp";
3937 }
3938 
3939 /* ----------------------- ExpressionSubstrBytes ---------------------------- */
3940 
evaluate(const Document & root,Variables * variables) const3941 Value ExpressionSubstrBytes::evaluate(const Document& root, Variables* variables) const {
3942     Value pString(vpOperand[0]->evaluate(root, variables));
3943     Value pLower(vpOperand[1]->evaluate(root, variables));
3944     Value pLength(vpOperand[2]->evaluate(root, variables));
3945 
3946     string str = pString.coerceToString();
3947     uassert(16034,
3948             str::stream() << getOpName()
3949                           << ":  starting index must be a numeric type (is BSON type "
3950                           << typeName(pLower.getType())
3951                           << ")",
3952             (pLower.getType() == NumberInt || pLower.getType() == NumberLong ||
3953              pLower.getType() == NumberDouble));
3954     uassert(16035,
3955             str::stream() << getOpName() << ":  length must be a numeric type (is BSON type "
3956                           << typeName(pLength.getType())
3957                           << ")",
3958             (pLength.getType() == NumberInt || pLength.getType() == NumberLong ||
3959              pLength.getType() == NumberDouble));
3960 
3961     string::size_type lower = static_cast<string::size_type>(pLower.coerceToLong());
3962     string::size_type length = static_cast<string::size_type>(pLength.coerceToLong());
3963 
3964     uassert(28656,
3965             str::stream() << getOpName()
3966                           << ":  Invalid range, starting index is a UTF-8 continuation byte.",
3967             (lower >= str.length() || !str::isUTF8ContinuationByte(str[lower])));
3968 
3969     // Check the byte after the last character we'd return. If it is a continuation byte, that
3970     // means we're in the middle of a UTF-8 character.
3971     uassert(
3972         28657,
3973         str::stream() << getOpName()
3974                       << ":  Invalid range, ending index is in the middle of a UTF-8 character.",
3975         (lower + length >= str.length() || !str::isUTF8ContinuationByte(str[lower + length])));
3976 
3977     if (lower >= str.length()) {
3978         // If lower > str.length() then string::substr() will throw out_of_range, so return an
3979         // empty string if lower is not a valid string index.
3980         return Value(StringData());
3981     }
3982     return Value(str.substr(lower, length));
3983 }
3984 
3985 // $substr is deprecated in favor of $substrBytes, but for now will just parse into a $substrBytes.
3986 REGISTER_EXPRESSION(substrBytes, ExpressionSubstrBytes::parse);
3987 REGISTER_EXPRESSION(substr, ExpressionSubstrBytes::parse);
getOpName() const3988 const char* ExpressionSubstrBytes::getOpName() const {
3989     return "$substrBytes";
3990 }
3991 
3992 /* ----------------------- ExpressionSubstrCP ---------------------------- */
3993 
evaluate(const Document & root,Variables * variables) const3994 Value ExpressionSubstrCP::evaluate(const Document& root, Variables* variables) const {
3995     Value inputVal(vpOperand[0]->evaluate(root, variables));
3996     Value lowerVal(vpOperand[1]->evaluate(root, variables));
3997     Value lengthVal(vpOperand[2]->evaluate(root, variables));
3998 
3999     std::string str = inputVal.coerceToString();
4000     uassert(34450,
4001             str::stream() << getOpName() << ": starting index must be a numeric type (is BSON type "
4002                           << typeName(lowerVal.getType())
4003                           << ")",
4004             lowerVal.numeric());
4005     uassert(34451,
4006             str::stream() << getOpName()
4007                           << ": starting index cannot be represented as a 32-bit integral value: "
4008                           << lowerVal.toString(),
4009             lowerVal.integral());
4010     uassert(34452,
4011             str::stream() << getOpName() << ": length must be a numeric type (is BSON type "
4012                           << typeName(lengthVal.getType())
4013                           << ")",
4014             lengthVal.numeric());
4015     uassert(34453,
4016             str::stream() << getOpName()
4017                           << ": length cannot be represented as a 32-bit integral value: "
4018                           << lengthVal.toString(),
4019             lengthVal.integral());
4020 
4021     int startIndexCodePoints = lowerVal.coerceToInt();
4022     int length = lengthVal.coerceToInt();
4023 
4024     uassert(34454,
4025             str::stream() << getOpName() << ": length must be a nonnegative integer.",
4026             length >= 0);
4027 
4028     uassert(34455,
4029             str::stream() << getOpName() << ": the starting index must be nonnegative integer.",
4030             startIndexCodePoints >= 0);
4031 
4032     size_t startIndexBytes = 0;
4033 
4034     for (int i = 0; i < startIndexCodePoints; i++) {
4035         if (startIndexBytes >= str.size()) {
4036             return Value(StringData());
4037         }
4038         uassert(34456,
4039                 str::stream() << getOpName() << ": invalid UTF-8 string",
4040                 !str::isUTF8ContinuationByte(str[startIndexBytes]));
4041         size_t codePointLength = getCodePointLength(str[startIndexBytes]);
4042         uassert(
4043             34457, str::stream() << getOpName() << ": invalid UTF-8 string", codePointLength <= 4);
4044         startIndexBytes += codePointLength;
4045     }
4046 
4047     size_t endIndexBytes = startIndexBytes;
4048 
4049     for (int i = 0; i < length && endIndexBytes < str.size(); i++) {
4050         uassert(34458,
4051                 str::stream() << getOpName() << ": invalid UTF-8 string",
4052                 !str::isUTF8ContinuationByte(str[endIndexBytes]));
4053         size_t codePointLength = getCodePointLength(str[endIndexBytes]);
4054         uassert(
4055             34459, str::stream() << getOpName() << ": invalid UTF-8 string", codePointLength <= 4);
4056         endIndexBytes += codePointLength;
4057     }
4058 
4059     return Value(std::string(str, startIndexBytes, endIndexBytes - startIndexBytes));
4060 }
4061 
4062 REGISTER_EXPRESSION(substrCP, ExpressionSubstrCP::parse);
getOpName() const4063 const char* ExpressionSubstrCP::getOpName() const {
4064     return "$substrCP";
4065 }
4066 
4067 /* ----------------------- ExpressionStrLenBytes ------------------------- */
4068 
evaluate(const Document & root,Variables * variables) const4069 Value ExpressionStrLenBytes::evaluate(const Document& root, Variables* variables) const {
4070     Value str(vpOperand[0]->evaluate(root, variables));
4071 
4072     uassert(34473,
4073             str::stream() << "$strLenBytes requires a string argument, found: "
4074                           << typeName(str.getType()),
4075             str.getType() == String);
4076 
4077     size_t strLen = str.getString().size();
4078 
4079     uassert(34470,
4080             "string length could not be represented as an int.",
4081             strLen <= std::numeric_limits<int>::max());
4082     return Value(static_cast<int>(strLen));
4083 }
4084 
4085 REGISTER_EXPRESSION(strLenBytes, ExpressionStrLenBytes::parse);
getOpName() const4086 const char* ExpressionStrLenBytes::getOpName() const {
4087     return "$strLenBytes";
4088 }
4089 
4090 /* ----------------------- ExpressionStrLenCP ------------------------- */
4091 
evaluate(const Document & root,Variables * variables) const4092 Value ExpressionStrLenCP::evaluate(const Document& root, Variables* variables) const {
4093     Value val(vpOperand[0]->evaluate(root, variables));
4094 
4095     uassert(34471,
4096             str::stream() << "$strLenCP requires a string argument, found: "
4097                           << typeName(val.getType()),
4098             val.getType() == String);
4099 
4100     std::string stringVal = val.getString();
4101     size_t strLen = str::lengthInUTF8CodePoints(stringVal);
4102 
4103     uassert(34472,
4104             "string length could not be represented as an int.",
4105             strLen <= std::numeric_limits<int>::max());
4106 
4107     return Value(static_cast<int>(strLen));
4108 }
4109 
4110 REGISTER_EXPRESSION(strLenCP, ExpressionStrLenCP::parse);
getOpName() const4111 const char* ExpressionStrLenCP::getOpName() const {
4112     return "$strLenCP";
4113 }
4114 
4115 /* ----------------------- ExpressionSubtract ---------------------------- */
4116 
evaluate(const Document & root,Variables * variables) const4117 Value ExpressionSubtract::evaluate(const Document& root, Variables* variables) const {
4118     Value lhs = vpOperand[0]->evaluate(root, variables);
4119     Value rhs = vpOperand[1]->evaluate(root, variables);
4120 
4121     BSONType diffType = Value::getWidestNumeric(rhs.getType(), lhs.getType());
4122 
4123     if (diffType == NumberDecimal) {
4124         Decimal128 right = rhs.coerceToDecimal();
4125         Decimal128 left = lhs.coerceToDecimal();
4126         return Value(left.subtract(right));
4127     } else if (diffType == NumberDouble) {
4128         double right = rhs.coerceToDouble();
4129         double left = lhs.coerceToDouble();
4130         return Value(left - right);
4131     } else if (diffType == NumberLong) {
4132         long long result;
4133 
4134         // If there is an overflow, convert the values to doubles.
4135         if (mongoSignedSubtractOverflow64(lhs.coerceToLong(), rhs.coerceToLong(), &result)) {
4136             return Value(lhs.coerceToDouble() - rhs.coerceToDouble());
4137         }
4138         return Value(result);
4139     } else if (diffType == NumberInt) {
4140         long long right = rhs.coerceToLong();
4141         long long left = lhs.coerceToLong();
4142         return Value::createIntOrLong(left - right);
4143     } else if (lhs.nullish() || rhs.nullish()) {
4144         return Value(BSONNULL);
4145     } else if (lhs.getType() == Date) {
4146         if (rhs.getType() == Date) {
4147             return Value(durationCount<Milliseconds>(lhs.getDate() - rhs.getDate()));
4148         } else if (rhs.numeric()) {
4149             return Value(lhs.getDate() - Milliseconds(rhs.coerceToLong()));
4150         } else {
4151             uasserted(16613,
4152                       str::stream() << "cant $subtract a " << typeName(rhs.getType())
4153                                     << " from a Date");
4154         }
4155     } else {
4156         uasserted(16556,
4157                   str::stream() << "cant $subtract a" << typeName(rhs.getType()) << " from a "
4158                                 << typeName(lhs.getType()));
4159     }
4160 }
4161 
4162 REGISTER_EXPRESSION(subtract, ExpressionSubtract::parse);
getOpName() const4163 const char* ExpressionSubtract::getOpName() const {
4164     return "$subtract";
4165 }
4166 
4167 /* ------------------------- ExpressionSwitch ------------------------------ */
4168 
4169 REGISTER_EXPRESSION(switch, ExpressionSwitch::parse);
4170 
evaluate(const Document & root,Variables * variables) const4171 Value ExpressionSwitch::evaluate(const Document& root, Variables* variables) const {
4172     for (auto&& branch : _branches) {
4173         Value caseExpression(branch.first->evaluate(root, variables));
4174 
4175         if (caseExpression.coerceToBool()) {
4176             return branch.second->evaluate(root, variables);
4177         }
4178     }
4179 
4180     uassert(40066,
4181             "$switch could not find a matching branch for an input, and no default was specified.",
4182             _default);
4183 
4184     return _default->evaluate(root, variables);
4185 }
4186 
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)4187 boost::intrusive_ptr<Expression> ExpressionSwitch::parse(
4188     const boost::intrusive_ptr<ExpressionContext>& expCtx,
4189     BSONElement expr,
4190     const VariablesParseState& vps) {
4191     uassert(40060,
4192             str::stream() << "$switch requires an object as an argument, found: "
4193                           << typeName(expr.type()),
4194             expr.type() == Object);
4195 
4196     intrusive_ptr<ExpressionSwitch> expression(new ExpressionSwitch(expCtx));
4197 
4198     for (auto&& elem : expr.Obj()) {
4199         auto field = elem.fieldNameStringData();
4200 
4201         if (field == "branches") {
4202             // Parse each branch separately.
4203             uassert(40061,
4204                     str::stream() << "$switch expected an array for 'branches', found: "
4205                                   << typeName(elem.type()),
4206                     elem.type() == Array);
4207 
4208             for (auto&& branch : elem.Array()) {
4209                 uassert(40062,
4210                         str::stream() << "$switch expected each branch to be an object, found: "
4211                                       << typeName(branch.type()),
4212                         branch.type() == Object);
4213 
4214                 ExpressionPair branchExpression;
4215 
4216                 for (auto&& branchElement : branch.Obj()) {
4217                     auto branchField = branchElement.fieldNameStringData();
4218 
4219                     if (branchField == "case") {
4220                         branchExpression.first = parseOperand(expCtx, branchElement, vps);
4221                     } else if (branchField == "then") {
4222                         branchExpression.second = parseOperand(expCtx, branchElement, vps);
4223                     } else {
4224                         uasserted(40063,
4225                                   str::stream() << "$switch found an unknown argument to a branch: "
4226                                                 << branchField);
4227                     }
4228                 }
4229 
4230                 uassert(40064,
4231                         "$switch requires each branch have a 'case' expression",
4232                         branchExpression.first);
4233                 uassert(40065,
4234                         "$switch requires each branch have a 'then' expression.",
4235                         branchExpression.second);
4236 
4237                 expression->_branches.push_back(branchExpression);
4238             }
4239         } else if (field == "default") {
4240             // Optional, arbitrary expression.
4241             expression->_default = parseOperand(expCtx, elem, vps);
4242         } else {
4243             uasserted(40067, str::stream() << "$switch found an unknown argument: " << field);
4244         }
4245     }
4246 
4247     uassert(40068, "$switch requires at least one branch.", !expression->_branches.empty());
4248 
4249     return expression;
4250 }
4251 
_doAddDependencies(DepsTracker * deps) const4252 void ExpressionSwitch::_doAddDependencies(DepsTracker* deps) const {
4253     for (auto&& branch : _branches) {
4254         branch.first->addDependencies(deps);
4255         branch.second->addDependencies(deps);
4256     }
4257 
4258     if (_default) {
4259         _default->addDependencies(deps);
4260     }
4261 }
4262 
optimize()4263 boost::intrusive_ptr<Expression> ExpressionSwitch::optimize() {
4264     if (_default) {
4265         _default = _default->optimize();
4266     }
4267 
4268     std::transform(_branches.begin(),
4269                    _branches.end(),
4270                    _branches.begin(),
4271                    [](ExpressionPair branch) -> ExpressionPair {
4272                        return {branch.first->optimize(), branch.second->optimize()};
4273                    });
4274 
4275     return this;
4276 }
4277 
serialize(bool explain) const4278 Value ExpressionSwitch::serialize(bool explain) const {
4279     std::vector<Value> serializedBranches;
4280     serializedBranches.reserve(_branches.size());
4281 
4282     for (auto&& branch : _branches) {
4283         serializedBranches.push_back(Value(Document{{"case", branch.first->serialize(explain)},
4284                                                     {"then", branch.second->serialize(explain)}}));
4285     }
4286 
4287     if (_default) {
4288         return Value(Document{{"$switch",
4289                                Document{{"branches", Value(serializedBranches)},
4290                                         {"default", _default->serialize(explain)}}}});
4291     }
4292 
4293     return Value(Document{{"$switch", Document{{"branches", Value(serializedBranches)}}}});
4294 }
4295 
4296 /* ------------------------- ExpressionToLower ----------------------------- */
4297 
evaluate(const Document & root,Variables * variables) const4298 Value ExpressionToLower::evaluate(const Document& root, Variables* variables) const {
4299     Value pString(vpOperand[0]->evaluate(root, variables));
4300     string str = pString.coerceToString();
4301     boost::to_lower(str);
4302     return Value(str);
4303 }
4304 
4305 REGISTER_EXPRESSION(toLower, ExpressionToLower::parse);
getOpName() const4306 const char* ExpressionToLower::getOpName() const {
4307     return "$toLower";
4308 }
4309 
4310 /* ------------------------- ExpressionToUpper -------------------------- */
4311 
evaluate(const Document & root,Variables * variables) const4312 Value ExpressionToUpper::evaluate(const Document& root, Variables* variables) const {
4313     Value pString(vpOperand[0]->evaluate(root, variables));
4314     string str(pString.coerceToString());
4315     boost::to_upper(str);
4316     return Value(str);
4317 }
4318 
4319 REGISTER_EXPRESSION(toUpper, ExpressionToUpper::parse);
getOpName() const4320 const char* ExpressionToUpper::getOpName() const {
4321     return "$toUpper";
4322 }
4323 
4324 /* ------------------------- ExpressionTrunc -------------------------- */
4325 
// Truncates 'numericArg' toward zero. Doubles go through std::trunc and decimals are quantized
// against Decimal128::kNormalizedZero with kRoundTowardZero. Every other type (e.g. integers and
// longs, for which truncation has no effect) is returned unchanged.
Value ExpressionTrunc::evaluateNumericArg(const Value& numericArg) const {
    const auto argType = numericArg.getType();

    if (argType == NumberDouble) {
        return Value(std::trunc(numericArg.getDouble()));
    }

    if (argType == NumberDecimal) {
        const auto truncated = numericArg.getDecimal().quantize(Decimal128::kNormalizedZero,
                                                                Decimal128::kRoundTowardZero);
        return Value(truncated);
    }

    return numericArg;
}
4339 
4340 REGISTER_EXPRESSION(trunc, ExpressionTrunc::parse);
getOpName() const4341 const char* ExpressionTrunc::getOpName() const {
4342     return "$trunc";
4343 }
4344 
4345 /* ------------------------- ExpressionType ----------------------------- */
4346 
evaluate(const Document & root,Variables * variables) const4347 Value ExpressionType::evaluate(const Document& root, Variables* variables) const {
4348     Value val(vpOperand[0]->evaluate(root, variables));
4349     return Value(StringData(typeName(val.getType())));
4350 }
4351 
4352 REGISTER_EXPRESSION(type, ExpressionType::parse);
getOpName() const4353 const char* ExpressionType::getOpName() const {
4354     return "$type";
4355 }
4356 
4357 /* -------------------------- ExpressionZip ------------------------------ */
4358 
4359 REGISTER_EXPRESSION(zip, ExpressionZip::parse);
parse(const boost::intrusive_ptr<ExpressionContext> & expCtx,BSONElement expr,const VariablesParseState & vps)4360 intrusive_ptr<Expression> ExpressionZip::parse(
4361     const boost::intrusive_ptr<ExpressionContext>& expCtx,
4362     BSONElement expr,
4363     const VariablesParseState& vps) {
4364     uassert(34460,
4365             str::stream() << "$zip only supports an object as an argument, found "
4366                           << typeName(expr.type()),
4367             expr.type() == Object);
4368 
4369     intrusive_ptr<ExpressionZip> newZip(new ExpressionZip(expCtx));
4370 
4371     for (auto&& elem : expr.Obj()) {
4372         const auto field = elem.fieldNameStringData();
4373         if (field == "inputs") {
4374             uassert(34461,
4375                     str::stream() << "inputs must be an array of expressions, found "
4376                                   << typeName(elem.type()),
4377                     elem.type() == Array);
4378             for (auto&& subExpr : elem.Array()) {
4379                 newZip->_inputs.push_back(parseOperand(expCtx, subExpr, vps));
4380             }
4381         } else if (field == "defaults") {
4382             uassert(34462,
4383                     str::stream() << "defaults must be an array of expressions, found "
4384                                   << typeName(elem.type()),
4385                     elem.type() == Array);
4386             for (auto&& subExpr : elem.Array()) {
4387                 newZip->_defaults.push_back(parseOperand(expCtx, subExpr, vps));
4388             }
4389         } else if (field == "useLongestLength") {
4390             uassert(34463,
4391                     str::stream() << "useLongestLength must be a bool, found "
4392                                   << typeName(expr.type()),
4393                     elem.type() == Bool);
4394             newZip->_useLongestLength = elem.Bool();
4395         } else {
4396             uasserted(34464,
4397                       str::stream() << "$zip found an unknown argument: " << elem.fieldName());
4398         }
4399     }
4400 
4401     uassert(34465, "$zip requires at least one input array", !newZip->_inputs.empty());
4402     uassert(34466,
4403             "cannot specify defaults unless useLongestLength is true",
4404             (newZip->_useLongestLength || newZip->_defaults.empty()));
4405     uassert(34467,
4406             "defaults and inputs must have the same length",
4407             (newZip->_defaults.empty() || newZip->_defaults.size() == newZip->_inputs.size()));
4408 
4409     return std::move(newZip);
4410 }
4411 
evaluate(const Document & root,Variables * variables) const4412 Value ExpressionZip::evaluate(const Document& root, Variables* variables) const {
4413     // Evaluate input values.
4414     vector<vector<Value>> inputValues;
4415     inputValues.reserve(_inputs.size());
4416 
4417     size_t minArraySize = 0;
4418     size_t maxArraySize = 0;
4419     for (size_t i = 0; i < _inputs.size(); i++) {
4420         Value evalExpr = _inputs[i].get()->evaluate(root, variables);
4421         if (evalExpr.nullish()) {
4422             return Value(BSONNULL);
4423         }
4424 
4425         uassert(34468,
4426                 str::stream() << "$zip found a non-array expression in input: "
4427                               << evalExpr.toString(),
4428                 evalExpr.isArray());
4429 
4430         inputValues.push_back(evalExpr.getArray());
4431 
4432         size_t arraySize = evalExpr.getArrayLength();
4433 
4434         if (i == 0) {
4435             minArraySize = arraySize;
4436             maxArraySize = arraySize;
4437         } else {
4438             auto arraySizes = std::minmax({minArraySize, arraySize, maxArraySize});
4439             minArraySize = arraySizes.first;
4440             maxArraySize = arraySizes.second;
4441         }
4442     }
4443 
4444     vector<Value> evaluatedDefaults(_inputs.size(), Value(BSONNULL));
4445 
4446     // If we need default values, evaluate each expression.
4447     if (minArraySize != maxArraySize) {
4448         for (size_t i = 0; i < _defaults.size(); i++) {
4449             evaluatedDefaults[i] = _defaults[i].get()->evaluate(root, variables);
4450         }
4451     }
4452 
4453     size_t outputLength = _useLongestLength ? maxArraySize : minArraySize;
4454 
4455     // The final output array, e.g. [[1, 2, 3], [2, 3, 4]].
4456     vector<Value> output;
4457 
4458     // Used to construct each array in the output, e.g. [1, 2, 3].
4459     vector<Value> outputChild;
4460 
4461     output.reserve(outputLength);
4462     outputChild.reserve(_inputs.size());
4463 
4464     for (size_t row = 0; row < outputLength; row++) {
4465         outputChild.clear();
4466         for (size_t col = 0; col < _inputs.size(); col++) {
4467             if (inputValues[col].size() > row) {
4468                 // Add the value from the appropriate input array.
4469                 outputChild.push_back(inputValues[col][row]);
4470             } else {
4471                 // Add the corresponding default value.
4472                 outputChild.push_back(evaluatedDefaults[col]);
4473             }
4474         }
4475         output.push_back(Value(outputChild));
4476     }
4477 
4478     return Value(output);
4479 }
4480 
optimize()4481 boost::intrusive_ptr<Expression> ExpressionZip::optimize() {
4482     std::transform(_inputs.begin(),
4483                    _inputs.end(),
4484                    _inputs.begin(),
4485                    [](intrusive_ptr<Expression> inputExpression) -> intrusive_ptr<Expression> {
4486                        return inputExpression->optimize();
4487                    });
4488 
4489     std::transform(_defaults.begin(),
4490                    _defaults.end(),
4491                    _defaults.begin(),
4492                    [](intrusive_ptr<Expression> defaultExpression) -> intrusive_ptr<Expression> {
4493                        return defaultExpression->optimize();
4494                    });
4495 
4496     return this;
4497 }
4498 
serialize(bool explain) const4499 Value ExpressionZip::serialize(bool explain) const {
4500     vector<Value> serializedInput;
4501     vector<Value> serializedDefaults;
4502     Value serializedUseLongestLength = Value(_useLongestLength);
4503 
4504     for (auto&& expr : _inputs) {
4505         serializedInput.push_back(expr->serialize(explain));
4506     }
4507 
4508     for (auto&& expr : _defaults) {
4509         serializedDefaults.push_back(expr->serialize(explain));
4510     }
4511 
4512     return Value(DOC("$zip" << DOC("inputs" << Value(serializedInput) << "defaults"
4513                                             << Value(serializedDefaults)
4514                                             << "useLongestLength"
4515                                             << serializedUseLongestLength)));
4516 }
4517 
_doAddDependencies(DepsTracker * deps) const4518 void ExpressionZip::_doAddDependencies(DepsTracker* deps) const {
4519     std::for_each(
4520         _inputs.begin(), _inputs.end(), [&deps](intrusive_ptr<Expression> inputExpression) -> void {
4521             inputExpression->addDependencies(deps);
4522         });
4523     std::for_each(_defaults.begin(),
4524                   _defaults.end(),
4525                   [&deps](intrusive_ptr<Expression> defaultExpression) -> void {
4526                       defaultExpression->addDependencies(deps);
4527                   });
4528 }
4529 
4530 }  // namespace mongo
4531