# Copyright 2009-present MongoDB, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""**DEPRECATED** Tools for creating `messages
<http://www.mongodb.org/display/DOCS/Mongo+Wire+Protocol>`_ to be sent to
MongoDB.

.. note:: This module is for internal use and is generally not needed by
   application developers.

.. versionchanged:: 3.12
  This module is deprecated and will be removed in PyMongo 4.0.
"""

import datetime
import random
import struct

import bson
from bson import (CodecOptions,
                  decode,
                  encode,
                  _decode_selective,
                  _dict_to_bson,
                  _make_c_string)
from bson.codec_options import DEFAULT_CODEC_OPTIONS
from bson.raw_bson import (_inflate_bson, DEFAULT_RAW_BSON_OPTIONS,
                           RawBSONDocument)
from bson.py3compat import b, StringIO
from bson.son import SON

try:
    from pymongo import _cmessage
    _use_c = True
except ImportError:
    _use_c = False
from pymongo.errors import (ConfigurationError,
                            CursorNotFound,
                            DocumentTooLarge,
                            ExecutionTimeout,
                            InvalidOperation,
                            NotPrimaryError,
                            OperationFailure,
                            ProtocolError)
from pymongo.hello_compat import HelloCompat
from pymongo.read_concern import DEFAULT_READ_CONCERN
from pymongo.read_preferences import ReadPreference
from pymongo.write_concern import WriteConcern


MAX_INT32 = 2147483647
MIN_INT32 = -2147483648

# Overhead allowed for encoded command documents.
_COMMAND_OVERHEAD = 16382

_INSERT = 0
_UPDATE = 1
_DELETE = 2

_EMPTY   = b''
_BSONOBJ = b'\x03'
_ZERO_8  = b'\x00'
_ZERO_16 = b'\x00\x00'
_ZERO_32 = b'\x00\x00\x00\x00'
_ZERO_64 = b'\x00\x00\x00\x00\x00\x00\x00\x00'
_SKIPLIM = b'\x00\x00\x00\x00\xff\xff\xff\xff'
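# Each value below is the start of a BSON array element: the type byte 0x04
# (array), the field name as a C string, and a four-byte length placeholder
# that is patched once the batch is written (see _batched_write_command_impl).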
_OP_MAP = {
    _INSERT: b'\x04documents\x00\x00\x00\x00\x00',
    _UPDATE: b'\x04updates\x00\x00\x00\x00\x00',
    _DELETE: b'\x04deletes\x00\x00\x00\x00\x00',
}
_FIELD_MAP = {
    'insert': 'documents',
    'update': 'updates',
    'delete': 'deletes'
}

_UJOIN = u"%s.%s"

_UNICODE_REPLACE_CODEC_OPTIONS = CodecOptions(
    unicode_decode_error_handler='replace')


def _randint():
    """Generate a pseudo random 32 bit integer."""
    return random.randint(MIN_INT32, MAX_INT32)


def _maybe_add_read_preference(spec, read_preference):
    """Add $readPreference to spec when appropriate."""
    mode = read_preference.mode
    document = read_preference.document
    # Only add $readPreference if it's something other than primary to avoid
    # problems with mongos versions that don't support read preferences. Also,
    # for maximum backwards compatibility, don't add $readPreference for
    # secondaryPreferred unless tags or maxStalenessSeconds are in use (setting
    # the secondaryOk bit has the same effect).
    if mode and (
            mode != ReadPreference.SECONDARY_PREFERRED.mode or
            len(document) > 1):
        if "$query" not in spec:
            spec = SON([("$query", spec)])
        spec["$readPreference"] = document
    return spec

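# Illustrative example (not executed here): with ReadPreference.SECONDARY,
# a spec like {"x": 1} is rewritten to
# SON([("$query", {"x": 1}), ("$readPreference", {"mode": "secondary"})]).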

def _convert_exception(exception):
    """Convert an Exception into a failure document for publishing."""
    return {'errmsg': str(exception),
            'errtype': exception.__class__.__name__}


def _convert_write_result(operation, command, result):
    """Convert a legacy write result to write command format."""

    # Based on _merge_legacy from bulk.py
    affected = result.get("n", 0)
    res = {"ok": 1, "n": affected}
    errmsg = result.get("errmsg", result.get("err", ""))
    if errmsg:
        # The write was successful on at least the primary so don't return.
        if result.get("wtimeout"):
            res["writeConcernError"] = {"errmsg": errmsg,
                                        "code": 64,
                                        "errInfo": {"wtimeout": True}}
        else:
            # The write failed.
            error = {"index": 0,
                     "code": result.get("code", 8),
                     "errmsg": errmsg}
            if "errInfo" in result:
                error["errInfo"] = result["errInfo"]
            res["writeErrors"] = [error]
            return res
    if operation == "insert":
        # GLE result for insert is always 0 in most MongoDB versions.
        res["n"] = len(command['documents'])
    elif operation == "update":
        if "upserted" in result:
            res["upserted"] = [{"index": 0, "_id": result["upserted"]}]
        # Versions of MongoDB before 2.6 don't return the _id for an
        # upsert if _id is not an ObjectId.
        elif result.get("updatedExisting") is False and affected == 1:
            # If _id is in both the update document *and* the query spec
            # the update document _id takes precedence.
            update = command['updates'][0]
            _id = update["u"].get("_id", update["q"].get("_id"))
            res["upserted"] = [{"index": 0, "_id": _id}]
    return res


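# OP_QUERY flag bits, keyed by the equivalent find command option. Used by
# _gen_find_command to translate wire protocol query flags into command
# options.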
_OPTIONS = SON([
    ('tailable', 2),
    ('oplogReplay', 8),
    ('noCursorTimeout', 16),
    ('awaitData', 32),
    ('allowPartialResults', 128)])


_MODIFIERS = SON([
    ('$query', 'filter'),
    ('$orderby', 'sort'),
    ('$hint', 'hint'),
    ('$comment', 'comment'),
    ('$maxScan', 'maxScan'),
    ('$maxTimeMS', 'maxTimeMS'),
    ('$max', 'max'),
    ('$min', 'min'),
    ('$returnKey', 'returnKey'),
    ('$showRecordId', 'showRecordId'),
    ('$showDiskLoc', 'showRecordId'),  # <= MongoDB 3.0
    ('$snapshot', 'snapshot')])


def _gen_find_command(coll, spec, projection, skip, limit, batch_size, options,
                      read_concern, collation=None, session=None,
                      allow_disk_use=None):
    """Generate a find command document."""
    cmd = SON([('find', coll)])
    if '$query' in spec:
        cmd.update([(_MODIFIERS[key], val) if key in _MODIFIERS else (key, val)
                    for key, val in spec.items()])
        if '$explain' in cmd:
            cmd.pop('$explain')
        if '$readPreference' in cmd:
            cmd.pop('$readPreference')
    else:
        cmd['filter'] = spec

    if projection:
        cmd['projection'] = projection
    if skip:
        cmd['skip'] = skip
    if limit:
        cmd['limit'] = abs(limit)
        if limit < 0:
            cmd['singleBatch'] = True
    if batch_size:
        cmd['batchSize'] = batch_size
    if read_concern.level and not (session and session.in_transaction):
        cmd['readConcern'] = read_concern.document
    if collation:
        cmd['collation'] = collation
    if allow_disk_use is not None:
        cmd['allowDiskUse'] = allow_disk_use
    if options:
        cmd.update([(opt, True)
                    for opt, val in _OPTIONS.items()
                    if options & val])

    return cmd

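# Illustrative example (not executed here): a legacy spec such as
# {"$query": {"x": 1}, "$orderby": {"y": 1}} becomes
# SON([("find", coll), ("filter", {"x": 1}), ("sort", {"y": 1})]).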

def _gen_get_more_command(cursor_id, coll, batch_size, max_await_time_ms):
    """Generate a getMore command document."""
    cmd = SON([('getMore', cursor_id),
               ('collection', coll)])
    if batch_size:
        cmd['batchSize'] = batch_size
    if max_await_time_ms is not None:
        cmd['maxTimeMS'] = max_await_time_ms
    return cmd


class _Query(object):
    """A query operation."""

    __slots__ = ('flags', 'db', 'coll', 'ntoskip', 'spec',
                 'fields', 'codec_options', 'read_preference', 'limit',
                 'batch_size', 'name', 'read_concern', 'collation',
                 'session', 'client', 'allow_disk_use', '_as_command',
                 'exhaust')

    # For compatibility with the _GetMore class.
    sock_mgr = None
    cursor_id = None

    def __init__(self, flags, db, coll, ntoskip, spec, fields,
                 codec_options, read_preference, limit,
                 batch_size, read_concern, collation, session, client,
                 allow_disk_use, exhaust):
        self.flags = flags
        self.db = db
        self.coll = coll
        self.ntoskip = ntoskip
        self.spec = spec
        self.fields = fields
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.read_concern = read_concern
        self.limit = limit
        self.batch_size = batch_size
        self.collation = collation
        self.session = session
        self.client = client
        self.allow_disk_use = allow_disk_use
        self.name = 'find'
        self._as_command = None
        self.exhaust = exhaust

    def namespace(self):
        return _UJOIN % (self.db, self.coll)

    def use_command(self, sock_info):
        use_find_cmd = False
        if sock_info.max_wire_version >= 4 and not self.exhaust:
            use_find_cmd = True
        elif sock_info.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_find_cmd = True
        elif not self.read_concern.ok_for_legacy:
            raise ConfigurationError(
                'read concern level of %s is not valid '
                'with a max wire version of %d.'
                % (self.read_concern.level,
                   sock_info.max_wire_version))

        if sock_info.max_wire_version < 5 and self.collation is not None:
            raise ConfigurationError(
                'Specifying a collation is unsupported with a max wire '
                'version of %d.' % (sock_info.max_wire_version,))

        if sock_info.max_wire_version < 4 and self.allow_disk_use is not None:
            raise ConfigurationError(
                'Specifying allowDiskUse is unsupported with a max wire '
                'version of %d.' % (sock_info.max_wire_version,))

        sock_info.validate_session(self.client, self.session)

        return use_find_cmd

    def as_command(self, sock_info):
        """Return a find command document for this query."""
        # We use the command twice: on the wire and for command monitoring.
        # Generate it once, for speed and to avoid repeating side-effects.
        if self._as_command is not None:
            return self._as_command

        explain = '$explain' in self.spec
        cmd = _gen_find_command(
            self.coll, self.spec, self.fields, self.ntoskip,
            self.limit, self.batch_size, self.flags, self.read_concern,
            self.collation, self.session, self.allow_disk_use)
        if explain:
            self.name = 'explain'
            cmd = SON([('explain', cmd)])
        session = self.session
        sock_info.add_server_api(cmd)
        if session:
            session._apply_to(cmd, False, self.read_preference, sock_info)
            # Explain does not support readConcern.
            if not explain and not session.in_transaction:
                session._update_read_concern(cmd, sock_info)
        sock_info.send_cluster_time(cmd, session, self.client)
        # Support auto encryption
        client = self.client
        if (client._encrypter and
                not client._encrypter._bypass_auto_encryption):
            cmd = client._encrypter.encrypt(
                self.db, cmd, False, self.codec_options)
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(self, set_secondary_ok, sock_info, use_cmd=False):
        """Get a query message, possibly setting the secondaryOk bit."""
        if set_secondary_ok:
            # Set the secondaryOk bit.
            flags = self.flags | 4
        else:
            flags = self.flags

        ns = self.namespace()
        spec = self.spec

        if use_cmd:
            spec = self.as_command(sock_info)[0]
            if sock_info.op_msg_enabled:
                request_id, msg, size, _ = _op_msg(
                    0, spec, self.db, self.read_preference,
                    set_secondary_ok, False, self.codec_options,
                    ctx=sock_info.compression_context)
                return request_id, msg, size
            ns = _UJOIN % (self.db, "$cmd")
            ntoreturn = -1  # All DB commands return 1 document
        else:
            # OP_QUERY treats ntoreturn of -1 and 1 the same, return
            # one document and close the cursor. We have to use 2 for
            # batch size if 1 is specified.
            ntoreturn = 2 if self.batch_size == 1 else self.batch_size
            if self.limit:
                if ntoreturn:
                    ntoreturn = min(self.limit, ntoreturn)
                else:
                    ntoreturn = self.limit

        if sock_info.is_mongos:
            spec = _maybe_add_read_preference(spec,
                                              self.read_preference)

        return query(flags, ns, self.ntoskip, ntoreturn,
                     spec, None if use_cmd else self.fields,
                     self.codec_options, ctx=sock_info.compression_context)


class _GetMore(object):
    """A getmore operation."""

    __slots__ = ('db', 'coll', 'ntoreturn', 'cursor_id', 'max_await_time_ms',
                 'codec_options', 'read_preference', 'session', 'client',
                 'sock_mgr', '_as_command', 'exhaust')

    name = 'getMore'

    def __init__(self, db, coll, ntoreturn, cursor_id, codec_options,
                 read_preference, session, client, max_await_time_ms,
                 sock_mgr, exhaust):
        self.db = db
        self.coll = coll
        self.ntoreturn = ntoreturn
        self.cursor_id = cursor_id
        self.codec_options = codec_options
        self.read_preference = read_preference
        self.session = session
        self.client = client
        self.max_await_time_ms = max_await_time_ms
        self.sock_mgr = sock_mgr
        self._as_command = None
        self.exhaust = exhaust

    def namespace(self):
        return _UJOIN % (self.db, self.coll)

    def use_command(self, sock_info):
        use_cmd = False
        if sock_info.max_wire_version >= 4 and not self.exhaust:
            use_cmd = True
        elif sock_info.max_wire_version >= 8:
            # OP_MSG supports exhaust on MongoDB 4.2+
            use_cmd = True

        sock_info.validate_session(self.client, self.session)
        return use_cmd

    def as_command(self, sock_info):
        """Return a getMore command document for this query."""
        # See _Query.as_command for an explanation of this caching.
        if self._as_command is not None:
            return self._as_command

        cmd = _gen_get_more_command(self.cursor_id, self.coll,
                                    self.ntoreturn,
                                    self.max_await_time_ms)

        if self.session:
            self.session._apply_to(cmd, False, self.read_preference, sock_info)
        sock_info.add_server_api(cmd)
        sock_info.send_cluster_time(cmd, self.session, self.client)
        # Support auto encryption
        client = self.client
        if (client._encrypter and
                not client._encrypter._bypass_auto_encryption):
            cmd = client._encrypter.encrypt(
                self.db, cmd, False, self.codec_options)
        self._as_command = cmd, self.db
        return self._as_command

    def get_message(self, dummy0, sock_info, use_cmd=False):
        """Get a getmore message."""

        ns = self.namespace()
        ctx = sock_info.compression_context

        if use_cmd:
            spec = self.as_command(sock_info)[0]
            if sock_info.op_msg_enabled:
                if self.sock_mgr:
                    flags = _OpMsg.EXHAUST_ALLOWED
                else:
                    flags = 0
                request_id, msg, size, _ = _op_msg(
                    flags, spec, self.db, None,
                    False, False, self.codec_options,
                    ctx=sock_info.compression_context)
                return request_id, msg, size
            ns = _UJOIN % (self.db, "$cmd")
            return query(0, ns, 0, -1, spec, None, self.codec_options, ctx=ctx)

        return get_more(ns, self.ntoreturn, self.cursor_id, ctx)


class _RawBatchQuery(_Query):
    def use_command(self, sock_info):
        # Compatibility checks.
        super(_RawBatchQuery, self).use_command(sock_info)
        if sock_info.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif sock_info.op_msg_enabled and not self.exhaust:
            return True
        return False


class _RawBatchGetMore(_GetMore):
    def use_command(self, sock_info):
        # Compatibility checks.
        super(_RawBatchGetMore, self).use_command(sock_info)
        if sock_info.max_wire_version >= 8:
            # MongoDB 4.2+ supports exhaust over OP_MSG
            return True
        elif sock_info.op_msg_enabled and not self.exhaust:
            return True
        return False


class _CursorAddress(tuple):
    """The server address (host, port) of a cursor, with namespace property."""

    def __new__(cls, address, namespace):
        self = tuple.__new__(cls, address)
        self.__namespace = namespace
        return self

    @property
    def namespace(self):
        """The namespace this cursor belongs to."""
        return self.__namespace

    def __hash__(self):
        # Two _CursorAddress instances with different namespaces
        # must not hash the same.
        return (self + (self.__namespace,)).__hash__()

    def __eq__(self, other):
        if isinstance(other, _CursorAddress):
            return (tuple(self) == tuple(other)
                    and self.namespace == other.namespace)
        return NotImplemented

    def __ne__(self, other):
        return not self == other


_pack_compression_header = struct.Struct("<iiiiiiB").pack
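# 16 byte standard message header plus originalOpcode (4 bytes),
# uncompressedSize (4 bytes) and compressorId (1 byte):
# struct.calcsize("<iiiiiiB") == 25.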
_COMPRESSION_HEADER_SIZE = 25

def _compress(operation, data, ctx):
    """Takes message data, compresses it, and adds an OP_COMPRESSED header."""
    compressed = ctx.compress(data)
    request_id = _randint()

    header = _pack_compression_header(
        _COMPRESSION_HEADER_SIZE + len(compressed), # Total message length
        request_id, # Request id
        0, # responseTo
        2012, # operation id
        operation, # original operation id
        len(data), # uncompressed message length
        ctx.compressor_id) # compressor id
    return request_id, header + compressed


def __last_error(namespace, args):
    """Build a getLastError query message for the given namespace."""
    cmd = SON([("getlasterror", 1)])
    cmd.update(args)
    splitns = namespace.split('.', 1)
    return query(0, splitns[0] + '.$cmd', 0, -1, cmd,
                 None, DEFAULT_CODEC_OPTIONS)


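# Standard MongoDB message header: messageLength, requestID, responseTo and
# opCode, all little-endian int32s.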
_pack_header = struct.Struct("<iiii").pack


def __pack_message(operation, data):
    """Takes message data and adds a message header based on the operation.

    Returns the resultant message string.
    """
    rid = _randint()
    message = _pack_header(16 + len(data), rid, 0, operation)
    return rid, message + data


_pack_int = struct.Struct("<i").pack


def _insert(collection_name, docs, check_keys, flags, opts):
    """Get an OP_INSERT message."""
    encode = _dict_to_bson  # Make local. Uses extensions.
    if len(docs) == 1:
        encoded = encode(docs[0], check_keys, opts)
        return b"".join([
            b"\x00\x00\x00\x00",  # Flags don't matter for one doc.
            _make_c_string(collection_name),
            encoded]), len(encoded)

    encoded = [encode(doc, check_keys, opts) for doc in docs]
    if not encoded:
        raise InvalidOperation("cannot do an empty bulk insert")
    return b"".join([
        _pack_int(flags),
        _make_c_string(collection_name),
        b"".join(encoded)]), max(map(len, encoded))


def _insert_compressed(
        collection_name, docs, check_keys, continue_on_error, opts, ctx):
    """Internal compressed unacknowledged insert message helper."""
    op_insert, max_bson_size = _insert(
        collection_name, docs, check_keys, continue_on_error, opts)
    rid, msg = _compress(2002, op_insert, ctx)
    return rid, msg, max_bson_size


def _insert_uncompressed(collection_name, docs, check_keys,
            safe, last_error_args, continue_on_error, opts):
    """Internal insert message helper."""
    op_insert, max_bson_size = _insert(
        collection_name, docs, check_keys, continue_on_error, opts)
    rid, msg = __pack_message(2002, op_insert)
    if safe:
        rid, gle, _ = __last_error(collection_name, last_error_args)
        return rid, msg + gle, max_bson_size
    return rid, msg, max_bson_size
if _use_c:
    _insert_uncompressed = _cmessage._insert_message


def insert(collection_name, docs, check_keys,
           safe, last_error_args, continue_on_error, opts, ctx=None):
    """**DEPRECATED** Get an **insert** message.

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    if ctx:
        return _insert_compressed(
            collection_name, docs, check_keys, continue_on_error, opts, ctx)
    return _insert_uncompressed(collection_name, docs, check_keys, safe,
                                last_error_args, continue_on_error, opts)


def _update(collection_name, upsert, multi, spec, doc, check_keys, opts):
    """Get an OP_UPDATE message."""
    flags = 0
    if upsert:
        flags += 1
    if multi:
        flags += 2
    encode = _dict_to_bson  # Make local. Uses extensions.
    encoded_update = encode(doc, check_keys, opts)
    return b"".join([
        _ZERO_32,
        _make_c_string(collection_name),
        _pack_int(flags),
        encode(spec, False, opts),
        encoded_update]), len(encoded_update)


def _update_compressed(
        collection_name, upsert, multi, spec, doc, check_keys, opts, ctx):
    """Internal compressed unacknowledged update message helper."""
    op_update, max_bson_size = _update(
        collection_name, upsert, multi, spec, doc, check_keys, opts)
    rid, msg = _compress(2001, op_update, ctx)
    return rid, msg, max_bson_size


def _update_uncompressed(collection_name, upsert, multi, spec,
                         doc, safe, last_error_args, check_keys, opts):
    """Internal update message helper."""
    op_update, max_bson_size = _update(
        collection_name, upsert, multi, spec, doc, check_keys, opts)
    rid, msg = __pack_message(2001, op_update)
    if safe:
        rid, gle, _ = __last_error(collection_name, last_error_args)
        return rid, msg + gle, max_bson_size
    return rid, msg, max_bson_size
if _use_c:
    _update_uncompressed = _cmessage._update_message


def update(collection_name, upsert, multi, spec,
           doc, safe, last_error_args, check_keys, opts, ctx=None):
    """**DEPRECATED** Get an **update** message.

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    if ctx:
        return _update_compressed(
            collection_name, upsert, multi, spec, doc, check_keys, opts, ctx)
    return _update_uncompressed(collection_name, upsert, multi, spec,
                                doc, safe, last_error_args, check_keys, opts)


_pack_op_msg_flags_type = struct.Struct("<IB").pack
_pack_byte = struct.Struct("<B").pack


def _op_msg_no_header(flags, command, identifier, docs, check_keys, opts):
    """Get an OP_MSG message.

    Note: this method handles multiple documents in a type one payload but
    it does not perform batch splitting and the total message size is
    only checked *after* generating the entire message.
    """
    # Encode the command document in payload 0 without checking keys.
    encoded = _dict_to_bson(command, False, opts)
    flags_type = _pack_op_msg_flags_type(flags, 0)
    total_size = len(encoded)
    max_doc_size = 0
    if identifier:
        type_one = _pack_byte(1)
        cstring = _make_c_string(identifier)
        encoded_docs = [_dict_to_bson(doc, check_keys, opts) for doc in docs]
        size = len(cstring) + sum(len(doc) for doc in encoded_docs) + 4
        encoded_size = _pack_int(size)
        total_size += size
        max_doc_size = max(len(doc) for doc in encoded_docs)
        data = ([flags_type, encoded, type_one, encoded_size, cstring] +
                encoded_docs)
    else:
        data = [flags_type, encoded]
    return b''.join(data), total_size, max_doc_size


def _op_msg_compressed(flags, command, identifier, docs, check_keys, opts,
                       ctx):
    """Internal compressed OP_MSG message helper."""
    msg, total_size, max_bson_size = _op_msg_no_header(
        flags, command, identifier, docs, check_keys, opts)
    rid, msg = _compress(2013, msg, ctx)
    return rid, msg, total_size, max_bson_size


def _op_msg_uncompressed(flags, command, identifier, docs, check_keys, opts):
    """Internal OP_MSG message helper."""
    data, total_size, max_bson_size = _op_msg_no_header(
        flags, command, identifier, docs, check_keys, opts)
    request_id, op_message = __pack_message(2013, data)
    return request_id, op_message, total_size, max_bson_size
if _use_c:
    _op_msg_uncompressed = _cmessage._op_msg


def _op_msg(flags, command, dbname, read_preference, secondary_ok, check_keys,
            opts, ctx=None):
    """Get an OP_MSG message."""
    command['$db'] = dbname
    # getMore commands do not send $readPreference.
    if read_preference is not None and "$readPreference" not in command:
        if secondary_ok and not read_preference.mode:
            command["$readPreference"] = (
                ReadPreference.PRIMARY_PREFERRED.document)
        else:
            command["$readPreference"] = read_preference.document
    name = next(iter(command))
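    # _FIELD_MAP.get returns None for commands that carry no document list
    # (e.g. find); command.pop(None) then raises KeyError and we fall
    # through to a payload type 0 only message.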
    try:
        identifier = _FIELD_MAP.get(name)
        docs = command.pop(identifier)
    except KeyError:
        identifier = ""
        docs = None
    try:
        if ctx:
            return _op_msg_compressed(
                flags, command, identifier, docs, check_keys, opts, ctx)
        return _op_msg_uncompressed(
            flags, command, identifier, docs, check_keys, opts)
    finally:
        # Add the field back to the command.
        if identifier:
            command[identifier] = docs


def _query(options, collection_name, num_to_skip,
           num_to_return, query, field_selector, opts, check_keys):
    """Get an OP_QUERY message."""
    encoded = _dict_to_bson(query, check_keys, opts)
    if field_selector:
        efs = _dict_to_bson(field_selector, False, opts)
    else:
        efs = b""
    max_bson_size = max(len(encoded), len(efs))
    return b"".join([
        _pack_int(options),
        _make_c_string(collection_name),
        _pack_int(num_to_skip),
        _pack_int(num_to_return),
        encoded,
        efs]), max_bson_size


def _query_compressed(options, collection_name, num_to_skip,
                      num_to_return, query, field_selector,
                      opts, check_keys=False, ctx=None):
    """Internal compressed query message helper."""
    op_query, max_bson_size = _query(
        options,
        collection_name,
        num_to_skip,
        num_to_return,
        query,
        field_selector,
        opts,
        check_keys)
    rid, msg = _compress(2004, op_query, ctx)
    return rid, msg, max_bson_size


def _query_uncompressed(options, collection_name, num_to_skip,
          num_to_return, query, field_selector, opts, check_keys=False):
    """Internal query message helper."""
    op_query, max_bson_size = _query(
        options,
        collection_name,
        num_to_skip,
        num_to_return,
        query,
        field_selector,
        opts,
        check_keys)
    rid, msg = __pack_message(2004, op_query)
    return rid, msg, max_bson_size
if _use_c:
    _query_uncompressed = _cmessage._query_message


def query(options, collection_name, num_to_skip, num_to_return,
          query, field_selector, opts, check_keys=False, ctx=None):
    """**DEPRECATED** Get a **query** message.

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    if ctx:
        return _query_compressed(options, collection_name, num_to_skip,
                                 num_to_return, query, field_selector,
                                 opts, check_keys, ctx)
    return _query_uncompressed(options, collection_name, num_to_skip,
                               num_to_return, query, field_selector, opts,
                               check_keys)


_pack_long_long = struct.Struct("<q").pack


def _get_more(collection_name, num_to_return, cursor_id):
    """Get an OP_GET_MORE message."""
    return b"".join([
        _ZERO_32,
        _make_c_string(collection_name),
        _pack_int(num_to_return),
        _pack_long_long(cursor_id)])


def _get_more_compressed(collection_name, num_to_return, cursor_id, ctx):
    """Internal compressed getMore message helper."""
    return _compress(
        2005, _get_more(collection_name, num_to_return, cursor_id), ctx)


def _get_more_uncompressed(collection_name, num_to_return, cursor_id):
    """Internal getMore message helper."""
    return __pack_message(
        2005, _get_more(collection_name, num_to_return, cursor_id))
if _use_c:
    _get_more_uncompressed = _cmessage._get_more_message


def get_more(collection_name, num_to_return, cursor_id, ctx=None):
    """**DEPRECATED** Get a **getMore** message.

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    if ctx:
        return _get_more_compressed(
            collection_name, num_to_return, cursor_id, ctx)
    return _get_more_uncompressed(collection_name, num_to_return, cursor_id)


def _delete(collection_name, spec, opts, flags):
    """Get an OP_DELETE message."""
    encoded = _dict_to_bson(spec, False, opts)  # Uses extensions.
    return b"".join([
        _ZERO_32,
        _make_c_string(collection_name),
        _pack_int(flags),
        encoded]), len(encoded)


def _delete_compressed(collection_name, spec, opts, flags, ctx):
    """Internal compressed unacknowledged delete message helper."""
    op_delete, max_bson_size = _delete(collection_name, spec, opts, flags)
    rid, msg = _compress(2006, op_delete, ctx)
    return rid, msg, max_bson_size


def _delete_uncompressed(
        collection_name, spec, safe, last_error_args, opts, flags=0):
    """Internal delete message helper."""
    op_delete, max_bson_size = _delete(collection_name, spec, opts, flags)
    rid, msg = __pack_message(2006, op_delete)
    if safe:
        rid, gle, _ = __last_error(collection_name, last_error_args)
        return rid, msg + gle, max_bson_size
    return rid, msg, max_bson_size


def delete(
        collection_name, spec, safe, last_error_args, opts, flags=0, ctx=None):
    """**DEPRECATED** Get a **delete** message.

    `opts` is a CodecOptions. `flags` is a bit vector that may contain
    the SingleRemove flag:

    http://docs.mongodb.org/meta-driver/latest/legacy/mongodb-wire-protocol/#op-delete

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    if ctx:
        return _delete_compressed(collection_name, spec, opts, flags, ctx)
    return _delete_uncompressed(
        collection_name, spec, safe, last_error_args, opts, flags)


def kill_cursors(cursor_ids):
    """**DEPRECATED** Get a **killCursors** message.

    .. versionchanged:: 3.12
      This function is deprecated and will be removed in PyMongo 4.0.
    """
    num_cursors = len(cursor_ids)
    pack = struct.Struct("<ii" + ("q" * num_cursors)).pack
    op_kill_cursors = pack(0, num_cursors, *cursor_ids)
    return __pack_message(2007, op_kill_cursors)


class _BulkWriteContext(object):
    """A wrapper around SocketInfo for use with write splitting functions."""

    __slots__ = ('db_name', 'command', 'sock_info', 'op_id',
                 'name', 'field', 'publish', 'start_time', 'listeners',
                 'session', 'compress', 'op_type', 'codec')

    def __init__(self, database_name, command, sock_info, operation_id,
                 listeners, session, op_type, codec):
        self.db_name = database_name
        self.command = command
        self.sock_info = sock_info
        self.op_id = operation_id
        self.listeners = listeners
        self.publish = listeners.enabled_for_commands
        self.name = next(iter(command))
        self.field = _FIELD_MAP[self.name]
        self.start_time = datetime.datetime.now() if self.publish else None
        self.session = session
        self.compress = True if sock_info.compression_context else False
        self.op_type = op_type
        self.codec = codec
        sock_info.add_server_api(command)

    def _batch_command(self, docs):
        namespace = self.db_name + '.$cmd'
        request_id, msg, to_send = _do_bulk_write_command(
            namespace, self.op_type, self.command, docs, self.check_keys,
            self.codec, self)
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")
        return request_id, msg, to_send

    def execute(self, docs, client):
        request_id, msg, to_send = self._batch_command(docs)
        result = self.write_command(request_id, msg, to_send)
        client._process_response(result, self.session)
        return result, to_send

    def execute_unack(self, docs, client):
        request_id, msg, to_send = self._batch_command(docs)
        # Though this isn't strictly a "legacy" write, the helper
        # handles publishing commands and sending our message
        # without receiving a result. Send 0 for max_doc_size
        # to disable size checking. Size checking is handled while
        # the documents are encoded to BSON.
        self.legacy_write(request_id, msg, 0, False, to_send)
        return to_send

    @property
    def check_keys(self):
        """Should we check keys for this operation type?"""
        return False

    @property
    def max_bson_size(self):
        """A proxy for SockInfo.max_bson_size."""
        return self.sock_info.max_bson_size

    @property
    def max_message_size(self):
        """A proxy for SockInfo.max_message_size."""
        if self.compress:
            # Subtract 16 bytes for the message header.
            return self.sock_info.max_message_size - 16
        return self.sock_info.max_message_size

    @property
    def max_write_batch_size(self):
        """A proxy for SockInfo.max_write_batch_size."""
        return self.sock_info.max_write_batch_size

    @property
    def max_split_size(self):
        """The maximum size of a BSON command before batch splitting."""
        return self.max_bson_size

    def legacy_bulk_insert(
            self, request_id, msg, max_doc_size, acknowledged, docs, compress):
        if compress:
            request_id, msg = _compress(
                2002, msg, self.sock_info.compression_context)
        return self.legacy_write(
            request_id, msg, max_doc_size, acknowledged, docs)

    def legacy_write(self, request_id, msg, max_doc_size, acknowledged, docs):
        """A proxy for SocketInfo.legacy_write that handles event publishing.
        """
        if self.publish:
            duration = datetime.datetime.now() - self.start_time
            cmd = self._start(request_id, docs)
            start = datetime.datetime.now()
        try:
            result = self.sock_info.legacy_write(
                request_id, msg, max_doc_size, acknowledged)
            if self.publish:
                duration = (datetime.datetime.now() - start) + duration
                if result is not None:
                    reply = _convert_write_result(self.name, cmd, result)
                else:
                    # Comply with APM spec.
                    reply = {'ok': 1}
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            if self.publish:
                duration = (datetime.datetime.now() - start) + duration
                if isinstance(exc, OperationFailure):
                    failure = _convert_write_result(
                        self.name, cmd, exc.details)
                elif isinstance(exc, NotPrimaryError):
                    failure = exc.details
                else:
                    failure = _convert_exception(exc)
                self._fail(request_id, failure, duration)
            raise
        finally:
            self.start_time = datetime.datetime.now()
        return result

    def write_command(self, request_id, msg, docs):
        """A proxy for SocketInfo.write_command that handles event publishing.
        """
        if self.publish:
            duration = datetime.datetime.now() - self.start_time
            self._start(request_id, docs)
            start = datetime.datetime.now()
        try:
            reply = self.sock_info.write_command(request_id, msg)
            if self.publish:
                duration = (datetime.datetime.now() - start) + duration
                self._succeed(request_id, reply, duration)
        except Exception as exc:
            if self.publish:
                duration = (datetime.datetime.now() - start) + duration
                if isinstance(exc, (NotPrimaryError, OperationFailure)):
                    failure = exc.details
                else:
                    failure = _convert_exception(exc)
                self._fail(request_id, failure, duration)
            raise
        finally:
            self.start_time = datetime.datetime.now()
        return reply

    def _start(self, request_id, docs):
        """Publish a CommandStartedEvent."""
        cmd = self.command.copy()
        cmd[self.field] = docs
        self.listeners.publish_command_start(
            cmd, self.db_name,
            request_id, self.sock_info.address, self.op_id,
            self.sock_info.service_id)
        return cmd

    def _succeed(self, request_id, reply, duration):
        """Publish a CommandSucceededEvent."""
        self.listeners.publish_command_success(
            duration, reply, self.name,
            request_id, self.sock_info.address, self.op_id,
            self.sock_info.service_id)

    def _fail(self, request_id, failure, duration):
        """Publish a CommandFailedEvent."""
        self.listeners.publish_command_failure(
            duration, failure, self.name,
            request_id, self.sock_info.address, self.op_id,
            self.sock_info.service_id)


# From the Client Side Encryption spec:
# Because automatic encryption increases the size of commands, the driver
# MUST split bulk writes at a reduced size limit before undergoing automatic
# encryption. The write payload MUST be split at 2MiB (2097152).
_MAX_SPLIT_SIZE_ENC = 2097152


class _EncryptedBulkWriteContext(_BulkWriteContext):
    __slots__ = ()

    def _batch_command(self, docs):
        namespace = self.db_name + '.$cmd'
        msg, to_send = _encode_batched_write_command(
            namespace, self.op_type, self.command, docs, self.check_keys,
            self.codec, self)
        if not to_send:
            raise InvalidOperation("cannot do an empty bulk write")

        # Chop off the OP_QUERY header to get a properly batched write command.
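        # The header is: four flag bytes, the namespace as a C string (its
        # NUL terminator is found by scanning from offset 4), then four
        # bytes each for numberToSkip and numberToReturn; hence the
        # index(...) + 1 + 8 == index(...) + 9 below.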
        cmd_start = msg.index(b"\x00", 4) + 9
        cmd = _inflate_bson(memoryview(msg)[cmd_start:],
                            DEFAULT_RAW_BSON_OPTIONS)
        return cmd, to_send

    def execute(self, docs, client):
        cmd, to_send = self._batch_command(docs)
        result = self.sock_info.command(
            self.db_name, cmd, codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
            session=self.session, client=client)
        return result, to_send

    def execute_unack(self, docs, client):
        cmd, to_send = self._batch_command(docs)
        self.sock_info.command(
            self.db_name, cmd, write_concern=WriteConcern(w=0),
            session=self.session, client=client)
        return to_send

    @property
    def max_split_size(self):
        """Reduce the batch splitting size."""
        return _MAX_SPLIT_SIZE_ENC


def _raise_document_too_large(operation, doc_size, max_size):
    """Internal helper for raising DocumentTooLarge."""
    if operation == "insert":
        raise DocumentTooLarge("BSON document too large (%d bytes)"
                               " - the connected server supports"
                               " BSON document sizes up to %d"
                               " bytes." % (doc_size, max_size))
    else:
        # There's nothing intelligent we can say
        # about size for update and delete
        raise DocumentTooLarge("%r command document too large" % (operation,))


def _do_batched_insert(collection_name, docs, check_keys,
                       safe, last_error_args, continue_on_error, opts,
                       ctx):
    """Insert `docs` using multiple batches.
    """
    def _insert_message(insert_message, send_safe):
        """Build the insert message with header and GLE.
        """
        request_id, final_message = __pack_message(2002, insert_message)
        if send_safe:
            request_id, error_message, _ = __last_error(collection_name,
                                                        last_error_args)
            final_message += error_message
        return request_id, final_message

    send_safe = safe or not continue_on_error
    last_error = None
    data = StringIO()
    data.write(struct.pack("<i", int(continue_on_error)))
    data.write(_make_c_string(collection_name))
    message_length = begin_loc = data.tell()
    has_docs = False
    to_send = []
    encode = _dict_to_bson  # Make local
    compress = ctx.compress and not (safe or send_safe)
    for doc in docs:
        encoded = encode(doc, check_keys, opts)
        encoded_length = len(encoded)
        too_large = (encoded_length > ctx.max_bson_size)

        message_length += encoded_length
        if message_length < ctx.max_message_size and not too_large:
            data.write(encoded)
            to_send.append(doc)
            has_docs = True
            continue

        if has_docs:
            # We have enough data, send this message.
            try:
                if compress:
                    rid, msg = None, data.getvalue()
                else:
                    rid, msg = _insert_message(data.getvalue(), send_safe)
                ctx.legacy_bulk_insert(
                    rid, msg, 0, send_safe, to_send, compress)
            # Exception type could be OperationFailure or a subtype
            # (e.g. DuplicateKeyError)
            except OperationFailure as exc:
                # Like it says, continue on error...
                if continue_on_error:
                    # Store exception details to re-raise after the final batch.
                    last_error = exc
                # With unacknowledged writes just return at the first error.
                elif not safe:
                    return
                # With acknowledged writes raise immediately.
                else:
                    raise

        if too_large:
            _raise_document_too_large(
                "insert", encoded_length, ctx.max_bson_size)

        message_length = begin_loc + encoded_length
        data.seek(begin_loc)
        data.truncate()
        data.write(encoded)
        to_send = [doc]

    if not has_docs:
        raise InvalidOperation("cannot do an empty bulk insert")

    if compress:
        request_id, msg = None, data.getvalue()
    else:
        request_id, msg = _insert_message(data.getvalue(), safe)
    ctx.legacy_bulk_insert(request_id, msg, 0, safe, to_send, compress)

    # Re-raise any exception stored due to continue_on_error
    if last_error is not None:
        raise last_error
if _use_c:
    _do_batched_insert = _cmessage._do_batched_insert

# OP_MSG -------------------------------------------------------------


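# OP_MSG payload type 1 identifiers: each is the document sequence's field
# name as a C string.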
_OP_MSG_MAP = {
    _INSERT: b'documents\x00',
    _UPDATE: b'updates\x00',
    _DELETE: b'deletes\x00',
}


def _batched_op_msg_impl(
        operation, command, docs, check_keys, ack, opts, ctx, buf):
    """Create a batched OP_MSG write."""
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    max_message_size = ctx.max_message_size

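    # OP_MSG flag bits: 0x02 is moreToCome, set for unacknowledged writes
    # so that the server sends no reply.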
    flags = b"\x00\x00\x00\x00" if ack else b"\x02\x00\x00\x00"
    # Flags
    buf.write(flags)

    # Type 0 Section
    buf.write(b"\x00")
    buf.write(_dict_to_bson(command, False, opts))

    # Type 1 Section
    buf.write(b"\x01")
    size_location = buf.tell()
    # Save space for size
    buf.write(b"\x00\x00\x00\x00")
    try:
        buf.write(_OP_MSG_MAP[operation])
    except KeyError:
        raise InvalidOperation('Unknown command')

    if operation in (_UPDATE, _DELETE):
        check_keys = False

    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        value = _dict_to_bson(doc, check_keys, opts)
        doc_length = len(value)
        new_message_size = buf.tell() + doc_length
        # Does first document exceed max_message_size?
        doc_too_large = (idx == 0 and (new_message_size > max_message_size))
        # When OP_MSG is used unacknowledged we have to check
        # document size client side or applications won't be notified.
        # Otherwise we let the server deal with documents that are too large
        # since ordered=False causes those documents to be skipped instead of
        # halting the bulk write operation.
        unacked_doc_too_large = (not ack and (doc_length > max_bson_size))
        if doc_too_large or unacked_doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(
                write_op, len(value), max_bson_size)
        # We have enough data, return this batch.
        if new_message_size > max_message_size:
            break
        buf.write(value)
        to_send.append(doc)
        idx += 1
        # We have enough documents, return this batch.
        if idx == max_write_batch_size:
            break

    # Write type 1 section size
    length = buf.tell()
    buf.seek(size_location)
    buf.write(_pack_int(length - size_location))

    return to_send, length


def _encode_batched_op_msg(
        operation, command, docs, check_keys, ack, opts, ctx):
    """Encode the next batched insert, update, or delete operation
    as OP_MSG.
    """
    buf = StringIO()

    to_send, _ = _batched_op_msg_impl(
        operation, command, docs, check_keys, ack, opts, ctx, buf)
    return buf.getvalue(), to_send
if _use_c:
    _encode_batched_op_msg = _cmessage._encode_batched_op_msg


def _batched_op_msg_compressed(
        operation, command, docs, check_keys, ack, opts, ctx):
    """Create the next batched insert, update, or delete operation
    with OP_MSG, compressed.
    """
    data, to_send = _encode_batched_op_msg(
        operation, command, docs, check_keys, ack, opts, ctx)

    request_id, msg = _compress(
        2013,
        data,
        ctx.sock_info.compression_context)
    return request_id, msg, to_send


def _batched_op_msg(
        operation, command, docs, check_keys, ack, opts, ctx):
    """OP_MSG implementation entry point."""
    buf = StringIO()

    # Save space for message length and request id
    buf.write(_ZERO_64)
    # responseTo, opCode
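    # (b"\xdd\x07\x00\x00" is little-endian 2013, the OP_MSG opcode)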
    buf.write(b"\x00\x00\x00\x00\xdd\x07\x00\x00")

    to_send, length = _batched_op_msg_impl(
        operation, command, docs, check_keys, ack, opts, ctx, buf)

    # Header - request id and message length
    buf.seek(4)
    request_id = _randint()
    buf.write(_pack_int(request_id))
    buf.seek(0)
    buf.write(_pack_int(length))

    return request_id, buf.getvalue(), to_send
if _use_c:
    _batched_op_msg = _cmessage._batched_op_msg


def _do_batched_op_msg(
        namespace, operation, command, docs, check_keys, opts, ctx):
    """Create the next batched insert, update, or delete operation
    using OP_MSG.
    """
    command['$db'] = namespace.split('.', 1)[0]
    if 'writeConcern' in command:
        ack = bool(command['writeConcern'].get('w', 1))
    else:
        ack = True
    if ctx.sock_info.compression_context:
        return _batched_op_msg_compressed(
            operation, command, docs, check_keys, ack, opts, ctx)
    return _batched_op_msg(
        operation, command, docs, check_keys, ack, opts, ctx)


# End OP_MSG -----------------------------------------------------


def _batched_write_command_compressed(
        namespace, operation, command, docs, check_keys, opts, ctx):
    """Create the next batched insert, update, or delete command, compressed.
    """
    data, to_send = _encode_batched_write_command(
        namespace, operation, command, docs, check_keys, opts, ctx)

    request_id, msg = _compress(
        2004,
        data,
        ctx.sock_info.compression_context)
    return request_id, msg, to_send


def _encode_batched_write_command(
        namespace, operation, command, docs, check_keys, opts, ctx):
    """Encode the next batched insert, update, or delete command.
    """
    buf = StringIO()

    to_send, _ = _batched_write_command_impl(
        namespace, operation, command, docs, check_keys, opts, ctx, buf)
    return buf.getvalue(), to_send
if _use_c:
    _encode_batched_write_command = _cmessage._encode_batched_write_command


def _batched_write_command(
        namespace, operation, command, docs, check_keys, opts, ctx):
    """Create the next batched insert, update, or delete command.
    """
    buf = StringIO()

    # Save space for message length and request id
    buf.write(_ZERO_64)
    # responseTo, opCode
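    # (b"\xd4\x07\x00\x00" is little-endian 2004, the OP_QUERY opcode)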
1411    buf.write(b"\x00\x00\x00\x00\xd4\x07\x00\x00")
1412
1413    # Write OP_QUERY write command
1414    to_send, length = _batched_write_command_impl(
1415        namespace, operation, command, docs, check_keys, opts, ctx, buf)
1416
1417    # Header - request id and message length
1418    buf.seek(4)
1419    request_id = _randint()
1420    buf.write(_pack_int(request_id))
1421    buf.seek(0)
1422    buf.write(_pack_int(length))
1423
1424    return request_id, buf.getvalue(), to_send
1425if _use_c:
1426    _batched_write_command = _cmessage._batched_write_command
1427
1428
1429def _do_batched_write_command(
1430        namespace, operation, command, docs, check_keys, opts, ctx):
1431    """Batched write commands entry point."""
1432    if ctx.sock_info.compression_context:
1433        return _batched_write_command_compressed(
1434            namespace, operation, command, docs, check_keys, opts, ctx)
1435    return _batched_write_command(
1436        namespace, operation, command, docs, check_keys, opts, ctx)
1437
1438
1439def _do_bulk_write_command(
1440        namespace, operation, command, docs, check_keys, opts, ctx):
1441    """Bulk write commands entry point."""
1442    if ctx.sock_info.max_wire_version > 5:
1443        return _do_batched_op_msg(
1444            namespace, operation, command, docs, check_keys, opts, ctx)
1445    return _do_batched_write_command(
1446        namespace, operation, command, docs, check_keys, opts, ctx)


def _batched_write_command_impl(
        namespace, operation, command, docs, check_keys, opts, ctx, buf):
    """Create a batched OP_QUERY write command."""
    max_bson_size = ctx.max_bson_size
    max_write_batch_size = ctx.max_write_batch_size
    # Max BSON object size + 16k - 2 bytes for ending NUL bytes.
    # Server guarantees there is enough room: SERVER-10643.
    max_cmd_size = max_bson_size + _COMMAND_OVERHEAD
    max_split_size = ctx.max_split_size

    # No options
    buf.write(_ZERO_32)
    # Namespace as C string
    buf.write(b(namespace))
    buf.write(_ZERO_8)
    # Skip: 0, Limit: -1
    buf.write(_SKIPLIM)

    # Where to write command document length
    command_start = buf.tell()
    buf.write(encode(command))

    # Start of payload: seek to just before the command document's
    # trailing null byte so the documents array can be spliced in.
    buf.seek(-1, 2)
    # Work around some Jython weirdness.
    buf.truncate()
    try:
        buf.write(_OP_MAP[operation])
    except KeyError:
        raise InvalidOperation('Unknown command')

    if operation in (_UPDATE, _DELETE):
        check_keys = False

    # Where to write list document length
    list_start = buf.tell() - 4
    to_send = []
    idx = 0
    for doc in docs:
        # Encode the current operation
        key = b(str(idx))
        value = encode(doc, check_keys, opts)
        # Is there enough room to add this document? max_cmd_size accounts for
        # the two trailing null bytes.
        doc_too_large = len(value) > max_cmd_size
        if doc_too_large:
            write_op = list(_FIELD_MAP.keys())[operation]
            _raise_document_too_large(
                write_op, len(value), max_bson_size)
        enough_data = (idx >= 1 and
                       (buf.tell() + len(key) + len(value)) >= max_split_size)
        enough_documents = (idx >= max_write_batch_size)
        if enough_data or enough_documents:
            break
        buf.write(_BSONOBJ)
        buf.write(key)
        buf.write(_ZERO_8)
        buf.write(value)
        to_send.append(doc)
        idx += 1

    # Finalize the current OP_QUERY message.
    # Close list and command documents
    buf.write(_ZERO_16)

    # Write document lengths and request id
    length = buf.tell()
    buf.seek(list_start)
    buf.write(_pack_int(length - list_start - 1))
    buf.seek(command_start)
    buf.write(_pack_int(length - command_start))

    return to_send, length
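

# For illustration only: for an insert batch, the command document encoded
# by _batched_write_command_impl is equivalent to something like the
# following, with the documents spliced in as a BSON array keyed "0", "1",
# and so on (any extra fields, e.g. "ordered", come from the caller's
# ``command``):
#
#     {
#         "insert": "collname",
#         "documents": [{"x": 1}, {"x": 2}],
#     }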


class _OpReply(object):
    """A MongoDB OP_REPLY response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "documents")

    UNPACK_FROM = struct.Struct("<iqii").unpack_from
    OP_CODE = 1

    def __init__(self, flags, cursor_id, number_returned, documents):
        self.flags = flags
        self.cursor_id = cursor_id
        self.number_returned = number_returned
        self.documents = documents

    def raw_response(self, cursor_id=None, user_fields=None):
        """Check the response header from the database, without decoding BSON.

        Check the response for errors and unpack.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :Parameters:
          - `cursor_id` (optional): cursor_id we sent to get this response -
            used to raise an informative exception if the server reports
            that the cursor id is not valid.
        """
        if self.flags & 1:  # CursorNotFound response flag.
            # Shouldn't get this response if we aren't doing a getMore
            if cursor_id is None:
                raise ProtocolError("No cursor id for getMore operation")

            # Fake a getMore command response. OP_GET_MORE provides no
            # document.
            msg = "Cursor not found, cursor id: %d" % (cursor_id,)
            errobj = {"ok": 0, "errmsg": msg, "code": 43}
            raise CursorNotFound(msg, 43, errobj)
        elif self.flags & 2:  # QueryFailure response flag.
            error_object = bson.BSON(self.documents).decode()
            # Fake the ok field if it doesn't exist.
            error_object.setdefault("ok", 0)
            if error_object["$err"].startswith(HelloCompat.LEGACY_ERROR):
                raise NotPrimaryError(error_object["$err"], error_object)
            elif error_object.get("code") == 50:
                raise ExecutionTimeout(error_object.get("$err"),
                                       error_object.get("code"),
                                       error_object)
            raise OperationFailure("database error: %s" %
                                   error_object.get("$err"),
                                   error_object.get("code"),
                                   error_object)
        if self.documents:
            return [self.documents]
        return []

    def unpack_response(self, cursor_id=None,
                        codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
                        user_fields=None, legacy_response=False):
        """Unpack a response from the database and decode the BSON document(s).

        Check the response for errors and unpack, returning a dictionary
        containing the response data.

        Can raise CursorNotFound, NotPrimaryError, ExecutionTimeout, or
        OperationFailure.

        :Parameters:
          - `cursor_id` (optional): cursor_id we sent to get this response -
            used to raise an informative exception if the server reports
            that the cursor id is not valid
          - `codec_options` (optional): an instance of
            :class:`~bson.codec_options.CodecOptions`
        """
        self.raw_response(cursor_id)
        if legacy_response:
            return bson.decode_all(self.documents, codec_options)
        return bson._decode_all_selective(
            self.documents, codec_options, user_fields)

    def command_response(self):
        """Unpack a command response."""
        docs = self.unpack_response()
        assert self.number_returned == 1
        return docs[0]

    def raw_command_response(self):
        """Return the bytes of the command response."""
        # This should never be called on _OpReply.
        raise NotImplementedError

    @property
    def more_to_come(self):
        """Is the moreToCome bit set on this response?"""
        return False

    @classmethod
    def unpack(cls, msg):
        """Construct an _OpReply from raw bytes."""
        # PYTHON-945: ignore starting_from field.
        flags, cursor_id, _, number_returned = cls.UNPACK_FROM(msg)

        # Convert Python 3 memoryview to bytes. Note we should call
        # memoryview.tobytes() if we start using memoryview in Python 2.7.
        documents = bytes(msg[20:])
        return cls(flags, cursor_id, number_returned, documents)
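

# For illustration only, a minimal sketch of consuming an OP_REPLY, assuming
# ``body`` holds the reply bytes after the 16-byte wire header (which is the
# form the unpackers in _UNPACK_REPLY below receive):
#
#     reply = _OpReply.unpack(body)
#     docs = reply.unpack_response()  # decoded BSON documents
#     if reply.cursor_id:
#         pass  # more batches remain; fetch them with OP_GET_MORE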


class _OpMsg(object):
    """A MongoDB OP_MSG response message."""

    __slots__ = ("flags", "cursor_id", "number_returned", "payload_document")

    UNPACK_FROM = struct.Struct("<IBi").unpack_from
    OP_CODE = 2013

    # Flag bits.
    CHECKSUM_PRESENT = 1
    MORE_TO_COME = 1 << 1
    EXHAUST_ALLOWED = 1 << 16  # Only present on requests.

    def __init__(self, flags, payload_document):
        self.flags = flags
        self.payload_document = payload_document

    def raw_response(self, cursor_id=None, user_fields=None):
        """Return the response with user fields left undecoded.

        `cursor_id` is ignored.
        `user_fields` determines which fields must not be decoded.
        """
        user_fields = user_fields or {}
        inflated_response = _decode_selective(
            RawBSONDocument(self.payload_document), user_fields,
            DEFAULT_RAW_BSON_OPTIONS)
        return [inflated_response]

    def unpack_response(self, cursor_id=None,
                        codec_options=_UNICODE_REPLACE_CODEC_OPTIONS,
                        user_fields=None, legacy_response=False):
        """Unpack an OP_MSG command response.

        :Parameters:
          - `cursor_id` (optional): Ignored, for compatibility with _OpReply.
          - `codec_options` (optional): an instance of
            :class:`~bson.codec_options.CodecOptions`
        """
        # If _OpMsg is in-use, this cannot be a legacy response.
        assert not legacy_response
        return bson._decode_all_selective(
            self.payload_document, codec_options, user_fields)

    def command_response(self):
        """Unpack a command response."""
        return self.unpack_response()[0]

    def raw_command_response(self):
        """Return the bytes of the command response."""
        return self.payload_document

    @property
    def more_to_come(self):
        """Is the moreToCome bit set on this response?"""
        return self.flags & self.MORE_TO_COME

    @classmethod
    def unpack(cls, msg):
        """Construct an _OpMsg from raw bytes."""
        flags, first_payload_type, first_payload_size = cls.UNPACK_FROM(msg)
        if flags != 0:
            if flags & cls.CHECKSUM_PRESENT:
                raise ProtocolError(
                    "Unsupported OP_MSG flag checksumPresent: "
                    "0x%x" % (flags,))

            # Only the moreToCome bit may remain set at this point.
            if flags ^ cls.MORE_TO_COME:
                raise ProtocolError(
                    "Unsupported OP_MSG flags: 0x%x" % (flags,))
        if first_payload_type != 0:
            raise ProtocolError(
                "Unsupported OP_MSG payload type: "
                "0x%x" % (first_payload_type,))

        # 4 bytes of flags plus the 1-byte payload type account for the 5.
        if len(msg) != first_payload_size + 5:
            raise ProtocolError("Unsupported OP_MSG reply: >1 section")

        # Convert Python 3 memoryview to bytes. Note we should call
        # memoryview.tobytes() if we start using memoryview in Python 2.7.
        payload_document = bytes(msg[5:])
        return cls(flags, payload_document)
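

# For illustration only, a runnable round-trip through _OpMsg.unpack; the
# body below is 4 zero flag bytes plus one kind-0 section holding a single
# BSON document:
#
#     body = _ZERO_32 + b"\x00" + encode({"ok": 1.0})
#     reply = _OpMsg.unpack(body)
#     assert reply.command_response()["ok"] == 1.0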


_UNPACK_REPLY = {
    _OpReply.OP_CODE: _OpReply.unpack,
    _OpMsg.OP_CODE: _OpMsg.unpack,
}
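

# For illustration only, a sketch of how a receiver might use _UNPACK_REPLY,
# assuming ``header`` and ``body`` are the 16-byte wire header and the
# remaining bytes of a server reply:
#
#     length, request_id, response_to, op_code = struct.unpack("<iiii", header)
#     try:
#         unpack_reply = _UNPACK_REPLY[op_code]
#     except KeyError:
#         raise ProtocolError("Unknown reply opcode %r" % (op_code,))
#     reply = unpack_reply(body)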


def _first_batch(sock_info, db, coll, query, ntoreturn,
                 secondary_ok, codec_options, read_preference, cmd, listeners):
    """Simple query helper for retrieving a first (and possibly only) batch."""
    query = _Query(
        0, db, coll, 0, query, None, codec_options,
        read_preference, ntoreturn, 0, DEFAULT_READ_CONCERN, None, None,
        None, None, False)

    # The command name is the first key of the command document.
    name = next(iter(cmd))
    publish = listeners.enabled_for_commands
    if publish:
        start = datetime.datetime.now()

    request_id, msg, max_doc_size = query.get_message(secondary_ok, sock_info)

    if publish:
        encoding_duration = datetime.datetime.now() - start
        listeners.publish_command_start(
            cmd, db, request_id, sock_info.address,
            service_id=sock_info.service_id)
        start = datetime.datetime.now()

    sock_info.send_message(msg, max_doc_size)
    reply = sock_info.receive_message(request_id)
    try:
        docs = reply.unpack_response(None, codec_options)
    except Exception as exc:
        if publish:
            duration = (datetime.datetime.now() - start) + encoding_duration
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure = exc.details
            else:
                failure = _convert_exception(exc)
            listeners.publish_command_failure(
                duration, failure, name, request_id, sock_info.address,
                service_id=sock_info.service_id)
        raise
    # Commands that return a cursor, e.g. listIndexes.
    if 'cursor' in cmd:
        result = {
            u'cursor': {
                u'firstBatch': docs,
                u'id': reply.cursor_id,
                u'ns': u'%s.%s' % (db, coll)
            },
            u'ok': 1.0
        }
    # Single-document commands, e.g. fsyncUnlock, currentOp.
    else:
        result = docs[0] if docs else {}
        result[u'ok'] = 1.0
    if publish:
        duration = (datetime.datetime.now() - start) + encoding_duration
        listeners.publish_command_success(
            duration, result, name, request_id, sock_info.address,
            service_id=sock_info.service_id)

    return result
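

# For illustration only: for a cursor command such as listIndexes,
# _first_batch synthesizes a command-style result shaped like this:
#
#     {
#         "cursor": {
#             "firstBatch": [...],  # the decoded documents
#             "id": 0,              # 0 means the cursor is exhausted
#             "ns": "db.coll",
#         },
#         "ok": 1.0,
#     }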