1# Copyright 2015-present MongoDB, Inc.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Internal network layer helper methods."""
16
17import datetime
18import errno
19import socket
20import struct
21
22
23from bson import _decode_all_selective
24from bson.py3compat import PY3
25
26from pymongo import helpers, message
27from pymongo.common import MAX_MESSAGE_SIZE
28from pymongo.compression_support import decompress, _NO_COMPRESSION
29from pymongo.errors import (AutoReconnect,
30                            NotPrimaryError,
31                            OperationFailure,
32                            ProtocolError,
33                            NetworkTimeout,
34                            _OperationCancelled)
35from pymongo.message import _UNPACK_REPLY, _OpMsg
36from pymongo.monotonic import time
37from pymongo.socket_checker import _errno_from_exception
38
39
# Unpacks the standard 16-byte message header as four little-endian int32s:
# (message length, request id, response-to id, opcode).
_UNPACK_HEADER = struct.Struct("<4i").unpack
41
42
def command(sock_info, dbname, spec, secondary_ok, is_mongos,
            read_preference, codec_options, session, client, check=True,
            allowable_errors=None, address=None,
            check_keys=False, listeners=None, max_bson_size=None,
            read_concern=None,
            parse_write_concern_error=False,
            collation=None,
            compression_ctx=None,
            use_op_msg=False,
            unacknowledged=False,
            user_fields=None,
            exhaust_allowed=False):
    """Execute a command over the socket, or raise socket.error.

    Builds either an OP_MSG (when `use_op_msg`) or a legacy ``<db>.$cmd``
    query message for `spec` and sends it on ``sock_info.sock``. Unless this
    is an unacknowledged OP_MSG command, it then waits for and decodes the
    reply. Command monitoring events are published when `listeners` has
    command listeners enabled.

    :Parameters:
      - `sock_info`: the connection to use; its ``sock`` attribute is the
        raw socket the message is written to and read from.
      - `dbname`: name of the database on which to run the command
      - `spec`: a command document as an ordered dict type, eg SON. NOTE:
        readConcern / collation may be added to it in place.
      - `secondary_ok`: whether to set the secondaryOkay wire protocol bit
      - `is_mongos`: are we connected to a mongos?
      - `read_preference`: a read preference
      - `codec_options`: a CodecOptions instance
      - `session`: optional ClientSession instance.
      - `client`: optional MongoClient instance for updating $clusterTime.
      - `check`: raise OperationFailure if there are errors
      - `allowable_errors`: errors to ignore if `check` is True
      - `address`: the (host, port) of ``sock_info.sock``
      - `check_keys`: if True, check `spec` for invalid keys
      - `listeners`: An instance of :class:`~pymongo.monitoring.EventListeners`
      - `max_bson_size`: The maximum encoded bson size for this server
      - `read_concern`: The read concern for this command.
      - `parse_write_concern_error`: Whether to parse the ``writeConcernError``
        field in the command response.
      - `collation`: The collation for this command.
      - `compression_ctx`: optional compression Context.
      - `use_op_msg`: True if we should use OP_MSG.
      - `unacknowledged`: True if this is an unacknowledged command.
      - `user_fields` (optional): Response fields that should be decoded
        using the TypeDecoders from codec_options, passed to
        bson._decode_all_selective.
      - `exhaust_allowed`: True if we should enable OP_MSG exhaustAllowed.

    :Returns:
      The first document of the decoded reply. When `client` has an
      encrypter, this document is decoded from the decrypted raw reply.
      For an unacknowledged OP_MSG command it is ``{"ok": 1}``.

    :Raises:
      - Errors from ``helpers._check_command_response`` when `check` is True
        (presumably OperationFailure / NotPrimaryError, the types handled
        specially in the failure event below).
      - Whatever ``message._raise_document_too_large`` raises when the
        encoded command is too large. This happens before any monitoring
        event is published.
      - Socket/network errors raised while sending or receiving.
    """
    name = next(iter(spec))
    ns = dbname + '.$cmd'
    # 4 is the secondaryOk bit of the legacy query flags; the OP_MSG path
    # below replaces `flags` with its own bits.
    flags = 4 if secondary_ok else 0

    # Publish the original command document, perhaps with lsid and $clusterTime.
    # `orig` aliases `spec`, so the in-place readConcern/collation additions
    # below also show up in `orig`. The exception is when
    # _maybe_add_read_preference returns a different document.
    orig = spec
    if is_mongos and not use_op_msg:
        spec = message._maybe_add_read_preference(spec, read_preference)
    # The read concern is not attached here for commands inside a transaction.
    if read_concern and not (session and session.in_transaction):
        if read_concern.level:
            spec['readConcern'] = read_concern.document
        if session:
            session._update_read_concern(spec, sock_info)
    if collation is not None:
        spec['collation'] = collation

    publish = listeners is not None and listeners.enabled_for_commands
    if publish:
        start = datetime.datetime.now()

    # Commands listed in _NO_COMPRESSION are always sent uncompressed.
    if compression_ctx and name.lower() in _NO_COMPRESSION:
        compression_ctx = None

    # With auto encryption, the encrypted command is both what is sent and
    # what is published to listeners (orig is rebound too).
    if (client and client._encrypter and
            not client._encrypter._bypass_auto_encryption):
        spec = orig = client._encrypter.encrypt(
            dbname, spec, check_keys, codec_options)
        # We already checked the keys, no need to do it again.
        check_keys = False

    if use_op_msg:
        # OP_MSG flag bits replace the legacy query flags computed above.
        flags = _OpMsg.MORE_TO_COME if unacknowledged else 0
        flags |= _OpMsg.EXHAUST_ALLOWED if exhaust_allowed else 0
        request_id, msg, size, max_doc_size = message._op_msg(
            flags, spec, dbname, read_preference, secondary_ok, check_keys,
            codec_options, ctx=compression_ctx)
        # If this is an unacknowledged write then make sure the encoded doc(s)
        # are small enough, otherwise rely on the server to return an error.
        if (unacknowledged and max_bson_size is not None and
                max_doc_size > max_bson_size):
            message._raise_document_too_large(name, size, max_bson_size)
    else:
        request_id, msg, size = message.query(
            flags, ns, 0, -1, spec, None, codec_options, check_keys,
            compression_ctx)

    # The whole encoded message may exceed max_bson_size by at most
    # message._COMMAND_OVERHEAD bytes.
    if (max_bson_size is not None
            and size > max_bson_size + message._COMMAND_OVERHEAD):
        message._raise_document_too_large(
            name, size, max_bson_size + message._COMMAND_OVERHEAD)

    if publish:
        # Encoding time is added to the network time reported in the
        # success/failure event.
        encoding_duration = datetime.datetime.now() - start
        listeners.publish_command_start(orig, dbname, request_id, address,
                                        service_id=sock_info.service_id)
        start = datetime.datetime.now()

    try:
        sock_info.sock.sendall(msg)
        if use_op_msg and unacknowledged:
            # Unacknowledged, fake a successful command response.
            reply = None
            response_doc = {"ok": 1}
        else:
            reply = receive_message(sock_info, request_id)
            # Record on the connection whether the reply was flagged as having
            # more replies to follow (presumably for exhaust cursors).
            sock_info.more_to_come = reply.more_to_come
            unpacked_docs = reply.unpack_response(
                codec_options=codec_options, user_fields=user_fields)

            response_doc = unpacked_docs[0]
            if client:
                client._process_response(response_doc, session)
            if check:
                helpers._check_command_response(
                    response_doc, sock_info.max_wire_version, allowable_errors,
                    parse_write_concern_error=parse_write_concern_error)
    except Exception as exc:
        if publish:
            duration = (datetime.datetime.now() - start) + encoding_duration
            # Server-reported errors publish their `details`; any other
            # exception is converted with message._convert_exception.
            if isinstance(exc, (NotPrimaryError, OperationFailure)):
                failure = exc.details
            else:
                failure = message._convert_exception(exc)
            listeners.publish_command_failure(
                duration, failure, name, request_id, address,
                service_id=sock_info.service_id)
        raise
    if publish:
        duration = (datetime.datetime.now() - start) + encoding_duration
        listeners.publish_command_success(
            duration, response_doc, name, request_id, address,
            service_id=sock_info.service_id)

    # Decryption runs whenever an encrypter exists, even if auto encryption
    # was bypassed above. The success event was already published with the
    # undecrypted response_doc.
    if client and client._encrypter and reply:
        decrypted = client._encrypter.decrypt(reply.raw_command_response())
        response_doc = _decode_all_selective(decrypted, codec_options,
                                             user_fields)[0]

    return response_doc
184
# Unpacks the 9-byte OP_COMPRESSED header that follows the standard header:
# (original opcode, an int32 not used here, compressor id).
_UNPACK_COMPRESSION_HEADER = struct.Struct("<iiB").unpack


def receive_message(sock_info, request_id, max_message_size=MAX_MESSAGE_SIZE):
    """Receive a raw BSON message or raise socket.error.

    Reads one reply from ``sock_info.sock``. An OP_COMPRESSED message
    (opcode 2012) is decompressed first. The result of the
    ``_UNPACK_REPLY`` parser registered for the (original) opcode is
    returned.

    :Parameters:
      - `sock_info`: the connection to read from.
      - `request_id`: the id of the request this reply must answer, or None
        to accept any reply (exhaust cursor "getMore" replies).
      - `max_message_size`: the largest declared message length accepted.

    :Raises:
      - ProtocolError if the reply answers a different request, declares a
        length too short to hold a body (or a compressed payload), declares
        a length above `max_message_size`, or uses an unknown opcode.
    """
    # One deadline is shared by every read that makes up this message.
    timeout = sock_info.sock.gettimeout()
    if timeout:
        deadline = time() + timeout
    else:
        deadline = None
    # Ignore the response's request id.
    length, _, response_to, op_code = _UNPACK_HEADER(
        _receive_data_on_socket(sock_info, 16, deadline))
    # No request_id for exhaust cursor "getMore".
    if request_id is not None:
        if request_id != response_to:
            raise ProtocolError("Got response id %r but expected "
                                "%r" % (response_to, request_id))
    if length <= 16:
        raise ProtocolError("Message length (%r) not longer than standard "
                            "message header size (16)" % (length,))
    if length > max_message_size:
        raise ProtocolError("Message length (%r) is larger than server max "
                            "message size (%r)" % (length, max_message_size))
    if op_code == 2012:
        # OP_COMPRESSED: 16-byte standard header + 9-byte compression header
        # + payload. Validate before reading further. Otherwise a short
        # declared length would read past this message's end and then ask
        # for a negative number of payload bytes (ValueError, not
        # ProtocolError).
        if length <= 25:
            raise ProtocolError("Message length (%r) not longer than standard "
                                "message header size plus compression header "
                                "size (25)" % (length,))
        op_code, _, compressor_id = _UNPACK_COMPRESSION_HEADER(
            _receive_data_on_socket(sock_info, 9, deadline))
        data = decompress(
            _receive_data_on_socket(sock_info, length - 25, deadline),
            compressor_id)
    else:
        data = _receive_data_on_socket(sock_info, length - 16, deadline)

    try:
        unpack_reply = _UNPACK_REPLY[op_code]
    except KeyError:
        raise ProtocolError("Got opcode %r but expected "
                            "%r" % (op_code, _UNPACK_REPLY.keys()))
    return unpack_reply(data)
223
224
_POLL_TIMEOUT = 0.5


def wait_for_read(sock_info, deadline):
    """Block until the socket is readable, the deadline passes, or a cancel.

    Only connections with a truthy ``cancel_context`` (Monitor connections)
    are polled here. For any other connection this returns immediately and
    the caller's subsequent read blocks on the socket itself. This function
    never consumes data.

    :Parameters:
      - `sock_info`: the connection. Its ``cancel_context``, ``sock`` and
        (when polling) ``socket_checker`` attributes are used.
      - `deadline`: absolute time on the same clock as ``time()``, or a
        falsy value to keep polling with no time limit.

    :Raises:
      - _OperationCancelled if the cancel context is cancelled.
      - socket.timeout if the deadline passes before the socket is readable.
    """
    context = sock_info.cancel_context
    # Only Monitor connections can be cancelled; nothing to wait for otherwise.
    if not context:
        return

    sock = sock_info.sock
    while True:
        # An SSLSocket may already hold decrypted bytes that select() won't see.
        buffered = hasattr(sock, 'pending') and sock.pending() > 0
        if buffered:
            readable = True
        else:
            # Poll in short slices so a cancellation is noticed promptly: at
            # most _POLL_TIMEOUT, never beyond the deadline, never zero.
            poll_for = _POLL_TIMEOUT
            if deadline:
                poll_for = max(min(deadline - time(), poll_for), 0.001)
            readable = sock_info.socket_checker.select(
                sock, read=True, timeout=poll_for)

        if context.cancelled:
            raise _OperationCancelled('hello cancelled')
        if readable:
            return
        if deadline and time() > deadline:
            raise socket.timeout("timed out")
253
254# memoryview was introduced in Python 2.7 but we only use it on Python 3
255# because before 2.7.4 the struct module did not support memoryview:
256# https://bugs.python.org/issue10212.
257# In Jython, using slice assignment on a memoryview results in a
258# NullPointerException.
259if not PY3:
260    def _receive_data_on_socket(sock_info, length, deadline):
261        buf = bytearray(length)
262        i = 0
263        while length:
264            try:
265                wait_for_read(sock_info, deadline)
266                chunk = sock_info.sock.recv(length)
267            except (IOError, OSError) as exc:
268                if _errno_from_exception(exc) == errno.EINTR:
269                    continue
270                raise
271            if chunk == b"":
272                raise AutoReconnect("connection closed")
273
274            buf[i:i + len(chunk)] = chunk
275            i += len(chunk)
276            length -= len(chunk)
277
278        return bytes(buf)
279else:
280    def _receive_data_on_socket(sock_info, length, deadline):
281        buf = bytearray(length)
282        mv = memoryview(buf)
283        bytes_read = 0
284        while bytes_read < length:
285            try:
286                wait_for_read(sock_info, deadline)
287                chunk_length = sock_info.sock.recv_into(mv[bytes_read:])
288            except (IOError, OSError) as exc:
289                if _errno_from_exception(exc) == errno.EINTR:
290                    continue
291                raise
292            if chunk_length == 0:
293                raise AutoReconnect("connection closed")
294
295            bytes_read += chunk_length
296
297        return mv
298