1%%
2%% Licensed to the Apache Software Foundation (ASF) under one
3%% or more contributor license agreements. See the NOTICE file
4%% distributed with this work for additional information
5%% regarding copyright ownership. The ASF licenses this file
6%% to you under the Apache License, Version 2.0 (the
7%% "License"); you may not use this file except in compliance
8%% with the License. You may obtain a copy of the License at
9%%
10%%   http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing,
13%% software distributed under the License is distributed on an
14%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15%% KIND, either express or implied. See the License for the
16%% specific language governing permissions and limitations
17%% under the License.
18%%
19
20-module(thrift_client).
21
22%% API
23-export([new/2, call/3, send_call/3, close/1]).
24
25-include("thrift_constants.hrl").
26-include("thrift_protocol.hrl").
27
%% Client state threaded through every API call in this module.
-record(tclient, {
    %% Module whose function_info/2 describes the service's functions.
    service,
    %% Protocol state handed to, and returned from, thrift_protocol calls.
    protocol,
    %% Sequence id written into each request header and expected back in the
    %% reply header (new/2 sets it to 0; nothing in this module changes it).
    seqid
}).
29
30
%% Creates a client that issues calls for Service over Protocol.
%% Service must be a module exporting function_info/2, which is used to look up
%% parameter, reply and exception types. The sequence id starts at 0.
%% Protocol is opaque here and is only passed through to thrift_protocol.
-spec new(term(), atom()) -> {ok, #tclient{}}.
new(Protocol, Service)
  when is_atom(Service) ->
    {ok, #tclient{protocol = Protocol,
                  service  = Service,
                  seqid    = 0}}.
36
%% Performs a synchronous RPC. It writes the request for Function(Args...) and
%% then reads the reply (oneway functions complete immediately with {ok, ok}).
%%
%% Returns {NewClient, {ok, Result}} or {NewClient, {error, Reason}}.
%% Declared service exceptions and TApplicationExceptions are NOT returned.
%% They are thrown as {NewClient, {exception, Exception}}.
-spec call(#tclient{}, atom(), list()) -> {#tclient{}, {ok, any()} | {error, any()}}.
call(Client = #tclient{}, Function, Args)
  when is_atom(Function), is_list(Args) ->
    finish_call(send_function_call(Client, Function, Args), Function).

%% Continues call/3 after the request was written (or failed to be written).
%% A send error is re-ordered into the {Client, {error, Reason}} shape.
finish_call({ok, SentClient}, Function) ->
    receive_function_result(SentClient, Function);
finish_call({{error, _} = Error, SentClient}, _Function) ->
    {SentClient, Error};
finish_call(Other, _Function) ->
    Other.
45
46
%% Sends a function call but does not read the result. This is useful
%% if you're trying to log non-oneway function calls to write-only
%% transports like thrift_disk_log_transport.
%%
%% Returns {NewClient, ok} on success. On a send failure it returns
%% {{error, Reason}, Client}.
%% NOTE: this error tuple order is the reverse of call/3's
%% {Client, {error, Reason}}. It is preserved here for existing callers, and the
%% spec now describes it instead of claiming only {#tclient{}, ok}.
-spec send_call(#tclient{}, atom(), list()) ->
    {#tclient{}, ok} | {{error, any()}, #tclient{}}.
send_call(Client = #tclient{}, Function, Args)
  when is_atom(Function), is_list(Args) ->
    case send_function_call(Client, Function, Args) of
        {ok, Client1} ->
            {Client1, ok};
        {{error, _}, _Client1} = Failure ->
            Failure
    end.
57
%% Closes the transport underlying the client's protocol.
-spec close(#tclient{}) -> ok.
close(Client = #tclient{}) ->
    thrift_protocol:close_transport(Client#tclient.protocol).
61
62
63%%--------------------------------------------------------------------
64%%% Internal functions
65%%--------------------------------------------------------------------
%% Looks up Function in the service description, checks the argument count
%% and writes the request.
%% Returns {ok, NewClient} on success. It returns {{error, Reason}, Client} for
%% an unknown function ({no_function, F}), a wrong argument count
%% ({bad_args, F, Args}), or a failed write.
-spec send_function_call(#tclient{}, atom(), list()) -> {ok | {error, any()}, #tclient{}}.
send_function_call(Client = #tclient{service = Service}, Function, Args) ->
  %% An undefined function is expected to surface as function_clause from
  %% the service's function_info/2.
  {ParamsType, ReplyType} =
    try
      {Service:function_info(Function, params_type),
       Service:function_info(Function, reply_type)}
    catch
      error:function_clause -> {no_function, 0}
    end,
  case ParamsType of
    no_function ->
      {{error, {no_function, Function}}, Client};
    %% The count check stays in a guard so a failing length/1 just fails the guard.
    {struct, Fields} when length(Fields) =/= length(Args) ->
      {{error, {bad_args, Function, Args}}, Client};
    {struct, _Fields} ->
      write_message(Client, Function, Args, ParamsType,
                    request_message_type(ReplyType))
  end.

%% Oneway functions are sent as ONEWAY messages; all others as CALL.
request_message_type(oneway_void) -> ?tMessageType_ONEWAY;
request_message_type(_ReplyType) -> ?tMessageType_CALL.
83
%% Writes one complete message and flushes the transport. The message is the
%% header, the argument struct and the end marker.
%% If any step yields {Proto, {error, _}}, the write is abandoned. That error is
%% returned paired with the ORIGINAL client, i.e. with the protocol state from
%% before this write began.
-spec write_message(#tclient{}, atom(), list(), {struct, list()}, integer()) ->
  {ok | {error, any()}, #tclient{}}.
write_message(Client = #tclient{protocol = Proto0, seqid = SeqId},
              Function, Args, Params, MsgType) ->
  Header = #protocol_message_begin{name  = atom_to_list(Function),
                                   type  = MsgType,
                                   seqid = SeqId},
  try
    Proto1 = assert_written(thrift_protocol:write(Proto0, Header)),
    Body = {Params, list_to_tuple([Function | Args])},
    Proto2 = assert_written(thrift_protocol:write(Proto1, Body)),
    Proto3 = assert_written(thrift_protocol:write(Proto2, message_end)),
    Proto4 = assert_written(thrift_protocol:flush_transport(Proto3)),
    {ok, Client#tclient{protocol = Proto4}}
  catch
    error:{badmatch, {_, {error, _} = Failure}} -> {Failure, Client}
  end.

%% Unwraps {Proto, ok}. Any other result raises {badmatch, Result}, which
%% write_message/5 converts into an error return.
assert_written(Result) ->
  {Proto, ok} = Result,
  Proto.
100
%% Reads the result of Function, using the reply type declared by the service.
-spec receive_function_result(#tclient{}, atom()) -> {#tclient{}, {ok, any()} | {error, any()}}.
receive_function_result(Client = #tclient{service = Service}, Function) ->
    read_result(Client, Function, Service:function_info(Function, reply_type)).
105
%% Reads the reply for Function.
%% Oneway functions have no reply and yield {ok, ok} immediately.
%% Otherwise the message header is read. A read error comes back as
%% {Client, {error, Reason}} and a sequence-id mismatch as
%% {Client, {error, {bad_seq_id, ExpectedSeqId}}. Otherwise the reply or
%% application exception is handed to the matching handler.
read_result(Client, _Function, oneway_void) ->
    {Client, {ok, ok}};
read_result(Client = #tclient{protocol = Proto0}, Function, ReplyType) ->
    case thrift_protocol:read(Proto0, message_begin) of
        {Proto1, {error, _} = Error} ->
            {Client#tclient{protocol = Proto1}, Error};
        {Proto1, MessageBegin} ->
            dispatch_message_begin(Client#tclient{protocol = Proto1},
                                   Function, ReplyType, MessageBegin)
    end.

%% Routes a decoded message header. The seqid check happens before the type check.
dispatch_message_begin(Client = #tclient{seqid = Expected}, Function, ReplyType,
                       MessageBegin) ->
    case MessageBegin of
        #protocol_message_begin{seqid = Received} when Received =/= Expected ->
            {Client, {error, {bad_seq_id, Expected}}};
        #protocol_message_begin{type = ?tMessageType_EXCEPTION} ->
            handle_application_exception(Client);
        #protocol_message_begin{type = ?tMessageType_REPLY} ->
            handle_reply(Client, Function, ReplyType)
    end.
128
129
%% Decodes a REPLY message body. The body is read as a struct whose field 0 is
%% the return value and whose remaining fields are the function's declared
%% exceptions.
%% Returns {NewClient, {ok, ok}} for a void function and
%% {NewClient, {ok, Result}} otherwise. If exactly one exception field is set,
%% it throws {NewClient, {exception, Exception}}.
handle_reply(Client = #tclient{protocol = Proto0, service = Service},
             Function, ReplyType) ->
    {struct, ExceptionFields} = Service:function_info(Function, exceptions),
    ReplyStructDef = {struct, [{0, ReplyType} | ExceptionFields]},
    {Proto1, {ok, Reply}} = thrift_protocol:read(Proto0, ReplyStructDef),
    {Proto2, ok} = thrift_protocol:read(Proto1, message_end),
    NewClient = Client#tclient{protocol = Proto2},
    Values = tuple_to_list(Reply),
    %% One value for the result plus one per declared exception.
    true = length(Values) =:= length(ExceptionFields) + 1,
    [Success | ExceptionVals] = Values,
    case [V || V <- ExceptionVals, V =/= undefined] of
        [] when ReplyType =:= {struct, []} ->
            {NewClient, {ok, ok}};
        [] ->
            {NewClient, {ok, Success}};
        [Exception] ->
            throw({NewClient, {exception, Exception}})
    end.
152
%% Decodes an EXCEPTION message body into a #'TApplicationException'{} record,
%% logs it, and throws {NewClient, {exception, Record}}. This function never
%% returns normally.
handle_application_exception(Client = #tclient{protocol = Proto0}) ->
    {Proto1, {ok, Exception}} =
        thrift_protocol:read(Proto0, ?TApplicationException_Structure),
    {Proto2, ok} = thrift_protocol:read(Proto1, message_end),
    %% Prepend the record tag to the decoded field tuple so callers can match
    %% on #'TApplicationException'{}.
    XRecord = list_to_tuple(
                ['TApplicationException' | tuple_to_list(Exception)]),
    %% Logged before the shape check so a malformed value is still visible.
    %% A bare {badmatch, false} crash would not show it.
    error_logger:error_msg("thrift_client: received TApplicationException: ~p~n",
                           [XRecord]),
    true = is_record(XRecord, 'TApplicationException'),
    NewClient = Client#tclient{protocol = Proto2},
    throw({NewClient, {exception, XRecord}}).
163