1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
#include <wangle/acceptor/SocketPeeker.h>

#include <folly/portability/GMock.h>
#include <folly/portability/GTest.h>
#include <thread>
#include <utility>

#include <folly/io/async/test/MockAsyncSocket.h>
24 
25 using namespace folly;
26 using namespace folly::test;
27 using namespace wangle;
28 using namespace testing;
29 
30 class MockSocketPeekerCallback : public SocketPeeker::Callback {
31  public:
32   ~MockSocketPeekerCallback() override = default;
33 
34   MOCK_METHOD1(peekSuccess_, void(typename std::vector<uint8_t>));
peekSuccess(std::vector<uint8_t> peekBytes)35   void peekSuccess(std::vector<uint8_t> peekBytes) noexcept override {
36     peekSuccess_(peekBytes);
37   }
38 
39   MOCK_METHOD1(peekError_, void(const folly::AsyncSocketException&));
peekError(const folly::AsyncSocketException & ex)40   void peekError(const folly::AsyncSocketException& ex) noexcept override {
41     peekError_(ex);
42   }
43 };
44 
45 class SocketPeekerTest : public Test {
46  public:
SetUp()47   void SetUp() override {
48     sock = new MockAsyncSocket(&base);
49   }
50 
TearDown()51   void TearDown() override {
52     sock->destroy();
53   }
54 
55   MockAsyncSocket* sock;
56   MockSocketPeekerCallback callback;
57   EventBase base;
58 };
59 
60 MATCHER_P2(BufMatches, buf, len, "") {
61   if (arg.size() != size_t(len)) {
62     return false;
63   } else if (len == 0) {
64     return true;
65   }
66   return memcmp(buf, arg.data(), len) == 0;
67 }
68 
69 MATCHER_P2(IOBufMatches, buf, len, "") {
70   return folly::IOBufEqualTo()(arg, folly::IOBuf::copyBuffer(buf, len));
71 }
72 
// Delivering exactly the requested 2 bytes completes the peek: the bytes are
// pushed back to the socket as pre-received data and reported to the callback.
TEST_F(SocketPeekerTest, TestPeekSuccess) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  // Fatal checks: the writes below would be out of bounds if the peeker
  // handed back a null or shorter buffer.
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  // first 2 bytes of SSL3+.
  buf[0] = 0x16;
  buf[1] = 0x03;
  EXPECT_CALL(*sock, _setPreReceivedData(IOBufMatches(buf, 2)));
  EXPECT_CALL(callback, peekSuccess_(BufMatches(buf, 2)));
  // Expected exactly once in total (gmock's default cardinality), whether it
  // happens once the peek completes or when `peeker` is destroyed.
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readDataAvailable(2);
}
91 
// EOF arriving before any peeked byte is reported as a peek error.
TEST_F(SocketPeekerTest, TestEOFDuringPeek) {
  // One read-callback registration while the peeker is created and started.
  EXPECT_CALL(*sock, setReadCB(_));
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  // No data is ever delivered: the EOF must surface through peekError and the
  // read callback must be cleared exactly once.
  EXPECT_CALL(callback, peekError_(_));
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readEOF();
}
101 
// A short read (1 of 2 requested bytes) followed by EOF is reported as a peek
// error, not a success.
TEST_F(SocketPeekerTest, TestNotEnoughDataError) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  // Fatal checks: stop before writing through a null or mis-sized buffer.
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  buf[0] = 0x16;
  peeker->readDataAvailable(1);

  EXPECT_CALL(callback, peekError_(_));
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readEOF();
}
118 
// The 2 peeked bytes arrive across two reads; they must be reassembled into a
// single 2-byte result for both the socket and the callback.
TEST_F(SocketPeekerTest, TestMultiplePeeks) {
  EXPECT_CALL(*sock, setReadCB(_));
  SocketPeeker::UniquePtr peeker(new SocketPeeker(*sock, &callback, 2));
  peeker->start();

  uint8_t* buf = nullptr;
  size_t len = 0;
  // First read: the full 2-byte window is offered, only 1 byte delivered.
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(2u, len);
  buf[0] = 0x16;
  peeker->readDataAvailable(1);

  // Second read: only the 1 remaining byte of the window is offered.
  peeker->getReadBuffer(reinterpret_cast<void**>(&buf), &len);
  ASSERT_NE(nullptr, buf);
  ASSERT_EQ(1u, len);
  buf[0] = 0x03;

  EXPECT_CALL(*sock, _setPreReceivedData(IOBufMatches("\x16\x03", 2)));
  EXPECT_CALL(callback, peekSuccess_(BufMatches("\x16\x03", 2)));
  EXPECT_CALL(*sock, setReadCB(nullptr));
  peeker->readDataAvailable(1);
}
140 
// Destroying the peeker mid-peek (no data, no EOF yet) must be safe.
// NOTE: "Destory" is a typo, kept so existing test filters keep matching.
TEST_F(SocketPeekerTest, TestDestoryWhilePeeking) {
  EXPECT_CALL(*sock, setReadCB(_));
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 2));
  peeker->start();
  peeker.reset();
}
147 
// Requesting a zero-byte peek succeeds with an empty byte vector.
TEST_F(SocketPeekerTest, TestNoPeekSuccess) {
  auto peeker = SocketPeeker::UniquePtr(new SocketPeeker(*sock, &callback, 0));

  // BufMatches never reads this at length 0. It holds a NUL, presumably so
  // &emptyText is also a valid empty C string if gmock prints the matcher's
  // char* parameter.
  char emptyText = '\0';
  EXPECT_CALL(callback, peekSuccess_(BufMatches(&emptyText, 0)));
  peeker->start();
}
155