1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
#include <wangle/acceptor/SocketPeeker.h>

#include <thread>
#include <utility>

#include <folly/io/async/test/MockAsyncSocket.h>
#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
24
25 using namespace folly;
26 using namespace folly::test;
27 using namespace wangle;
28 using namespace testing;
29
30 class MockSocketPeekerCallback : public SocketPeeker::Callback {
31 public:
32 ~MockSocketPeekerCallback() override = default;
33
34 MOCK_METHOD1(peekSuccess_, void(typename std::vector<uint8_t>));
peekSuccess(std::vector<uint8_t> peekBytes)35 void peekSuccess(std::vector<uint8_t> peekBytes) noexcept override {
36 peekSuccess_(peekBytes);
37 }
38
39 MOCK_METHOD1(peekError_, void(const folly::AsyncSocketException&));
peekError(const folly::AsyncSocketException & ex)40 void peekError(const folly::AsyncSocketException& ex) noexcept override {
41 peekError_(ex);
42 }
43 };
44
45 class SocketPeekerTest : public Test {
46 public:
SetUp()47 void SetUp() override {
48 sock = new MockAsyncSocket(&base);
49 }
50
TearDown()51 void TearDown() override {
52 sock->destroy();
53 }
54
55 MockAsyncSocket* sock;
56 MockSocketPeekerCallback callback;
57 EventBase base;
58 };
59
60 MATCHER_P2(BufMatches, buf, len, "") {
61 if (arg.size() != size_t(len)) {
62 return false;
63 } else if (len == 0) {
64 return true;
65 }
66 return memcmp(buf, arg.data(), len) == 0;
67 }
68
69 MATCHER_P2(IOBufMatches, buf, len, "") {
70 return folly::IOBufEqualTo()(arg, folly::IOBuf::copyBuffer(buf, len));
71 }
72
// Peeking two bytes that arrive in one read: the bytes are handed back to the
// socket as pre-received data, the read callback is cleared, and the peek
// callback receives exactly those bytes.
TEST_F(SocketPeekerTest, TestPeekSuccess) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  // ASSERT rather than EXPECT: the writes below would go through a null or
  // too-short buffer if these checks failed and the test kept running.
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  // first 2 bytes of SSL3+.
  buf[0] = 0x16;
  buf[1] = 0x03;
  EXPECT_CALL(*sock, _setPreReceivedData(IOBufMatches(buf, 2)));
  EXPECT_CALL(callback, peekSuccess_(BufMatches(buf, 2)));
  // The read callback is cleared exactly once (gmock's default cardinality
  // for this expectation) when the peek completes.
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readDataAvailable(2);
}
91
// EOF arriving before any peeked bytes: the peek callback is expected to see
// an error and the peeker to clear the socket's read callback.
TEST_F(SocketPeekerTest, TestEOFDuringPeek) {
  EXPECT_CALL(*sock, setReadCB(_));
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  // Expectations for the EOF path.
  EXPECT_CALL(*sock, setReadCB(nullptr));
  EXPECT_CALL(callback, peekError_(_));
  peeker->readEOF();
}
101
// Only one of the two requested bytes arrives before EOF: the peek callback
// is expected to see an error (not success) and the read callback is cleared.
TEST_F(SocketPeekerTest, TestNotEnoughDataError) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  // ASSERT rather than EXPECT: buf[0] is written below.
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  buf[0] = 0x16;
  // Deliver a single byte: the peek is still incomplete.
  peeker->readDataAvailable(1);

  EXPECT_CALL(callback, peekError_(_));
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readEOF();
}
118
// The two peeked bytes arrive in two separate reads. The second
// getReadBuffer() call is expected to offer only the one remaining byte, and
// completion delivers both bytes in order.
TEST_F(SocketPeekerTest, TestMultiplePeeks) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  // ASSERT rather than EXPECT: buf[0] is written right after each check.
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  buf[0] = 0x16;
  peeker->readDataAvailable(1);

  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(1u, len);
  buf[0] = 0x03;

  EXPECT_CALL(*sock, _setPreReceivedData(IOBufMatches("\x16\x03", 2)));
  EXPECT_CALL(callback, peekSuccess_(BufMatches("\x16\x03", 2)));
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readDataAvailable(1);
}
140
// Destroying the peeker while a peek is still outstanding, with no further
// expectations set on the socket or the callback.
// (The misspelled test name is kept so existing gtest filters still match.)
TEST_F(SocketPeekerTest, TestDestoryWhilePeeking) {
  EXPECT_CALL(*sock, setReadCB(_));
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 2));
  peeker->start();
  peeker.reset();
}
147
// A peek size of zero: no data is ever delivered, yet the callback is
// expected to receive success with an empty byte vector once start() runs.
TEST_F(SocketPeekerTest, TestNoPeekSuccess) {
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 0));

  // Zero-length expectation: BufMatches only compares the (empty) size.
  EXPECT_CALL(callback, peekSuccess_(BufMatches("", 0)));
  peeker->start();
}
155