1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package wiremessage 8 9import ( 10 "errors" 11 "fmt" 12 13 "go.mongodb.org/mongo-driver/bson" 14 "go.mongodb.org/mongo-driver/x/bsonx" 15) 16 17// Msg represents the OP_MSG message of the MongoDB wire protocol. 18type Msg struct { 19 MsgHeader Header 20 FlagBits MsgFlag 21 Sections []Section 22 Checksum uint32 23} 24 25// MarshalWireMessage implements the Marshaler and WireMessage interfaces. 26func (m Msg) MarshalWireMessage() ([]byte, error) { 27 b := make([]byte, 0, m.Len()) 28 return m.AppendWireMessage(b) 29} 30 31// ValidateWireMessage implements the Validator and WireMessage interfaces. 32func (m Msg) ValidateWireMessage() error { 33 if int(m.MsgHeader.MessageLength) != m.Len() { 34 return errors.New("incorrect header: message length is not correct") 35 } 36 if m.MsgHeader.OpCode != OpMsg { 37 return errors.New("incorrect header: opcode is not OpMsg") 38 } 39 40 return nil 41} 42 43// AppendWireMessage implements the Appender and WireMessage interfaces. 44// 45// AppendWireMesssage will set the MessageLength property of the MsgHeader if it is zero. It will also set the Opcode 46// to OP_MSG if it is zero. If either of these properties are non-zero and not correct, this method will return both the 47// []byte with the wire message appended to it and an invalid header error. 48func (m Msg) AppendWireMessage(b []byte) ([]byte, error) { 49 var err error 50 err = m.MsgHeader.SetDefaults(m.Len(), OpMsg) 51 52 b = m.MsgHeader.AppendHeader(b) 53 b = appendInt32(b, int32(m.FlagBits)) 54 55 for _, section := range m.Sections { 56 newB := make([]byte, 0) 57 newB = section.AppendSection(newB) 58 59 b = section.AppendSection(b) 60 } 61 62 return b, err 63} 64 65// String implements the fmt.Stringer interface. 66func (m Msg) String() string { 67 return fmt.Sprintf( 68 `OP_MSG{MsgHeader: %v, FlagBits: %d, Sections: %v, Checksum: %d}`, 69 m.MsgHeader, m.FlagBits, m.Sections, m.Checksum, 70 ) 71} 72 73// Len implements the WireMessage interface. 74func (m Msg) Len() int { 75 // Header + Flags + len of each section + optional checksum 76 totalLen := 16 + 4 // header and flag 77 78 for _, section := range m.Sections { 79 totalLen += section.Len() 80 } 81 82 if m.FlagBits&ChecksumPresent > 0 { 83 totalLen += 4 84 } 85 86 return totalLen 87} 88 89// UnmarshalWireMessage implements the Unmarshaler interface. 90func (m *Msg) UnmarshalWireMessage(b []byte) error { 91 var err error 92 93 m.MsgHeader, err = ReadHeader(b, 0) 94 if err != nil { 95 return err 96 } 97 if len(b) < int(m.MsgHeader.MessageLength) { 98 return Error{ 99 Type: ErrOpMsg, 100 Message: "[]byte too small", 101 } 102 } 103 104 m.FlagBits = MsgFlag(readInt32(b, 16)) 105 106 // read each section 107 sectionBytes := m.MsgHeader.MessageLength - 16 - 4 // number of bytes taken up by sections 108 hasChecksum := m.FlagBits&ChecksumPresent > 0 109 if hasChecksum { 110 sectionBytes -= 4 // 4 bytes at end for checksum 111 } 112 113 m.Sections = make([]Section, 0) 114 position := 20 // position to read from 115 for sectionBytes > 0 { 116 sectionType := SectionType(b[position]) 117 position++ 118 119 switch sectionType { 120 case SingleDocument: 121 rdr, size, err := readDocument(b, int32(position)) 122 if err.Message != "" { 123 err.Type = ErrOpMsg 124 return err 125 } 126 127 position += size 128 sb := SectionBody{ 129 Document: rdr, 130 } 131 sb.PayloadType = sb.Kind() 132 133 sectionBytes -= int32(sb.Len()) 134 m.Sections = append(m.Sections, sb) 135 case DocumentSequence: 136 sds := SectionDocumentSequence{} 137 sds.Size = readInt32(b, int32(position)) 138 position += 4 139 140 identifier, err := readCString(b, int32(position)) 141 if err != nil { 142 return err 143 } 144 145 sds.Identifier = identifier 146 position += len(identifier) + 1 // +1 for \0 147 sds.PayloadType = sds.Kind() 148 149 // length of documents to read 150 // sequenceLen - 4 bytes for size field - identifierLength (including \0) 151 docsLen := int(sds.Size) - 4 - len(identifier) - 1 152 for docsLen > 0 { 153 rdr, size, err := readDocument(b, int32(position)) 154 if err.Message != "" { 155 err.Type = ErrOpMsg 156 return err 157 } 158 159 position += size 160 sds.Documents = append(sds.Documents, rdr) 161 docsLen -= size 162 } 163 164 sectionBytes -= int32(sds.Len()) 165 m.Sections = append(m.Sections, sds) 166 } 167 } 168 169 if hasChecksum { 170 m.Checksum = uint32(readInt32(b, int32(position))) 171 } 172 173 return nil 174} 175 176// GetMainDocument returns the document containing the message to send. 177func (m *Msg) GetMainDocument() (bsonx.Doc, error) { 178 return bsonx.ReadDoc(m.Sections[0].(SectionBody).Document) 179} 180 181// GetSequenceArray returns this message's document sequence as a BSON array along with the array identifier. 182// If this message has no associated document sequence, a nil array is returned. 183func (m *Msg) GetSequenceArray() (bsonx.Arr, string, error) { 184 if len(m.Sections) == 1 { 185 return nil, "", nil 186 } 187 188 arr := bsonx.Arr{} 189 sds := m.Sections[1].(SectionDocumentSequence) 190 191 for _, rdr := range sds.Documents { 192 doc, err := bsonx.ReadDoc([]byte(rdr)) 193 if err != nil { 194 return nil, "", err 195 } 196 197 arr = append(arr, bsonx.Document(doc)) 198 } 199 200 return arr, sds.Identifier, nil 201} 202 203// AcknowledgedWrite returns true if this msg represents an acknowledged write command. 204func (m *Msg) AcknowledgedWrite() bool { 205 return m.FlagBits&MoreToCome == 0 206} 207 208// MsgFlag represents the flags on an OP_MSG message. 209type MsgFlag uint32 210 211// These constants represent the individual flags on an OP_MSG message. 212const ( 213 ChecksumPresent MsgFlag = 1 << iota 214 MoreToCome 215 216 ExhaustAllowed MsgFlag = 1 << 16 217) 218 219// Section represents a section on an OP_MSG message. 220type Section interface { 221 Kind() SectionType 222 Len() int 223 AppendSection([]byte) []byte 224} 225 226// SectionBody represents the kind body of an OP_MSG message. 227type SectionBody struct { 228 PayloadType SectionType 229 Document bson.Raw 230} 231 232// Kind implements the Section interface. 233func (sb SectionBody) Kind() SectionType { 234 return SingleDocument 235} 236 237// Len implements the Section interface 238func (sb SectionBody) Len() int { 239 return 1 + len(sb.Document) // 1 for PayloadType 240} 241 242// AppendSection implements the Section interface. 243func (sb SectionBody) AppendSection(dest []byte) []byte { 244 dest = append(dest, byte(SingleDocument)) 245 dest = append(dest, sb.Document...) 246 return dest 247} 248 249// SectionDocumentSequence represents the kind document sequence of an OP_MSG message. 250type SectionDocumentSequence struct { 251 PayloadType SectionType 252 Size int32 253 Identifier string 254 Documents []bson.Raw 255} 256 257// Kind implements the Section interface. 258func (sds SectionDocumentSequence) Kind() SectionType { 259 return DocumentSequence 260} 261 262// Len implements the Section interface 263func (sds SectionDocumentSequence) Len() int { 264 // PayloadType + Size + Identifier + 1 (null terminator) + totalDocLen 265 totalDocLen := 0 266 for _, doc := range sds.Documents { 267 totalDocLen += len(doc) 268 } 269 270 return 1 + 4 + len(sds.Identifier) + 1 + totalDocLen 271} 272 273// PayloadLen returns the length of the payload 274func (sds SectionDocumentSequence) PayloadLen() int { 275 // 4 bytes for size field, len identifier (including \0), and total docs len 276 return sds.Len() - 1 277} 278 279// AppendSection implements the Section interface 280func (sds SectionDocumentSequence) AppendSection(dest []byte) []byte { 281 dest = append(dest, byte(DocumentSequence)) 282 dest = appendInt32(dest, sds.Size) 283 dest = appendCString(dest, sds.Identifier) 284 285 for _, doc := range sds.Documents { 286 dest = append(dest, doc...) 287 } 288 289 return dest 290} 291 292// SectionType represents the type for 1 section in an OP_MSG 293type SectionType uint8 294 295// These constants represent the individual section types for a section in an OP_MSG 296const ( 297 SingleDocument SectionType = iota 298 DocumentSequence 299) 300 301// OpmsgWireVersion is the minimum wire version needed to use OP_MSG 302const OpmsgWireVersion = 6 303