1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from thrift.Thrift import TProcessor, TMessageType
21from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
22from thrift.protocol.TProtocol import TProtocolException
23
24
class TMultiplexedProcessor(TProcessor):
    """Route each incoming call to one of several registered processors.

    Message names of the form ``<serviceName><SEPARATOR><method>`` (as
    produced by a client using TMultiplexedProtocol) are dispatched to the
    processor registered under ``serviceName``; that processor then reads
    the plain ``<method>`` name.  Names without a separator are handled by
    the optional default processor, so a server can keep serving clients
    that do not multiplex.
    """

    def __init__(self):
        self.defaultProcessor = None
        # serviceName -> processor, populated by registerProcessor().
        self.services = {}
        # Last callback passed to on_message_begin(); remembered so that
        # processors registered afterwards receive it as well, making the
        # outcome independent of registration order.
        self._on_message_begin_func = None

    def registerDefault(self, processor):
        """
        If a non-multiplexed processor connects to the server and wants to
        communicate, use the given processor to handle it.  This mechanism
        allows servers to upgrade from non-multiplexed to multiplexed in a
        backwards-compatible way and still handle old clients.

        Passing ``None`` removes the default processor.  If a message-begin
        callback was already installed via on_message_begin(), it is applied
        to ``processor`` as well.
        """
        self.defaultProcessor = processor
        self._install_callback(processor)

    def registerProcessor(self, serviceName, processor):
        """Register ``processor`` for calls prefixed with ``serviceName``.

        Registering the same name again replaces the previous processor.
        If a message-begin callback was already installed via
        on_message_begin(), it is applied to ``processor`` as well.
        """
        self.services[serviceName] = processor
        self._install_callback(processor)

    def on_message_begin(self, func):
        """Install ``func`` on every delegate processor, current and future.

        The callback is forwarded to all registered service processors and to
        the default processor (if any), and is remembered so that processors
        registered later also receive it.  Delegates read the message header
        through StoredMessageProtocol, so any service prefix has already been
        removed from the name they see.
        """
        self._on_message_begin_func = func
        for processor in self.services.values():
            processor.on_message_begin(func)
        if self.defaultProcessor is not None:
            self.defaultProcessor.on_message_begin(func)

    def _install_callback(self, processor):
        """Apply the remembered message-begin callback to one new delegate."""
        if processor is not None and self._on_message_begin_func is not None:
            processor.on_message_begin(self._on_message_begin_func)

    def process(self, iprot, oprot):
        """Read one message header and hand the call to the right processor.

        :param iprot: input protocol; its message header is consumed here.
        :param oprot: output protocol, passed through to the delegate as-is.
        :returns: whatever the selected delegate's ``process`` returns.
        :raises TProtocolException: with NOT_IMPLEMENTED when the message type
            is neither CALL nor ONEWAY, when the name has no service prefix
            and no default processor is registered, or when no processor is
            registered for the service prefix.
        """
        (name, msg_type, seqid) = iprot.readMessageBegin()
        if msg_type not in (TMessageType.CALL, TMessageType.ONEWAY):
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "TMultiplexedProtocol only supports CALL & ONEWAY")

        # Split at the first separator only: everything before it names the
        # service, everything after it is the method name.
        serviceName, separator, call = name.partition(
            TMultiplexedProtocol.SEPARATOR)

        if not separator:
            # Non-multiplexed client: replay the untouched header to the
            # default processor, if one was registered.
            if self.defaultProcessor is None:
                raise TProtocolException(
                    TProtocolException.NOT_IMPLEMENTED,
                    "Service name not found in message name: " + name + ".  " +
                    "Did you forget to use TMultiplexedProtocol in your client?")
            return self.defaultProcessor.process(
                StoredMessageProtocol(iprot, (name, msg_type, seqid)), oprot)

        processor = self.services.get(serviceName)
        if processor is None:
            raise TProtocolException(
                TProtocolException.NOT_IMPLEMENTED,
                "Service name not found: " + serviceName + ".  " +
                "Did you forget to call registerProcessor()?")

        # The header has already been consumed from iprot above, so the
        # delegate reads it back (with the service prefix stripped) from
        # StoredMessageProtocol instead of from the wire.
        return processor.process(
            StoredMessageProtocol(iprot, (call, msg_type, seqid)), oprot)
75
76
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol wrapper that replays an already-consumed message header.

    TMultiplexedProcessor reads the message header itself to decide which
    processor should handle the call.  It then wraps the input protocol in
    this class so that the delegate's ``readMessageBegin()`` returns the
    stored ``(name, type, seqid)`` tuple instead of reading from the stream
    again.  All other protocol operations are presumably forwarded to the
    wrapped protocol by the TProtocolDecorator base (not visible here --
    confirm there).
    """

    def __init__(self, protocol, messageBegin):
        """Wrap ``protocol`` and remember the header to replay.

        :param protocol: the underlying input protocol being decorated.
        :param messageBegin: the ``(name, type, seqid)`` tuple that
            readMessageBegin() will return.
        """
        # NOTE(review): `protocol` is neither stored nor passed to a base
        # __init__ here; presumably the TProtocolDecorator base performs the
        # wrapping itself at construction time (e.g. in __new__).  Confirm
        # against the base class before adding a super().__init__ call.
        self.messageBegin = messageBegin

    def readMessageBegin(self):
        """Return the stored header tuple without touching the stream."""
        return self.messageBegin
83