1%%
2%% Licensed to the Apache Software Foundation (ASF) under one
3%% or more contributor license agreements. See the NOTICE file
4%% distributed with this work for additional information
5%% regarding copyright ownership. The ASF licenses this file
6%% to you under the Apache License, Version 2.0 (the
7%% "License"); you may not use this file except in compliance
8%% with the License. You may obtain a copy of the License at
9%%
10%%   http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing,
13%% software distributed under the License is distributed on an
14%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15%% KIND, either express or implied. See the License for the
16%% specific language governing permissions and limitations
17%% under the License.
18%%
19
20-module(thrift_protocol).
21
22-export([new/2,
23         write/2,
24         read/2,
25         read/3,
26         skip/2,
27         flush_transport/1,
28         close_transport/1,
29         typeid_to_atom/1
30        ]).
31
32-export([behaviour_info/1]).
33
34-include("thrift_constants.hrl").
35-include("thrift_protocol.hrl").
36
37-record(protocol, {module, data}).
38
%% Old-style behaviour declaration. Asked for `callbacks' it returns the
%% name/arity pairs a protocol callback module is expected to export; any
%% other request yields `undefined'.
behaviour_info(callbacks) ->
    [{read, 2},
     {write, 2},
     {flush_transport, 1},
     {close_transport, 1}];
behaviour_info(_Other) ->
    undefined.
47
%% Wraps a callback module and that module's own state in a #protocol{}
%% record and returns {ok, Protocol}. Module must be an atom; any other
%% value fails the guard (function_clause). Data is stored as given.
new(Module, Data) when is_atom(Module) ->
    Proto = #protocol{module = Module, data = Data},
    {ok, Proto}.
51
%% Calls Module:flush_transport/1 on the wrapped state and returns the
%% protocol record carrying the state the callback hands back, paired with
%% the callback's result.
%% NOTE(review): the result is passed through unchecked, so the `ok' in the
%% spec holds only if the callback never returns anything else -- confirm.
-spec flush_transport(#protocol{}) -> {#protocol{}, ok}.
flush_transport(#protocol{module = Module, data = Data0} = Proto) ->
    {Data1, Result} = Module:flush_transport(Data0),
    {Proto#protocol{data = Data1}, Result}.
57
%% Calls Module:close_transport/1 on the wrapped state and returns whatever
%% the callback returns. No updated protocol record is produced.
-spec close_transport(#protocol{}) -> ok.
close_transport(#protocol{module = CallbackMod, data = State}) ->
    CallbackMod:close_transport(State).
62
%% Translates a ?tType_* constant into the atom used as a type tag by the
%% read/skip functions of this module. Note that ?tType_I8 maps to `byte'
%% and ?tType_STOP to `field_stop'. Any other value raises function_clause.
typeid_to_atom(?tType_STOP)   -> field_stop;
typeid_to_atom(?tType_VOID)   -> void;
typeid_to_atom(?tType_BOOL)   -> bool;
typeid_to_atom(?tType_DOUBLE) -> double;
typeid_to_atom(?tType_I8)     -> byte;
typeid_to_atom(?tType_I16)    -> i16;
typeid_to_atom(?tType_I32)    -> i32;
typeid_to_atom(?tType_I64)    -> i64;
typeid_to_atom(?tType_STRING) -> string;
typeid_to_atom(?tType_STRUCT) -> struct;
typeid_to_atom(?tType_MAP)    -> map;
typeid_to_atom(?tType_SET)    -> set;
typeid_to_atom(?tType_LIST)   -> list.
76
%% Translates a type description into its ?tType_* constant. Base types are
%% given as atoms (`byte' and `i8' both map to ?tType_I8); container and
%% struct types are given as tuples whose inner contents are not inspected
%% here. Any other term raises function_clause.
term_to_typeid(void)        -> ?tType_VOID;
term_to_typeid(bool)        -> ?tType_BOOL;
term_to_typeid(byte)        -> ?tType_I8;
term_to_typeid(double)      -> ?tType_DOUBLE;
term_to_typeid(i8)          -> ?tType_I8;
term_to_typeid(i16)         -> ?tType_I16;
term_to_typeid(i32)         -> ?tType_I32;
term_to_typeid(i64)         -> ?tType_I64;
term_to_typeid(string)      -> ?tType_STRING;
term_to_typeid({struct, _}) -> ?tType_STRUCT;
term_to_typeid({map, _, _}) -> ?tType_MAP;
term_to_typeid({set, _})    -> ?tType_SET;
term_to_typeid({list, _})   -> ?tType_LIST.
90
%% Reads a struct whose layout is given as a field list:
%%    [{Fid, Type}, ...]
%% The result is a tuple with one slot per entry of that list, in list
%% order. When Tag is not `undefined' the tuple additionally carries Tag as
%% its first element. Slots whose field is never read stay `undefined'.
-spec read(#protocol{}, {struct, _StructDef}, atom()) -> {#protocol{}, {ok, tuple()}}.
read(IProto0, {struct, Structure}, Tag)
  when is_list(Structure), is_atom(Tag) ->
    FieldCount = length(Structure),

    %% A tagged tuple keeps the tag in element 1, so every field slot is
    %% shifted by one to avoid overwriting it.
    Offset = case Tag of
                 undefined -> 0;
                 _Tagged   -> 1
             end,

    %% lists:seq/2 yields [] when there are no fields.
    Slots = lists:seq(1 + Offset, FieldCount + Offset),

    %% Fid -> {Type, Index}
    SDict = dict:from_list([{Fid, {Type, Index}} ||
                               {{Fid, Type}, Index} <- lists:zip(Structure, Slots)]),

    {IProto1, ok} = read(IProto0, struct_begin),

    Blank = erlang:make_tuple(FieldCount + Offset, undefined),
    Initial = case Tag of
                  undefined -> Blank;
                  _         -> setelement(1, Blank, Tag)
              end,

    {IProto2, Filled} = read_struct_loop(IProto1, SDict, Initial),
    {IProto2, {ok, Filled}}.
119
120
%% NOTE: Keep this in sync with thrift_protocol_behaviour:read
%%
%% Reads one value of the requested type.
%%   {struct, {Module, Name}} -- layout comes from Module:struct_info(Name);
%%                               the resulting tuple is tagged with Name.
%%   {struct, FieldList}      -- untagged tuple, see read/3.
%%   {list, Type}             -- {ok, list()} in wire order.
%%   {map, KeyType, ValType}  -- {ok, dict()}.
%%   {set, Type}              -- {ok, set()}.
%%   anything else            -- forwarded to the protocol module through
%%                               read_specific/2.
%% Container clauses assert that the element type announced on the wire
%% equals the requested one and that every nested read succeeds; a violation
%% raises badmatch instead of returning {error, _}.
-spec read
        (#protocol{}, {struct, _Info}) ->    {#protocol{}, {ok, tuple()}      | {error, _Reason}};
        (#protocol{}, tprot_cont_tag()) ->   {#protocol{}, {ok, any()}        | {error, _Reason}};
        (#protocol{}, tprot_empty_tag()) ->  {#protocol{},  ok                | {error, _Reason}};
        (#protocol{}, tprot_header_tag()) -> {#protocol{}, tprot_header_val() | {error, _Reason}};
        (#protocol{}, tprot_data_tag()) ->   {#protocol{}, {ok, any()}        | {error, _Reason}}.

read(IProto, {struct, {Module, StructureName}}) when is_atom(Module),
                                                     is_atom(StructureName) ->
    read(IProto, Module:struct_info(StructureName), StructureName);

read(IProto, {struct, Structure} = StructType) when is_list(Structure) ->
    read(IProto, StructType, undefined);

read(IProto0, {list, Type}) ->
    {IProto1, #protocol_list_begin{etype = EType, size = Size}} =
        read(IProto0, list_begin),
    %% Both sides are kept in the match so a mismatch reports expected and
    %% received type ids together in the badmatch term.
    {EType, EType} = {term_to_typeid(Type), EType},
    {List, IProto2} = read_counted(IProto1, Size, element_reader(Type)),
    {IProto3, ok} = read(IProto2, list_end),
    {IProto3, {ok, List}};

read(IProto0, {map, KeyType, ValType}) ->
    {IProto1, #protocol_map_begin{size = Size, ktype = KType, vtype = VType}} =
        read(IProto0, map_begin),
    %% The announced key/value types are only checked for non-empty maps.
    _ = case Size of
            0 -> 0;
            _ ->
                {KType, KType} = {term_to_typeid(KeyType), KType},
                {VType, VType} = {term_to_typeid(ValType), VType}
        end,
    ReadPair = fun(ProtoS0) ->
                       {ProtoS1, {ok, Key}} = read(ProtoS0, KeyType),
                       {ProtoS2, {ok, Val}} = read(ProtoS1, ValType),
                       {{Key, Val}, ProtoS2}
               end,
    {Pairs, IProto2} = read_counted(IProto1, Size, ReadPair),
    {IProto3, ok} = read(IProto2, map_end),
    {IProto3, {ok, dict:from_list(Pairs)}};

read(IProto0, {set, Type}) ->
    {IProto1, #protocol_set_begin{etype = EType, size = Size}} =
        read(IProto0, set_begin),
    {EType, EType} = {term_to_typeid(Type), EType},
    {Items, IProto2} = read_counted(IProto1, Size, element_reader(Type)),
    {IProto3, ok} = read(IProto2, set_end),
    {IProto3, {ok, sets:from_list(Items)}};

read(Protocol, ProtocolType) ->
    read_specific(Protocol, ProtocolType).

%% Returns a fun that reads a single value of Type and asserts it succeeded.
element_reader(Type) ->
    fun(ProtoS0) ->
            {ProtoS1, {ok, Item}} = read(ProtoS0, Type),
            {Item, ProtoS1}
    end.

%% Applies ReadOne to the protocol record Count times, threading the record
%% through, and returns {ItemsInReadOrder, FinalProto}. Count must be a
%% non-negative integer; anything else raises function_clause before any
%% element is read. No intermediate list of Count placeholders is built.
read_counted(Proto, Count, ReadOne) when is_integer(Count), Count >= 0 ->
    read_counted(Proto, Count, ReadOne, []).

read_counted(Proto, 0, _ReadOne, Acc) ->
    {lists:reverse(Acc), Proto};
read_counted(Proto0, Remaining, ReadOne, Acc) ->
    {Item, Proto1} = ReadOne(Proto0),
    read_counted(Proto1, Remaining - 1, ReadOne, [Item | Acc]).
183
%% NOTE: Keep this in sync with thrift_protocol_behaviour:read
%%
%% Forwards a read request to the wrapped protocol module: calls
%% Module:read/2 with the module's state and returns the protocol record
%% holding the state the callback hands back, paired with the callback's
%% result.
-spec read_specific
        (#protocol{}, tprot_empty_tag()) ->  {#protocol{},  ok                | {error, _Reason}};
        (#protocol{}, tprot_header_tag()) -> {#protocol{}, tprot_header_val() | {error, _Reason}};
        (#protocol{}, tprot_data_tag()) ->   {#protocol{}, {ok, any()}        | {error, _Reason}}.
read_specific(#protocol{module = Module, data = State0} = Proto, ProtocolType) ->
    {State1, Result} = Module:read(State0, ProtocolType),
    {Proto#protocol{data = State1}, Result}.
193
%% Reads fields until the stop marker and stores each recognised value in
%% its slot of RTuple. SDict maps Fid -> {Type, Index}. Returns
%% {Proto, FilledTuple} after consuming struct_end.
read_struct_loop(IProto0, SDict, RTuple) ->
    {IProto1, #protocol_field_begin{type = FType, id = Fid}} =
        thrift_protocol:read(IProto0, field_begin),
    case FType of
        ?tType_STOP ->
            {IProto2, ok} = read(IProto1, struct_end),
            {IProto2, RTuple};
        _Other ->
            read_struct_field(dict:find(Fid, SDict), FType, Fid,
                              IProto1, SDict, RTuple)
    end.

%% Handles one non-stop field, given the result of looking its id up.
%%  * known id, wire type equals the declared type: read it and store it;
%%  * known id, wire type differs: log and skip using the wire type;
%%  * unknown id: skip silently.
%% Every branch continues with read_struct_loop/3.
read_struct_field({ok, {Type, Index}}, FType, Fid, IProto0, SDict, RTuple) ->
    case term_to_typeid(Type) of
        FType ->
            {IProto1, {ok, Val}} = read(IProto0, Type),
            {IProto2, ok} = thrift_protocol:read(IProto1, field_end),
            read_struct_loop(IProto2, SDict, setelement(Index, RTuple, Val));
        Expected ->
            error_logger:info_msg(
              "Skipping field ~p with wrong type (~p != ~p)~n",
              [Fid, FType, Expected]),
            skip_field(FType, IProto0, SDict, RTuple)
    end;
read_struct_field(_NotFound, FType, _Fid, IProto0, SDict, RTuple) ->
    skip_field(FType, IProto0, SDict, RTuple).
220
%% Discards the value of the current field (interpreted by its wire type id
%% FType), consumes field_end, and resumes the struct read loop with RTuple
%% unchanged.
skip_field(FType, IProto0, SDict, RTuple) ->
    TypeAtom = typeid_to_atom(FType),
    {IProto1, ok} = skip(IProto0, TypeAtom),
    {IProto2, ok} = read(IProto1, field_end),
    read_struct_loop(IProto2, SDict, RTuple).
225
%% Consumes and discards one value of the given type. Type is the atom form
%% produced by typeid_to_atom/1. Containers are skipped element by element
%% between their begin and end markers; for any other atom the value is read
%% and its result -- whatever it is -- is dropped.
-spec skip(#protocol{}, atom()) -> {#protocol{}, ok}.

skip(Proto0, struct) ->
    {Proto1, ok} = read(Proto0, struct_begin),
    {Proto2, ok} = skip_struct_loop(Proto1),
    {_, ok} = Done = read(Proto2, struct_end),
    Done;

skip(Proto0, map) ->
    {Proto1, MapBegin} = read(Proto0, map_begin),
    {Proto2, ok} = skip_map_loop(Proto1, MapBegin),
    {_, ok} = Done = read(Proto2, map_end),
    Done;

skip(Proto0, set) ->
    {Proto1, SetBegin} = read(Proto0, set_begin),
    {Proto2, ok} = skip_set_loop(Proto1, SetBegin),
    {_, ok} = Done = read(Proto2, set_end),
    Done;

skip(Proto0, list) ->
    {Proto1, ListBegin} = read(Proto0, list_begin),
    {Proto2, ok} = skip_list_loop(Proto1, ListBegin),
    {_, ok} = Done = read(Proto2, list_end),
    Done;

skip(Proto0, Type) when is_atom(Type) ->
    %% The read result is deliberately not inspected.
    {Proto1, _Discarded} = read(Proto0, Type),
    {Proto1, ok}.
255
256
%% Skips the fields of a struct up to and including the stop marker.
%% struct_end is left for the caller to consume.
skip_struct_loop(Proto0) ->
    {Proto1, #protocol_field_begin{type = Type}} = read(Proto0, field_begin),
    skip_struct_field(Type, Proto1).

%% Stop marker: done. Any other type id: skip the value, consume field_end
%% and carry on with the next field.
skip_struct_field(?tType_STOP, Proto) ->
    {Proto, ok};
skip_struct_field(Type, Proto0) ->
    {Proto1, ok} = skip(Proto0, typeid_to_atom(Type)),
    {Proto2, ok} = read(Proto1, field_end),
    skip_struct_loop(Proto2).
267
%% Skips the remaining key/value pairs of a map. The `size' field of the
%% map_begin record serves as the countdown; a size that is neither 0 nor
%% greater than 0 raises case_clause.
skip_map_loop(Proto0, MapBegin = #protocol_map_begin{ktype = KeyTypeId,
                                                     vtype = ValTypeId,
                                                     size = Remaining}) ->
    case Remaining of
        0 ->
            {Proto0, ok};
        N when N > 0 ->
            {Proto1, ok} = skip(Proto0, typeid_to_atom(KeyTypeId)),
            {Proto2, ok} = skip(Proto1, typeid_to_atom(ValTypeId)),
            skip_map_loop(Proto2, MapBegin#protocol_map_begin{size = N - 1})
    end.
279
%% Skips the remaining elements of a set. The `size' field of the set_begin
%% record serves as the countdown; a size that is neither 0 nor greater than
%% 0 raises case_clause.
skip_set_loop(Proto0, SetBegin = #protocol_set_begin{etype = ElemTypeId,
                                                     size = Remaining}) ->
    case Remaining of
        0 ->
            {Proto0, ok};
        N when N > 0 ->
            {Proto1, ok} = skip(Proto0, typeid_to_atom(ElemTypeId)),
            skip_set_loop(Proto1, SetBegin#protocol_set_begin{size = N - 1})
    end.
289
%% Skips the remaining elements of a list. The `size' field of the
%% list_begin record serves as the countdown; a size that is neither 0 nor
%% greater than 0 raises case_clause.
skip_list_loop(Proto0, ListBegin = #protocol_list_begin{etype = ElemTypeId,
                                                        size = Remaining}) ->
    case Remaining of
        0 ->
            {Proto0, ok};
        N when N > 0 ->
            {Proto1, ok} = skip(Proto0, typeid_to_atom(ElemTypeId)),
            skip_list_loop(Proto1, ListBegin#protocol_list_begin{size = N - 1})
    end.
299
300
%%--------------------------------------------------------------------
%% Function: write(Proto, {Type, Data}) -> {NewProto, ok | {error, Reason}}
%%
%% Type = {struct, StructDef} |
%%        {struct, {Module, StructureName}} |
%%        {list, Type} |
%%        {map, KeyType, ValType} |
%%        {set, Type} |
%%        BaseType
%%
%% Data =
%%         tuple()  -- for struct; element 1 is the struct name and the
%%                     remaining elements line up with StructDef
%%       | list()   -- for list
%%       | dictionary()   -- for map
%%       | set()    -- for set
%%       | any()    -- for base types
%%
%% Description: Writes Data as a value of Type. Containers and structs are
%% written here as a begin marker, their contents (recursively through
%% write/2) and an end marker. Any argument that is not one of the
%% container forms above is handed unchanged to Module:write/2 of the
%% wrapped protocol module, whose result is returned as is.
%%
%% Nested writes are asserted to return `ok', so a failing nested write
%% raises badmatch rather than coming back as {error, Reason}. A struct
%% given as {Module, StructureName} whose Data is not tagged with
%% StructureName raises the error `struct_unmatched'.
%%--------------------------------------------------------------------
-spec write(#protocol{}, any()) -> {#protocol{}, ok | {error, _Reason}}.

write(Proto0, {{struct, StructDef}, Data})
  when is_list(StructDef), is_tuple(Data),
       length(StructDef) =:= tuple_size(Data) - 1 ->

    [StructName | Elems] = tuple_to_list(Data),
    {Proto1, ok} = write(Proto0, #protocol_struct_begin{name = StructName}),
    {Proto2, ok} = struct_write_loop(Proto1, StructDef, Elems),
    {Proto3, ok} = write(Proto2, struct_end),
    {Proto3, ok};

write(Proto, {{struct, {Module, StructureName}}, Data})
  when is_atom(Module),
       is_atom(StructureName),
       element(1, Data) =:= StructureName ->
    write(Proto, {Module:struct_info(StructureName), Data});

write(_, {{struct, {Module, StructureName}}, Data})
  when is_atom(Module),
       is_atom(StructureName) ->
    %% Data may be something element/2 cannot index (a non-tuple or the
    %% empty tuple); report the whole term in that case instead of failing
    %% with an unrelated badarg while building the error.
    Provided = if
                   is_tuple(Data), tuple_size(Data) > 0 -> element(1, Data);
                   true -> Data
               end,
    %% NOTE(review): erlang:error/2 documents its second argument as the
    %% argument list of the current function; here it carries diagnostic
    %% detail instead. Left as is so the raised reason stays
    %% `struct_unmatched'.
    erlang:error(struct_unmatched, {{provided, Provided},
                             {expected, StructureName}});

write(Proto0, {{list, Type}, Data})
  when is_list(Data) ->
    {Proto1, ok} = write(Proto0,
               #protocol_list_begin{
                 etype = term_to_typeid(Type),
                 size = length(Data)
                }),
    Proto2 = lists:foldl(fun(Elem, ProtoIn) ->
                                 {ProtoOut, ok} = write(ProtoIn, {Type, Elem}),
                                 ProtoOut
                         end,
                         Proto1,
                         Data),
    {Proto3, ok} = write(Proto2, list_end),
    {Proto3, ok};

write(Proto0, {{map, KeyType, ValType}, Data}) ->
    {Proto1, ok} = write(Proto0,
                         #protocol_map_begin{
                           ktype = term_to_typeid(KeyType),
                           vtype = term_to_typeid(ValType),
                           size  = dict:size(Data)
                          }),
    %% Each entry is written as its key immediately followed by its value.
    Proto2 = dict:fold(fun(KeyData, ValData, ProtoS0) ->
                               {ProtoS1, ok} = write(ProtoS0, {KeyType, KeyData}),
                               {ProtoS2, ok} = write(ProtoS1, {ValType, ValData}),
                               ProtoS2
                       end,
                       Proto1,
                       Data),
    {Proto3, ok} = write(Proto2, map_end),
    {Proto3, ok};

write(Proto0, {{set, Type}, Data}) ->
    %% Reject anything that is not a `sets' set before writing the header.
    true = sets:is_set(Data),
    {Proto1, ok} = write(Proto0,
                         #protocol_set_begin{
                           etype = term_to_typeid(Type),
                           size  = sets:size(Data)
                          }),
    Proto2 = sets:fold(fun(Elem, ProtoIn) ->
                               {ProtoOut, ok} = write(ProtoIn, {Type, Elem}),
                               ProtoOut
                       end,
                       Proto1,
                       Data),
    {Proto3, ok} = write(Proto2, set_end),
    {Proto3, ok};

%% Everything else (base-type values, begin/end markers, header records) is
%% the protocol module's business.
write(Proto = #protocol{module = Module,
                        data = ModuleData}, Data) ->
    {NewData, Result} = Module:write(ModuleData, Data),
    {Proto#protocol{data = NewData}, Result}.
395
%% Writes the fields of a struct followed by the field_stop marker.
%% The field definitions [{Fid, Type}, ...] and the field values are walked
%% in lockstep; lists of different length raise function_clause.
struct_write_loop(Proto, [], []) ->
    write(Proto, field_stop);
struct_write_loop(Proto, [{_Fid, _Type} | RestStructDef], [undefined | RestData]) ->
    %% null fields are skipped in response
    struct_write_loop(Proto, RestStructDef, RestData);
struct_write_loop(Proto0, [{Fid, Type} | RestStructDef], [Data | RestData]) ->
    FieldBegin = #protocol_field_begin{type = term_to_typeid(Type),
                                       id = Fid},
    {Proto1, ok} = write(Proto0, FieldBegin),
    {Proto2, ok} = write(Proto1, {Type, Data}),
    {Proto3, ok} = write(Proto2, field_end),
    struct_write_loop(Proto3, RestStructDef, RestData).
413