1// Copyright (C) MongoDB, Inc. 2017-present. 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); you may 4// not use this file except in compliance with the License. You may obtain 5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 6 7package wiremessage 8 9import ( 10 "errors" 11 "fmt" 12 "strings" 13 14 "go.mongodb.org/mongo-driver/bson" 15 "go.mongodb.org/mongo-driver/x/bsonx" 16) 17 18// Reply represents the OP_REPLY message of the MongoDB wire protocol. 19type Reply struct { 20 MsgHeader Header 21 ResponseFlags ReplyFlag 22 CursorID int64 23 StartingFrom int32 24 NumberReturned int32 25 Documents []bson.Raw 26} 27 28// MarshalWireMessage implements the Marshaler and WireMessage interfaces. 29// 30// See AppendWireMessage for a description of the rules this method follows. 31func (r Reply) MarshalWireMessage() ([]byte, error) { 32 b := make([]byte, 0, r.Len()) 33 return r.AppendWireMessage(b) 34} 35 36// ValidateWireMessage implements the Validator and WireMessage interfaces. 37func (r Reply) ValidateWireMessage() error { 38 if int(r.MsgHeader.MessageLength) != r.Len() { 39 return errors.New("incorrect header: message length is not correct") 40 } 41 if r.MsgHeader.OpCode != OpReply { 42 return errors.New("incorrect header: op code is not OpReply") 43 } 44 45 return nil 46} 47 48// AppendWireMessage implements the Appender and WireMessage interfaces. 49// 50// AppendWireMessage will set the MessageLength property of the MsgHeader 51// if it is zero. It will also set the OpCode to OpQuery if the OpCode is 52// zero. If either of these properties are non-zero and not correct, this 53// method will return both the []byte with the wire message appended to it 54// and an invalid header error. 55func (r Reply) AppendWireMessage(b []byte) ([]byte, error) { 56 var err error 57 err = r.MsgHeader.SetDefaults(r.Len(), OpReply) 58 59 b = r.MsgHeader.AppendHeader(b) 60 b = appendInt32(b, int32(r.ResponseFlags)) 61 b = appendInt64(b, r.CursorID) 62 b = appendInt32(b, r.StartingFrom) 63 b = appendInt32(b, r.NumberReturned) 64 for _, d := range r.Documents { 65 b = append(b, d...) 66 } 67 return b, err 68} 69 70// String implements the fmt.Stringer interface. 71func (r Reply) String() string { 72 return fmt.Sprintf( 73 `OP_REPLY{MsgHeader: %s, ResponseFlags: %s, CursorID: %d, StartingFrom: %d, NumberReturned: %d, Documents: %v}`, 74 r.MsgHeader, r.ResponseFlags, r.CursorID, r.StartingFrom, r.NumberReturned, r.Documents, 75 ) 76} 77 78// Len implements the WireMessage interface. 79func (r Reply) Len() int { 80 // Header + Flags + CursorID + StartingFrom + NumberReturned + Length of Length of Documents 81 docsLen := 0 82 for _, d := range r.Documents { 83 docsLen += len(d) 84 } 85 return 16 + 4 + 8 + 4 + 4 + docsLen 86} 87 88// UnmarshalWireMessage implements the Unmarshaler interface. 89func (r *Reply) UnmarshalWireMessage(b []byte) error { 90 var err error 91 r.MsgHeader, err = ReadHeader(b, 0) 92 if err != nil { 93 return err 94 } 95 if r.MsgHeader.MessageLength < 36 { 96 return errors.New("invalid OP_REPLY: header length too small") 97 } 98 if len(b) < int(r.MsgHeader.MessageLength) { 99 return errors.New("invalid OP_REPLY: []byte too small") 100 } 101 102 r.ResponseFlags = ReplyFlag(readInt32(b, 16)) 103 r.CursorID = readInt64(b, 20) 104 r.StartingFrom = readInt32(b, 28) 105 r.NumberReturned = readInt32(b, 32) 106 pos := 36 107 for pos < len(b) { 108 rdr, size, err := readDocument(b, int32(pos)) 109 if err.Message != "" { 110 err.Type = ErrOpReply 111 return err 112 } 113 r.Documents = append(r.Documents, rdr) 114 pos += size 115 } 116 117 return nil 118} 119 120// GetMainLegacyDocument constructs and returns a BSON document for this reply. 121func (r *Reply) GetMainLegacyDocument(fullCollectionName string) (bsonx.Doc, error) { 122 if r.ResponseFlags&CursorNotFound > 0 { 123 fmt.Println("cursor not found err") 124 return bsonx.Doc{ 125 {"ok", bsonx.Int32(0)}, 126 }, nil 127 } 128 if r.ResponseFlags&QueryFailure > 0 { 129 firstDoc := r.Documents[0] 130 return bsonx.Doc{ 131 {"ok", bsonx.Int32(0)}, 132 {"errmsg", bsonx.String(firstDoc.Lookup("$err").StringValue())}, 133 {"code", bsonx.Int32(firstDoc.Lookup("code").Int32())}, 134 }, nil 135 } 136 137 doc := bsonx.Doc{ 138 {"ok", bsonx.Int32(1)}, 139 } 140 141 batchStr := "firstBatch" 142 if r.StartingFrom != 0 { 143 batchStr = "nextBatch" 144 } 145 146 batchArr := make([]bsonx.Val, len(r.Documents)) 147 for i, docRaw := range r.Documents { 148 doc, err := bsonx.ReadDoc(docRaw) 149 if err != nil { 150 return nil, err 151 } 152 153 batchArr[i] = bsonx.Document(doc) 154 } 155 156 cursorDoc := bsonx.Doc{ 157 {"id", bsonx.Int64(r.CursorID)}, 158 {"ns", bsonx.String(fullCollectionName)}, 159 {batchStr, bsonx.Array(batchArr)}, 160 } 161 162 doc = doc.Append("cursor", bsonx.Document(cursorDoc)) 163 return doc, nil 164} 165 166// GetMainDocument returns the main BSON document for this reply. 167func (r *Reply) GetMainDocument() (bsonx.Doc, error) { 168 return bsonx.ReadDoc([]byte(r.Documents[0])) 169} 170 171// ReplyFlag represents the flags of an OP_REPLY message. 172type ReplyFlag int32 173 174// These constants represent the individual flags of an OP_REPLY message. 175const ( 176 CursorNotFound ReplyFlag = 1 << iota 177 QueryFailure 178 ShardConfigStale 179 AwaitCapable 180) 181 182// String implements the fmt.Stringer interface. 183func (rf ReplyFlag) String() string { 184 strs := make([]string, 0) 185 if rf&CursorNotFound == CursorNotFound { 186 strs = append(strs, "CursorNotFound") 187 } 188 if rf&QueryFailure == QueryFailure { 189 strs = append(strs, "QueryFailure") 190 } 191 if rf&ShardConfigStale == ShardConfigStale { 192 strs = append(strs, "ShardConfigStale") 193 } 194 if rf&AwaitCapable == AwaitCapable { 195 strs = append(strs, "AwaitCapable") 196 } 197 str := "[" 198 str += strings.Join(strs, ", ") 199 str += "]" 200 return str 201} 202