1#
2# Licensed to the Apache Software Foundation (ASF) under one
3# or more contributor license agreements. See the NOTICE file
4# distributed with this work for additional information
5# regarding copyright ownership. The ASF licenses this file
6# to you under the Apache License, Version 2.0 (the
7# "License"); you may not use this file except in compliance
8# with the License. You may obtain a copy of the License at
9#
10#   http://www.apache.org/licenses/LICENSE-2.0
11#
12# Unless required by applicable law or agreed to in writing,
13# software distributed under the License is distributed on an
14# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15# KIND, either express or implied. See the License for the
16# specific language governing permissions and limitations
17# under the License.
18#
19
20from thrift.Thrift import TProcessor, TMessageType, TException
21from thrift.protocol import TProtocolDecorator, TMultiplexedProtocol
22
23
class TMultiplexedProcessor(TProcessor):
    """Dispatch incoming messages to one of several registered processors.

    Multiplexed clients prefix every method name with
    ``<serviceName><TMultiplexedProtocol.SEPARATOR>``. This processor splits
    that prefix off, picks the processor registered for ``serviceName`` and
    lets it handle the message with the bare method name.
    """

    def __init__(self):
        # Processor for messages whose name has no service prefix, i.e. from
        # clients that do not use TMultiplexedProtocol. None means "reject".
        self.defaultProcessor = None
        # Service name -> processor that handles that service.
        self.services = {}

    def registerDefault(self, processor):
        """Use *processor* for messages that carry no service-name prefix.

        This lets a server switch from a single plain processor to a
        multiplexed one while still answering older, non-multiplexed clients.
        """
        self.defaultProcessor = processor

    def registerProcessor(self, serviceName, processor):
        """Route messages prefixed with *serviceName* to *processor*.

        Registering an already-registered name replaces the previous processor.
        """
        self.services[serviceName] = processor

    def process(self, iprot, oprot):
        """Read one message header from *iprot* and dispatch the message.

        Returns whatever the chosen processor's ``process`` call returns.

        Raises:
            TException: if the message type is neither CALL nor ONEWAY; if the
                name has no service prefix and no default processor has been
                registered; or if no processor is registered for the service.
        """
        (name, msg_type, seqid) = iprot.readMessageBegin()
        if msg_type not in (TMessageType.CALL, TMessageType.ONEWAY):
            raise TException("TMultiplexed protocol only supports CALL & ONEWAY")

        # partition() splits at the first separator, like find() + slicing;
        # an empty `separator` means the name carried no service prefix.
        serviceName, separator, call = name.partition(TMultiplexedProtocol.SEPARATOR)
        if not separator:
            # getattr keeps subclasses that override __init__ without calling
            # ours working: they simply get the old "reject" behavior.
            default = getattr(self, "defaultProcessor", None)
            if default is None:
                raise TException("Service name not found in message name: " + name + ". Did you forget to use TMultiplexedProtocol in your client?")
            # The header was already consumed from iprot, so replay it
            # unchanged to the default processor.
            return default.process(StoredMessageProtocol(iprot, (name, msg_type, seqid)), oprot)

        if serviceName not in self.services:
            raise TException("Service name not found: " + serviceName + ". Did you forget to call registerProcessor()?")

        # Replay the header with the service prefix stripped so the target
        # processor sees a plain, non-multiplexed method name.
        standardMessage = (call, msg_type, seqid)
        return self.services[serviceName].process(StoredMessageProtocol(iprot, standardMessage), oprot)
47
48
class StoredMessageProtocol(TProtocolDecorator.TProtocolDecorator):
    """Protocol wrapper whose readMessageBegin() returns a pre-read header.

    TMultiplexedProcessor.process reads the message header from the real input
    protocol to learn the service name, then passes this wrapper to the target
    processor. That processor's own readMessageBegin() call then gets the
    stored header back instead of reading the transport again.

    NOTE(review): every other protocol method is presumably forwarded to the
    wrapped protocol by TProtocolDecorator; confirm against the thrift version
    in use.
    """

    def __init__(self, protocol, messageBegin):
        # `protocol` is deliberately not stored here. NOTE(review): this
        # appears to rely on TProtocolDecorator receiving it at construction
        # time (e.g. in __new__) to bind the wrapped protocol. Confirm before
        # adding a super().__init__ call, which could rebind attributes.
        # messageBegin: the (name, type, seqid) tuple to hand back, as built
        # by TMultiplexedProcessor.process.
        self.messageBegin = messageBegin

    def readMessageBegin(self):
        """Return the stored message header without reading the transport."""
        return self.messageBegin
55