%%
%% Licensed to the Apache Software Foundation (ASF) under one
%% or more contributor license agreements. See the NOTICE file
%% distributed with this work for additional information
%% regarding copyright ownership. The ASF licenses this file
%% to you under the Apache License, Version 2.0 (the
%% "License"); you may not use this file except in compliance
%% with the License. You may obtain a copy of the License at
%%
%% http://www.apache.org/licenses/LICENSE-2.0
%%
%% Unless required by applicable law or agreed to in writing,
%% software distributed under the License is distributed on an
%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
%% KIND, either express or implied. See the License for the
%% specific language governing permissions and limitations
%% under the License.
%%

%% Thrift transport over an SSL socket. new/3 upgrades an accepted TCP
%% socket on the server side; new_transport_factory/3 builds client-side
%% connection funs that connect over TCP and then upgrade to SSL.
-module(thrift_sslsocket_transport).

-include("thrift_transport_behaviour.hrl").

-behaviour(thrift_transport).

-export([new/3,
         write/2, read/2, flush/1, close/1,

         new_transport_factory/3]).

%% Export only for the transport factory
-export([new/2]).

%% Per-connection transport state.
%%   socket       - the SSL socket used by read/2, write/2 and close/1.
%%   recv_timeout - timeout handed to ssl:recv/3 by read/2; new/2 sets it
%%                  from a positive integer {recv_timeout, T} option and
%%                  otherwise leaves this default.
-record(data, {socket,
               recv_timeout=infinity}).
%% NOTE(review): presumably referenced by callback specs in
%% thrift_transport_behaviour.hrl -- confirm before removing.
-type state() :: #data{}.

%% The following "local" record is filled in by parse_factory_options/2
%% below. These options can be passed to new_transport_factory/3 in a
%% proplists-style option list. They're parsed in a single pass so it is an
%% O(n) operation instead of O(n^2).
%%   connect_timeout - integer or infinity; used for both gen_tcp:connect/4
%%                     and ssl:connect/3.
%%   sockopts        - extra gen_tcp options, placed after the factory defaults.
%%   framed          - true selects thrift_framed_transport, false
%%                     thrift_buffered_transport.
%%   ssloptions      - option list passed to ssl:connect/3.
-record(factory_opts, {connect_timeout = infinity,
                       sockopts = [],
                       framed = false,
                       ssloptions = []}).

%% Folds a proplists-style option list into a #factory_opts{} record, one
%% element at a time. An unknown key, or a known key whose value fails its
%% guard, matches no clause and crashes with function_clause.
parse_factory_options([], Acc) ->
    Acc;
parse_factory_options([{ssloptions, SslOpts} | Tail], Acc) when is_list(SslOpts) ->
    parse_factory_options(Tail, Acc#factory_opts{ssloptions = SslOpts});
parse_factory_options([{connect_timeout, Timeout} | Tail], Acc)
  when Timeout =:= infinity; is_integer(Timeout) ->
    parse_factory_options(Tail, Acc#factory_opts{connect_timeout = Timeout});
parse_factory_options([{sockopts, ExtraOpts} | Tail], Acc) when is_list(ExtraOpts) ->
    parse_factory_options(Tail, Acc#factory_opts{sockopts = ExtraOpts});
parse_factory_options([{framed, IsFramed} | Tail], Acc) when is_boolean(IsFramed) ->
    parse_factory_options(Tail, Acc#factory_opts{framed = IsFramed}).

%% Server side: performs the SSL server handshake on an already-accepted TCP
%% socket and wraps the resulting SSL socket in a thrift_transport via new/2.
%%   SockOpts   - may carry {recv_timeout, T}, interpreted by new/2.
%%   SslOptions - passed unchanged to ssl:ssl_accept/2.
%% Exits with {error, Reason} when the handshake returns an error tuple, and
%% with {error, ssl_accept_failed} (after an error report) on any other
%% outcome, including an exception caught by `catch`.
new(Socket, SockOpts, SslOptions) when is_list(SockOpts), is_list(SslOptions) ->
    %% Go passive before the handshake so no handshake data is delivered to
    %% this process's mailbox and lost.
    inet:setopts(Socket, [{active, false}]),

    %% ssl:ssl_accept/2 has no timeout argument, so it waits indefinitely.
    %% NOTE(review): ssl:ssl_accept/2 is deprecated since OTP 21 in favour of
    %% ssl:handshake/2; if it is missing at runtime the resulting undef is
    %% caught here and reported as ssl_accept_failed. Confirm the supported
    %% OTP range before switching.
    Accepted = (catch ssl:ssl_accept(Socket, SslOptions)),
    case Accepted of
        {ok, SslSocket} ->
            new(SslSocket, SockOpts);
        {error, _Reason} = Failure ->
            exit(Failure);
        Unexpected ->
            error_logger:error_report(
              [{application, thrift},
               "SSL accept failed error",
               lists:flatten(io_lib:format("~p", [Unexpected]))]),
            exit({error, ssl_accept_failed})
    end.

%% Wraps an established SSL socket in a thrift_transport. The first
%% {recv_timeout, T} entry in SockOpts is used as the read timeout only when T
%% is a positive integer; otherwise the record default is kept.
new(SslSocket, SockOpts) ->
    Base = #data{socket = SslSocket},
    State = case lists:keyfind(recv_timeout, 1, SockOpts) of
                {recv_timeout, Ms} when is_integer(Ms), Ms > 0 ->
                    Base#data{recv_timeout = Ms};
                _ ->
                    Base
            end,
    thrift_transport:new(?MODULE, State).

%% Sends Data (an iolist) on the SSL socket. Returns the unchanged state
%% paired with whatever ssl:send/2 returned.
write(#data{socket = Sock} = State, Data) ->
    Result = ssl:send(Sock, Data),
    {State, Result}.

%% Reads Len bytes with ssl:recv/3, waiting at most recv_timeout.
%% On {error, timeout}, logs the peer address, closes the SSL socket and
%% returns the timeout error. Any other ssl:recv/3 result is returned as-is.
read(This = #data{socket = Socket, recv_timeout = Timeout}, Len)
  when is_integer(Len), Len >= 0 ->
    case ssl:recv(Socket, Len, Timeout) of
        Err = {error, timeout} ->
            %% Socket is an SSL socket, so the peer must be looked up with
            %% ssl:peername/1. inet:peername/1 only handles plain inet sockets
            %% and would crash here, leaving the socket open and the timeout
            %% unreported to the caller.
            error_logger:info_msg("read timeout: peer conn ~p", [ssl:peername(Socket)]),
            ssl:close(Socket),
            {This, Err};
        Data ->
            {This, Data}
    end.

%% We can't really flush - everything is flushed when we write
flush(This) ->
    {This, ok}.

%% Closes the SSL socket. Returns the unchanged state paired with the
%% ssl:close/1 result.
close(This = #data{socket = Socket}) ->
    {This, ssl:close(Socket)}.

%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

%%
%% Generates a "transport factory" function - a fun which returns a thrift_transport()
%% instance.
%% This can be passed into a protocol factory to generate a connection to a
%% thrift server over a socket.
%%
%% Options (see parse_factory_options/2): framed, sockopts, connect_timeout,
%% ssloptions. Each call of the returned fun makes one connection attempt:
%%   - {ok, Transport} on success, where Transport is a framed or buffered
%%     transport layered on this module's SSL transport;
%%   - the gen_tcp:connect/4 failure (or the exception caught around it),
%%     returned unchanged, if the TCP connect fails;
%%   - exit(error), after closing the TCP socket, if the SSL handshake fails.
%%
new_transport_factory(Host, Port, Options) ->
    ParsedOpts = parse_factory_options(Options, #factory_opts{}),
    F = fun() -> connect_ssl_transport(Host, Port, ParsedOpts) end,
    {ok, F}.

%% One connection attempt: TCP connect, SSL upgrade, then framing/buffering.
connect_ssl_transport(Host, Port, ParsedOpts) ->
    %% User-supplied sockopts are placed after these defaults.
    SockOpts = [binary,
                {packet, 0},
                {active, false},
                {nodelay, true} |
                ParsedOpts#factory_opts.sockopts],
    case catch gen_tcp:connect(Host, Port, SockOpts,
                               ParsedOpts#factory_opts.connect_timeout) of
        {ok, Sock} ->
            SslSock = upgrade_to_ssl(Sock, ParsedOpts),
            {ok, Transport} = thrift_sslsocket_transport:new(SslSock, SockOpts),
            wrap_transport(Transport, ParsedOpts#factory_opts.framed);
        Error ->
            Error
    end.

%% Runs the SSL client handshake on an open TCP socket and returns the SSL
%% socket. On any failure it logs, closes the TCP socket and exits with `error`.
upgrade_to_ssl(Sock, ParsedOpts) ->
    case catch ssl:connect(Sock, ParsedOpts#factory_opts.ssloptions,
                           ParsedOpts#factory_opts.connect_timeout) of
        {ok, SslSocket} ->
            SslSocket;
        Other ->
            error_logger:info_msg("error while connecting over ssl - reason: ~p~n", [Other]),
            catch gen_tcp:close(Sock),
            exit(error)
    end.

%% Layers a framed (true) or buffered (false) transport over Transport and
%% asserts that the wrapper constructor returned {ok, _}.
wrap_transport(Transport, true) ->
    {ok, _} = thrift_framed_transport:new(Transport);
wrap_transport(Transport, false) ->
    {ok, _} = thrift_buffered_transport:new(Transport).