1
2 /**
3 * Copyright (C) 2018-present MongoDB, Inc.
4 *
5 * This program is free software: you can redistribute it and/or modify
6 * it under the terms of the Server Side Public License, version 1,
7 * as published by MongoDB, Inc.
8 *
9 * This program is distributed in the hope that it will be useful,
10 * but WITHOUT ANY WARRANTY; without even the implied warranty of
11 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12 * Server Side Public License for more details.
13 *
14 * You should have received a copy of the Server Side Public License
15 * along with this program. If not, see
16 * <http://www.mongodb.com/licensing/server-side-public-license>.
17 *
18 * As a special exception, the copyright holders give permission to link the
19 * code of portions of this program with the OpenSSL library under certain
20 * conditions as described in each individual source file and distribute
21 * linked combinations including the program with the OpenSSL library. You
22 * must comply with the Server Side Public License in all respects for
23 * all of the code used other than as permitted herein. If you modify file(s)
24 * with this exception, you may extend this exception to your version of the
25 * file(s), but you are not obligated to do so. If you do not wish to do so,
26 * delete this exception statement from your version. If you delete this
27 * exception statement from all source files in the program, then also delete
28 * it in the license file.
29 */
30
31 #pragma once
32
33 #include <cstdint>
34
35 #include "mongo/base/data_type_endian.h"
36 #include "mongo/base/data_view.h"
37 #include "mongo/base/encoded_value_storage.h"
38 #include "mongo/base/static_assert.h"
39 #include "mongo/util/mongoutils/str.h"
40
41 namespace mongo {
42
/**
 * Largest message, in bytes, accepted on the wire protocol.
 * Note this is 48 decimal megabytes (48 * 10^6), not 48 MiB.
 */
constexpr size_t MaxMessageSizeBytes = 48 * 1000 * 1000;
47
/**
 * Opcodes carried in the opCode field of the message header.
 *
 * These numeric values are what is read from and written to the wire, so existing
 * enumerators must not be renumbered.
 */
enum NetworkOp : int32_t {
    opInvalid = 0,
    opReply = 1,  // A reply; its responseTo field is set.
    dbUpdate = 2001,
    dbInsert = 2002,
    // 2003 (dbGetByOID) is no longer defined.
    dbQuery = 2004,
    dbGetMore = 2005,
    dbDelete = 2006,
    dbKillCursors = 2007,
    // 2008 (dbCommand_DEPRECATED) and 2009 (dbCommandReply_DEPRECATED) are no longer defined.
    dbCommand = 2010,
    dbCommandReply = 2011,
    dbCompressed = 2012,
    dbMsg = 2013,
};
65
isSupportedRequestNetworkOp(NetworkOp op)66 inline bool isSupportedRequestNetworkOp(NetworkOp op) {
67 switch (op) {
68 case dbUpdate:
69 case dbInsert:
70 case dbQuery:
71 case dbGetMore:
72 case dbDelete:
73 case dbKillCursors:
74 case dbCommand:
75 case dbCompressed:
76 case dbMsg:
77 return true;
78 case dbCommandReply:
79 case opReply:
80 default:
81 return false;
82 }
83 }
84
/**
 * Logical operation kinds, independent of the opcode that carried them.
 * Several opcodes may collapse onto one kind; see networkOpToLogicalOp().
 */
enum class LogicalOp {
    opInvalid,      // No operation / not yet known.
    opUpdate,
    opInsert,
    opQuery,
    opGetMore,
    opDelete,
    opKillCursors,
    opCommand,      // Commands, whichever opcode carried them.
    opCompressed,
};
96
networkOpToLogicalOp(NetworkOp networkOp)97 inline LogicalOp networkOpToLogicalOp(NetworkOp networkOp) {
98 switch (networkOp) {
99 case dbUpdate:
100 return LogicalOp::opUpdate;
101 case dbInsert:
102 return LogicalOp::opInsert;
103 case dbQuery:
104 return LogicalOp::opQuery;
105 case dbGetMore:
106 return LogicalOp::opGetMore;
107 case dbDelete:
108 return LogicalOp::opDelete;
109 case dbKillCursors:
110 return LogicalOp::opKillCursors;
111 case dbMsg:
112 case dbCommand:
113 return LogicalOp::opCommand;
114 case dbCompressed:
115 return LogicalOp::opCompressed;
116 default:
117 int op = int(networkOp);
118 massert(34348, str::stream() << "cannot translate opcode " << op, !op);
119 return LogicalOp::opInvalid;
120 }
121 }
122
networkOpToString(NetworkOp networkOp)123 inline const char* networkOpToString(NetworkOp networkOp) {
124 switch (networkOp) {
125 case opInvalid:
126 return "none";
127 case opReply:
128 return "reply";
129 case dbUpdate:
130 return "update";
131 case dbInsert:
132 return "insert";
133 case dbQuery:
134 return "query";
135 case dbGetMore:
136 return "getmore";
137 case dbDelete:
138 return "remove";
139 case dbKillCursors:
140 return "killcursors";
141 case dbCommand:
142 return "command";
143 case dbCommandReply:
144 return "commandReply";
145 case dbCompressed:
146 return "compressed";
147 case dbMsg:
148 return "msg";
149 default:
150 int op = static_cast<int>(networkOp);
151 massert(16141, str::stream() << "cannot translate opcode " << op, !op);
152 return "";
153 }
154 }
155
logicalOpToString(LogicalOp logicalOp)156 inline const char* logicalOpToString(LogicalOp logicalOp) {
157 switch (logicalOp) {
158 case LogicalOp::opInvalid:
159 return "none";
160 case LogicalOp::opUpdate:
161 return "update";
162 case LogicalOp::opInsert:
163 return "insert";
164 case LogicalOp::opQuery:
165 return "query";
166 case LogicalOp::opGetMore:
167 return "getmore";
168 case LogicalOp::opDelete:
169 return "remove";
170 case LogicalOp::opKillCursors:
171 return "killcursors";
172 case LogicalOp::opCommand:
173 return "command";
174 case LogicalOp::opCompressed:
175 return "compressed";
176 default:
177 MONGO_UNREACHABLE;
178 }
179 }
180
181 namespace MSGHEADER {
182
#pragma pack(1)
/**
 * Packed 16-byte header (four int32 fields) found at the start of every message.
 * See http://dochub.mongodb.org/core/mongowireprotocol
 */
struct Layout {
    int32_t messageLength;  // Total message size in bytes, this header included.
    int32_t requestID;      // Identifier of this message.
    int32_t responseTo;     // requestID of the original request (set in responses from db).
    int32_t opCode;         // Opcode; read back as a NetworkOp.
};
#pragma pack()
195
196 class ConstView {
197 public:
198 typedef ConstDataView view_type;
199
ConstView(const char * data)200 ConstView(const char* data) : _data(data) {}
201
view2ptr()202 const char* view2ptr() const {
203 return data().view();
204 }
205
getMessageLength()206 int32_t getMessageLength() const {
207 return data().read<LittleEndian<int32_t>>(offsetof(Layout, messageLength));
208 }
209
getRequestMsgId()210 int32_t getRequestMsgId() const {
211 return data().read<LittleEndian<int32_t>>(offsetof(Layout, requestID));
212 }
213
getResponseToMsgId()214 int32_t getResponseToMsgId() const {
215 return data().read<LittleEndian<int32_t>>(offsetof(Layout, responseTo));
216 }
217
getOpCode()218 int32_t getOpCode() const {
219 return data().read<LittleEndian<int32_t>>(offsetof(Layout, opCode));
220 }
221
222 protected:
data()223 const view_type& data() const {
224 return _data;
225 }
226
227 private:
228 view_type _data;
229 };
230
231 class View : public ConstView {
232 public:
233 typedef DataView view_type;
234
View(char * data)235 View(char* data) : ConstView(data) {}
236
237 using ConstView::view2ptr;
view2ptr()238 char* view2ptr() {
239 return data().view();
240 }
241
setMessageLength(int32_t value)242 void setMessageLength(int32_t value) {
243 data().write(tagLittleEndian(value), offsetof(Layout, messageLength));
244 }
245
setRequestMsgId(int32_t value)246 void setRequestMsgId(int32_t value) {
247 data().write(tagLittleEndian(value), offsetof(Layout, requestID));
248 }
249
setResponseToMsgId(int32_t value)250 void setResponseToMsgId(int32_t value) {
251 data().write(tagLittleEndian(value), offsetof(Layout, responseTo));
252 }
253
setOpCode(int32_t value)254 void setOpCode(int32_t value) {
255 data().write(tagLittleEndian(value), offsetof(Layout, opCode));
256 }
257
258 private:
data()259 view_type data() const {
260 return const_cast<char*>(ConstView::view2ptr());
261 }
262 };
263
264 class Value : public EncodedValueStorage<Layout, ConstView, View> {
265 public:
Value()266 Value() {
267 MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout));
268 }
269
Value(ZeroInitTag_t zit)270 Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {}
271 };
272
273 } // namespace MSGHEADER
274
275 namespace MsgData {
276
277 #pragma pack(1)
278 struct Layout {
279 MSGHEADER::Layout header;
280 char data[4];
281 };
282 #pragma pack()
283
284 class ConstView {
285 public:
ConstView(const char * storage)286 ConstView(const char* storage) : _storage(storage) {}
287
view2ptr()288 const char* view2ptr() const {
289 return storage().view();
290 }
291
getLen()292 int32_t getLen() const {
293 return header().getMessageLength();
294 }
295
getId()296 int32_t getId() const {
297 return header().getRequestMsgId();
298 }
299
getResponseToMsgId()300 int32_t getResponseToMsgId() const {
301 return header().getResponseToMsgId();
302 }
303
getNetworkOp()304 NetworkOp getNetworkOp() const {
305 return NetworkOp(header().getOpCode());
306 }
307
data()308 const char* data() const {
309 return storage().view(offsetof(Layout, data));
310 }
311
valid()312 bool valid() const {
313 if (getLen() <= 0 || getLen() > (4 * BSONObjMaxInternalSize))
314 return false;
315 if (getNetworkOp() < 0 || getNetworkOp() > 30000)
316 return false;
317 return true;
318 }
319
getCursor()320 int64_t getCursor() const {
321 verify(getResponseToMsgId() > 0);
322 verify(getNetworkOp() == opReply);
323 return ConstDataView(data() + sizeof(int32_t)).read<LittleEndian<int64_t>>();
324 }
325
326 int dataLen() const; // len without header
327
328 protected:
storage()329 const ConstDataView& storage() const {
330 return _storage;
331 }
332
header()333 MSGHEADER::ConstView header() const {
334 return storage().view(offsetof(Layout, header));
335 }
336
337 private:
338 ConstDataView _storage;
339 };
340
341 class View : public ConstView {
342 public:
View(char * storage)343 View(char* storage) : ConstView(storage) {}
344
345 using ConstView::view2ptr;
view2ptr()346 char* view2ptr() {
347 return storage().view();
348 }
349
setLen(int value)350 void setLen(int value) {
351 return header().setMessageLength(value);
352 }
353
setId(int32_t value)354 void setId(int32_t value) {
355 return header().setRequestMsgId(value);
356 }
357
setResponseToMsgId(int32_t value)358 void setResponseToMsgId(int32_t value) {
359 return header().setResponseToMsgId(value);
360 }
361
setOperation(int value)362 void setOperation(int value) {
363 return header().setOpCode(value);
364 }
365
366 using ConstView::data;
data()367 char* data() {
368 return storage().view(offsetof(Layout, data));
369 }
370
371 private:
storage()372 DataView storage() const {
373 return const_cast<char*>(ConstView::view2ptr());
374 }
375
header()376 MSGHEADER::View header() const {
377 return storage().view(offsetof(Layout, header));
378 }
379 };
380
381 class Value : public EncodedValueStorage<Layout, ConstView, View> {
382 public:
Value()383 Value() {
384 MONGO_STATIC_ASSERT(sizeof(Value) == sizeof(Layout));
385 }
386
Value(ZeroInitTag_t zit)387 Value(ZeroInitTag_t zit) : EncodedValueStorage<Layout, ConstView, View>(zit) {}
388 };
389
390 const int MsgDataHeaderSize = sizeof(Value) - 4;
391
dataLen()392 inline int ConstView::dataLen() const {
393 return getLen() - MsgDataHeaderSize;
394 }
395
396 } // namespace MsgData
397
398 class Message {
399 public:
400 Message() = default;
Message(SharedBuffer data)401 explicit Message(SharedBuffer data) : _buf(std::move(data)) {}
402
header()403 MsgData::View header() const {
404 verify(!empty());
405 return _buf.get();
406 }
407
operation()408 NetworkOp operation() const {
409 return header().getNetworkOp();
410 }
411
singleData()412 MsgData::View singleData() const {
413 massert(13273, "single data buffer expected", _buf);
414 return header();
415 }
416
empty()417 bool empty() const {
418 return !_buf;
419 }
420
size()421 int size() const {
422 if (_buf) {
423 return MsgData::ConstView(_buf.get()).getLen();
424 }
425 return 0;
426 }
427
dataSize()428 int dataSize() const {
429 return size() - sizeof(MSGHEADER::Value);
430 }
431
reset()432 void reset() {
433 _buf = {};
434 }
435
436 // use to set first buffer if empty
setData(SharedBuffer buf)437 void setData(SharedBuffer buf) {
438 verify(empty());
439 _buf = std::move(buf);
440 }
setData(int operation,const char * msgtxt)441 void setData(int operation, const char* msgtxt) {
442 setData(operation, msgtxt, strlen(msgtxt) + 1);
443 }
setData(int operation,const char * msgdata,size_t len)444 void setData(int operation, const char* msgdata, size_t len) {
445 verify(empty());
446 size_t dataLen = len + sizeof(MsgData::Value) - 4;
447 _buf = SharedBuffer::allocate(dataLen);
448 MsgData::View d = _buf.get();
449 if (len)
450 memcpy(d.data(), msgdata, len);
451 d.setLen(dataLen);
452 d.setOperation(operation);
453 }
454
buf()455 char* buf() {
456 return _buf.get();
457 }
458
buf()459 const char* buf() const {
460 return _buf.get();
461 }
462
sharedBuffer()463 SharedBuffer sharedBuffer() {
464 return _buf;
465 }
466
sharedBuffer()467 ConstSharedBuffer sharedBuffer() const {
468 return _buf;
469 }
470
471 private:
472 SharedBuffer _buf;
473 };
474
/**
 * Produces the identifier to assign to the next network message.
 * Each call returns a value larger than the one before.
 */
int32_t nextMessageId();
479
480 } // namespace mongo
481