1 
/**
 *    Copyright (C) 2018-present MongoDB, Inc.
 *
 *    This program is free software: you can redistribute it and/or modify
 *    it under the terms of the Server Side Public License, version 1,
 *    as published by MongoDB, Inc.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    Server Side Public License for more details.
 *
 *    You should have received a copy of the Server Side Public License
 *    along with this program. If not, see
 *    <http://www.mongodb.com/licensing/server-side-public-license>.
 *
 *    As a special exception, the copyright holders give permission to link the
 *    code of portions of this program with the OpenSSL library under certain
 *    conditions as described in each individual source file and distribute
 *    linked combinations including the program with the OpenSSL library. You
 *    must comply with the Server Side Public License in all respects for
 *    all of the code used other than as permitted herein. If you modify file(s)
 *    with this exception, you may extend this exception to your version of the
 *    file(s), but you are not obligated to do so. If you do not wish to do so,
 *    delete this exception statement from your version. If you delete this
 *    exception statement from all source files in the program, then also delete
 *    it in the license file.
 */
30 
31 #pragma once
32 
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <utility>

#include "mongo/base/data_type_endian.h"
#include "mongo/base/data_view.h"
#include "mongo/base/encoded_value_storage.h"
#include "mongo/base/static_assert.h"
#include "mongo/util/mongoutils/str.h"
40 
41 namespace mongo {
42 
/**
 * Maximum accepted message size on the wire protocol, in bytes (48,000,000).
 */
constexpr size_t MaxMessageSizeBytes = 48 * 1000 * 1000;
47 
/**
 * Opcodes carried in the 'opCode' field of a wire-protocol message header.
 *
 * The numeric values are part of the wire format and must not be changed.
 * Values 2003, 2008 and 2009 are intentionally left unassigned (see the
 * commented-out entries below).
 */
enum NetworkOp : int32_t {
    opInvalid = 0,
    opReply = 1,      // reply. responseTo is set.
    dbUpdate = 2001,  // update object
    dbInsert = 2002,
    // dbGetByOID = 2003,
    dbQuery = 2004,
    dbGetMore = 2005,
    dbDelete = 2006,
    dbKillCursors = 2007,
    // dbCommand_DEPRECATED = 2008,
    // dbCommandReply_DEPRECATED = 2009,
    dbCommand = 2010,
    dbCommandReply = 2011,
    dbCompressed = 2012,
    dbMsg = 2013,
};
65 
isSupportedRequestNetworkOp(NetworkOp op)66 inline bool isSupportedRequestNetworkOp(NetworkOp op) {
67     switch (op) {
68         case dbUpdate:
69         case dbInsert:
70         case dbQuery:
71         case dbGetMore:
72         case dbDelete:
73         case dbKillCursors:
74         case dbCommand:
75         case dbCompressed:
76         case dbMsg:
77             return true;
78         case dbCommandReply:
79         case opReply:
80         default:
81             return false;
82     }
83 }
84 
/**
 * Operation categories that network opcodes are mapped onto by
 * networkOpToLogicalOp(). Several opcodes may share one logical operation
 * (both dbMsg and dbCommand map to opCommand).
 */
enum class LogicalOp {
    opInvalid,
    opUpdate,
    opInsert,
    opQuery,
    opGetMore,
    opDelete,
    opKillCursors,
    opCommand,
    opCompressed,
};
96 
networkOpToLogicalOp(NetworkOp networkOp)97 inline LogicalOp networkOpToLogicalOp(NetworkOp networkOp) {
98     switch (networkOp) {
99         case dbUpdate:
100             return LogicalOp::opUpdate;
101         case dbInsert:
102             return LogicalOp::opInsert;
103         case dbQuery:
104             return LogicalOp::opQuery;
105         case dbGetMore:
106             return LogicalOp::opGetMore;
107         case dbDelete:
108             return LogicalOp::opDelete;
109         case dbKillCursors:
110             return LogicalOp::opKillCursors;
111         case dbMsg:
112         case dbCommand:
113             return LogicalOp::opCommand;
114         case dbCompressed:
115             return LogicalOp::opCompressed;
116         default:
117             int op = int(networkOp);
118             massert(34348, str::stream() << "cannot translate opcode " << op, !op);
119             return LogicalOp::opInvalid;
120     }
121 }
122 
networkOpToString(NetworkOp networkOp)123 inline const char* networkOpToString(NetworkOp networkOp) {
124     switch (networkOp) {
125         case opInvalid:
126             return "none";
127         case opReply:
128             return "reply";
129         case dbUpdate:
130             return "update";
131         case dbInsert:
132             return "insert";
133         case dbQuery:
134             return "query";
135         case dbGetMore:
136             return "getmore";
137         case dbDelete:
138             return "remove";
139         case dbKillCursors:
140             return "killcursors";
141         case dbCommand:
142             return "command";
143         case dbCommandReply:
144             return "commandReply";
145         case dbCompressed:
146             return "compressed";
147         case dbMsg:
148             return "msg";
149         default:
150             int op = static_cast<int>(networkOp);
151             massert(16141, str::stream() << "cannot translate opcode " << op, !op);
152             return "";
153     }
154 }
155 
logicalOpToString(LogicalOp logicalOp)156 inline const char* logicalOpToString(LogicalOp logicalOp) {
157     switch (logicalOp) {
158         case LogicalOp::opInvalid:
159             return "none";
160         case LogicalOp::opUpdate:
161             return "update";
162         case LogicalOp::opInsert:
163             return "insert";
164         case LogicalOp::opQuery:
165             return "query";
166         case LogicalOp::opGetMore:
167             return "getmore";
168         case LogicalOp::opDelete:
169             return "remove";
170         case LogicalOp::opKillCursors:
171             return "killcursors";
172         case LogicalOp::opCommand:
173             return "command";
174         case LogicalOp::opCompressed:
175             return "compressed";
176         default:
177             MONGO_UNREACHABLE;
178     }
179 }
180 
181 namespace MSGHEADER {
182 
// push/pop restores whatever packing was in effect before this header,
// whereas a bare '#pragma pack()' would reset it to the compiler default.
#pragma pack(push, 1)
/**
 * Byte layout of the wire-protocol message header: four consecutive int32
 * fields with no padding (16 bytes in total).
 *
 * See http://dochub.mongodb.org/core/mongowireprotocol
 */
struct Layout {
    int32_t messageLength;  // total message size, including this
    int32_t requestID;      // identifier for this message
    int32_t responseTo;     // requestID from the original request
                            //   (used in responses from db)
    int32_t opCode;
};
#pragma pack(pop)
195 
196 class ConstView {
197 public:
198     typedef ConstDataView view_type;
199 
ConstView(const char * data)200     ConstView(const char* data) : _data(data) {}
201 
view2ptr()202     const char* view2ptr() const {
203         return data().view();
204     }
205 
getMessageLength()206     int32_t getMessageLength() const {
207         return data().read<LittleEndian<int32_t>>(offsetof(Layout, messageLength));
208     }
209 
getRequestMsgId()210     int32_t getRequestMsgId() const {
211         return data().read<LittleEndian<int32_t>>(offsetof(Layout, requestID));
212     }
213 
getResponseToMsgId()214     int32_t getResponseToMsgId() const {
215         return data().read<LittleEndian<int32_t>>(offsetof(Layout, responseTo));
216     }
217 
getOpCode()218     int32_t getOpCode() const {
219         return data().read<LittleEndian<int32_t>>(offsetof(Layout, opCode));
220     }
221 
222 protected:
data()223     const view_type& data() const {
224         return _data;
225     }
226 
227 private:
228     view_type _data;
229 };
230 
231 class View : public ConstView {
232 public:
233     typedef DataView view_type;
234 
View(char * data)235     View(char* data) : ConstView(data) {}
236 
237     using ConstView::view2ptr;
view2ptr()238     char* view2ptr() {
239         return data().view();
240     }
241 
setMessageLength(int32_t value)242     void setMessageLength(int32_t value) {
243         data().write(tagLittleEndian(value), offsetof(Layout, messageLength));
244     }
245 
setRequestMsgId(int32_t value)246     void setRequestMsgId(int32_t value) {
247         data().write(tagLittleEndian(value), offsetof(Layout, requestID));
248     }
249 
setResponseToMsgId(int32_t value)250     void setResponseToMsgId(int32_t value) {
251         data().write(tagLittleEndian(value), offsetof(Layout, responseTo));
252     }
253 
setOpCode(int32_t value)254     void setOpCode(int32_t value) {
255         data().write(tagLittleEndian(value), offsetof(Layout, opCode));
256     }
257 
258 private:
data()259     view_type data() const {
260         return const_cast<char*>(ConstView::view2ptr());
261     }
262 };
263 
264 class Value : public EncodedValueStorage<Layout, ConstView, View> {
265 public:
Value()266     Value() {
267         MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout));
268     }
269 
Value(ZeroInitTag_t zit)270     Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {}
271 };
272 
273 }  // namespace MSGHEADER
274 
275 namespace MsgData {
276 
277 #pragma pack(1)
278 struct Layout {
279     MSGHEADER::Layout header;
280     char data[4];
281 };
282 #pragma pack()
283 
284 class ConstView {
285 public:
ConstView(const char * storage)286     ConstView(const char* storage) : _storage(storage) {}
287 
view2ptr()288     const char* view2ptr() const {
289         return storage().view();
290     }
291 
getLen()292     int32_t getLen() const {
293         return header().getMessageLength();
294     }
295 
getId()296     int32_t getId() const {
297         return header().getRequestMsgId();
298     }
299 
getResponseToMsgId()300     int32_t getResponseToMsgId() const {
301         return header().getResponseToMsgId();
302     }
303 
getNetworkOp()304     NetworkOp getNetworkOp() const {
305         return NetworkOp(header().getOpCode());
306     }
307 
data()308     const char* data() const {
309         return storage().view(offsetof(Layout, data));
310     }
311 
valid()312     bool valid() const {
313         if (getLen() <= 0 || getLen() > (4 * BSONObjMaxInternalSize))
314             return false;
315         if (getNetworkOp() < 0 || getNetworkOp() > 30000)
316             return false;
317         return true;
318     }
319 
getCursor()320     int64_t getCursor() const {
321         verify(getResponseToMsgId() > 0);
322         verify(getNetworkOp() == opReply);
323         return ConstDataView(data() + sizeof(int32_t)).read<LittleEndian<int64_t>>();
324     }
325 
326     int dataLen() const;  // len without header
327 
328 protected:
storage()329     const ConstDataView& storage() const {
330         return _storage;
331     }
332 
header()333     MSGHEADER::ConstView header() const {
334         return storage().view(offsetof(Layout, header));
335     }
336 
337 private:
338     ConstDataView _storage;
339 };
340 
341 class View : public ConstView {
342 public:
View(char * storage)343     View(char* storage) : ConstView(storage) {}
344 
345     using ConstView::view2ptr;
view2ptr()346     char* view2ptr() {
347         return storage().view();
348     }
349 
setLen(int value)350     void setLen(int value) {
351         return header().setMessageLength(value);
352     }
353 
setId(int32_t value)354     void setId(int32_t value) {
355         return header().setRequestMsgId(value);
356     }
357 
setResponseToMsgId(int32_t value)358     void setResponseToMsgId(int32_t value) {
359         return header().setResponseToMsgId(value);
360     }
361 
setOperation(int value)362     void setOperation(int value) {
363         return header().setOpCode(value);
364     }
365 
366     using ConstView::data;
data()367     char* data() {
368         return storage().view(offsetof(Layout, data));
369     }
370 
371 private:
storage()372     DataView storage() const {
373         return const_cast<char*>(ConstView::view2ptr());
374     }
375 
header()376     MSGHEADER::View header() const {
377         return storage().view(offsetof(Layout, header));
378     }
379 };
380 
381 class Value : public EncodedValueStorage<Layout, ConstView, View> {
382 public:
Value()383     Value() {
384         MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout));
385     }
386 
Value(ZeroInitTag_t zit)387     Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {}
388 };
389 
390 const int MsgDataHeaderSize = sizeof(Value) - 4;
391 
dataLen()392 inline int ConstView::dataLen() const {
393     return getLen() - MsgDataHeaderSize;
394 }
395 
396 }  // namespace MsgData
397 
398 class Message {
399 public:
400     Message() = default;
Message(SharedBuffer data)401     explicit Message(SharedBuffer data) : _buf(std::move(data)) {}
402 
header()403     MsgData::View header() const {
404         verify(!empty());
405         return _buf.get();
406     }
407 
operation()408     NetworkOp operation() const {
409         return header().getNetworkOp();
410     }
411 
singleData()412     MsgData::View singleData() const {
413         massert(13273, "single data buffer expected", _buf);
414         return header();
415     }
416 
empty()417     bool empty() const {
418         return !_buf;
419     }
420 
size()421     int size() const {
422         if (_buf) {
423             return MsgData::ConstView(_buf.get()).getLen();
424         }
425         return 0;
426     }
427 
dataSize()428     int dataSize() const {
429         return size() - sizeof(MSGHEADER::Value);
430     }
431 
reset()432     void reset() {
433         _buf = {};
434     }
435 
436     // use to set first buffer if empty
setData(SharedBuffer buf)437     void setData(SharedBuffer buf) {
438         verify(empty());
439         _buf = std::move(buf);
440     }
setData(int operation,const char * msgtxt)441     void setData(int operation, const char* msgtxt) {
442         setData(operation, msgtxt, strlen(msgtxt) + 1);
443     }
setData(int operation,const char * msgdata,size_t len)444     void setData(int operation, const char* msgdata, size_t len) {
445         verify(empty());
446         size_t dataLen = len + sizeof(MsgData::Value) - 4;
447         _buf = SharedBuffer::allocate(dataLen);
448         MsgData::View d = _buf.get();
449         if (len)
450             memcpy(d.data(), msgdata, len);
451         d.setLen(dataLen);
452         d.setOperation(operation);
453     }
454 
buf()455     char* buf() {
456         return _buf.get();
457     }
458 
buf()459     const char* buf() const {
460         return _buf.get();
461     }
462 
sharedBuffer()463     SharedBuffer sharedBuffer() {
464         return _buf;
465     }
466 
sharedBuffer()467     ConstSharedBuffer sharedBuffer() const {
468         return _buf;
469     }
470 
471 private:
472     SharedBuffer _buf;
473 };
474 
/**
 * Returns an always incrementing value to be used to assign to the next received network message.
 */
int32_t nextMessageId();
479 
480 }  // namespace mongo
481