1%%
2%% Licensed to the Apache Software Foundation (ASF) under one
3%% or more contributor license agreements. See the NOTICE file
4%% distributed with this work for additional information
5%% regarding copyright ownership. The ASF licenses this file
6%% to you under the Apache License, Version 2.0 (the
7%% "License"); you may not use this file except in compliance
8%% with the License. You may obtain a copy of the License at
9%%
10%%   http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing,
13%% software distributed under the License is distributed on an
14%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15%% KIND, either express or implied. See the License for the
16%% specific language governing permissions and limitations
17%% under the License.
18%%
19-module(thrift_sslsocket_transport).
20
21-include("thrift_transport_behaviour.hrl").
22
23-behaviour(thrift_transport).
24
25-export([new/3,
26         write/2, read/2, flush/1, close/1,
27
28         new_transport_factory/3]).
29
30%% Export only for the transport factory
31-export([new/2]).
32
%% Per-connection transport state: the ssl socket plus the timeout used by
%% read/2 (infinity unless new/2 finds a positive recv_timeout).
-record(data, {socket :: ssl:sslsocket() | undefined,
               recv_timeout = infinity :: timeout()}).
-type state() :: #data{}.
36
%% The following module-local record is filled in by parse_factory_options/2
%% below. These options can be passed to new_transport_factory/3 in a
%% proplists-style option list. They are parsed in a single pass so that
%% handling the list is an O(n) operation instead of O(n^2).
41-record(factory_opts, {connect_timeout = infinity,
42                       sockopts = [],
43                       framed = false,
44                       ssloptions = []}).
45
%% Folds a proplists-style option list into the #factory_opts{} record Opts.
%% Options are applied left to right, so a later occurrence of a key
%% overrides an earlier one. An unrecognised option, or a value that fails
%% its guard, raises function_clause.
parse_factory_options(Options, Opts) ->
    lists:foldl(fun set_factory_option/2, Opts, Options).

%% Applies one option tuple to the accumulated record.
set_factory_option({framed, Bool}, Acc) when is_boolean(Bool) ->
    Acc#factory_opts{framed = Bool};
set_factory_option({sockopts, OptList}, Acc) when is_list(OptList) ->
    Acc#factory_opts{sockopts = OptList};
set_factory_option({connect_timeout, TO}, Acc) when TO =:= infinity; is_integer(TO) ->
    Acc#factory_opts{connect_timeout = TO};
set_factory_option({ssloptions, SslOptions}, Acc) when is_list(SslOptions) ->
    Acc#factory_opts{ssloptions = SslOptions}.
56
%% Server-side constructor: upgrades an already-accepted TCP Socket to SSL
%% and wraps the SSL socket in a thrift transport via new/2.
%%
%% SockOpts is only scanned by new/2 (for {recv_timeout, Ms}); SslOptions is
%% handed to ssl:ssl_accept/2. If the handshake yields {error, Reason} the
%% calling process exits with {error, Reason}; any other unexpected outcome
%% is logged and the process exits with {error, ssl_accept_failed}.
new(Socket, SockOpts, SslOptions) when is_list(SockOpts), is_list(SslOptions) ->
    %% Switch to passive mode before the upgrade so handshake packets are
    %% not delivered to this process as messages (and lost).
    inet:setopts(Socket, [{active, false}]),
    %% No timeout argument is given, so the handshake may wait indefinitely.
    %% NOTE(review): ssl:ssl_accept/2 is deprecated in newer OTP releases in
    %% favour of ssl:handshake/2 -- confirm against the targeted OTP version.
    Accepted = (catch ssl:ssl_accept(Socket, SslOptions)),
    handle_ssl_accept(Accepted, SockOpts).

%% Turns the caught ssl_accept outcome into a transport or an exit.
handle_ssl_accept({ok, SslSocket}, SockOpts) ->
    new(SslSocket, SockOpts);
handle_ssl_accept({error, _} = Error, _SockOpts) ->
    exit(Error);
handle_ssl_accept(Other, _SockOpts) ->
    error_logger:error_report(
      [{application, thrift},
       "SSL accept failed error",
       lists:flatten(io_lib:format("~p", [Other]))]),
    exit({error, ssl_accept_failed}).
73
%% Wraps an established SSL socket in a thrift transport.
%% SockOpts is searched for the first {recv_timeout, Ms} entry; a positive
%% integer value becomes the read timeout, anything else leaves the record
%% default (infinity).
new(SslSocket, SockOpts) ->
    Base = #data{socket = SslSocket},
    State =
        case lists:keyfind(recv_timeout, 1, SockOpts) of
            {recv_timeout, Ms} when is_integer(Ms), Ms > 0 ->
                Base#data{recv_timeout = Ms};
            _NoUsableTimeout ->
                Base
        end,
    thrift_transport:new(?MODULE, State).
84
%% Sends Data (an iolist) over the SSL socket. The state is returned
%% unchanged, paired with the result of ssl:send/2.
write(#data{socket = Sock} = State, Data) ->
    SendResult = ssl:send(Sock, Data),
    {State, SendResult}.
88
%% Receives Len bytes from the SSL socket with ssl:recv/3, waiting at most
%% the configured recv_timeout (infinity unless set by new/2).
%%
%% Returns {This, Result}, where Result is whatever ssl:recv/3 returned.
%% On {error, timeout} the peer is logged and the socket is closed before
%% the error is returned; the returned state still refers to that now
%% closed socket.
read(This = #data{socket = Socket, recv_timeout = Timeout}, Len)
  when is_integer(Len), Len >= 0 ->
    case ssl:recv(Socket, Len, Timeout) of
        {error, timeout} = Err ->
            %% Socket is an ssl socket, not a port: inet:peername/1 would
            %% raise function_clause here, so ask the ssl module instead.
            error_logger:info_msg("read timeout: peer conn ~p", [ssl:peername(Socket)]),
            ssl:close(Socket),
            {This, Err};
        Data ->
            {This, Data}
    end.
99
%% No-op: write/2 passes data straight to ssl:send/2, so this transport
%% keeps no buffer of its own to flush.
flush(State) ->
    {State, ok}.
103
%% Closes the SSL socket; the state is returned alongside ssl:close/1's result.
close(#data{socket = Sock} = State) ->
    CloseResult = ssl:close(Sock),
    {State, CloseResult}.
106
107%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
108
%%
%% Generates a "transport factory": a zero-arity fun which, each time it is
%% called, opens a TCP connection to Host:Port, upgrades it to SSL with
%% ssl:connect/3 and returns a thrift_transport() instance for it.
%% The factory can be passed into a protocol factory to generate
%% connections to a thrift server.
%%
%% Options is a proplist of {framed, boolean()}, {sockopts, list()},
%% {connect_timeout, infinity | integer()}, {ssloptions, list()} and
%% {recv_timeout, Ms}. recv_timeout may also appear inside sockopts (a
%% top-level entry takes precedence). Either way it configures the
%% transport's read timeout via new/2 and is never passed to
%% gen_tcp:connect/4, which rejects options it does not know.
%% connect_timeout bounds both the TCP connect and the SSL handshake.
%%
%% The generated fun returns {ok, Transport} (framed or buffered, per the
%% framed option) or {error, Reason} if the TCP connect or the SSL
%% handshake fails.
%%
new_transport_factory(Host, Port, Options) ->
    %% parse_factory_options/2 has no clause for recv_timeout; peel it off
    %% here and forward it to new/2 instead.
    {RecvTimeoutOpts, FactoryOpts} =
        lists:partition(fun is_recv_timeout_opt/1, Options),
    ParsedOpts = parse_factory_options(FactoryOpts, #factory_opts{}),
    F = fun() -> open_ssl_transport(Host, Port, ParsedOpts, RecvTimeoutOpts) end,
    {ok, F}.

%% True for a {recv_timeout, _} option entry.
is_recv_timeout_opt({recv_timeout, _}) -> true;
is_recv_timeout_opt(_) -> false.

%% Opens one TCP connection, upgrades it to SSL and wraps it in a transport.
open_ssl_transport(Host, Port, #factory_opts{connect_timeout = Timeout,
                                             sockopts = UserSockOpts,
                                             ssloptions = SslOptions,
                                             framed = Framed},
                   RecvTimeoutOpts) ->
    %% Keep recv_timeout out of the inet options: gen_tcp would reject it.
    {SockRecvTimeoutOpts, InetOpts} =
        lists:partition(fun is_recv_timeout_opt/1, UserSockOpts),
    SockOpts = [binary,
                {packet, 0},
                {active, false},
                {nodelay, true} |
                InetOpts],
    case tcp_connect(Host, Port, SockOpts, Timeout) of
        {ok, Sock} ->
            case ssl_upgrade(Sock, SslOptions, Timeout) of
                {ok, SslSock} ->
                    {ok, Transport} =
                        ?MODULE:new(SslSock, RecvTimeoutOpts ++ SockRecvTimeoutOpts),
                    wrap_transport(Framed, Transport);
                {error, _} = SslError ->
                    SslError
            end;
        {error, _} = TcpError ->
            TcpError
    end.

%% gen_tcp:connect/4 can raise (e.g. badarg on an invalid socket option)
%% instead of returning; convert that into {error, Reason}.
tcp_connect(Host, Port, SockOpts, Timeout) ->
    try
        gen_tcp:connect(Host, Port, SockOpts, Timeout)
    catch
        _:Reason -> {error, Reason}
    end.

%% Performs the SSL client handshake over Sock. On any outcome other than
%% {ok, SslSocket} the failure is logged, the TCP socket is closed and an
%% {error, Reason} tuple is returned.
ssl_upgrade(Sock, SslOptions, Timeout) ->
    Result = try
                 ssl:connect(Sock, SslOptions, Timeout)
             catch
                 _:Reason -> {error, Reason}
             end,
    case Result of
        {ok, _SslSock} ->
            Result;
        Failure ->
            error_logger:info_msg("error while connecting over ssl - reason: ~p~n", [Failure]),
            %% Best-effort cleanup: the connection is unusable without SSL,
            %% and a failure to close must not mask the handshake error.
            try gen_tcp:close(Sock) catch _:_ -> ok end,
            case Failure of
                {error, _} -> Failure;
                _ -> {error, Failure}
            end
    end.

%% Wraps Transport in a framed or a buffered thrift transport, asserting
%% that the wrapper was created.
wrap_transport(true, Transport) ->
    {ok, _} = thrift_framed_transport:new(Transport);
wrap_transport(false, Transport) ->
    {ok, _} = thrift_buffered_transport:new(Transport).