/**
 *    Copyright (C) 2018-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */

#include "mongo/platform/basic.h"

#include "mongo/s/write_ops/batch_write_op.h"

#include <numeric>

#include "mongo/base/error_codes.h"
#include "mongo/db/operation_context.h"
#include "mongo/stdx/memory.h"
#include "mongo/util/transitional_tools_do_not_use/vector_spooling.h"

namespace mongo {

using std::unique_ptr;
using std::set;
using std::stringstream;
using std::vector;

namespace {

// Conservative overhead per element contained in the write batch. This value was calculated as 1
// byte (element type) + 5 bytes (max string length of the array index, since the largest index is
// 99999) + 1 byte (null terminator) = 7 bytes.
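// This overhead is added on top of each write's own serialized size when estimating how much a
// write contributes to a child batch (see the writeSizeBytes computation in targetBatch below).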
const int kBSONArrayPerElementOverheadBytes = 7;

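// Comparator used to sort per-item write errors by the index of the write they refer to within
// the original client batch.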
struct WriteErrorDetailComp {
    bool operator()(const WriteErrorDetail* errorA, const WriteErrorDetail* errorB) const {
        return errorA->getIndex() < errorB->getIndex();
    }
};

// MAGIC NUMBERS
//
// Before serializing updates/deletes, we don't know exactly how big their fields will be, but we
// must break up batches before serializing, so we reserve a conservative per-op overhead.
//
// TODO: Revisit when we revisit command limits in general
const int kEstUpdateOverheadBytes = (BSONObjMaxInternalSize - BSONObjMaxUserSize) / 100;
const int kEstDeleteOverheadBytes = (BSONObjMaxInternalSize - BSONObjMaxUserSize) / 100;

/**
 * Returns a new write concern that is a copy of every field from the original document, but with
 * 'w' set to 1. This is intended for upgrading a { w: 0 } write concern to { w: 1 }.
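 *
 * For example, { w: 0, wtimeout: 100 } would become { w: 1, wtimeout: 100 }.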
 */
BSONObj upgradeWriteConcern(const BSONObj& origWriteConcern) {
    BSONObjIterator iter(origWriteConcern);
    BSONObjBuilder newWriteConcern;

    while (iter.more()) {
        BSONElement elem(iter.next());

        if (strncmp(elem.fieldName(), "w", 2) == 0) {
            newWriteConcern.append("w", 1);
        } else {
            newWriteConcern.append(elem);
        }
    }

    return newWriteConcern.obj();
}

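/**
 * Fills out a WriteErrorDetail with the code and reason of the given targeting error status.
 */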
void buildTargetError(const Status& errStatus, WriteErrorDetail* details) {
    details->setErrCode(errStatus.code());
    details->setErrMessage(errStatus.reason());
}

/**
 * Helper to determine whether a number of targeted writes require a new targeted batch.
 */
bool isNewBatchRequiredOrdered(const std::vector<TargetedWrite*>& writes,
                               const TargetedBatchMap& batchMap) {
    for (const auto write : writes) {
        if (batchMap.find(&write->endpoint) == batchMap.end()) {
            return true;
        }
    }

    return false;
}

/**
 * Helper to determine whether a shard is already targeted with a different shardVersion, which
 * necessitates a new batch. This happens when a batch write includes both a multi-target write
 * and a single-target write.
 */
bool isNewBatchRequiredUnordered(const std::vector<TargetedWrite*>& writes,
                                 const TargetedBatchMap& batchMap,
                                 const std::set<ShardId>& targetedShards) {
    for (const auto write : writes) {
        if (batchMap.find(&write->endpoint) == batchMap.end()) {
            if (targetedShards.find((&write->endpoint)->shardName) != targetedShards.end()) {
                return true;
            }
        }
    }

    return false;
}

/**
 * Helper to determine whether adding the given targeted writes to their batches would push any
 * batch past the maximum number of ops or the maximum batch size in bytes.
 */
bool wouldMakeBatchesTooBig(const std::vector<TargetedWrite*>& writes,
                            int writeSizeBytes,
                            const TargetedBatchMap& batchMap) {
    for (const auto write : writes) {
        TargetedBatchMap::const_iterator it = batchMap.find(&write->endpoint);
        if (it == batchMap.end()) {
            // If this is the first item in the batch, it can't be too big
            continue;
        }

        const auto& batch = it->second;

        if (batch->getNumOps() >= write_ops::kMaxWriteBatchSize) {
            // Too many items in batch
            return true;
        }

        if (batch->getEstimatedSizeBytes() + writeSizeBytes > BSONObjMaxUserSize) {
            // Batch would be too big
            return true;
        }
    }

    return false;
}

/**
 * Gets an estimated size of how much the particular write operation would add to the size of the
 * batch.
 */
int getWriteSizeBytes(const WriteOp& writeOp) {
    const BatchItemRef& item = writeOp.getWriteItem();
    const BatchedCommandRequest::BatchType batchType = item.getOpType();

    if (batchType == BatchedCommandRequest::BatchType_Insert) {
        return item.getDocument().objsize();
    } else if (batchType == BatchedCommandRequest::BatchType_Update) {
        // Note: Be conservative here - it's okay if we send slightly too many batches
        auto collationSize =
            item.getUpdate().getCollation() ? item.getUpdate().getCollation()->objsize() : 0;
        auto estSize = item.getUpdate().getQ().objsize() + item.getUpdate().getU().objsize() +
            collationSize + kEstUpdateOverheadBytes;
        dassert(estSize >= item.getUpdate().toBSON().objsize());
        return estSize;
    } else if (batchType == BatchedCommandRequest::BatchType_Delete) {
        // Note: Be conservative here - it's okay if we send slightly too many batches
        auto collationSize =
            item.getDelete().getCollation() ? item.getDelete().getCollation()->objsize() : 0;
        auto estSize = item.getDelete().getQ().objsize() + collationSize + kEstDeleteOverheadBytes;
        dassert(estSize >= item.getDelete().toBSON().objsize());
        return estSize;
    }

    MONGO_UNREACHABLE;
}

/**
 * Given *either* a batch error or an array of per-item errors, copies errors we're interested in
 * into a TrackedErrorMap
 */
void trackErrors(const ShardEndpoint& endpoint,
                 const vector<WriteErrorDetail*> itemErrors,
                 TrackedErrors* trackedErrors) {
    for (const auto error : itemErrors) {
        if (trackedErrors->isTracking(error->getErrCode())) {
            trackedErrors->addError(ShardError(endpoint, *error));
        }
    }
}

}  // namespace

BatchWriteOp::BatchWriteOp(OperationContext* opCtx, const BatchedCommandRequest& clientRequest)
    : _opCtx(opCtx), _clientRequest(clientRequest), _batchTxnNum(_opCtx->getTxnNumber()) {
    _writeOps.reserve(_clientRequest.sizeWriteOps());

    for (size_t i = 0; i < _clientRequest.sizeWriteOps(); ++i) {
        _writeOps.emplace_back(BatchItemRef(&_clientRequest, i));
    }
}

Status BatchWriteOp::targetBatch(const NSTargeter& targeter,
                                 bool recordTargetErrors,
                                 std::map<ShardId, TargetedWriteBatch*>* targetedBatches) {
    //
    // Targeting of unordered batches is fairly simple - each remaining write op is targeted,
    // and each of those targeted writes is grouped into a batch for a particular shard
    // endpoint.
    //
    // Targeting of ordered batches is a bit more complex - to respect the ordering of the
    // batch, we can only send:
    // A) a single targeted batch to one shard endpoint
    // B) multiple targeted batches, but only containing targeted writes for a single write op
    //
    // This means that any multi-shard write operation must be targeted and sent one-by-one.
    // Subsequent single-shard write operations can be batched together if they go to the same
    // place.
    //
    // Ex: ShardA : { skey : a->k }, ShardB : { skey : k->z }
    //
    // Ordered insert batch of: [{ skey : a }, { skey : b }, { skey : x }]
    // broken into:
    //  [{ skey : a }, { skey : b }],
    //  [{ skey : x }]
    //
    // Ordered update batch of:
    //  [{ skey : a }{ $push },
    //   { skey : b }{ $push },
    //   { skey : [c, x] }{ $push },
    //   { skey : y }{ $push },
    //   { skey : z }{ $push }]
    // broken into:
    //  [{ skey : a }, { skey : b }],
    //  [{ skey : [c,x] }],
    //  [{ skey : y }, { skey : z }]
    //

    const bool ordered = _clientRequest.getWriteCommandBase().getOrdered();

    TargetedBatchMap batchMap;
    std::set<ShardId> targetedShards;

    int numTargetErrors = 0;

    const size_t numWriteOps = _clientRequest.sizeWriteOps();

    for (size_t i = 0; i < numWriteOps; ++i) {
        WriteOp& writeOp = _writeOps[i];

        // Only target _Ready ops
        if (writeOp.getWriteState() != WriteOpState_Ready)
            continue;

        //
        // Get TargetedWrites from the targeter for the write operation
        //
        // TargetedWrites need to be owned once returned

        OwnedPointerVector<TargetedWrite> writesOwned;
        vector<TargetedWrite*>& writes = writesOwned.mutableVector();

        Status targetStatus = writeOp.targetWrites(_opCtx, targeter, &writes);

        if (!targetStatus.isOK()) {
            WriteErrorDetail targetError;
            buildTargetError(targetStatus, &targetError);

            if (!recordTargetErrors) {
                // Cancel current batch state with an error
                _cancelBatches(targetError, std::move(batchMap));
                return targetStatus;
            } else if (!ordered || batchMap.empty()) {
                // Record an error for this batch

                writeOp.setOpError(targetError);
                ++numTargetErrors;

                if (ordered)
                    return Status::OK();

                continue;
            } else {
                dassert(ordered && !batchMap.empty());

                // Send out what we have, but don't record an error yet, since there may be an error
                // in the writes before this point.
                writeOp.cancelWrites(&targetError);
                break;
            }
        }

        //
        // If ordered and we have a previous endpoint, make sure we don't need to send these
        // targeted writes to any other endpoints.
        //

        if (ordered && !batchMap.empty()) {
            dassert(batchMap.size() == 1u);
            if (isNewBatchRequiredOrdered(writes, batchMap)) {
                writeOp.cancelWrites(NULL);
                break;
            }
        }

        // Account the array overhead once for the actual updates array and once for the statement
        // ids array, if retryable writes are used
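        // (each statement id is serialized as a 4-byte int32 value, hence the extra 4 bytes per
        // write when the batch carries a transaction number).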
        const int writeSizeBytes = getWriteSizeBytes(writeOp) + kBSONArrayPerElementOverheadBytes +
            (_batchTxnNum ? kBSONArrayPerElementOverheadBytes + 4 : 0);

        if (wouldMakeBatchesTooBig(writes, writeSizeBytes, batchMap)) {
            invariant(!batchMap.empty());
            writeOp.cancelWrites(nullptr);
            break;
        }

        if (!ordered && !batchMap.empty() &&
            isNewBatchRequiredUnordered(writes, batchMap, targetedShards)) {
            writeOp.cancelWrites(nullptr);
            break;
        }

        //
        // Targeting went ok, add to appropriate TargetedBatch
        //

        for (const auto write : writes) {
            TargetedBatchMap::iterator batchIt = batchMap.find(&write->endpoint);
            if (batchIt == batchMap.end()) {
                TargetedWriteBatch* newBatch = new TargetedWriteBatch(write->endpoint);
                batchIt = batchMap.emplace(&newBatch->getEndpoint(), newBatch).first;
                targetedShards.insert((&newBatch->getEndpoint())->shardName);
            }

            TargetedWriteBatch* batch = batchIt->second;
            batch->addWrite(write, writeSizeBytes);
        }

        // Relinquish ownership of TargetedWrites, now the TargetedBatches own them
        writesOwned.mutableVector().clear();

        //
        // Break if we're ordered and we have more than one endpoint - later writes cannot be
        // enforced as ordered across multiple shard endpoints.
        //

        if (ordered && batchMap.size() > 1u)
            break;
    }

    //
    // Send back our targeted batches
    //

    for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end(); ++it) {
        TargetedWriteBatch* batch = it->second;

        if (batch->getWrites().empty())
            continue;

        // Remember targeted batch for reporting
        _targeted.insert(batch);

        // Send the handle back to caller
        invariant(targetedBatches->find(batch->getEndpoint().shardName) == targetedBatches->end());
        targetedBatches->emplace(batch->getEndpoint().shardName, batch);
    }

    return Status::OK();
}

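/**
 * Assembles the child write command request to dispatch to the shard for the given targeted
 * batch, copying over only the documents/updates/deletes that were targeted to that shard, along
 * with their statement ids when the batch carries a transaction number.
 */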
BatchedCommandRequest BatchWriteOp::buildBatchRequest(
    const TargetedWriteBatch& targetedBatch) const {
    const auto batchType = _clientRequest.getBatchType();

    boost::optional<std::vector<int32_t>> stmtIdsForOp;
    if (_batchTxnNum) {
        stmtIdsForOp.emplace();
    }

    boost::optional<std::vector<BSONObj>> insertDocs;
    boost::optional<std::vector<write_ops::UpdateOpEntry>> updates;
    boost::optional<std::vector<write_ops::DeleteOpEntry>> deletes;

    for (const auto& targetedWrite : targetedBatch.getWrites()) {
        const WriteOpRef& writeOpRef = targetedWrite->writeOpRef;

        switch (batchType) {
            case BatchedCommandRequest::BatchType_Insert:
                if (!insertDocs)
                    insertDocs.emplace();
                insertDocs->emplace_back(
                    _clientRequest.getInsertRequest().getDocuments().at(writeOpRef.first));
                break;
            case BatchedCommandRequest::BatchType_Update:
                if (!updates)
                    updates.emplace();
                updates->emplace_back(
                    _clientRequest.getUpdateRequest().getUpdates().at(writeOpRef.first));
                break;
            case BatchedCommandRequest::BatchType_Delete:
                if (!deletes)
                    deletes.emplace();
                deletes->emplace_back(
                    _clientRequest.getDeleteRequest().getDeletes().at(writeOpRef.first));
                break;
            default:
                MONGO_UNREACHABLE;
        }

        if (stmtIdsForOp) {
            stmtIdsForOp->push_back(write_ops::getStmtIdForWriteAt(
                _clientRequest.getWriteCommandBase(), writeOpRef.first));
        }
    }

    BatchedCommandRequest request([&] {
        switch (batchType) {
            case BatchedCommandRequest::BatchType_Insert:
                return BatchedCommandRequest([&] {
                    write_ops::Insert insertOp(_clientRequest.getNS());
                    insertOp.setDocuments(std::move(*insertDocs));
                    return insertOp;
                }());
            case BatchedCommandRequest::BatchType_Update:
                return BatchedCommandRequest([&] {
                    write_ops::Update updateOp(_clientRequest.getNS());
                    updateOp.setUpdates(std::move(*updates));
                    return updateOp;
                }());
            case BatchedCommandRequest::BatchType_Delete:
                return BatchedCommandRequest([&] {
                    write_ops::Delete deleteOp(_clientRequest.getNS());
                    deleteOp.setDeletes(std::move(*deletes));
                    return deleteOp;
                }());
        }
        MONGO_UNREACHABLE;
    }());

    request.setWriteCommandBase([&] {
        write_ops::WriteCommandBase wcb;

        wcb.setBypassDocumentValidation(
            _clientRequest.getWriteCommandBase().getBypassDocumentValidation());
        wcb.setOrdered(_clientRequest.getWriteCommandBase().getOrdered());

        if (_batchTxnNum) {
            wcb.setStmtIds(std::move(stmtIdsForOp));
        }

        return wcb;
    }());

    request.setShardVersion(targetedBatch.getEndpoint().shardVersion);

    if (_clientRequest.hasWriteConcern()) {
        if (_clientRequest.isVerboseWC()) {
            request.setWriteConcern(_clientRequest.getWriteConcern());
        } else {
            // Mongos needs to send the write to the shard with w > 0 so that it can see the
            // writeErrors
            request.setWriteConcern(upgradeWriteConcern(_clientRequest.getWriteConcern()));
        }
    }

    return request;
}

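/**
 * Records the response for a previously dispatched targeted batch: assigns batch-level or
 * per-item errors to the corresponding write ops, stashes any write concern error for later
 * reporting, and remembers upserted ids re-indexed against the original client batch.
 */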
void BatchWriteOp::noteBatchResponse(const TargetedWriteBatch& targetedBatch,
                                     const BatchedCommandResponse& response,
                                     TrackedErrors* trackedErrors) {
    if (!response.getOk()) {
        WriteErrorDetail error;
        error.setErrCode(response.getErrCode());
        error.setErrMessage(response.getErrMessage());

        // Treat command errors exactly like other failures of the batch.
        //
        // Note that no errors will be tracked from these failures - as-designed.
        noteBatchError(targetedBatch, error);
        return;
    }

    // Stop tracking targeted batch
    _targeted.erase(&targetedBatch);

    // Increment stats for this batch
    _incBatchStats(response);

    //
    // Assign errors to particular items.
    // Write Concern errors are stored and handled later.
    //

    // Special handling for write concern errors, save for later
    if (response.isWriteConcernErrorSet()) {
        _wcErrors.emplace_back(targetedBatch.getEndpoint(), *response.getWriteConcernError());
    }

    vector<WriteErrorDetail*> itemErrors;

    // Handle batch and per-item errors
    if (response.isErrDetailsSet()) {
        // Per-item errors were set
        itemErrors.insert(
            itemErrors.begin(), response.getErrDetails().begin(), response.getErrDetails().end());

        // Sort per-item errors by index
        std::sort(itemErrors.begin(), itemErrors.end(), WriteErrorDetailComp());
    }

    //
    // Go through all pending writes in the targeted batch and the sorted remote responses, and
    // populate errors. This will either set all errors to the batch error or apply per-item
    // errors as needed.
    //
    // If the batch is ordered, cancel all writes after the first error for retargeting.
    //

    const bool ordered = _clientRequest.getWriteCommandBase().getOrdered();

    vector<WriteErrorDetail*>::iterator itemErrorIt = itemErrors.begin();
    int index = 0;
    WriteErrorDetail* lastError = NULL;
    for (vector<TargetedWrite*>::const_iterator it = targetedBatch.getWrites().begin();
         it != targetedBatch.getWrites().end();
         ++it, ++index) {
        const TargetedWrite* write = *it;
        WriteOp& writeOp = _writeOps[write->writeOpRef.first];

        dassert(writeOp.getWriteState() == WriteOpState_Pending);

        // See if we have an error for the write
        WriteErrorDetail* writeError = NULL;

        if (itemErrorIt != itemErrors.end() && (*itemErrorIt)->getIndex() == index) {
            // We have a per-item error for this write op's index
            writeError = *itemErrorIt;
            ++itemErrorIt;
        }

        // Finish the response (with error, if needed)
        if (NULL == writeError) {
            if (!ordered || !lastError) {
                writeOp.noteWriteComplete(*write);
            } else {
                // We didn't actually apply this write - cancel so we can retarget
                dassert(writeOp.getNumTargeted() == 1u);
                writeOp.cancelWrites(lastError);
            }
        } else {
            writeOp.noteWriteError(*write, *writeError);
            lastError = writeError;
        }
    }

    // Track errors we care about, whether batch or individual errors
    if (NULL != trackedErrors) {
        trackErrors(targetedBatch.getEndpoint(), itemErrors, trackedErrors);
    }

    // Track upserted ids if we need to
    if (response.isUpsertDetailsSet()) {
        const vector<BatchedUpsertDetail*>& upsertedIds = response.getUpsertDetails();
        for (vector<BatchedUpsertDetail*>::const_iterator it = upsertedIds.begin();
             it != upsertedIds.end();
             ++it) {
            // The child upserted details don't have the correct index for the full batch
            const BatchedUpsertDetail* childUpsertedId = *it;

            // Work backward from the child batch item index to the batch item index
            int childBatchIndex = childUpsertedId->getIndex();
            int batchIndex = targetedBatch.getWrites()[childBatchIndex]->writeOpRef.first;

            // Push the upserted id with the correct index into the batch upserted ids
            auto upsertedId = stdx::make_unique<BatchedUpsertDetail>();
            upsertedId->setIndex(batchIndex);
            upsertedId->setUpsertedID(childUpsertedId->getUpsertedID());
            _upsertedIds.push_back(std::move(upsertedId));
        }
    }
}

void BatchWriteOp::noteBatchError(const TargetedWriteBatch& targetedBatch,
                                  const WriteErrorDetail& error) {
    // Treat errors to get a batch response as failures of the contained writes
    BatchedCommandResponse emulatedResponse;
    emulatedResponse.setOk(true);
    emulatedResponse.setN(0);

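    // For an ordered batch only the first write needs to carry the error (the remaining writes
    // will be cancelled by noteBatchResponse so they can be retargeted); for an unordered batch
    // every targeted write gets a copy of the error.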
    const int numErrors =
        _clientRequest.getWriteCommandBase().getOrdered() ? 1 : targetedBatch.getWrites().size();

    for (int i = 0; i < numErrors; i++) {
        auto errorClone(stdx::make_unique<WriteErrorDetail>());
        error.cloneTo(errorClone.get());
        errorClone->setIndex(i);
        emulatedResponse.addToErrDetails(errorClone.release());
    }

    dassert(emulatedResponse.isValid(nullptr));
    noteBatchResponse(targetedBatch, emulatedResponse, nullptr);
}

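/**
 * Marks every write op that has not yet completed with the given error (only the first such op
 * for an ordered batch). May only be called when no targeted batches are outstanding.
 */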
void BatchWriteOp::abortBatch(const WriteErrorDetail& error) {
    dassert(!isFinished());
    dassert(numWriteOpsIn(WriteOpState_Pending) == 0);

    const size_t numWriteOps = _clientRequest.sizeWriteOps();
    const bool orderedOps = _clientRequest.getWriteCommandBase().getOrdered();
    for (size_t i = 0; i < numWriteOps; ++i) {
        WriteOp& writeOp = _writeOps[i];
        // Can only be called with no outstanding batches
        dassert(writeOp.getWriteState() != WriteOpState_Pending);

        if (writeOp.getWriteState() < WriteOpState_Completed) {
            writeOp.setOpError(error);

            // Only one error if we're ordered
            if (orderedOps)
                break;
        }
    }

    dassert(isFinished());
}

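/**
 * The batch is finished when every write op has either completed or errored. For an ordered
 * batch, execution stops at the first error, so a single errored op also finishes the batch.
 */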
bool BatchWriteOp::isFinished() {
    const size_t numWriteOps = _clientRequest.sizeWriteOps();
    const bool orderedOps = _clientRequest.getWriteCommandBase().getOrdered();
    for (size_t i = 0; i < numWriteOps; ++i) {
        WriteOp& writeOp = _writeOps[i];
        if (writeOp.getWriteState() < WriteOpState_Completed)
            return false;
        else if (orderedOps && writeOp.getWriteState() == WriteOpState_Error)
            return true;
    }

    return true;
}

void BatchWriteOp::buildClientResponse(BatchedCommandResponse* batchResp) {
    dassert(isFinished());

    // Result is OK
    batchResp->setOk(true);

    // For non-verbose, it's all we need.
    if (!_clientRequest.isVerboseWC()) {
        dassert(batchResp->isValid(NULL));
        return;
    }

    //
    // Find all the errors in the batch
    //

    vector<WriteOp*> errOps;

    const size_t numWriteOps = _clientRequest.sizeWriteOps();
    for (size_t i = 0; i < numWriteOps; ++i) {
        WriteOp& writeOp = _writeOps[i];

        if (writeOp.getWriteState() == WriteOpState_Error) {
            errOps.push_back(&writeOp);
        }
    }

    //
    // Build the per-item errors.
    //

    if (!errOps.empty()) {
        for (vector<WriteOp*>::iterator it = errOps.begin(); it != errOps.end(); ++it) {
            WriteOp& writeOp = **it;
            WriteErrorDetail* error = new WriteErrorDetail();
            writeOp.getOpError().cloneTo(error);
            batchResp->addToErrDetails(error);
        }
    }

    // Only return a write concern error if everything succeeded (unordered or ordered)
    // OR if something succeeded and we're unordered
    const bool orderedOps = _clientRequest.getWriteCommandBase().getOrdered();
    const bool reportWCError =
        errOps.empty() || (!orderedOps && errOps.size() < _clientRequest.sizeWriteOps());
    if (!_wcErrors.empty() && reportWCError) {
        WriteConcernErrorDetail* error = new WriteConcernErrorDetail;

        // Generate the multi-error message below
        StringBuilder msg;
        if (_wcErrors.size() > 1) {
            msg << "multiple errors reported : ";
            error->setErrCode(ErrorCodes::WriteConcernFailed);
        } else {
            error->setErrCode(_wcErrors.begin()->error.getErrCode());
        }

        for (auto it = _wcErrors.begin(); it != _wcErrors.end(); ++it) {
            const auto& wcError = *it;
            if (it != _wcErrors.begin()) {
                msg << " :: and :: ";
            }
            msg << wcError.error.getErrMessage() << " at " << wcError.endpoint.shardName;
        }

        error->setErrMessage(msg.str());
        batchResp->setWriteConcernError(error);
    }

    //
    // Append the upserted ids, if required
    //

    if (_upsertedIds.size() != 0) {
        batchResp->setUpsertDetails(transitional_tools_do_not_use::unspool_vector(_upsertedIds));
    }

    // Stats
    const int nValue = _numInserted + _numUpserted + _numMatched + _numDeleted;
    batchResp->setN(nValue);
    if (_clientRequest.getBatchType() == BatchedCommandRequest::BatchType_Update &&
        _numModified >= 0) {
        batchResp->setNModified(_numModified);
    }

    dassert(batchResp->isValid(NULL));
}

int BatchWriteOp::numWriteOpsIn(WriteOpState opState) const {
    // TODO: This could be faster, if we tracked this info explicitly
    return std::accumulate(
        _writeOps.begin(), _writeOps.end(), 0, [opState](int sum, const WriteOp& writeOp) {
            return sum + (writeOp.getWriteState() == opState ? 1 : 0);
        });
}

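/**
 * Updates the batch-type-specific counters (inserted/matched/modified/upserted/deleted) from a
 * successful shard response; these feed the n/nModified fields of the client response.
 */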
void BatchWriteOp::_incBatchStats(const BatchedCommandResponse& response) {
    const auto batchType = _clientRequest.getBatchType();

    if (batchType == BatchedCommandRequest::BatchType_Insert) {
        _numInserted += response.getN();
    } else if (batchType == BatchedCommandRequest::BatchType_Update) {
        int numUpserted = 0;
        if (response.isUpsertDetailsSet()) {
            numUpserted = response.sizeUpsertDetails();
        }
        _numMatched += (response.getN() - numUpserted);
        long long numModified = response.getNModified();

        if (numModified >= 0)
            _numModified += numModified;
        else
            _numModified = -1;  // sentinel used to indicate we omit the field downstream

        _numUpserted += numUpserted;
    } else {
        dassert(batchType == BatchedCommandRequest::BatchType_Delete);
        _numDeleted += response.getN();
    }
}

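/**
 * Cancels every write in the given targeted batches with the given error and destroys the
 * batches themselves.
 */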
void BatchWriteOp::_cancelBatches(const WriteErrorDetail& why,
                                  TargetedBatchMap&& batchMapToCancel) {
    TargetedBatchMap batchMap(batchMapToCancel);

    // Collect all the writeOps that are currently targeted
    for (TargetedBatchMap::iterator it = batchMap.begin(); it != batchMap.end();) {
        TargetedWriteBatch* batch = it->second;
        const vector<TargetedWrite*>& writes = batch->getWrites();

        for (vector<TargetedWrite*>::const_iterator writeIt = writes.begin();
             writeIt != writes.end();
             ++writeIt) {
            TargetedWrite* write = *writeIt;

            // NOTE: We may repeatedly cancel a write op here, but that's fast and we want to cancel
            // before erasing the TargetedWrite* (which owns the cancelled targeting info) for
            // reporting reasons.
            _writeOps[write->writeOpRef.first].cancelWrites(&why);
        }

        // Note that we need to *erase* first, *then* delete, since the map keys are pointers into
        // the values
        batchMap.erase(it++);
        delete batch;
    }
}

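// Provides a strict weak ordering over shard endpoints: first by shard name, then by chunk
// version (compared as a combined major/minor value), and finally by the version epoch.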
bool EndpointComp::operator()(const ShardEndpoint* endpointA,
                              const ShardEndpoint* endpointB) const {
    const int shardNameDiff = endpointA->shardName.compare(endpointB->shardName);
    if (shardNameDiff) {
        return shardNameDiff < 0;
    }

    const long shardVersionDiff =
        endpointA->shardVersion.toLong() - endpointB->shardVersion.toLong();
    if (shardVersionDiff) {
        return shardVersionDiff < 0;
    }

    return endpointA->shardVersion.epoch().compare(endpointB->shardVersion.epoch()) < 0;
}

void TrackedErrors::startTracking(int errCode) {
    dassert(!isTracking(errCode));
    _errorMap.emplace(errCode, std::vector<ShardError>());
}

bool TrackedErrors::isTracking(int errCode) const {
    return _errorMap.count(errCode) != 0;
}

void TrackedErrors::addError(ShardError error) {
    TrackedErrorMap::iterator seenIt = _errorMap.find(error.error.getErrCode());
    if (seenIt == _errorMap.end())
        return;
    seenIt->second.emplace_back(std::move(error));
}

const std::vector<ShardError>& TrackedErrors::getErrors(int errCode) const {
    dassert(isTracking(errCode));
    return _errorMap.find(errCode)->second;
}

}  // namespace mongo