1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 *   http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20package thrift
21
22import (
23	"context"
24	"fmt"
25	"strings"
26)
27
28/*
TMultiplexedProtocol is a protocol-independent concrete decorator
that allows a Thrift client to communicate with a multiplexing Thrift server
by prepending the service name to the function name during function calls.
32
NOTE: THIS IS NOT USED BY SERVERS.  On the server, use TMultiplexedProcessor to handle requests
from a multiplexing client.
35
36This example uses a single socket transport to invoke two services:
37
38socket := thrift.NewTSocketFromAddrTimeout(addr, TIMEOUT)
39transport := thrift.NewTFramedTransport(socket)
40protocol := thrift.NewTBinaryProtocolTransport(transport)
41
42mp := thrift.NewTMultiplexedProtocol(protocol, "Calculator")
43service := Calculator.NewCalculatorClient(mp)
44
45mp2 := thrift.NewTMultiplexedProtocol(protocol, "WeatherReport")
46service2 := WeatherReport.NewWeatherReportClient(mp2)
47
48err := transport.Open()
49if err != nil {
50	t.Fatal("Unable to open client socket", err)
51}
52
53fmt.Println(service.Add(2,2))
54fmt.Println(service2.GetTemperature())
55*/
56
// TMultiplexedProtocol wraps another TProtocol for use by a client that talks
// to a multiplexing server. All methods are promoted from the embedded
// protocol except WriteMessageBegin, which prefixes the names of outgoing
// CALL and ONEWAY messages with serviceName followed by MULTIPLEXED_SEPARATOR.
type TMultiplexedProtocol struct {
	TProtocol
	// serviceName is the name under which the target service is registered on
	// the server; it becomes the prefix of every outgoing request name.
	serviceName string
}
61
// MULTIPLEXED_SEPARATOR separates the service name from the method name in a
// multiplexed message name ("<service>:<method>"). TMultiplexedProtocol inserts
// it when writing requests; TMultiplexedProcessor splits on its first
// occurrence when reading them.
const MULTIPLEXED_SEPARATOR = ":"
63
64func NewTMultiplexedProtocol(protocol TProtocol, serviceName string) *TMultiplexedProtocol {
65	return &TMultiplexedProtocol{
66		TProtocol:   protocol,
67		serviceName: serviceName,
68	}
69}
70
71func (t *TMultiplexedProtocol) WriteMessageBegin(name string, typeId TMessageType, seqid int32) error {
72	if typeId == CALL || typeId == ONEWAY {
73		return t.TProtocol.WriteMessageBegin(t.serviceName+MULTIPLEXED_SEPARATOR+name, typeId, seqid)
74	} else {
75		return t.TProtocol.WriteMessageBegin(name, typeId, seqid)
76	}
77}
78
79/*
80TMultiplexedProcessor is a TProcessor allowing
81a single TServer to provide multiple services.
82
83To do so, you instantiate the processor and then register additional
84processors with it, as shown in the following example:
85
var processor = thrift.NewTMultiplexedProcessor()

processor.RegisterProcessor(
  "Calculator",
  Calculator.NewCalculatorProcessor(&CalculatorHandler{}),
)

processor.RegisterProcessor(
  "WeatherReport",
  WeatherReport.NewWeatherReportProcessor(&WeatherReportHandler{}),
)

serverTransport, err := thrift.NewTServerSocketTimeout(addr, TIMEOUT)
if err != nil {
  t.Fatal("Unable to create server socket", err)
}
server := thrift.NewTSimpleServer2(processor, serverTransport)
server.Serve()
107*/
108
// TMultiplexedProcessor is a TProcessor that lets a single server expose
// several services. Each incoming request is routed, by the service-name
// prefix of its message name, to a processor added with RegisterProcessor.
// The zero value is usable: RegisterProcessor allocates the routing map on
// first use, and lookups on a nil map simply find nothing.
type TMultiplexedProcessor struct {
	// serviceProcessorMap maps a service name (the part of the message name
	// before the first MULTIPLEXED_SEPARATOR) to the processor handling it.
	serviceProcessorMap map[string]TProcessor
	// DefaultProcessor, when non-nil, receives messages whose names contain no
	// MULTIPLEXED_SEPARATOR at all (typically from clients that do not wrap
	// their protocol in TMultiplexedProtocol).
	DefaultProcessor    TProcessor
}
113
114func NewTMultiplexedProcessor() *TMultiplexedProcessor {
115	return &TMultiplexedProcessor{
116		serviceProcessorMap: make(map[string]TProcessor),
117	}
118}
119
// RegisterDefault sets the processor that handles messages whose names contain
// no MULTIPLEXED_SEPARATOR. Passing nil clears any previously set default, in
// which case such messages are rejected by Process with an error.
func (t *TMultiplexedProcessor) RegisterDefault(processor TProcessor) {
	t.DefaultProcessor = processor
}
123
124func (t *TMultiplexedProcessor) RegisterProcessor(name string, processor TProcessor) {
125	if t.serviceProcessorMap == nil {
126		t.serviceProcessorMap = make(map[string]TProcessor)
127	}
128	t.serviceProcessorMap[name] = processor
129}
130
131func (t *TMultiplexedProcessor) Process(ctx context.Context, in, out TProtocol) (bool, TException) {
132	name, typeId, seqid, err := in.ReadMessageBegin()
133	if err != nil {
134		return false, err
135	}
136	if typeId != CALL && typeId != ONEWAY {
137		return false, fmt.Errorf("Unexpected message type %v", typeId)
138	}
139	//extract the service name
140	v := strings.SplitN(name, MULTIPLEXED_SEPARATOR, 2)
141	if len(v) != 2 {
142		if t.DefaultProcessor != nil {
143			smb := NewStoredMessageProtocol(in, name, typeId, seqid)
144			return t.DefaultProcessor.Process(ctx, smb, out)
145		}
146		return false, fmt.Errorf("Service name not found in message name: %s.  Did you forget to use a TMultiplexProtocol in your client?", name)
147	}
148	actualProcessor, ok := t.serviceProcessorMap[v[0]]
149	if !ok {
150		return false, fmt.Errorf("Service name not found: %s.  Did you forget to call registerProcessor()?", v[0])
151	}
152	smb := NewStoredMessageProtocol(in, v[1], typeId, seqid)
153	return actualProcessor.Process(ctx, smb, out)
154}
155
// storedMessageProtocol wraps a TProtocol whose message header has already been
// read. ReadMessageBegin returns the stored header values instead of reading
// from the wrapped protocol; all other methods are promoted from the embedded
// TProtocol.
type storedMessageProtocol struct {
	TProtocol
	// name, typeId and seqid are the header values replayed by ReadMessageBegin.
	name   string
	typeId TMessageType
	seqid  int32
}
163
164func NewStoredMessageProtocol(protocol TProtocol, name string, typeId TMessageType, seqid int32) *storedMessageProtocol {
165	return &storedMessageProtocol{protocol, name, typeId, seqid}
166}
167
168func (s *storedMessageProtocol) ReadMessageBegin() (name string, typeId TMessageType, seqid int32, err error) {
169	return s.name, s.typeId, s.seqid, nil
170}
171