1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
9import (
10	"errors"
11	"fmt"
12
13	"go.mongodb.org/mongo-driver/bson"
14	"go.mongodb.org/mongo-driver/x/bsonx"
15)
16
// Msg represents the OP_MSG message of the MongoDB wire protocol.
//
// On the wire the message is laid out as: the 16-byte header, the 4-byte FlagBits, the encoded Sections, and,
// only when FlagBits has ChecksumPresent set, a trailing 4-byte checksum (see Len).
type Msg struct {
	// MsgHeader is the standard message header; its MessageLength is expected to equal Len() and its OpCode
	// to be OpMsg (see ValidateWireMessage).
	MsgHeader Header
	// FlagBits holds the OP_MSG flags (see MsgFlag).
	FlagBits  MsgFlag
	// Sections are the message sections in wire order; the implementations in this file are SectionBody
	// and SectionDocumentSequence.
	Sections  []Section
	// Checksum is the trailing checksum value; UnmarshalWireMessage only reads it when FlagBits has
	// ChecksumPresent set.
	Checksum  uint32
}
24
25// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
26func (m Msg) MarshalWireMessage() ([]byte, error) {
27	b := make([]byte, 0, m.Len())
28	return m.AppendWireMessage(b)
29}
30
31// ValidateWireMessage implements the Validator and WireMessage interfaces.
32func (m Msg) ValidateWireMessage() error {
33	if int(m.MsgHeader.MessageLength) != m.Len() {
34		return errors.New("incorrect header: message length is not correct")
35	}
36	if m.MsgHeader.OpCode != OpMsg {
37		return errors.New("incorrect header: opcode is not OpMsg")
38	}
39
40	return nil
41}
42
43// AppendWireMessage implements the Appender and WireMessage interfaces.
44//
45// AppendWireMesssage will set the MessageLength property of the MsgHeader if it is zero. It will also set the Opcode
46// to OP_MSG if it is zero. If either of these properties are non-zero and not correct, this method will return both the
47// []byte with the wire message appended to it and an invalid header error.
48func (m Msg) AppendWireMessage(b []byte) ([]byte, error) {
49	var err error
50	err = m.MsgHeader.SetDefaults(m.Len(), OpMsg)
51
52	b = m.MsgHeader.AppendHeader(b)
53	b = appendInt32(b, int32(m.FlagBits))
54
55	for _, section := range m.Sections {
56		newB := make([]byte, 0)
57		newB = section.AppendSection(newB)
58
59		b = section.AppendSection(b)
60	}
61
62	return b, err
63}
64
65// String implements the fmt.Stringer interface.
66func (m Msg) String() string {
67	return fmt.Sprintf(
68		`OP_MSG{MsgHeader: %v, FlagBits: %d, Sections: %v, Checksum: %d}`,
69		m.MsgHeader, m.FlagBits, m.Sections, m.Checksum,
70	)
71}
72
73// Len implements the WireMessage interface.
74func (m Msg) Len() int {
75	// Header + Flags + len of each section + optional checksum
76	totalLen := 16 + 4 // header and flag
77
78	for _, section := range m.Sections {
79		totalLen += section.Len()
80	}
81
82	if m.FlagBits&ChecksumPresent > 0 {
83		totalLen += 4
84	}
85
86	return totalLen
87}
88
89// UnmarshalWireMessage implements the Unmarshaler interface.
90func (m *Msg) UnmarshalWireMessage(b []byte) error {
91	var err error
92
93	m.MsgHeader, err = ReadHeader(b, 0)
94	if err != nil {
95		return err
96	}
97	if len(b) < int(m.MsgHeader.MessageLength) {
98		return Error{
99			Type:    ErrOpMsg,
100			Message: "[]byte too small",
101		}
102	}
103
104	m.FlagBits = MsgFlag(readInt32(b, 16))
105
106	// read each section
107	sectionBytes := m.MsgHeader.MessageLength - 16 - 4 // number of bytes taken up by sections
108	hasChecksum := m.FlagBits&ChecksumPresent > 0
109	if hasChecksum {
110		sectionBytes -= 4 // 4 bytes at end for checksum
111	}
112
113	m.Sections = make([]Section, 0)
114	position := 20 // position to read from
115	for sectionBytes > 0 {
116		sectionType := SectionType(b[position])
117		position++
118
119		switch sectionType {
120		case SingleDocument:
121			rdr, size, err := readDocument(b, int32(position))
122			if err.Message != "" {
123				err.Type = ErrOpMsg
124				return err
125			}
126
127			position += size
128			sb := SectionBody{
129				Document: rdr,
130			}
131			sb.PayloadType = sb.Kind()
132
133			sectionBytes -= int32(sb.Len())
134			m.Sections = append(m.Sections, sb)
135		case DocumentSequence:
136			sds := SectionDocumentSequence{}
137			sds.Size = readInt32(b, int32(position))
138			position += 4
139
140			identifier, err := readCString(b, int32(position))
141			if err != nil {
142				return err
143			}
144
145			sds.Identifier = identifier
146			position += len(identifier) + 1 // +1 for \0
147			sds.PayloadType = sds.Kind()
148
149			// length of documents to read
150			// sequenceLen - 4 bytes for size field - identifierLength (including \0)
151			docsLen := int(sds.Size) - 4 - len(identifier) - 1
152			for docsLen > 0 {
153				rdr, size, err := readDocument(b, int32(position))
154				if err.Message != "" {
155					err.Type = ErrOpMsg
156					return err
157				}
158
159				position += size
160				sds.Documents = append(sds.Documents, rdr)
161				docsLen -= size
162			}
163
164			sectionBytes -= int32(sds.Len())
165			m.Sections = append(m.Sections, sds)
166		}
167	}
168
169	if hasChecksum {
170		m.Checksum = uint32(readInt32(b, int32(position)))
171	}
172
173	return nil
174}
175
176// GetMainDocument returns the document containing the message to send.
177func (m *Msg) GetMainDocument() (bsonx.Doc, error) {
178	return bsonx.ReadDoc(m.Sections[0].(SectionBody).Document)
179}
180
181// GetSequenceArray returns this message's document sequence as a BSON array along with the array identifier.
182// If this message has no associated document sequence, a nil array is returned.
183func (m *Msg) GetSequenceArray() (bsonx.Arr, string, error) {
184	if len(m.Sections) == 1 {
185		return nil, "", nil
186	}
187
188	arr := bsonx.Arr{}
189	sds := m.Sections[1].(SectionDocumentSequence)
190
191	for _, rdr := range sds.Documents {
192		doc, err := bsonx.ReadDoc([]byte(rdr))
193		if err != nil {
194			return nil, "", err
195		}
196
197		arr = append(arr, bsonx.Document(doc))
198	}
199
200	return arr, sds.Identifier, nil
201}
202
203// AcknowledgedWrite returns true if this msg represents an acknowledged write command.
204func (m *Msg) AcknowledgedWrite() bool {
205	return m.FlagBits&MoreToCome == 0
206}
207
// MsgFlag represents the flags on an OP_MSG message; each constant below is a single bit of the 32-bit field.
type MsgFlag uint32

// These constants represent the individual flags on an OP_MSG message.
const (
	// ChecksumPresent (bit 0) signals that a 4-byte checksum follows the sections.
	ChecksumPresent MsgFlag = 1 << 0
	// MoreToCome (bit 1); a message with this bit set is not treated as an acknowledged write.
	MoreToCome MsgFlag = 1 << 1
	// ExhaustAllowed is bit 16 of the flags.
	ExhaustAllowed MsgFlag = 1 << 16
)
218
// Section represents a section on an OP_MSG message.
//
// SectionBody and SectionDocumentSequence are the implementations defined in this file.
type Section interface {
	// Kind returns the section's payload type.
	Kind() SectionType
	// Len returns the number of bytes the section occupies when encoded, including its leading kind byte.
	Len() int
	// AppendSection appends the section's encoding to the given slice and returns the extended slice.
	AppendSection([]byte) []byte
}
225
// SectionBody represents the kind body of an OP_MSG message: a single raw BSON document.
type SectionBody struct {
	// PayloadType mirrors Kind(); UnmarshalWireMessage sets it to SingleDocument. AppendSection does not
	// read it and always writes SingleDocument.
	PayloadType SectionType
	// Document is the raw BSON body document, encoded verbatim after the kind byte.
	Document    bson.Raw
}
231
232// Kind implements the Section interface.
233func (sb SectionBody) Kind() SectionType {
234	return SingleDocument
235}
236
237// Len implements the Section interface
238func (sb SectionBody) Len() int {
239	return 1 + len(sb.Document) // 1 for PayloadType
240}
241
242// AppendSection implements the Section interface.
243func (sb SectionBody) AppendSection(dest []byte) []byte {
244	dest = append(dest, byte(SingleDocument))
245	dest = append(dest, sb.Document...)
246	return dest
247}
248
// SectionDocumentSequence represents the kind document sequence of an OP_MSG message: a named sequence of
// raw BSON documents.
type SectionDocumentSequence struct {
	// PayloadType mirrors Kind(); UnmarshalWireMessage sets it to DocumentSequence.
	PayloadType SectionType
	// Size is the sequence's int32 size field. UnmarshalWireMessage reads it from the wire and interprets it as
	// 4 (the size field itself) + the identifier including its \0 + all document bytes, which is what
	// PayloadLen computes from the contents.
	Size        int32
	// Identifier names the sequence; it is encoded as a \0-terminated C string.
	Identifier  string
	// Documents are the raw BSON documents of the sequence, encoded back to back.
	Documents   []bson.Raw
}
256
257// Kind implements the Section interface.
258func (sds SectionDocumentSequence) Kind() SectionType {
259	return DocumentSequence
260}
261
262// Len implements the Section interface
263func (sds SectionDocumentSequence) Len() int {
264	// PayloadType + Size + Identifier + 1 (null terminator) + totalDocLen
265	totalDocLen := 0
266	for _, doc := range sds.Documents {
267		totalDocLen += len(doc)
268	}
269
270	return 1 + 4 + len(sds.Identifier) + 1 + totalDocLen
271}
272
273// PayloadLen returns the length of the payload
274func (sds SectionDocumentSequence) PayloadLen() int {
275	// 4 bytes for size field, len identifier (including \0), and total docs len
276	return sds.Len() - 1
277}
278
279// AppendSection implements the Section interface
280func (sds SectionDocumentSequence) AppendSection(dest []byte) []byte {
281	dest = append(dest, byte(DocumentSequence))
282	dest = appendInt32(dest, sds.Size)
283	dest = appendCString(dest, sds.Identifier)
284
285	for _, doc := range sds.Documents {
286		dest = append(dest, doc...)
287	}
288
289	return dest
290}
291
// SectionType represents the type for 1 section in an OP_MSG; it is encoded as the single leading byte of
// each section.
type SectionType uint8

// These constants represent the individual section types for a section in an OP_MSG.
const (
	// SingleDocument (kind 0) is a section holding exactly one document (see SectionBody).
	SingleDocument SectionType = 0
	// DocumentSequence (kind 1) is a section holding an identified run of documents (see SectionDocumentSequence).
	DocumentSequence SectionType = 1
)
300
// OpmsgWireVersion is the minimum wire version needed to use OP_MSG.
// NOTE(review): presumably compared by callers outside this file against the server's reported wire version
// when deciding whether to send OP_MSG instead of a legacy opcode — confirm against those callers.
const OpmsgWireVersion = 6
303