1 /*
2 restinio
3 */
4
5 /*!
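Test websocket echo over a connection established by an upgrade request: an echoed text frame whose send-completion notificator is checked, followed by the close handshake.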
7 */
8
#include <array>
#include <atomic>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <future>
#include <iostream>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include <catch2/catch.hpp>

#include <restinio/all.hpp>
#include <restinio/websocket/websocket.hpp>
#include <restinio/utils/base64.hpp>
#include <restinio/utils/sha1.hpp>
#include <restinio/websocket/impl/utf8.hpp>

#include <test/common/utest_logger.hpp>
#include <test/common/pub.hpp>
#include <test/websocket/common/pub.hpp>

#include <so_5/all.hpp>
22
23 namespace rws = restinio::websocket::basic;
24
25 using traits_t =
26 restinio::traits_t<
27 restinio::asio_timer_manager_t,
28 utest_logger_t >;
29
30 using http_server_t = restinio::http_server_t< traits_t >;
31
32 struct upgrade_request_t : public so_5::message_t
33 {
upgrade_request_tupgrade_request_t34 upgrade_request_t( restinio::request_handle_t req )
35 : m_req{ std::move( req ) }
36 {}
37
38 restinio::request_handle_t m_req;
39 };
40
41 struct msg_ws_message_t : public so_5::message_t
42 {
msg_ws_message_tmsg_ws_message_t43 msg_ws_message_t( rws::message_handle_t msg )
44 : m_msg{ msg }
45 {}
46
47 rws::message_handle_t m_msg;
48 };
49
50 struct server_started_t : public so_5::signal_t {};
51
//
// Shared test state (reset in a_server_t's constructor).
//

//! Status code from the last close frame the server received.
std::atomic< std::uint16_t > g_last_close_code{ 0u };

//! Incremented by the request handler for every accepted upgrade request.
std::atomic< std::uint16_t > g_message_handled{ 0u };
58
59 //
60 // a_server_t
61 //
62
63 //! Agent running ws server logic.
64 class a_server_t
65 : public so_5::agent_t
66 {
67 using so_base_type_t = so_5::agent_t;
68
69 public:
a_server_t(context_t ctx,so_5::mchain_t server_started_mchain,std::promise<void> & notificator_promise)70 a_server_t(
71 context_t ctx,
72 so_5::mchain_t server_started_mchain,
73 std::promise< void > & notificator_promise )
74 : so_base_type_t{ ctx }
75 , m_server_started_mchain( std::move(server_started_mchain) )
76 , m_notificator_promise{ ¬ificator_promise }
77 , m_http_server{
78 restinio::own_io_context(),
__anon03b238c00102( )79 [this]( auto & settings ){
80 auto mbox = this->so_direct_mbox();
81 settings
82 .port( utest_default_port() )
83 .address( "127.0.0.1" )
84 .request_handler(
85 [mbox]( auto req ){
86 if( restinio::http_connection_header_t::upgrade ==
87 req->header().connection() )
88 {
89 ++g_message_handled;
90 so_5::send< upgrade_request_t >( mbox, std::move( req ) );
91
92 return restinio::request_accepted();
93 }
94
95 return restinio::request_rejected();
96 } );
97 } }
98 , m_other_thread{ m_http_server }
99 {
100 g_last_close_code = 0;
101 g_message_handled = 0;
102 }
103
104 virtual void
so_define_agent()105 so_define_agent() override
106 {
107 so_subscribe_self()
108 .event( &a_server_t::evt_upgrade_request )
109 .event( &a_server_t::evt_ws_message );
110 }
111
112 virtual void
so_evt_start()113 so_evt_start() override
114 {
115 m_other_thread.run();
116 so_5::send<server_started_t>( m_server_started_mchain );
117 }
118
119 virtual void
so_evt_finish()120 so_evt_finish() override
121 {
122 m_ws.reset();
123 m_other_thread.stop_and_join();
124 }
125
126 private:
127 void
evt_upgrade_request(const upgrade_request_t & msg)128 evt_upgrade_request( const upgrade_request_t & msg )
129 {
130 auto req = msg.m_req;
131
132 m_ws =
133 rws::upgrade< traits_t >(
134 *req,
135 rws::activation_t::immediate,
136 [mbox = so_direct_mbox()]( auto /* ws_handle*/, rws::message_handle_t m ){
137 so_5::send< msg_ws_message_t >( mbox, m );
138 } );
139 }
140
141 void
evt_ws_message(const msg_ws_message_t & msg)142 evt_ws_message( const msg_ws_message_t & msg )
143 {
144 if( m_ws )
145 {
146 auto & req = *(msg.m_msg);
147
148 if( rws::opcode_t::text_frame == req.opcode() ||
149 rws::opcode_t::binary_frame == req.opcode() )
150 {
151 if( req.payload() == "close" )
152 {
153 m_ws->send_message(
154 rws::final_frame,
155 rws::opcode_t::connection_close_frame,
156 rws::status_code_to_bin( rws::status_code_t::normal_closure ) );
157 }
158 else if( req.payload() == "shutdown" )
159 {
160 m_ws->shutdown();
161 m_ws.reset();
162 }
163 else if( req.payload() == "kill" )
164 {
165 m_ws->kill();
166 m_ws.reset();
167 }
168 else
169 {
170 auto resp = req;
171 m_ws->send_message(
172 resp,
173 [ notificator_promise = m_notificator_promise ]( const auto & ){
174 notificator_promise->set_value();
175 } );
176 }
177 }
178 else if( rws::opcode_t::ping_frame == req.opcode() )
179 {
180 auto resp = req;
181 resp.set_opcode( rws::opcode_t::pong_frame );
182 m_ws->send_message( resp );
183 }
184 // else if( rws::opcode_t::pong_frame == req.opcode() )
185 // {
186 // // ?
187 // }
188 else if( rws::opcode_t::connection_close_frame == req.opcode() )
189 {
190 g_last_close_code = (std::uint16_t)rws::status_code_from_bin( req.payload() );
191 std::cout << "CLOSE FRAME: " << g_last_close_code << std::endl;
192 m_ws.reset();
193 }
194 }
195 }
196
197 const so_5::mchain_t m_server_started_mchain;
198 std::promise< void > * const m_notificator_promise;
199 http_server_t m_http_server;
200 other_work_thread_for_server_t<http_server_t> m_other_thread;
201 rws::ws_handle_t m_ws;
202
203 };
204
//! Raw HTTP/1.1 request asking the server to switch to the websocket protocol.
/*!
 * - Sec-WebSocket-Key is the sample nonce from RFC 6455 section 1.3.
 * - Sec-WebSocket-Version is 13, the only value RFC 6455 (section 4.1)
 *   allows. The former value 1 is not a valid websocket version.
 */
const std::string upgrade_request{
	"GET /chat HTTP/1.1\r\n"
	"Host: 127.0.0.1\r\n"
	"Upgrade: websocket\r\n"
	"Connection: Upgrade\r\n"
	"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
	"Sec-WebSocket-Protocol: chat\r\n"
	"Sec-WebSocket-Version: 13\r\n"
	"User-Agent: unit-test\r\n"
	"\r\n" };
215
216 class sobj_t
217 {
218 so_5::wrapped_env_t m_sobj;
219
220 static void
init(so_5::environment_t & env,std::promise<void> & p)221 init( so_5::environment_t & env, std::promise< void > & p )
222 {
223 auto server_started_mchain = so_5::create_mchain(env);
224 const auto binder_maker = [](auto & env) {
225 #if !defined(SO_5_VERSION) || SO_5_VERSION < SO_5_VERSION_MAKE(6ull, 0ull, 0ull)
226 return so_5::disp::active_obj::create_private_disp(env)->binder();
227 #else
228 return so_5::disp::active_obj::make_dispatcher(env).binder();
229 #endif
230 };
231
232 // Launch server as separate coop.
233 env.introduce_coop(
234 binder_maker(env),
235 [&]( so_5::coop_t & coop ) {
236 coop.make_agent< a_server_t >( server_started_mchain, p );
237 } );
238 // Wait acknowledgement about successful server start.
239 so_5::receive(
240 from(server_started_mchain)
241 .handle_n(1u)
242 .empty_timeout(std::chrono::seconds(5)),
243 [](so_5::mhood_t<server_started_t>) {});
244 }
245
246 public:
247 sobj_t( const sobj_t & ) = delete;
248 sobj_t( sobj_t && ) = delete;
249
sobj_t(std::promise<void> & p)250 sobj_t( std::promise< void > & p )
251 {
252 init( m_sobj.environment(), p );
253 }
254
255 void
stop_and_join()256 stop_and_join()
257 {
258 m_sobj.stop();
259 m_sobj.join();
260 }
261 };
262
263 template < typename Socket >
264 void
fragmented_send(Socket & socket,void * buf,std::size_t n)265 fragmented_send( Socket & socket, void * buf, std::size_t n )
266 {
267 const auto * b = static_cast< std::uint8_t * >( buf );
268 while( n-- )
269 {
270 restinio::asio_ns::write( socket, restinio::asio_ns::buffer( b++, 1 ) );
271 if( 0 < n )
272 std::this_thread::sleep_for( std::chrono::milliseconds( n ) );
273 }
274 }
275
276 TEST_CASE( "Simple echo" , "[ws_connection][echo][notificator]" )
277 {
278 std::promise< void > notificator;
279 sobj_t sobj{ notificator };
280
281 do_with_socket(
__anon03b238c00802( auto & socket, auto & )282 [&]( auto & socket, auto & /*io_context*/ ){
283 REQUIRE_NOTHROW(
284 restinio::asio_ns::write(
285 socket, restinio::asio_ns::buffer( upgrade_request.data(), upgrade_request.size() ) )
286 );
287
288 std::array< std::uint8_t, 1024 > data;
289
290 std::size_t len{ 0 };
291 REQUIRE_NOTHROW(
292 len = socket.read_some( restinio::asio_ns::buffer( data.data(), data.size() ) )
293 );
294
295 std::vector< std::uint8_t > msg_frame =
296 { 0x81, 0x85, 0xAA,0xBB,0xCC,0xDD,
297 0xAA ^ 'H', 0xBB ^ 'e', 0xCC ^ 'l', 0xDD ^ 'l', 0xAA ^ 'o' };
298
299 REQUIRE_NOTHROW(
300 restinio::asio_ns::write( socket, restinio::asio_ns::buffer( msg_frame.data(), msg_frame.size() ) )
301 );
302
303 REQUIRE_NOTHROW(
304 len = socket.read_some( restinio::asio_ns::buffer( data.data(), data.size() ) )
305 );
306
307 REQUIRE( 7 == len );
308 REQUIRE( 0x81 == data[ 0 ] );
309 REQUIRE( 0x05 == data[ 1 ] );
310 REQUIRE( 'H' == data[ 2 ] );
311 REQUIRE( 'e' == data[ 3 ] );
312 REQUIRE( 'l' == data[ 4 ] );
313 REQUIRE( 'l' == data[ 5 ] );
314 REQUIRE( 'o' == data[ 6 ] );
315
316 notificator.get_future().wait();
317
318 std::vector< std::uint8_t > close_frame =
319 {0x88, 0x82, 0xFF,0xFF,0xFF,0xFF, 0xFF ^ 0x03, 0xFF ^ 0xe8 };
320
321 REQUIRE_NOTHROW(
322 restinio::asio_ns::write(
323 socket, restinio::asio_ns::buffer( close_frame.data(), close_frame.size() ) )
324 );
325
326 REQUIRE_NOTHROW(
327 len = socket.read_some( restinio::asio_ns::buffer( data.data(), data.size() ) )
328 );
329
330 REQUIRE( 4 == len );
331 REQUIRE( 0x88 == data[ 0 ] );
332 REQUIRE( 0x02 == data[ 1 ] );
333 REQUIRE( 0x03 == data[ 2 ] );
334 REQUIRE( 0xe8 == data[ 3 ] );
335
336 restinio::asio_ns::error_code ec;
337 len = socket.read_some( restinio::asio_ns::buffer( data.data(), data.size() ), ec );
338 REQUIRE( ec );
339 REQUIRE( restinio::asio_ns::error::eof == ec.value() );
340
341 } );
342
343 sobj.stop_and_join();
344
345 REQUIRE( 1000 == g_last_close_code );
346 }
347