1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7package wiremessage
8
9import (
10	"bytes"
11	"testing"
12
13	"github.com/google/go-cmp/cmp"
14	"go.mongodb.org/mongo-driver/bson"
15)
16
17func TestQuery(t *testing.T) {
18	t.Run("AppendWireMessage", func(t *testing.T) {
19		testCases := []struct {
20			name string
21			q    Query
22			res  []byte
23			err  error
24		}{
25			{
26				"success",
27				Query{
28					MsgHeader:          Header{},
29					FullCollectionName: "foo.bar",
30					Query:              bson.Raw{0x05, 0x00, 0x00, 0x00, 0x00},
31				},
32				[]byte{
33					0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
34					0x00, 0x00, 0x00, 0x00, 0xD4, 0x07, 0x00, 0x00,
35					0x00, 0x00, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x2e,
36					0x62, 0x61, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
37					0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
38					0x00,
39				},
40				nil,
41			},
42		}
43
44		for _, tc := range testCases {
45			t.Run(tc.name, func(t *testing.T) {
46				res := make([]byte, 0)
47				res, err := tc.q.AppendWireMessage(res)
48				if err != tc.err {
49					t.Errorf("Did not get expected error. got %v; want %v", err, tc.err)
50				}
51				if !bytes.Equal(res, tc.res) {
52					t.Errorf("Results do not match. got %#v; want %#v", res, tc.res)
53				}
54			})
55		}
56	})
57	t.Run("UnmarshalWireMessage", func(t *testing.T) {
58		testCases := []struct {
59			name string
60			req  []byte
61			q    Query
62			err  error
63		}{
64			{
65				"success",
66				[]byte{
67					0x29, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
68					0x00, 0x00, 0x00, 0x00, 0xD4, 0x07, 0x00, 0x00,
69					0x00, 0x00, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x2e,
70					0x62, 0x61, 0x72, 0x00, 0x00, 0x00, 0x00, 0x00,
71					0x00, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00,
72					0x00,
73				},
74				Query{
75					MsgHeader: Header{
76						MessageLength: 41,
77						OpCode:        OpQuery,
78					},
79					FullCollectionName: "foo.bar",
80					Query:              bson.Raw{0x05, 0x00, 0x00, 0x00, 0x00},
81				},
82				nil,
83			},
84		}
85
86		for _, tc := range testCases {
87			t.Run(tc.name, func(t *testing.T) {
88				var q Query
89				err := q.UnmarshalWireMessage(tc.req)
90				if err != tc.err {
91					t.Errorf("Did not get expected error. got %v; want %v", err, tc.err)
92				}
93				if diff := cmp.Diff(q, tc.q); diff != "" {
94					t.Errorf("Results do not match. (-got +want):\n%s", diff)
95				}
96			})
97		}
98	})
99}
100