1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
9import (
10	"errors"
11	"fmt"
12	"strings"
13
14	"go.mongodb.org/mongo-driver/bson"
15	"go.mongodb.org/mongo-driver/x/bsonx"
16)
17
// Reply represents the OP_REPLY message of the MongoDB wire protocol.
//
// The fields are listed in the order AppendWireMessage encodes them and
// UnmarshalWireMessage decodes them.
type Reply struct {
	// MsgHeader is the message header; Len accounts 16 bytes for it.
	MsgHeader Header
	// ResponseFlags is the bit set of ReplyFlag values for this reply.
	ResponseFlags ReplyFlag
	// CursorID is encoded as 8 bytes following the flags.
	CursorID int64
	// StartingFrom is the batch's starting position; GetMainLegacyDocument
	// reports the batch as "nextBatch" when it is non-zero.
	StartingFrom int32
	// NumberReturned is the document count as carried in the message. It is
	// not checked against len(Documents) by this file's decoding.
	NumberReturned int32
	// Documents holds the returned BSON documents, written back to back on
	// the wire after the fixed 36-byte prefix.
	Documents []bson.Raw
}
27
28// MarshalWireMessage implements the Marshaler and WireMessage interfaces.
29//
30// See AppendWireMessage for a description of the rules this method follows.
31func (r Reply) MarshalWireMessage() ([]byte, error) {
32	b := make([]byte, 0, r.Len())
33	return r.AppendWireMessage(b)
34}
35
36// ValidateWireMessage implements the Validator and WireMessage interfaces.
37func (r Reply) ValidateWireMessage() error {
38	if int(r.MsgHeader.MessageLength) != r.Len() {
39		return errors.New("incorrect header: message length is not correct")
40	}
41	if r.MsgHeader.OpCode != OpReply {
42		return errors.New("incorrect header: op code is not OpReply")
43	}
44
45	return nil
46}
47
48// AppendWireMessage implements the Appender and WireMessage interfaces.
49//
50// AppendWireMessage will set the MessageLength property of the MsgHeader
51// if it is zero. It will also set the OpCode to OpQuery if the OpCode is
52// zero. If either of these properties are non-zero and not correct, this
53// method will return both the []byte with the wire message appended to it
54// and an invalid header error.
55func (r Reply) AppendWireMessage(b []byte) ([]byte, error) {
56	var err error
57	err = r.MsgHeader.SetDefaults(r.Len(), OpReply)
58
59	b = r.MsgHeader.AppendHeader(b)
60	b = appendInt32(b, int32(r.ResponseFlags))
61	b = appendInt64(b, r.CursorID)
62	b = appendInt32(b, r.StartingFrom)
63	b = appendInt32(b, r.NumberReturned)
64	for _, d := range r.Documents {
65		b = append(b, d...)
66	}
67	return b, err
68}
69
70// String implements the fmt.Stringer interface.
71func (r Reply) String() string {
72	return fmt.Sprintf(
73		`OP_REPLY{MsgHeader: %s, ResponseFlags: %s, CursorID: %d, StartingFrom: %d, NumberReturned: %d, Documents: %v}`,
74		r.MsgHeader, r.ResponseFlags, r.CursorID, r.StartingFrom, r.NumberReturned, r.Documents,
75	)
76}
77
78// Len implements the WireMessage interface.
79func (r Reply) Len() int {
80	// Header + Flags + CursorID + StartingFrom + NumberReturned + Length of Length of Documents
81	docsLen := 0
82	for _, d := range r.Documents {
83		docsLen += len(d)
84	}
85	return 16 + 4 + 8 + 4 + 4 + docsLen
86}
87
88// UnmarshalWireMessage implements the Unmarshaler interface.
89func (r *Reply) UnmarshalWireMessage(b []byte) error {
90	var err error
91	r.MsgHeader, err = ReadHeader(b, 0)
92	if err != nil {
93		return err
94	}
95	if r.MsgHeader.MessageLength < 36 {
96		return errors.New("invalid OP_REPLY: header length too small")
97	}
98	if len(b) < int(r.MsgHeader.MessageLength) {
99		return errors.New("invalid OP_REPLY: []byte too small")
100	}
101
102	r.ResponseFlags = ReplyFlag(readInt32(b, 16))
103	r.CursorID = readInt64(b, 20)
104	r.StartingFrom = readInt32(b, 28)
105	r.NumberReturned = readInt32(b, 32)
106	pos := 36
107	for pos < len(b) {
108		rdr, size, err := readDocument(b, int32(pos))
109		if err.Message != "" {
110			err.Type = ErrOpReply
111			return err
112		}
113		r.Documents = append(r.Documents, rdr)
114		pos += size
115	}
116
117	return nil
118}
119
120// GetMainLegacyDocument constructs and returns a BSON document for this reply.
121func (r *Reply) GetMainLegacyDocument(fullCollectionName string) (bsonx.Doc, error) {
122	if r.ResponseFlags&CursorNotFound > 0 {
123		fmt.Println("cursor not found err")
124		return bsonx.Doc{
125			{"ok", bsonx.Int32(0)},
126		}, nil
127	}
128	if r.ResponseFlags&QueryFailure > 0 {
129		firstDoc := r.Documents[0]
130		return bsonx.Doc{
131			{"ok", bsonx.Int32(0)},
132			{"errmsg", bsonx.String(firstDoc.Lookup("$err").StringValue())},
133			{"code", bsonx.Int32(firstDoc.Lookup("code").Int32())},
134		}, nil
135	}
136
137	doc := bsonx.Doc{
138		{"ok", bsonx.Int32(1)},
139	}
140
141	batchStr := "firstBatch"
142	if r.StartingFrom != 0 {
143		batchStr = "nextBatch"
144	}
145
146	batchArr := make([]bsonx.Val, len(r.Documents))
147	for i, docRaw := range r.Documents {
148		doc, err := bsonx.ReadDoc(docRaw)
149		if err != nil {
150			return nil, err
151		}
152
153		batchArr[i] = bsonx.Document(doc)
154	}
155
156	cursorDoc := bsonx.Doc{
157		{"id", bsonx.Int64(r.CursorID)},
158		{"ns", bsonx.String(fullCollectionName)},
159		{batchStr, bsonx.Array(batchArr)},
160	}
161
162	doc = doc.Append("cursor", bsonx.Document(cursorDoc))
163	return doc, nil
164}
165
166// GetMainDocument returns the main BSON document for this reply.
167func (r *Reply) GetMainDocument() (bsonx.Doc, error) {
168	return bsonx.ReadDoc([]byte(r.Documents[0]))
169}
170
// ReplyFlag is the bit set stored in the response-flags field of an OP_REPLY
// message.
type ReplyFlag int32

// Individual OP_REPLY response flags. Each one occupies a distinct bit, so
// flags can be combined with | and tested with &.
const (
	CursorNotFound   ReplyFlag = 1 << 0
	QueryFailure     ReplyFlag = 1 << 1
	ShardConfigStale ReplyFlag = 1 << 2
	AwaitCapable     ReplyFlag = 1 << 3
)
181
182// String implements the fmt.Stringer interface.
183func (rf ReplyFlag) String() string {
184	strs := make([]string, 0)
185	if rf&CursorNotFound == CursorNotFound {
186		strs = append(strs, "CursorNotFound")
187	}
188	if rf&QueryFailure == QueryFailure {
189		strs = append(strs, "QueryFailure")
190	}
191	if rf&ShardConfigStale == ShardConfigStale {
192		strs = append(strs, "ShardConfigStale")
193	}
194	if rf&AwaitCapable == AwaitCapable {
195		strs = append(strs, "AwaitCapable")
196	}
197	str := "["
198	str += strings.Join(strs, ", ")
199	str += "]"
200	return str
201}
202