1%%
2%% Licensed to the Apache Software Foundation (ASF) under one
3%% or more contributor license agreements. See the NOTICE file
4%% distributed with this work for additional information
5%% regarding copyright ownership. The ASF licenses this file
6%% to you under the Apache License, Version 2.0 (the
7%% "License"); you may not use this file except in compliance
8%% with the License. You may obtain a copy of the License at
9%%
10%%   http://www.apache.org/licenses/LICENSE-2.0
11%%
12%% Unless required by applicable law or agreed to in writing,
13%% software distributed under the License is distributed on an
14%% "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15%% KIND, either express or implied. See the License for the
16%% specific language governing permissions and limitations
17%% under the License.
18%%
19
20-module(thrift_binary_protocol).
21
22-behaviour(thrift_protocol).
23
24-include("thrift_constants.hrl").
25-include("thrift_protocol.hrl").
26
27-export([new/1, new/2,
28         read/2,
29         write/2,
30         flush_transport/1,
31         close_transport/1,
32
33         new_protocol_factory/2
34        ]).
35
36-record(binary_protocol, {transport,
37                          strict_read=true,
38                          strict_write=true
39                         }).
40-type state() :: #binary_protocol{}.
41-include("thrift_protocol_behaviour.hrl").
42
43-define(VERSION_MASK, 16#FFFF0000).
44-define(VERSION_1, 16#80010000).
45-define(TYPE_MASK, 16#000000ff).
46
%% Wraps Transport in a binary protocol using the default options
%% (strict_read and strict_write both left at their record defaults).
new(Transport) ->
    DefaultOptions = [],
    new(Transport, DefaultOptions).
49
%% Builds the binary protocol state for Transport, applies the option list
%% (see parse_options/2) and registers it via thrift_protocol:new/2.
new(Transport, Options) ->
    Initial = #binary_protocol{transport = Transport},
    Configured = parse_options(Options, Initial),
    thrift_protocol:new(?MODULE, Configured).
54
%% Applies each {strict_read, Bool} / {strict_write, Bool} option, left to
%% right, to State; a later occurrence of a key overrides an earlier one.
%% Any other option (or a non-boolean value) matches no fun clause and
%% crashes with function_clause.
parse_options(Options, State) ->
    lists:foldl(
      fun({strict_read, Bool}, Acc) when is_boolean(Bool) ->
              Acc#binary_protocol{strict_read = Bool};
         ({strict_write, Bool}, Acc) when is_boolean(Bool) ->
              Acc#binary_protocol{strict_write = Bool}
      end,
      State,
      Options).
61
62
%% Flushes the underlying transport and stores the transport state it hands
%% back; the flush result is returned alongside the updated protocol state.
flush_transport(#binary_protocol{transport = Trans0} = This) ->
    {Trans1, Result} = thrift_transport:flush(Trans0),
    Updated = This#binary_protocol{transport = Trans1},
    {Updated, Result}.
66
%% Closes the underlying transport and stores the transport state it hands
%% back; the close result is returned alongside the updated protocol state.
close_transport(#binary_protocol{transport = Trans0} = This) ->
    {Trans1, Result} = thrift_transport:close(Trans0),
    Updated = This#binary_protocol{transport = Trans1},
    {Updated, Result}.
70
71%%%
72%%% instance methods
73%%%
74
%% write(State, Item) -> {NewState, Result}
%%
%% Encodes one protocol item, writes it to the transport and returns the
%% updated state together with the transport's write result.
%%   * Composite headers (message/field/map/list/set begin) are written as a
%%     sequence of primitive writes; each intermediate `{_, ok} = ...' match
%%     deliberately crashes if one of those writes fails.
%%   * The *_end markers and struct begin/end write nothing and return ok.
%%   * Integers and doubles are encoded big-endian; bool is one 0/1 byte.
%%   * Any term not matched by an earlier clause is treated as iodata and
%%     passed straight to thrift_transport:write/2.
write(This0, #protocol_message_begin{
        name = Name,
        type = Type,
        seqid = Seqid}) ->
    case This0#binary_protocol.strict_write of
        true ->
            %% Strict header: version word OR'd with the message type,
            %% then the name and the sequence id.
            {This1, ok} = write(This0, {i32, ?VERSION_1 bor Type}),
            {This2, ok} = write(This1, {string, Name}),
            {This3, ok} = write(This2, {i32, Seqid}),
            {This3, ok};
        false ->
            %% Legacy (non-strict) header: name, one-byte type, sequence id.
            {This1, ok} = write(This0, {string, Name}),
            {This2, ok} = write(This1, {byte, Type}),
            {This3, ok} = write(This2, {i32, Seqid}),
            {This3, ok}
    end;

write(This, message_end) -> {This, ok};

write(This0, #protocol_field_begin{
       name = _Name,
       type = Type,
       id = Id}) ->
    %% The field name is not part of the binary encoding; only type and id.
    {This1, ok} = write(This0, {byte, Type}),
    {This2, ok} = write(This1, {i16, Id}),
    {This2, ok};

write(This, field_stop) ->
    write(This, {byte, ?tType_STOP});

write(This, field_end) -> {This, ok};

write(This0, #protocol_map_begin{
       ktype = Ktype,
       vtype = Vtype,
       size = Size}) ->
    {This1, ok} = write(This0, {byte, Ktype}),
    {This2, ok} = write(This1, {byte, Vtype}),
    {This3, ok} = write(This2, {i32, Size}),
    {This3, ok};

write(This, map_end) -> {This, ok};

write(This0, #protocol_list_begin{
        etype = Etype,
        size = Size}) ->
    {This1, ok} = write(This0, {byte, Etype}),
    {This2, ok} = write(This1, {i32, Size}),
    {This2, ok};

write(This, list_end) -> {This, ok};

write(This0, #protocol_set_begin{
        etype = Etype,
        size = Size}) ->
    {This1, ok} = write(This0, {byte, Etype}),
    {This2, ok} = write(This1, {i32, Size}),
    {This2, ok};

write(This, set_end) -> {This, ok};

write(This, #protocol_struct_begin{}) -> {This, ok};
write(This, struct_end) -> {This, ok};

write(This, {bool, true})  -> write(This, {byte, 1});
write(This, {bool, false}) -> write(This, {byte, 0});

write(This, {byte, Byte}) ->
    write(This, <<Byte:8/big-signed>>);

write(This, {i16, I16}) ->
    write(This, <<I16:16/big-signed>>);

write(This, {i32, I32}) ->
    write(This, <<I32:32/big-signed>>);

write(This, {i64, I64}) ->
    write(This, <<I64:64/big-signed>>);

write(This, {double, Double}) ->
    write(This, <<Double:64/big-signed-float>>);

write(This, {string, Str}) when is_list(Str) ->
    %% Flatten first so the length prefix is the payload's byte count.
    %% length/1 would count top-level elements, which differs from the byte
    %% count for deep iolists such as ["ab", <<"cd">>]. Elements that are not
    %% bytes/binaries (e.g. code points > 255) still fail with badarg.
    write(This, {string, iolist_to_binary(Str)});

write(This0, {string, Bin}) when is_binary(Bin) ->
    %% i32 byte length followed by the raw bytes.
    {This1, ok} = write(This0, {i32, byte_size(Bin)}),
    {This2, ok} = write(This1, Bin),
    {This2, ok};

%% Data :: iolist()
write(This = #binary_protocol{transport = Trans}, Data) ->
    {NewTransport, Result} = thrift_transport:write(Trans, Data),
    {This#binary_protocol{transport = NewTransport}, Result}.
171
172%%
173
%% read(State, What) -> {NewState, Result}
%%
%% Decodes one protocol item from the transport.
%%   * Scalars (bool, byte, i16, i32, ui32, i64, double, string) return
%%     {ok, Value} on success; any other read result (e.g. a transport error)
%%     is passed back unchanged. `string' yields a binary; call
%%     binary_to_list/1 if a list is needed.
%%   * message_begin, field_begin, map_begin, list_begin and set_begin return
%%     the matching protocol record. message_begin and field_begin may
%%     instead return an error or a non-matching read result.
%%   * struct_begin and the *_end markers consume nothing and return ok.
%% Intermediate reads inside composite items use assertive `{ok, _}' matches
%% and therefore crash on a failed read.
read(This0, message_begin) ->
    {This1, Initial} = read(This0, ui32),
    case Initial of
        {ok, Sz} when Sz band ?VERSION_MASK =:= ?VERSION_1 ->
            %% Strict (version 1) header; the low byte carries the type.
            {This2, {ok, Name}}  = read(This1, string),
            {This3, {ok, SeqId}} = read(This2, i32),
            Type                 = Sz band ?TYPE_MASK,
            {This3, #protocol_message_begin{name  = binary_to_list(Name),
                                            type  = Type,
                                            seqid = SeqId}};

        {ok, Sz} when Sz band 16#80000000 =/= 0 ->
            %% The high bit marks a versioned header, but not version 1.
            %% The word was read unsigned (ui32), so test the bit directly:
            %% a `Sz < 0' guard can never match here and such a header would
            %% otherwise be misread below as a ~2 GB name length.
            {This1, {error, {bad_binary_protocol_version, Sz}}};

        {ok, _Sz} when This1#binary_protocol.strict_read =:= true ->
            %% strict_read is true and there's no version header; that's an error
            {This1, {error, no_binary_protocol_version}};

        {ok, Sz} when This1#binary_protocol.strict_read =:= false ->
            %% Legacy header: the first word is the name's byte length.
            {This2, {ok, Name}}  = read_data(This1, Sz),
            {This3, {ok, Type}}  = read(This2, byte),
            {This4, {ok, SeqId}} = read(This3, i32),
            {This4, #protocol_message_begin{name  = binary_to_list(Name),
                                            type  = Type,
                                            seqid = SeqId}};

        Else ->
            {This1, Else}
    end;

read(This, message_end) -> {This, ok};

read(This, struct_begin) -> {This, ok};
read(This, struct_end) -> {This, ok};

read(This0, field_begin) ->
    {This1, Result} = read(This0, byte),
    case Result of
        {ok, Type = ?tType_STOP} ->
            %% A STOP marker has no field id following it.
            {This1, #protocol_field_begin{type = Type}};
        {ok, Type} ->
            {This2, {ok, Id}} = read(This1, i16),
            {This2, #protocol_field_begin{type = Type,
                                          id = Id}};
        Else ->
            %% Pass a failed read back, as the scalar readers do.
            {This1, Else}
    end;

read(This, field_end) -> {This, ok};

read(This0, map_begin) ->
    {This1, {ok, Ktype}} = read(This0, byte),
    {This2, {ok, Vtype}} = read(This1, byte),
    {This3, {ok, Size}}  = read(This2, i32),
    {This3, #protocol_map_begin{ktype = Ktype,
                                vtype = Vtype,
                                size = Size}};
read(This, map_end) -> {This, ok};

read(This0, list_begin) ->
    {This1, {ok, Etype}} = read(This0, byte),
    {This2, {ok, Size}}  = read(This1, i32),
    {This2, #protocol_list_begin{etype = Etype,
                                 size = Size}};
read(This, list_end) -> {This, ok};

read(This0, set_begin) ->
    {This1, {ok, Etype}} = read(This0, byte),
    {This2, {ok, Size}}  = read(This1, i32),
    {This2, #protocol_set_begin{etype = Etype,
                                 size = Size}};
read(This, set_end) -> {This, ok};

read(This0, field_stop) ->
    %% Asserts that the next byte is the STOP marker.
    {This1, {ok, ?tType_STOP}} = read(This0, byte),
    {This1, ok};

%%

read(This0, bool) ->
    {This1, Result} = read(This0, byte),
    case Result of
        {ok, Byte} -> {This1, {ok, Byte /= 0}};
        Else -> {This1, Else}
    end;

read(This0, byte) ->
    {This1, Bytes} = read_data(This0, 1),
    case Bytes of
        {ok, <<Val:8/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

read(This0, i16) ->
    {This1, Bytes} = read_data(This0, 2),
    case Bytes of
        {ok, <<Val:16/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

read(This0, i32) ->
    {This1, Bytes} = read_data(This0, 4),
    case Bytes of
        {ok, <<Val:32/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

%% unsigned ints aren't used by thrift itself, but it's used for the parsing
%% of the packet version header. Without this special function BEAM works fine
%% but hipe thinks it received a bad version header.
read(This0, ui32) ->
    {This1, Bytes} = read_data(This0, 4),
    case Bytes of
        {ok, <<Val:32/integer-unsigned-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

read(This0, i64) ->
    {This1, Bytes} = read_data(This0, 8),
    case Bytes of
        {ok, <<Val:64/integer-signed-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

read(This0, double) ->
    {This1, Bytes} = read_data(This0, 8),
    case Bytes of
        {ok, <<Val:64/float-signed-big, _/binary>>} -> {This1, {ok, Val}};
        Else -> {This1, Else}
    end;

%% returns a binary directly, call binary_to_list if necessary
read(This0, string) ->
    case read(This0, i32) of
        {This1, {ok, Sz}} ->
            read_data(This1, Sz);
        {This1, Else} ->
            %% The length prefix could not be read; pass the failure back.
            {This1, Else}
    end.
310
%% Requests Len bytes from the underlying transport and stores the transport
%% state it hands back. A zero-length request is answered locally with an
%% empty binary and never touches the transport. Negative or non-integer
%% lengths match no clause and crash with function_clause.
-spec read_data(#binary_protocol{}, non_neg_integer()) ->
    {#binary_protocol{}, {ok, binary()} | {error, _Reason}}.
read_data(This, 0) ->
    {This, {ok, <<>>}};
read_data(#binary_protocol{transport = Trans0} = This, Len)
  when is_integer(Len), Len > 0 ->
    {Trans1, Result} = thrift_transport:read(Trans0, Len),
    Updated = This#binary_protocol{transport = Trans1},
    {Updated, Result}.
317
318
319%%%% FACTORY GENERATION %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
320
%% new_protocol_factory(TransportFactory, Options) -> {ok, Factory}
%%
%% Options are validated and normalised once, here, by the same
%% parse_options/2 that new/2 uses. An invalid option therefore crashes when
%% the factory is created, and new options only need to be added in one place.
%%
%% Factory is a fun() that calls TransportFactory(). On {ok, Transport} it
%% returns the result of thrift_binary_protocol:new(Transport, Opts), where
%% Opts holds the final strict_read/strict_write values. On {error, Error} it
%% returns {error, Error}. Any other result crashes with case_clause.
new_protocol_factory(TransportFactory, Options) ->
    #binary_protocol{strict_read  = StrictRead,
                     strict_write = StrictWrite} =
        parse_options(Options, #binary_protocol{}),
    ProtoOpts = [{strict_read,  StrictRead},
                 {strict_write, StrictWrite}],
    F = fun() ->
                case TransportFactory() of
                    {ok, Transport} ->
                        thrift_binary_protocol:new(Transport, ProtoOpts);
                    {error, Error} ->
                        {error, Error}
                end
        end,
    {ok, F}.
347
348