/**
 *    Copyright (C) 2018-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#define MONGO_LOG_DEFAULT_COMPONENT ::mongo::logger::LogComponent::kCommand

#include "mongo/platform/basic.h"

#include "mongo/s/commands/cluster_aggregate.h"

#include <boost/intrusive_ptr.hpp>

#include "mongo/bson/util/bson_extract.h"
#include "mongo/db/auth/authorization_session.h"
#include "mongo/db/client.h"
#include "mongo/db/commands.h"
#include "mongo/db/operation_context.h"
#include "mongo/db/pipeline/document_source_change_stream.h"
#include "mongo/db/pipeline/document_source_out.h"
#include "mongo/db/pipeline/expression_context.h"
#include "mongo/db/pipeline/lite_parsed_pipeline.h"
#include "mongo/db/pipeline/pipeline.h"
#include "mongo/db/query/collation/collator_factory_interface.h"
#include "mongo/db/query/cursor_response.h"
#include "mongo/db/query/explain_common.h"
#include "mongo/db/query/find_common.h"
#include "mongo/db/views/resolved_view.h"
#include "mongo/db/views/view.h"
#include "mongo/executor/task_executor_pool.h"
#include "mongo/rpc/get_status_from_command_result.h"
#include "mongo/s/catalog_cache.h"
#include "mongo/s/client/shard_connection.h"
#include "mongo/s/client/shard_registry.h"
#include "mongo/s/commands/cluster_commands_helpers.h"
#include "mongo/s/commands/pipeline_s.h"
#include "mongo/s/grid.h"
#include "mongo/s/query/cluster_client_cursor_impl.h"
#include "mongo/s/query/cluster_client_cursor_params.h"
#include "mongo/s/query/cluster_cursor_manager.h"
#include "mongo/s/query/cluster_query_knobs.h"
#include "mongo/s/query/establish_cursors.h"
#include "mongo/s/query/router_stage_update_on_add_shard.h"
#include "mongo/s/query/store_possible_cursor.h"
#include "mongo/s/stale_exception.h"
#include "mongo/util/fail_point.h"
#include "mongo/util/log.h"

namespace mongo {

MONGO_FP_DECLARE(clusterAggregateHangBeforeEstablishingShardCursors);

namespace {
// Given a document representing an aggregation command such as
//
//   {aggregate: "myCollection", pipeline: [], ...},
//
// produces the corresponding explain command:
//
//   {explain: {aggregate: "myCollection", pipeline: [], ...}, $queryOptions: {...}, verbosity: ...}
Document wrapAggAsExplain(Document aggregateCommand, ExplainOptions::Verbosity verbosity) {
    MutableDocument explainCommandBuilder;
    explainCommandBuilder["explain"] = Value(aggregateCommand);
    // Downstream host targeting code expects queryOptions at the top level of the command object.
    explainCommandBuilder[QueryRequest::kUnwrappedReadPrefField] =
        Value(aggregateCommand[QueryRequest::kUnwrappedReadPrefField]);

    // readConcern needs to be promoted to the top-level of the request.
    explainCommandBuilder[repl::ReadConcernArgs::kReadConcernFieldName] =
        Value(aggregateCommand[repl::ReadConcernArgs::kReadConcernFieldName]);

    // Add explain command options.
    for (auto&& explainOption : ExplainOptions::toBSON(verbosity)) {
        explainCommandBuilder[explainOption.fieldNameStringData()] = Value(explainOption);
    }

    return explainCommandBuilder.freeze();
}

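/**
 * Appends the explain output from each targeted shard to 'result', together with 'mergeType' and
 * 'splitPipeline' fields describing how the pipeline was divided between the shards and the
 * merger.
 */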
Status appendExplainResults(
    const std::vector<AsyncRequestsSender::Response>& shardResults,
    const boost::intrusive_ptr<ExpressionContext>& mergeCtx,
    const std::unique_ptr<Pipeline, Pipeline::Deleter>& pipelineForTargetedShards,
    const std::unique_ptr<Pipeline, Pipeline::Deleter>& pipelineForMerging,
    BSONObjBuilder* result) {
    if (pipelineForTargetedShards->isSplitForShards()) {
        *result << "mergeType"
                << (pipelineForMerging->canRunOnMongos()
                        ? "mongos"
                        : pipelineForMerging->needsPrimaryShardMerger() ? "primaryShard"
                                                                        : "anyShard")
                << "splitPipeline"
                << Document{
                       {"shardsPart",
                        pipelineForTargetedShards->writeExplainOps(*mergeCtx->explain)},
                       {"mergerPart", pipelineForMerging->writeExplainOps(*mergeCtx->explain)}};
    } else {
        *result << "splitPipeline" << BSONNULL;
    }

    BSONObjBuilder shardExplains(result->subobjStart("shards"));
    for (const auto& shardResult : shardResults) {
        invariant(shardResult.shardHostAndPort);
        shardExplains.append(shardResult.shardId.toString(),
                             BSON("host" << shardResult.shardHostAndPort->toString() << "stages"
                                         << shardResult.swResponse.getValue().data["stages"]));
    }

    return Status::OK();
}

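/**
 * Appends a single remote shard's cursor response to the mongoS command result, surfacing any
 * writeConcernError before the remaining fields are merged in.
 */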
Status appendCursorResponseToCommandResult(const ShardId& shardId,
                                           const BSONObj cursorResponse,
                                           BSONObjBuilder* result) {
    // If a write concern error was encountered, append it to the output buffer first.
    if (auto wcErrorElem = cursorResponse["writeConcernError"]) {
        appendWriteConcernErrorToCmdResponse(shardId, wcErrorElem, *result);
    }

    // Pass the results from the remote shard into our command response.
    result->appendElementsUnique(Command::filterCommandReplyForPassthrough(cursorResponse));
    return getStatusFromCommandResult(result->asTempObj());
}

bool mustRunOnAllShards(const NamespaceString& nss,
                        const CachedCollectionRoutingInfo& routingInfo,
                        const LiteParsedPipeline& litePipe) {
    // Collectionless aggregations such as $currentOp, and any change stream on a sharded
    // collection, must run on all shards.
    const bool nsIsSharded = static_cast<bool>(routingInfo.cm());
    return nss.isCollectionlessAggregateNS() || (nsIsSharded && litePipe.hasChangeStream());
}

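/**
 * Retrieves the routing info for the execution namespace. Returns a non-OK status if the database
 * does not exist, or if the aggregation is collectionless and there are no shards in the cluster.
 */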
StatusWith<CachedCollectionRoutingInfo> getExecutionNsRoutingInfo(OperationContext* opCtx,
                                                                  const NamespaceString& execNss,
                                                                  CatalogCache* catalogCache) {
    // This call to getCollectionRoutingInfo will return !OK if the database does not exist.
    auto swRoutingInfo = catalogCache->getCollectionRoutingInfo(opCtx, execNss);

    // Collectionless aggregations, however, may be run on 'admin' (which should always exist) but
    // are subsequently targeted towards the shards. If getCollectionRoutingInfo succeeds, we
    // therefore perform a further check that at least one shard exists when the aggregation is
    // collectionless.
    if (swRoutingInfo.isOK() && execNss.isCollectionlessAggregateNS()) {
        std::vector<ShardId> shardIds;
        Grid::get(opCtx)->shardRegistry()->getAllShardIds(&shardIds);

        if (shardIds.size() == 0) {
            return {ErrorCodes::NamespaceNotFound, "No shards are present in the cluster"};
        }
    }

    return swRoutingInfo;
}

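/**
 * Computes the set of shards to which the aggregation should be dispatched: all shards for
 * broadcast stages, the shards owning chunks matched by 'shardQuery' for a sharded collection, or
 * the database's primary shard for an unsharded collection.
 */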
std::set<ShardId> getTargetedShards(OperationContext* opCtx,
                                    const NamespaceString& nss,
                                    const LiteParsedPipeline& litePipe,
                                    const CachedCollectionRoutingInfo& routingInfo,
                                    const BSONObj shardQuery,
                                    const BSONObj collation) {
    if (mustRunOnAllShards(nss, routingInfo, litePipe)) {
        // The pipeline begins with a stage which must be run on all shards.
        std::vector<ShardId> shardIds;
        Grid::get(opCtx)->shardRegistry()->getAllShardIds(&shardIds);
        return {shardIds.begin(), shardIds.end()};
    }

    if (routingInfo.cm()) {
        // The collection is sharded. Use the routing table to decide which shards to target
        // based on the query and collation.
        std::set<ShardId> shardIds;
        routingInfo.cm()->getShardIdsForQuery(opCtx, shardQuery, collation, &shardIds);
        return shardIds;
    }

    // The collection is unsharded. Target only the primary shard for the database.
    return {routingInfo.primaryId()};
}

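/**
 * Builds the aggregate command to send to the targeted shards. Marks the command as coming from
 * mongos, substitutes the shards part of a split pipeline, and wraps the whole command in an
 * explain envelope when the request is an explain. For a split pipeline the result looks roughly
 * like the following sketch (assuming the kFromMongosName/kNeedsMergeName constants serialize as
 * 'fromMongos'/'needsMerge'; field order is not guaranteed):
 *
 *   {aggregate: "myCollection", pipeline: [...], fromMongos: true, needsMerge: true,
 *    cursor: {batchSize: 0}, ...}
 */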
BSONObj createCommandForTargetedShards(
    const AggregationRequest& request,
    const BSONObj originalCmdObj,
    const std::unique_ptr<Pipeline, Pipeline::Deleter>& pipelineForTargetedShards) {
    // Create the command for the shards.
    MutableDocument targetedCmd(request.serializeToCommandObj());
    targetedCmd[AggregationRequest::kFromMongosName] = Value(true);

    // If 'pipelineForTargetedShards' is 'nullptr', this is an unsharded direct passthrough.
    if (pipelineForTargetedShards) {
        targetedCmd[AggregationRequest::kPipelineName] =
            Value(pipelineForTargetedShards->serialize());

        if (pipelineForTargetedShards->isSplitForShards()) {
            targetedCmd[AggregationRequest::kNeedsMergeName] = Value(true);
            targetedCmd[AggregationRequest::kCursorName] =
                Value(DOC(AggregationRequest::kBatchSizeName << 0));
        }
    }

    // If this pipeline is not split, ensure that the write concern is propagated if present.
    if (!pipelineForTargetedShards || !pipelineForTargetedShards->isSplitForShards()) {
        targetedCmd["writeConcern"] = Value(originalCmdObj["writeConcern"]);
    }

    // If this is a request for an aggregation explain, then we must wrap the aggregate inside an
    // explain command.
    if (auto explainVerbosity = request.getExplain()) {
        targetedCmd.reset(wrapAggAsExplain(targetedCmd.freeze(), *explainVerbosity));
    }

    return targetedCmd.freeze().toBson();
}

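/**
 * Builds the aggregate command to send to the merging shard, attaching an explicit collation if
 * the user did not specify one, since the merging shard may not have the collection metadata.
 */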
BSONObj createCommandForMergingShard(
    const AggregationRequest& request,
    const boost::intrusive_ptr<ExpressionContext>& mergeCtx,
    const BSONObj originalCmdObj,
    const std::unique_ptr<Pipeline, Pipeline::Deleter>& pipelineForMerging) {
    MutableDocument mergeCmd(request.serializeToCommandObj());

    mergeCmd["pipeline"] = Value(pipelineForMerging->serialize());
    mergeCmd[AggregationRequest::kFromMongosName] = Value(true);
    mergeCmd["writeConcern"] = Value(originalCmdObj["writeConcern"]);

    // If the user didn't specify a collation already, make sure there's a collation attached to
    // the merge command, since the merging shard may not have the collection metadata.
    if (mergeCmd.peek()["collation"].missing()) {
        mergeCmd["collation"] = mergeCtx->getCollator()
            ? Value(mergeCtx->getCollator()->getSpec().toBSON())
            : Value(Document{CollationSpec::kSimpleSpec});
    }

    return mergeCmd.freeze().toBson();
}

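/**
 * Dispatches 'cmdObj' to the targeted shards and returns the cursors they establish. Requests are
 * versioned according to the routing table, except for broadcast stages, and the routing table
 * cache is invalidated if any shard reports a stale shardVersion.
 */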
StatusWith<std::vector<ClusterClientCursorParams::RemoteCursor>> establishShardCursors(
    OperationContext* opCtx,
    const NamespaceString& nss,
    const LiteParsedPipeline& litePipe,
    CachedCollectionRoutingInfo* routingInfo,
    const BSONObj& cmdObj,
    const ReadPreferenceSetting& readPref,
    const BSONObj& shardQuery,
    const BSONObj& collation) {
    LOG(1) << "Dispatching command " << redact(cmdObj) << " to establish cursors on shards";

    std::set<ShardId> shardIds =
        getTargetedShards(opCtx, nss, litePipe, *routingInfo, shardQuery, collation);
    std::vector<std::pair<ShardId, BSONObj>> requests;

    if (mustRunOnAllShards(nss, *routingInfo, litePipe)) {
        // The pipeline contains a stage which must be run on all shards. Skip versioning and
        // enqueue the raw command objects.
        for (auto&& shardId : shardIds) {
            requests.emplace_back(std::move(shardId), cmdObj);
        }
    } else if (routingInfo->cm()) {
        // The collection is sharded. Use the routing table to decide which shards to target
        // based on the query and collation, and build versioned requests for them.
        for (auto& shardId : shardIds) {
            auto versionedCmdObj =
                appendShardVersion(cmdObj, routingInfo->cm()->getVersion(shardId));
            requests.emplace_back(std::move(shardId), std::move(versionedCmdObj));
        }
    } else {
        // The collection is unsharded. Target only the primary shard for the database.
        // Don't append shard version info when contacting the config servers.
        requests.emplace_back(routingInfo->primaryId(),
                              !routingInfo->primary()->isConfig()
                                  ? appendShardVersion(cmdObj, ChunkVersion::UNSHARDED())
                                  : cmdObj);
    }

    if (MONGO_FAIL_POINT(clusterAggregateHangBeforeEstablishingShardCursors)) {
        log() << "clusterAggregateHangBeforeEstablishingShardCursors fail point enabled.  Blocking "
                 "until fail point is disabled.";
        while (MONGO_FAIL_POINT(clusterAggregateHangBeforeEstablishingShardCursors)) {
            sleepsecs(1);
        }
    }

    // If we reach this point, we're either trying to establish cursors on a sharded execution
    // namespace, or handling the case where a sharded collection was dropped and recreated as
    // unsharded. Since views cannot be sharded, and because we will return an error rather than
    // attempting to continue in the event that a recreated namespace is a view, we set
    // viewDefinitionOut to nullptr.
    BSONObj* viewDefinitionOut = nullptr;
    auto swCursors = establishCursors(opCtx,
                                      Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor(),
                                      nss,
                                      readPref,
                                      requests,
                                      false /* do not allow partial results */,
                                      viewDefinitionOut /* can't receive view definition */);

    // If any shard returned a stale shardVersion error, invalidate the routing table cache.
    // This will cause the cache to be refreshed the next time it is accessed.
    if (ErrorCodes::isStaleShardingError(swCursors.getStatus().code())) {
        Grid::get(opCtx)->catalogCache()->onStaleConfigError(std::move(*routingInfo));
    }

    return swCursors;
}

struct DispatchShardPipelineResults {
    // True if this pipeline was split, and the second half of the pipeline needs to be run on the
    // primary shard for the database.
    bool needsPrimaryShardMerge;

    // Populated if this *is not* an explain; contains the cursors established on the remote
    // shards.
    std::vector<ClusterClientCursorParams::RemoteCursor> remoteCursors;

    // Populated if this *is* an explain; contains the results returned from each shard.
    std::vector<AsyncRequestsSender::Response> remoteExplainOutput;

    // The half of the pipeline that was sent to each shard, or the entire pipeline if there was
    // only one shard targeted.
    std::unique_ptr<Pipeline, Pipeline::Deleter> pipelineForTargetedShards;

    // The merging half of the pipeline if more than one shard was targeted, otherwise nullptr.
    std::unique_ptr<Pipeline, Pipeline::Deleter> pipelineForMerging;

    // The command object to send to the targeted shards.
    BSONObj commandForTargetedShards;
};

/**
 * Targets shards for the pipeline and returns a struct with the remote cursors or results, and
 * the pipeline that will need to be executed to merge the results from the remotes. If a stale
 * shard version is encountered, refreshes the routing table and tries again.
 */
StatusWith<DispatchShardPipelineResults> dispatchShardPipeline(
    const boost::intrusive_ptr<ExpressionContext>& expCtx,
    const NamespaceString& executionNss,
    BSONObj originalCmdObj,
    const AggregationRequest& aggRequest,
    const LiteParsedPipeline& liteParsedPipeline,
    std::unique_ptr<Pipeline, Pipeline::Deleter> pipeline) {
    // The process is as follows:
    // - First, determine whether we need to target more than one shard. If so, we split the
    // pipeline; if not, we retain the existing pipeline.
    // - Call establishShardCursors to dispatch the aggregation to the targeted shards.
    // - If we get a staleConfig exception, re-evaluate whether we need to split the pipeline with
    // the refreshed routing table data.
    // - If the pipeline is not split and we now need to target multiple shards, split it. If the
    // pipeline is already split and we now only need to target a single shard, reassemble the
    // original pipeline.
    // - After exhausting 10 attempts to establish the cursors, we give up and throw.
    auto swCursors = makeStatusWith<std::vector<ClusterClientCursorParams::RemoteCursor>>();
    auto swShardResults = makeStatusWith<std::vector<AsyncRequestsSender::Response>>();
    auto opCtx = expCtx->opCtx;

    const bool needsPrimaryShardMerge =
        (pipeline->needsPrimaryShardMerger() || internalQueryAlwaysMergeOnPrimaryShard.load());

    const bool needsMongosMerge = pipeline->needsMongosMerger();

    const auto shardQuery = pipeline->getInitialQuery();

    auto pipelineForTargetedShards = std::move(pipeline);
    std::unique_ptr<Pipeline, Pipeline::Deleter> pipelineForMerging;
    BSONObj targetedCommand;

    int numAttempts = 0;

    do {
        // We need to grab a new routing table at the start of each iteration, since a stale config
        // exception will invalidate the previous one.
        auto executionNsRoutingInfo = uassertStatusOK(
            Grid::get(opCtx)->catalogCache()->getCollectionRoutingInfo(opCtx, executionNss));

        // Determine whether we can run the entire aggregation on a single shard.
        std::set<ShardId> shardIds = getTargetedShards(opCtx,
                                                       executionNss,
                                                       liteParsedPipeline,
                                                       executionNsRoutingInfo,
                                                       shardQuery,
                                                       aggRequest.getCollation());

        uassert(ErrorCodes::ShardNotFound,
                "No targets were found for this aggregation. All shards were removed from the "
                "cluster mid-operation",
                shardIds.size() > 0);

        // Don't need to split the pipeline if we are only targeting a single shard, unless:
        // - There is a stage that needs to be run on the primary shard and the single target shard
        //   is not the primary.
        // - The pipeline contains one or more stages which must always merge on mongoS.
        const bool needsSplit =
            (shardIds.size() > 1u || needsMongosMerge ||
             (needsPrimaryShardMerge && *(shardIds.begin()) != executionNsRoutingInfo.primaryId()));

        const bool isSplit = pipelineForTargetedShards->isSplitForShards();

        // If we have to run on multiple shards and the pipeline is not yet split, split it. If we
        // can run on a single shard and the pipeline is already split, reassemble it.
        if (needsSplit && !isSplit) {
            pipelineForMerging = std::move(pipelineForTargetedShards);
            pipelineForTargetedShards = pipelineForMerging->splitForSharded();
        } else if (!needsSplit && isSplit) {
            pipelineForTargetedShards->unsplitFromSharded(std::move(pipelineForMerging));
        }

        // Generate the command object for the targeted shards.
        targetedCommand =
            createCommandForTargetedShards(aggRequest, originalCmdObj, pipelineForTargetedShards);

        // Refresh the shard registry if we're targeting all shards.  We need the shard registry
        // to be at least as current as the logical time used when creating the command for
        // $changeStream to work reliably, so we do a "hard" reload.
        if (mustRunOnAllShards(executionNss, executionNsRoutingInfo, liteParsedPipeline)) {
            auto* shardRegistry = Grid::get(opCtx)->shardRegistry();
            if (!shardRegistry->reload(opCtx)) {
                shardRegistry->reload(opCtx);
            }
        }

        // Explain does not produce a cursor, so instead we scatter-gather commands to the shards.
        if (expCtx->explain) {
            if (mustRunOnAllShards(executionNss, executionNsRoutingInfo, liteParsedPipeline)) {
                // Some stages (such as $currentOp) need to be broadcast to all shards, and should
                // not participate in the shard version protocol.
                swShardResults =
                    scatterGatherUnversionedTargetAllShards(opCtx,
                                                            executionNss.db().toString(),
                                                            executionNss,
                                                            targetedCommand,
                                                            ReadPreferenceSetting::get(opCtx),
                                                            Shard::RetryPolicy::kIdempotent);
            } else {
                // Aggregations on a real namespace should use the routing table to target shards,
                // and should participate in the shard version protocol.
                swShardResults =
                    scatterGatherVersionedTargetByRoutingTable(opCtx,
                                                               executionNss.db().toString(),
                                                               executionNss,
                                                               targetedCommand,
                                                               ReadPreferenceSetting::get(opCtx),
                                                               Shard::RetryPolicy::kIdempotent,
                                                               shardQuery,
                                                               aggRequest.getCollation(),
                                                               nullptr /* viewDefinition */);
            }
        } else {
            swCursors = establishShardCursors(opCtx,
                                              executionNss,
                                              liteParsedPipeline,
                                              &executionNsRoutingInfo,
                                              targetedCommand,
                                              ReadPreferenceSetting::get(opCtx),
                                              shardQuery,
                                              aggRequest.getCollation());

            if (ErrorCodes::isStaleShardingError(swCursors.getStatus().code())) {
                LOG(1) << "got stale shardVersion error " << swCursors.getStatus()
                       << " while dispatching " << redact(targetedCommand) << " after "
                       << (numAttempts + 1) << " dispatch attempts";
            }
        }
    } while (++numAttempts < kMaxNumStaleVersionRetries &&
             (expCtx->explain ? !swShardResults.isOK() : !swCursors.isOK()));

    if (!swShardResults.isOK()) {
        return swShardResults.getStatus();
    }
    if (!swCursors.isOK()) {
        return swCursors.getStatus();
    }
    return DispatchShardPipelineResults{needsPrimaryShardMerge,
                                        std::move(swCursors.getValue()),
                                        std::move(swShardResults.getValue()),
                                        std::move(pipelineForTargetedShards),
                                        std::move(pipelineForMerging),
                                        targetedCommand};
}

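/**
 * Runs the merging command on a single shard and returns that shard's id along with its response.
 * The merge runs on the primary shard if one is required, otherwise on a randomly chosen shard
 * that already holds one of the remote cursors.
 */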
StatusWith<std::pair<ShardId, Shard::CommandResponse>> establishMergingShardCursor(
    OperationContext* opCtx,
    const NamespaceString& nss,
    const std::vector<ClusterClientCursorParams::RemoteCursor>& cursors,
    const BSONObj mergeCmdObj,
    const boost::optional<ShardId> primaryShard) {
    // Run the merging command on a random shard, unless we need to run it on the primary shard.
    auto& prng = opCtx->getClient()->getPrng();
    const auto mergingShardId =
        primaryShard ? primaryShard.get() : cursors[prng.nextInt32(cursors.size())].shardId;
    const auto mergingShard =
        uassertStatusOK(Grid::get(opCtx)->shardRegistry()->getShard(opCtx, mergingShardId));

    auto shardCmdResponse = uassertStatusOK(
        mergingShard->runCommandWithFixedRetryAttempts(opCtx,
                                                       ReadPreferenceSetting::get(opCtx),
                                                       nss.db().toString(),
                                                       mergeCmdObj,
                                                       Shard::RetryPolicy::kIdempotent));

    return {{std::move(mergingShardId), std::move(shardCmdResponse)}};
}

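/**
 * Builds a merging cursor on mongoS over the remote cursors, retrieves the initial batch of
 * results, registers the cursor with the cluster cursor manager if it is not exhausted, and
 * returns the serialized cursor response.
 */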
BSONObj establishMergingMongosCursor(
    OperationContext* opCtx,
    const AggregationRequest& request,
    const NamespaceString& requestedNss,
    BSONObj cmdToRunOnNewShards,
    const LiteParsedPipeline& liteParsedPipeline,
    std::unique_ptr<Pipeline, Pipeline::Deleter> pipelineForMerging,
    std::vector<ClusterClientCursorParams::RemoteCursor> cursors) {

    // Inject the MongosProcessInterface for sources which need it.
    PipelineS::injectMongosInterface(pipelineForMerging.get());

    ClusterClientCursorParams params(
        requestedNss,
        AuthorizationSession::get(opCtx->getClient())->getAuthenticatedUserNames(),
        ReadPreferenceSetting::get(opCtx));

    params.tailableMode = pipelineForMerging->getContext()->tailableMode;
    params.mergePipeline = std::move(pipelineForMerging);
    params.remotes = std::move(cursors);

    // A batch size of 0 is legal for the initial aggregate, but not for getMores. The batch size
    // we pass here is used for getMores, so do not specify a batch size if the initial request had
    // a batch size of 0.
    params.batchSize = request.getBatchSize() == 0
        ? boost::none
        : boost::optional<long long>(request.getBatchSize());

    if (liteParsedPipeline.hasChangeStream()) {
        // For change streams, we need to set up a custom stage to establish cursors on new shards
        // when they are added.
        params.createCustomCursorSource = [cmdToRunOnNewShards](OperationContext* opCtx,
                                                                executor::TaskExecutor* executor,
                                                                ClusterClientCursorParams* params) {
            return stdx::make_unique<RouterStageUpdateOnAddShard>(
                opCtx, executor, params, cmdToRunOnNewShards);
        };
    }
    auto ccc = ClusterClientCursorImpl::make(
        opCtx, Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor(), std::move(params));

    auto cursorState = ClusterCursorManager::CursorState::NotExhausted;
    BSONObjBuilder cursorResponse;

    CursorResponseBuilder responseBuilder(true, &cursorResponse);

    for (long long objCount = 0; objCount < request.getBatchSize(); ++objCount) {
        ClusterQueryResult next;
        try {
            next = uassertStatusOK(ccc->next(RouterExecStage::ExecContext::kInitialFind));
        } catch (const ExceptionFor<ErrorCodes::CloseChangeStream>&) {
            // This exception is thrown when a $changeStream stage encounters an event
            // that invalidates the cursor. We should close the cursor and return without
            // error.
            cursorState = ClusterCursorManager::CursorState::Exhausted;
            break;
        }

        // Check whether we have exhausted the pipeline's results.
        if (next.isEOF()) {
            // We reached end-of-stream. If the cursor is not tailable, then we mark it as
            // exhausted. If it is tailable, usually we keep it open (i.e. "NotExhausted") even when
            // we reach end-of-stream. However, if all the remote cursors are exhausted, there is no
            // hope of returning data and thus we need to close the mongos cursor as well.
            if (!ccc->isTailable() || ccc->remotesExhausted()) {
                cursorState = ClusterCursorManager::CursorState::Exhausted;
            }
            break;
        }

        // If this result will fit into the current batch, add it. Otherwise, stash it in the cursor
        // to be returned on the next getMore.
        auto nextObj = *next.getResult();

        if (!FindCommon::haveSpaceForNext(nextObj, objCount, responseBuilder.bytesUsed())) {
            ccc->queueResult(nextObj);
            break;
        }

        responseBuilder.append(nextObj);
    }

    ccc->detachFromOperationContext();

    CursorId clusterCursorId = 0;

    if (cursorState == ClusterCursorManager::CursorState::NotExhausted) {
        clusterCursorId = uassertStatusOK(Grid::get(opCtx)->getCursorManager()->registerCursor(
            opCtx,
            ccc.releaseCursor(),
            requestedNss,
            ClusterCursorManager::CursorType::MultiTarget,
            ClusterCursorManager::CursorLifetime::Mortal));
    }

    responseBuilder.done(clusterCursorId, requestedNss.ns());

    Command::appendCommandStatus(cursorResponse, Status::OK());

    return cursorResponse.obj();
}

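/**
 * Fetches the default collation for an unsharded collection from the collection metadata on its
 * primary shard. Returns an empty BSONObj if the collection does not exist or has no default
 * collation.
 */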
BSONObj getDefaultCollationForUnshardedCollection(const Shard* primaryShard,
                                                  const NamespaceString& nss) {
    ScopedDbConnection conn(primaryShard->getConnString());
    BSONObj defaultCollation;
    std::list<BSONObj> all =
        conn->getCollectionInfos(nss.db().toString(), BSON("name" << nss.coll()));
    if (all.empty()) {
        return defaultCollation;
    }
    BSONObj collectionInfo = all.front();
    if (collectionInfo["options"].type() == BSONType::Object) {
        BSONObj collectionOptions = collectionInfo["options"].Obj();
        BSONElement collationElement;
        auto status = bsonExtractTypedField(
            collectionOptions, "collation", BSONType::Object, &collationElement);
        if (status.isOK()) {
            defaultCollation = collationElement.Obj().getOwned();
            uassert(ErrorCodes::BadValue,
                    "Default collation in collection metadata cannot be empty.",
                    !defaultCollation.isEmpty());
        } else if (status != ErrorCodes::NoSuchKey) {
            uassertStatusOK(status);
        }
    }
    return defaultCollation;
}

}  // namespace

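/**
 * Entry point for aggregation on mongoS. Resolves the routing info and involved namespaces,
 * passes the command through to the owning shard when possible, and otherwise parses and splits
 * the pipeline, dispatches it to the targeted shards, and merges the results either on mongoS or
 * on a merging shard.
 */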
Status ClusterAggregate::runAggregate(OperationContext* opCtx,
                                      const Namespaces& namespaces,
                                      const AggregationRequest& request,
                                      BSONObj cmdObj,
                                      BSONObjBuilder* result) {
    if (request.getExplain()) {
        explain_common::generateServerInfo(result);
    }

    const auto catalogCache = Grid::get(opCtx)->catalogCache();

    auto executionNsRoutingInfoStatus =
        getExecutionNsRoutingInfo(opCtx, namespaces.executionNss, catalogCache);

    LiteParsedPipeline liteParsedPipeline(request);

    if (!executionNsRoutingInfoStatus.isOK()) {
        // Standard aggregations swallow 'NamespaceNotFound' and return an empty cursor with id 0 in
        // the event that the database does not exist. For $changeStream aggregations, however, we
        // throw the exception in all error cases, including that of a non-existent database.
        uassert(executionNsRoutingInfoStatus.getStatus().code(),
                str::stream() << "failed to open $changeStream: "
                              << executionNsRoutingInfoStatus.getStatus().reason(),
                !liteParsedPipeline.hasChangeStream());
        appendEmptyResultSet(
            *result, executionNsRoutingInfoStatus.getStatus(), namespaces.requestedNss.ns());
        return Status::OK();
    }

    auto executionNsRoutingInfo = executionNsRoutingInfoStatus.getValue();

    // Determine the appropriate collation and 'resolve' involved namespaces to make the
    // ExpressionContext.

    // We won't try to execute anything on a mongos, but we still have to populate this map so that
    // any $lookups, etc. will be able to have a resolved view definition. It's okay that this is
    // incorrect; we will repopulate the real resolved namespace map on the mongod. Note that we
    // need to check if any involved collections are sharded before forwarding an aggregation
    // command on an unsharded collection.
    StringMap<ExpressionContext::ResolvedNamespace> resolvedNamespaces;

    for (auto&& nss : liteParsedPipeline.getInvolvedNamespaces()) {
        const auto resolvedNsRoutingInfo =
            uassertStatusOK(catalogCache->getCollectionRoutingInfo(opCtx, nss));
        uassert(
            28769, str::stream() << nss.ns() << " cannot be sharded", !resolvedNsRoutingInfo.cm());
        resolvedNamespaces.try_emplace(nss.coll(), nss, std::vector<BSONObj>{});
    }

    // If this pipeline is on an unsharded collection, is allowed to be forwarded to shards, does
    // not need to run on all shards, and doesn't need transformation via
    // DocumentSource::serialize(), then go ahead and pass it through to the owning shard
    // unmodified.
    if (!executionNsRoutingInfo.cm() &&
        !mustRunOnAllShards(namespaces.executionNss, executionNsRoutingInfo, liteParsedPipeline) &&
        liteParsedPipeline.allowedToForwardFromMongos() &&
        liteParsedPipeline.allowedToPassthroughFromMongos()) {
        return aggPassthrough(opCtx,
                              namespaces,
                              executionNsRoutingInfo.primary()->getId(),
                              cmdObj,
                              request,
                              liteParsedPipeline,
                              result);
    }

    std::unique_ptr<CollatorInterface> collation;
    if (!request.getCollation().isEmpty()) {
        collation = uassertStatusOK(CollatorFactoryInterface::get(opCtx->getServiceContext())
                                        ->makeFromBSON(request.getCollation()));
    } else if (const auto chunkMgr = executionNsRoutingInfo.cm()) {
        if (chunkMgr->getDefaultCollator()) {
            collation = chunkMgr->getDefaultCollator()->clone();
        }
    } else {
        // Unsharded collection. Get the collection metadata from the primary shard.
        auto collationObj = getDefaultCollationForUnshardedCollection(
            executionNsRoutingInfo.primary().get(), namespaces.executionNss);
        if (!collationObj.isEmpty()) {
            collation = uassertStatusOK(CollatorFactoryInterface::get(opCtx->getServiceContext())
                                            ->makeFromBSON(collationObj));
        }
    }

    boost::intrusive_ptr<ExpressionContext> mergeCtx =
        new ExpressionContext(opCtx, request, std::move(collation), std::move(resolvedNamespaces));
    mergeCtx->inMongos = true;
    // explicitly *not* setting mergeCtx->tempDir

    auto pipeline = uassertStatusOK(Pipeline::parse(request.getPipeline(), mergeCtx));
    pipeline->optimizePipeline();

    // Check whether the entire pipeline must be run on mongoS.
    if (pipeline->requiredToRunOnMongos()) {
        uassert(ErrorCodes::IllegalOperation,
                str::stream() << "Aggregation pipeline must be run on mongoS, but "
                              << pipeline->getSources().front()->getSourceName()
                              << " is not capable of producing input",
                !pipeline->getSources().front()->constraints().requiresInputDocSource);

        auto cursorResponse = establishMergingMongosCursor(opCtx,
                                                           request,
                                                           namespaces.requestedNss,
                                                           cmdObj,
                                                           liteParsedPipeline,
                                                           std::move(pipeline),
                                                           {});
        Command::filterCommandReplyForPassthrough(cursorResponse, result);
        return getStatusFromCommandResult(result->asTempObj());
    }

    auto dispatchResults = uassertStatusOK(dispatchShardPipeline(mergeCtx,
                                                                 namespaces.executionNss,
                                                                 cmdObj,
                                                                 request,
                                                                 liteParsedPipeline,
                                                                 std::move(pipeline)));

    if (mergeCtx->explain) {
        // If we reach here, we've either succeeded in running the explain or exhausted all
        // attempts. In either case, attempt to append the explain results to the output builder.
        uassertAllShardsSupportExplain(dispatchResults.remoteExplainOutput);

        return appendExplainResults(std::move(dispatchResults.remoteExplainOutput),
                                    mergeCtx,
                                    dispatchResults.pipelineForTargetedShards,
                                    dispatchResults.pipelineForMerging,
                                    result);
    }

    invariant(dispatchResults.remoteCursors.size() > 0);

    // If we dispatched to a single shard, store the remote cursor and return immediately.
    if (!dispatchResults.pipelineForTargetedShards->isSplitForShards()) {
        invariant(dispatchResults.remoteCursors.size() == 1);
        const auto& remoteCursor = dispatchResults.remoteCursors[0];
        auto executorPool = Grid::get(opCtx)->getExecutorPool();
        const BSONObj reply = uassertStatusOK(storePossibleCursor(
            opCtx,
            remoteCursor.shardId,
            remoteCursor.hostAndPort,
            remoteCursor.cursorResponse.toBSON(CursorResponse::ResponseType::InitialResponse),
            namespaces.requestedNss,
            executorPool->getArbitraryExecutor(),
            Grid::get(opCtx)->getCursorManager(),
            mergeCtx->tailableMode));

        return appendCursorResponseToCommandResult(remoteCursor.shardId, reply, result);
    }

    // If we reach here, we have a merge pipeline to dispatch.
    auto mergingPipeline = std::move(dispatchResults.pipelineForMerging);
    invariant(mergingPipeline);

    // First, check whether we can merge on the mongoS. If the merge pipeline MUST run on mongoS,
    // then ignore the internalQueryProhibitMergingOnMongoS parameter.
    if (mergingPipeline->requiredToRunOnMongos() ||
        (!internalQueryProhibitMergingOnMongoS.load() && mergingPipeline->canRunOnMongos())) {
        // Register the new mongoS cursor, and retrieve the initial batch of results.
        auto cursorResponse =
            establishMergingMongosCursor(opCtx,
                                         request,
                                         namespaces.requestedNss,
                                         dispatchResults.commandForTargetedShards,
                                         liteParsedPipeline,
                                         std::move(mergingPipeline),
                                         std::move(dispatchResults.remoteCursors));

        // We don't need to storePossibleCursor or propagate writeConcern errors; an $out pipeline
        // can never run on mongoS. Filter the command response and return immediately.
        Command::filterCommandReplyForPassthrough(cursorResponse, result);
        return getStatusFromCommandResult(result->asTempObj());
    }

    // If we cannot merge on mongoS, establish the merge cursor on a shard.
    mergingPipeline->addInitialSource(
        DocumentSourceMergeCursors::create(parseCursors(dispatchResults.remoteCursors), mergeCtx));
    auto mergeCmdObj = createCommandForMergingShard(request, mergeCtx, cmdObj, mergingPipeline);

    auto mergeResponse = uassertStatusOK(
        establishMergingShardCursor(opCtx,
                                    namespaces.executionNss,
                                    dispatchResults.remoteCursors,
                                    mergeCmdObj,
                                    boost::optional<ShardId>{dispatchResults.needsPrimaryShardMerge,
                                                             executionNsRoutingInfo.primaryId()}));

    auto mergingShardId = mergeResponse.first;
    auto response = mergeResponse.second;

    // The merging shard is remote, so if a response was received, a HostAndPort must have been set.
    invariant(response.hostAndPort);
    auto mergeCursorResponse = uassertStatusOK(
        storePossibleCursor(opCtx,
                            mergingShardId,
                            *response.hostAndPort,
                            response.response,
                            namespaces.requestedNss,
                            Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor(),
                            Grid::get(opCtx)->getCursorManager()));

    return appendCursorResponseToCommandResult(mergingShardId, mergeCursorResponse, result);
}

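/**
 * Converts the remote cursors into DocumentSourceMergeCursors descriptors, asserting that each
 * cursor is still open and returned an empty first batch.
 */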
std::vector<DocumentSourceMergeCursors::CursorDescriptor> ClusterAggregate::parseCursors(
    const std::vector<ClusterClientCursorParams::RemoteCursor>& responses) {
    std::vector<DocumentSourceMergeCursors::CursorDescriptor> cursors;
    for (const auto& response : responses) {
        invariant(0 != response.cursorResponse.getCursorId());
        invariant(response.cursorResponse.getBatch().empty());
        cursors.emplace_back(ConnectionString(response.hostAndPort),
                             response.cursorResponse.getNSS().toString(),
                             response.cursorResponse.getCursorId());
    }
    return cursors;
}

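/**
 * Asserts that every shard responded successfully to the explain and included a 'stages' field in
 * its response.
 */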
void ClusterAggregate::uassertAllShardsSupportExplain(
    const std::vector<AsyncRequestsSender::Response>& shardResults) {
    for (const auto& result : shardResults) {
        auto status = result.swResponse.getStatus();
        if (status.isOK()) {
            status = getStatusFromCommandResult(result.swResponse.getValue().data);
        }
        uassert(17403,
                str::stream() << "Shard " << result.shardId.toString() << " failed: "
                              << causedBy(status),
                status.isOK());

        uassert(17404,
                str::stream() << "Shard " << result.shardId.toString()
                              << " does not support $explain",
                result.swResponse.getValue().data.hasField("stages"));
    }
}

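/**
 * Forwards the aggregation unmodified to the shard that owns the unsharded collection, storing any
 * returned cursor so that later getMore and killCursors requests can be routed. If the response
 * indicates the namespace is a view, re-runs the aggregation expanded over the view's underlying
 * collection.
 */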
Status ClusterAggregate::aggPassthrough(OperationContext* opCtx,
                                        const Namespaces& namespaces,
                                        const ShardId& shardId,
                                        BSONObj cmdObj,
                                        const AggregationRequest& aggRequest,
                                        const LiteParsedPipeline& liteParsedPipeline,
                                        BSONObjBuilder* out) {
    // Temporary hack. See comment on declaration for details.
    auto swShard = Grid::get(opCtx)->shardRegistry()->getShard(opCtx, shardId);
    if (!swShard.isOK()) {
        return swShard.getStatus();
    }
    auto shard = std::move(swShard.getValue());

    // Format the command for the shard. This adds the 'fromMongos' field, wraps the command as an
    // explain if necessary, and rewrites the result into a format safe to forward to shards.
    cmdObj = Command::filterCommandRequestForPassthrough(
        createCommandForTargetedShards(aggRequest, cmdObj, nullptr));

    auto cmdResponse = uassertStatusOK(shard->runCommandWithFixedRetryAttempts(
        opCtx,
        ReadPreferenceSetting::get(opCtx),
        namespaces.executionNss.db().toString(),
        !shard->isConfig() ? appendShardVersion(std::move(cmdObj), ChunkVersion::UNSHARDED())
                           : std::move(cmdObj),
        Shard::RetryPolicy::kIdempotent));

    if (ErrorCodes::isStaleShardingError(cmdResponse.commandStatus.code())) {
        throw StaleConfigException("command failed because of stale config", cmdResponse.response);
    }

    BSONObj result;
    if (aggRequest.getExplain()) {
        // If this was an explain, then we get back an explain result object rather than a cursor.
        result = cmdResponse.response;
    } else {
        // The merging shard is remote, so if a response was received, a HostAndPort must have been
        // set.
        invariant(cmdResponse.hostAndPort);
        result = uassertStatusOK(storePossibleCursor(
            opCtx,
            shard->getId(),
            *cmdResponse.hostAndPort,
            cmdResponse.response,
            namespaces.requestedNss,
            Grid::get(opCtx)->getExecutorPool()->getArbitraryExecutor(),
            Grid::get(opCtx)->getCursorManager(),
            liteParsedPipeline.hasChangeStream() ? TailableMode::kTailableAndAwaitData
                                                 : TailableMode::kNormal));
    }

    // First append the properly constructed writeConcernError. It will then be skipped
    // in appendElementsUnique.
    if (auto wcErrorElem = result["writeConcernError"]) {
        appendWriteConcernErrorToCmdResponse(shard->getId(), wcErrorElem, *out);
    }

    out->appendElementsUnique(Command::filterCommandReplyForPassthrough(result));

    BSONObj responseObj = out->asTempObj();
    if (ResolvedView::isResolvedViewErrorResponse(responseObj)) {
        auto resolvedView = ResolvedView::fromBSON(responseObj);

        auto resolvedAggRequest = resolvedView.asExpandedViewAggregation(aggRequest);
        auto resolvedAggCmd = resolvedAggRequest.serializeToCommandObj().toBson();
        out->resetToEmpty();

        // We pass both the underlying collection namespace and the view namespace here. The
        // underlying collection namespace is used to execute the aggregation on mongoD. Any cursor
        // returned will be registered under the view namespace so that subsequent getMore and
        // killCursors calls against the view have access.
        Namespaces nsStruct;
        nsStruct.requestedNss = namespaces.requestedNss;
        nsStruct.executionNss = resolvedView.getNamespace();

        return ClusterAggregate::runAggregate(
            opCtx, nsStruct, resolvedAggRequest, resolvedAggCmd, out);
    }

    return getStatusFromCommandResult(result);
}

}  // namespace mongo