1%%
2%% Licensed to the Apache Software Foundation (ASF) under one
3%% or more contributor license agreements. See the NOTICE file
4%% distributed with this work for additional information
5%% regarding copyright ownership. The ASF licenses this file
6%% to you under the Apache License, Version 2.0 (the
7%% "License"); you may not use this file except in compliance
8%% with the License. You may obtain a copy of the License at
9%%
10%%   http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing,
13%% software distributed under the License is distributed on an
14%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15%% KIND, either express or implied. See the License for the
16%% specific language governing permissions and limitations
17%% under the License.
18%%
19
20-module(thrift_socket_server).
21
22-behaviour(gen_server).
23
24-include ("thrift_constants.hrl").
25
26-ifdef(TEST).
27    -compile(export_all).
28    -export_records([thrift_socket_server]).
29-else.
30    -export([start/1, stop/1]).
31
32    -export([init/1, handle_call/3, handle_cast/2, terminate/2, code_change/3,
33             handle_info/2]).
34
35    -export([acceptor_loop/1]).
36-endif.
37
%% Server state. Built by parse_options/2 from the start/1 option list and
%% then filled in at runtime with the listen socket and the acceptor pid.
-record(thrift_socket_server,
        {
         %% Port to listen on; replaced by the actually bound port once listening.
         port,
         %% Service module, or a multiplexed map of service name => module.
         service,
         %% Handler module, or a multiplexed map of service name => handler module.
         handler,
         %% Registration passed to gen_server:start_link/4; undefined = unregistered.
         name,
         %% Remaining connection slots: decremented per accepted connection,
         %% incremented when a connection process exits; 0 stops accepting.
         max = 2048,
         %% Address to bind: the atom any or an inet address tuple.
         ip = any,
         %% Listen socket, set once gen_tcp:listen/2 succeeds.
         listen = null,
         %% Pid of the current acceptor process, or null when none is running.
         acceptor = null,
         %% Options handed to the socket transport.
         socket_opts = [{recv_timeout, 500}],
         %% Wire protocol: compact, json, anything else selects binary.
         protocol = binary,
         %% true selects the framed transport, false the buffered one.
         framed = false,
         %% true wraps the socket with thrift_sslsocket_transport.
         ssltransport = false,
         %% Options handed to the SSL socket transport.
         ssloptions = []
        }).
53
%% @doc Start the socket server linked to the caller.
%% Takes either a ready #thrift_socket_server{} record or an option list;
%% an option list is first converted with parse_options/1.
start(Options) when is_record(Options, thrift_socket_server) ->
    start_server(Options);
start(Options) ->
    ParsedState = parse_options(Options),
    start(ParsedState).
58
%% @doc Ask a running server to shut down. The request is sent with
%% gen_server:cast/2, so this does not wait for the server to exit.
%% Accepts a locally registered atom, a pid, a {local, Name},
%% {global, Name} or {via, Module, Name} registration, or the same option
%% list given to start/1 (its name option is used to locate the server).
%% NOTE(review): an option list without a name option resolves to the atom
%% undefined, which addresses no server.
stop(Name) when is_atom(Name) ->
    gen_server:cast(Name, stop);
stop(Pid) when is_pid(Pid) ->
    gen_server:cast(Pid, stop);
stop({local, Name}) ->
    stop(Name);
stop({global, _Name} = GlobalName) ->
    %% Keep the {global, _} wrapper: a bare atom is only looked up in the
    %% local registry and would never reach a globally registered server.
    gen_server:cast(GlobalName, stop);
stop({via, _Module, _Name} = ViaName) ->
    %% start_server/1 accepts any gen_server registration, including via.
    gen_server:cast(ViaName, stop);
stop(Options) ->
    State = parse_options(Options),
    stop(State#thrift_socket_server.name).
70
71%% Internal API
72
%% Convert a start/1 option list into a #thrift_socket_server{} record,
%% starting from the record's default field values.
parse_options(Options) ->
    Defaults = #thrift_socket_server{},
    parse_options(Options, Defaults).
75
%% Fold the start/1 option list, left to right, into the state record.
%% An option that matches no clause (unknown key or wrong value type) fails
%% with function_clause; a bad ip/max value crashes inside its clause; an
%% invalid multiplexed handler/service configuration is reported via throw/1.
parse_options([], State) ->
    State;
parse_options([{name, L} | Rest], State) when is_list(L) ->
    %% list_to_atom/1 creates atoms: option lists are assumed to be trusted
    %% configuration, not user input.
    Name = {local, list_to_atom(L)},
    parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{name, A} | Rest], State) when is_atom(A) ->
    Name = {local, A},
    parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{name, Name} | Rest], State) ->
    %% Any other term (e.g. {global, N}) is used as the registration as-is.
    parse_options(Rest, State#thrift_socket_server{name=Name});
parse_options([{port, L} | Rest], State) when is_list(L) ->
    Port = list_to_integer(L),
    parse_options(Rest, State#thrift_socket_server{port=Port});
parse_options([{port, Port} | Rest], State) ->
    parse_options(Rest, State#thrift_socket_server{port=Port});
parse_options([{ip, Ip} | Rest], State) ->
    ParsedIp = case Ip of
                   any ->
                       any;
                   _ when is_tuple(Ip) ->
                       Ip;
                   _ when is_list(Ip) ->
                       %% Public API for textual addresses; an unparsable
                       %% address crashes with badmatch.
                       {ok, IpTuple} = inet:parse_address(Ip),
                       IpTuple
               end,
    parse_options(Rest, State#thrift_socket_server{ip=ParsedIp});
parse_options([{socket_opts, [_ | _] = L} | Rest], State) ->
    %% Only a non-empty list replaces the default socket options.
    parse_options(Rest, State#thrift_socket_server{socket_opts=L});

parse_options([{handler, []} | _Rest], _State) ->
    throw("At least an error handler must be defined.");
parse_options([{handler, ServiceHandlerPropertyList} | Rest], State) when is_list(ServiceHandlerPropertyList) ->
    %% Multiplexed form: [{ServiceName, HandlerModule}, ...]; it must contain
    %% an entry for ?MULTIPLEXED_ERROR_HANDLER_KEY.
    FormatError = "The handler option is not properly configured for multiplexed services. It should be a kind of [{\"error_handler\", Module::atom()}, {SericeName::list(), Module::atom()}, ...]",
    ServiceHandlerMap = multiplexed_map(State#thrift_socket_server.handler,
                                        ServiceHandlerPropertyList,
                                        "Error while parsing the handler option.",
                                        FormatError),
    case thrift_multiplexed_map_wrapper:find(?MULTIPLEXED_ERROR_HANDLER_KEY, ServiceHandlerMap) of
        {ok, _ErrorHandler} -> parse_options(Rest, State#thrift_socket_server{handler=ServiceHandlerMap});
        error -> throw(FormatError)
    end;
parse_options([{handler, Handler} | Rest], State) when State#thrift_socket_server.handler =:= undefined, is_atom(Handler) ->
    parse_options(Rest, State#thrift_socket_server{handler=Handler});

parse_options([{service, []} | _Rest], _State) ->
    throw("At least one service module must be defined.");
parse_options([{service, ServiceModulePropertyList} | Rest], State) when is_list(ServiceModulePropertyList) ->
    %% Multiplexed form: [{ServiceName, ServiceModule}, ...].
    ServiceModuleMap = multiplexed_map(State#thrift_socket_server.service,
                                       ServiceModulePropertyList,
                                       "Error while parsing the service option.",
                                       "The service option is not properly configured for multiplexed services. It should be a kind of [{SericeName::list(), ServiceModule::atom()}, ...]"),
    parse_options(Rest, State#thrift_socket_server{service=ServiceModuleMap});
parse_options([{service, Service} | Rest], State) when State#thrift_socket_server.service =:= undefined, is_atom(Service) ->
    parse_options(Rest, State#thrift_socket_server{service=Service});

parse_options([{max, Max} | Rest], State) ->
    MaxInt = case Max of
                 _ when is_list(Max) -> list_to_integer(Max);
                 _ when is_integer(Max) -> Max
             end,
    parse_options(Rest, State#thrift_socket_server{max=MaxInt});

parse_options([{protocol, Proto} | Rest], State) when is_atom(Proto) ->
    parse_options(Rest, State#thrift_socket_server{protocol=Proto});

parse_options([{framed, Framed} | Rest], State) when is_boolean(Framed) ->
    parse_options(Rest, State#thrift_socket_server{framed=Framed});

parse_options([{ssltransport, SSLTransport} | Rest], State) when is_boolean(SSLTransport) ->
    parse_options(Rest, State#thrift_socket_server{ssltransport=SSLTransport});
parse_options([{ssloptions, SSLOptions} | Rest], State) when is_list(SSLOptions) ->
    parse_options(Rest, State#thrift_socket_server{ssloptions=SSLOptions}).

%% Build a thrift_multiplexed_map_wrapper map from a list of
%% {ServiceName :: string(), Module :: atom()} pairs.
%% Throws AlreadySetError when Current is not undefined (the field was
%% already set by an earlier option) and FormatError on a malformed pair.
multiplexed_map(undefined, PropList, _AlreadySetError, FormatError) ->
    lists:foldl(
      fun({ServiceName, Module}, Acc) when is_list(ServiceName), is_atom(Module) ->
              thrift_multiplexed_map_wrapper:store(ServiceName, Module, Acc);
         (_, _Acc) ->
              throw(FormatError)
      end, thrift_multiplexed_map_wrapper:new(), PropList);
multiplexed_map(_Current, _PropList, AlreadySetError, _FormatError) ->
    throw(AlreadySetError).
163
%% Start the gen_server linked to the caller, registered under the
%% configured name; with no name configured the process stays unregistered.
start_server(State=#thrift_socket_server{name=undefined}) ->
    gen_server:start_link(?MODULE, State, []);
start_server(State=#thrift_socket_server{name=Name}) ->
    gen_server:start_link(Name, ?MODULE, State, []).
171
%% gen_server callback: open the listen socket and spawn the first acceptor.
%% Exits are trapped so that acceptor and connection process exits arrive in
%% handle_info/2 instead of taking the server down.
%% Returns {ok, State} or {stop, Reason}.
init(State=#thrift_socket_server{ip=Ip, port=Port}) ->
    process_flag(trap_exit, true),
    Opts = listen_opts(Ip),
    Result = case gen_tcp_listen(Port, Opts, State) of
                 {stop, eacces} when Port < 1024 ->
                     %% fdsrv module allows another shot to bind
                     %% ports which require root access
                     listen_via_fdsrv(Port, Opts, State);
                 Other ->
                     Other
             end,
    case Result of
        {ok, #thrift_socket_server{port=ListenPort}} ->
            %% Log only on success, with the port actually bound (which can
            %% differ from the requested one, e.g. when 0 was requested).
            error_logger:info_msg("thrift service listening on port ~p", [ListenPort]);
        _ ->
            ok
    end,
    Result.

%% Listen socket options; binds to Ip unless it is the atom any.
listen_opts(Ip) ->
    BaseOpts = [binary,
                {reuseaddr, true},
                {packet, 0},
                {backlog, 4096},
                {recbuf, 8192},
                {active, false}],
    case Ip of
        any ->
            BaseOpts;
        _ ->
            [{ip, Ip} | BaseOpts]
    end.

%% Second attempt for a privileged port: let fdsrv bind it and listen on
%% the file descriptor it hands back.
listen_via_fdsrv(Port, Opts, State) ->
    case fdsrv:start() of
        {ok, _} ->
            case fdsrv:bind_socket(tcp, Port) of
                {ok, Fd} ->
                    gen_tcp_listen(Port, [{fd, Fd} | Opts], State);
                _ ->
                    {stop, fdsrv_bind_failed}
            end;
        _ ->
            {stop, fdsrv_start_failed}
    end.
210
%% Open the listen socket and spawn the first acceptor.
%% On success returns {ok, State} with listen set and port replaced by the
%% port actually bound; on failure returns {stop, Reason}, the shape
%% gen_server expects from init/1.
gen_tcp_listen(Port, Opts, State) ->
    ListenResult = gen_tcp:listen(Port, Opts),
    case ListenResult of
        {error, Reason} ->
            {stop, Reason};
        {ok, Listen} ->
            {ok, BoundPort} = inet:port(Listen),
            Listening = State#thrift_socket_server{listen=Listen, port=BoundPort},
            {ok, new_acceptor(Listening)}
    end.
220
%% Spawn (linked) a process that waits for the next connection and record
%% its pid in the state. With no connection slots left (max = 0) nothing is
%% spawned and acceptor is set to null; handle_info/2 restarts accepting
%% when a connection process exits.
new_acceptor(State=#thrift_socket_server{max=0}) ->
    error_logger:error_msg("Not accepting new connections"),
    State#thrift_socket_server{acceptor=null};
new_acceptor(State=#thrift_socket_server{}) ->
    #thrift_socket_server{listen=Listen, service=Service, handler=Handler,
                          socket_opts=SocketOpts, framed=Framed, protocol=Proto,
                          ssltransport=SslTransport, ssloptions=SslOptions} = State,
    AcceptorArgs = {self(), Listen, Service, Handler, SocketOpts, Framed,
                    SslTransport, SslOptions, Proto},
    AcceptorPid = proc_lib:spawn_link(?MODULE, acceptor_loop, [AcceptorArgs]),
    State#thrift_socket_server{acceptor=AcceptorPid}.
232
%% Body of an acceptor process started by new_acceptor/1.
%% Blocks in gen_tcp:accept/1 with no timeout. On success it notifies the
%% server with {accepted, self()} (so the server can spawn the next
%% acceptor) and then serves the connection itself via thrift_processor:init/1,
%% passing a fun that builds the protocol stack for the accepted socket.
%% Exits with {error, closed} if the listen socket is closed, and with
%% {error, accept_failed} (after logging) on any other accept result.
acceptor_loop({Server, Listen, Service, Handler, SocketOpts, Framed, SslTransport, SslOptions, Proto})
  when is_pid(Server), is_list(SocketOpts) ->
    AcceptResult = (catch gen_tcp:accept(Listen)), % infinite timeout
    case AcceptResult of
        {ok, Socket} ->
            gen_server:cast(Server, {accepted, self()}),
            ProtoGen = fun() ->
                               new_protocol_stack(Socket, SocketOpts, SslTransport,
                                                  SslOptions, Framed, Proto)
                       end,
            thrift_processor:init({Server, ProtoGen, Service, Handler});
        {error, closed} ->
            exit({error, closed});
        Other ->
            error_logger:error_report(
              [{application, thrift},
               "Accept failed error",
               lists:flatten(io_lib:format("~p", [Other]))]),
            exit({error, accept_failed})
    end.

%% Layer socket transport -> framed or buffered transport -> protocol and
%% return {ok, Protocol}. Every constructor must return {ok, _}; anything
%% else crashes with badmatch.
new_protocol_stack(Socket, SocketOpts, SslTransport, SslOptions, Framed, Proto) ->
    {ok, SocketTransport} =
        case SslTransport of
            true  -> thrift_sslsocket_transport:new(Socket, SocketOpts, SslOptions);
            false -> thrift_socket_transport:new(Socket, SocketOpts)
        end,
    {ok, Transport} =
        case Framed of
            true  -> thrift_framed_transport:new(SocketTransport);
            false -> thrift_buffered_transport:new(SocketTransport)
        end,
    {ok, _Protocol} =
        case Proto of
            compact -> thrift_compact_protocol:new(Transport);
            json    -> thrift_json_protocol:new(Transport);
            _       -> thrift_binary_protocol:new(Transport)
        end.
264
%% gen_server callback. {get, port} returns the port stored in the state
%% (the actually bound port once init/1 has succeeded); every other
%% request is answered with the atom error.
handle_call({get, port}, _From, #thrift_socket_server{port=ListenPort} = State) ->
    {reply, ListenPort, State};
handle_call(_Request, _From, State) ->
    {reply, error, State}.
270
%% gen_server callback.
%% {accepted, Pid} from the current acceptor: that process now serves its
%% connection, so consume one slot of max and start the next acceptor.
%% stop: shut down normally (sent by stop/1).
%% Anything else, including {accepted, _} from a process that is not the
%% current acceptor, is logged and ignored instead of crashing the listener
%% (matching the catch-alls in handle_call/3 and handle_info/2).
handle_cast({accepted, Pid},
            State=#thrift_socket_server{acceptor=Pid, max=Max}) ->
    State1 = State#thrift_socket_server{max=Max - 1},
    {noreply, new_acceptor(State1)};
handle_cast(stop, State) ->
    {stop, normal, State};
handle_cast(Msg, State) ->
    error_logger:error_report({?MODULE, ?LINE, {unexpected_cast, Msg}}),
    {noreply, State}.
278
%% gen_server callback: close the listen socket and clean up.
%% Reasons other than normal/shutdown are reported together with this
%% process's backtrace. For ports below 1024 fdsrv is stopped; the catch
%% makes that best-effort (presumably because fdsrv may not be running).
%% Always returns ok.
terminate(Reason, #thrift_socket_server{listen=Listen, port=Port}) ->
    gen_tcp:close(Listen),
    case lists:member(Reason, [normal, shutdown]) of
        true ->
            ok;
        false ->
            {backtrace, Bt} = erlang:process_info(self(), backtrace),
            error_logger:error_report({?MODULE, ?LINE,
                                       {child_error, Reason, Bt}})
    end,
    if
        Port < 1024 ->
            catch fdsrv:stop(),
            ok;
        true ->
            ok
    end.
295
%% gen_server callback: the state needs no migration between versions.
%% gen_server requires {ok, NewState} here; returning the bare state
%% makes a hot code upgrade fail.
code_change(_OldVsn, State, _Extra) ->
    {ok, State}.
298
%% gen_server callback for exit signals (exits are trapped in init/1).
%% - Current acceptor exited normally: start a replacement.
%% - Current acceptor crashed: report it, pause 100 ms (so a persistent
%%   accept failure does not spin), then start a replacement.
%% - Any other linked process (presumably a connection process handed off
%%   by an earlier acceptor) exited: free its slot in max, and start a new
%%   acceptor if accepting had stopped (acceptor = null).
%% - Anything else is logged and ignored.
handle_info({'EXIT', Pid, normal},
            State=#thrift_socket_server{acceptor=Pid}) ->
    {noreply, new_acceptor(State)};
handle_info({'EXIT', Pid, Reason},
            State=#thrift_socket_server{acceptor=Pid}) ->
    error_logger:error_report({?MODULE, ?LINE,
                               {acceptor_error, Reason}}),
    timer:sleep(100),
    {noreply, new_acceptor(State)};
handle_info({'EXIT', _LoopPid, Reason},
            State=#thrift_socket_server{acceptor=Pid, max=Max}) ->
    case Reason of
        normal -> ok;
        shutdown -> ok;
        %% The child's stacktrace cannot be fetched from this process
        %% (erlang:get_stacktrace/0 only sees this process's own exceptions
        %% and is gone in recent OTP); for runtime errors it is already
        %% part of Reason.
        _ -> error_logger:error_report({?MODULE, ?LINE,
                                        {child_error, Reason}})
    end,
    State1 = State#thrift_socket_server{max=Max + 1},
    State2 = case Pid of
                 null -> new_acceptor(State1);
                 _ -> State1
             end,
    {noreply, State2};
handle_info(Info, State) ->
    error_logger:info_report([{'INFO', Info}, {'State', State}]),
    {noreply, State}.
325