1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 // http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17
18 use log::debug;
19
20 use std::collections::HashMap;
21 use std::convert::Into;
22 use std::fmt;
23 use std::fmt::{Debug, Formatter};
24 use std::sync::{Arc, Mutex};
25
26 use crate::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol};
27
28 use super::{handle_process_result, TProcessor};
29
/// Error text used when an incoming message name carries no service
/// separator and no default processor has been registered to receive it.
const MISSING_SEPARATOR_AND_NO_DEFAULT: &str =
    "missing service separator and no default processor set";
/// A boxed, dynamically dispatched `TProcessor` that is `Send + Sync`, i.e.
/// the form in which service-specific processors are accepted and stored.
type ThreadSafeProcessor = Box<dyn TProcessor + Send + Sync>;
33
/// A `TProcessor` that can demux service calls to multiple underlying
/// Thrift services.
///
/// Users register service-specific `TProcessor` instances with a
/// `TMultiplexedProcessor`, and then register that processor with a server
/// implementation. Following that, all incoming service calls are automatically
/// routed to the service-specific `TProcessor`.
///
/// Routing is driven by the message name: the text before the first `:` is
/// the service name and the remainder is the call name. Such names are what
/// a `TMultiplexedOutputProtocol` is expected to send. A message whose name
/// contains no `:` is only handled when a default processor has been
/// registered; otherwise it is rejected with an error.
#[derive(Default)]
pub struct TMultiplexedProcessor {
    /// Registered processors (by service name) plus the optional default
    /// processor, guarded by a mutex so lookups can happen through `&self`.
    stored: Mutex<StoredProcessors>,
}
48
/// The mutable registry held inside a `TMultiplexedProcessor`.
#[derive(Default)]
struct StoredProcessors {
    /// Service name -> processor registered for that service.
    processors: HashMap<String, Arc<ThreadSafeProcessor>>,
    /// Processor that receives messages whose name has no service prefix,
    /// if one was registered as the default.
    default_processor: Option<Arc<ThreadSafeProcessor>>,
}
54
55 impl TMultiplexedProcessor {
56 /// Create a new `TMultiplexedProcessor` with no registered service-specific
57 /// processors.
new() -> TMultiplexedProcessor58 pub fn new() -> TMultiplexedProcessor {
59 TMultiplexedProcessor {
60 stored: Mutex::new(StoredProcessors {
61 processors: HashMap::new(),
62 default_processor: None,
63 }),
64 }
65 }
66
67 /// Register a service-specific `processor` for the service named
68 /// `service_name`. This implementation is also backwards-compatible with
69 /// non-multiplexed clients. Set `as_default` to `true` to allow
70 /// non-namespaced requests to be dispatched to a default processor.
71 ///
72 /// Returns success if a new entry was inserted. Returns an error if:
73 /// * A processor exists for `service_name`
74 /// * You attempt to register a processor as default, and an existing default exists
75 #[allow(clippy::map_entry)]
register<S: Into<String>>( &mut self, service_name: S, processor: Box<dyn TProcessor + Send + Sync>, as_default: bool, ) -> crate::Result<()>76 pub fn register<S: Into<String>>(
77 &mut self,
78 service_name: S,
79 processor: Box<dyn TProcessor + Send + Sync>,
80 as_default: bool,
81 ) -> crate::Result<()> {
82 let mut stored = self.stored.lock().unwrap();
83
84 let name = service_name.into();
85 if !stored.processors.contains_key(&name) {
86 let processor = Arc::new(processor);
87
88 if as_default {
89 if stored.default_processor.is_none() {
90 stored.processors.insert(name, processor.clone());
91 stored.default_processor = Some(processor.clone());
92 Ok(())
93 } else {
94 Err("cannot reset default processor".into())
95 }
96 } else {
97 stored.processors.insert(name, processor);
98 Ok(())
99 }
100 } else {
101 Err(format!("cannot overwrite existing processor for service {}", name).into())
102 }
103 }
104
process_message( &self, msg_ident: &TMessageIdentifier, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol, ) -> crate::Result<()>105 fn process_message(
106 &self,
107 msg_ident: &TMessageIdentifier,
108 i_prot: &mut dyn TInputProtocol,
109 o_prot: &mut dyn TOutputProtocol,
110 ) -> crate::Result<()> {
111 let (svc_name, svc_call) = split_ident_name(&msg_ident.name);
112 debug!("routing svc_name {:?} svc_call {}", &svc_name, &svc_call);
113
114 let processor: Option<Arc<ThreadSafeProcessor>> = {
115 let stored = self.stored.lock().unwrap();
116 if let Some(name) = svc_name {
117 stored.processors.get(name).cloned()
118 } else {
119 stored.default_processor.clone()
120 }
121 };
122
123 match processor {
124 Some(arc) => {
125 let new_msg_ident = TMessageIdentifier::new(
126 svc_call,
127 msg_ident.message_type,
128 msg_ident.sequence_number,
129 );
130 let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident);
131 (*arc).process(&mut proxy_i_prot, o_prot)
132 }
133 None => Err(missing_processor_message(svc_name).into()),
134 }
135 }
136 }
137
138 impl TProcessor for TMultiplexedProcessor {
process(&self, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol) -> crate::Result<()>139 fn process(&self, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol) -> crate::Result<()> {
140 let msg_ident = i_prot.read_message_begin()?;
141
142 debug!("process incoming msg id:{:?}", &msg_ident);
143 let res = self.process_message(&msg_ident, i_prot, o_prot);
144
145 handle_process_result(&msg_ident, res, o_prot)
146 }
147 }
148
149 impl Debug for TMultiplexedProcessor {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result150 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
151 let stored = self.stored.lock().unwrap();
152 write!(
153 f,
154 "TMultiplexedProcess {{ registered_count: {:?} default: {:?} }}",
155 stored.processors.keys().len(),
156 stored.default_processor.is_some()
157 )
158 }
159 }
160
/// Split a message name into `(service name, call name)` at the first `:`.
///
/// * `"svc:call"` yields `(Some("svc"), "call")`.
/// * Only the first `:` separates; later ones stay in the call name.
/// * A name without `:` yields `(None, <whole name>)`.
fn split_ident_name(ident_name: &str) -> (Option<&str>, &str) {
    match ident_name.find(':') {
        // `:` is a single byte, so the call name starts right after it.
        Some(colon) => (Some(&ident_name[..colon]), &ident_name[colon + 1..]),
        None => (None, ident_name),
    }
}
172
missing_processor_message(svc_name: Option<&str>) -> String173 fn missing_processor_message(svc_name: Option<&str>) -> String {
174 match svc_name {
175 Some(name) => format!("no processor found for service {}", name),
176 None => MISSING_SEPARATOR_AND_NO_DEFAULT.to_owned(),
177 }
178 }
179
#[cfg(test)]
mod tests {
    use std::convert::Into;
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use crate::protocol::{
        TBinaryInputProtocol, TBinaryOutputProtocol, TMessageIdentifier, TMessageType,
    };
    use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
    use crate::{ApplicationError, ApplicationErrorKind};

    use super::*;

    #[test]
    fn should_split_name_into_proper_separator_and_service_call() {
        let ident_name = "foo:bar_call";
        let (serv, call) = split_ident_name(ident_name);
        assert_eq!(serv, Some("foo"));
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_return_full_ident_if_no_separator_exists() {
        let ident_name = "bar_call";
        let (serv, call) = split_ident_name(ident_name);
        assert_eq!(serv, None);
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_write_error_if_no_separator_found_and_no_default_processor_exists() {
        let (mut i, mut o) = build_objects();

        // write a non-namespaced call and loop it back to the read side
        let sent_ident = TMessageIdentifier::new("foo", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        // read back what the processor wrote and check it is the expected exception
        i.transport.set_readable_bytes(&o.transport.write_bytes());
        let rcvd_ident = i.read_message_begin().unwrap();
        let expected_ident = TMessageIdentifier::new("foo", TMessageType::Exception, 10);
        assert_eq!(rcvd_ident, expected_ident);
        let rcvd_err = crate::Error::read_application_error_from_in_protocol(&mut i).unwrap();
        let expected_err = ApplicationError::new(
            ApplicationErrorKind::Unknown,
            MISSING_SEPARATOR_AND_NO_DEFAULT,
        );
        assert_eq!(rcvd_err, expected_err);
    }

    #[test]
    fn should_write_error_if_separator_exists_and_no_processor_found() {
        let (mut i, mut o) = build_objects();

        // write a namespaced call for a service nobody registered
        let sent_ident = TMessageIdentifier::new("missing:call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        // read back what the processor wrote and check it is the expected exception
        i.transport.set_readable_bytes(&o.transport.write_bytes());
        let rcvd_ident = i.read_message_begin().unwrap();
        let expected_ident = TMessageIdentifier::new("missing:call", TMessageType::Exception, 10);
        assert_eq!(rcvd_ident, expected_ident);
        let rcvd_err = crate::Error::read_application_error_from_in_protocol(&mut i).unwrap();
        let expected_err = ApplicationError::new(
            ApplicationErrorKind::Unknown,
            missing_processor_message(Some("missing")),
        );
        assert_eq!(rcvd_err, expected_err);
    }

    /// Test double that records, via a shared flag, whether it was invoked.
    #[derive(Default)]
    struct Service {
        pub invoked: Arc<AtomicBool>,
    }

    impl TProcessor for Service {
        /// Flip `invoked` from `false` to `true`.
        ///
        /// Returns `Ok(())` when this call performed the flip, and
        /// `Err("failed swap")` when the flag was already set (i.e. the
        /// service was invoked more than once).
        fn process(
            &self,
            _: &mut dyn TInputProtocol,
            _: &mut dyn TOutputProtocol,
        ) -> crate::Result<()> {
            // `compare_exchange` reports success/failure directly, unlike the
            // deprecated `compare_and_swap`, which returns the previous value
            // and is easy to misread as a success flag.
            self.invoked
                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
                .map(|_| ())
                .map_err(|_| "failed swap".into())
        }
    }

    #[test]
    fn should_route_call_to_correct_processor() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_1 = svc_1.invoked.clone();
        let svc_2 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_2 = svc_2.invoked.clone();

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), false).unwrap();

        // make the service call
        let sent_ident = TMessageIdentifier::new("service_1:call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        p.process(&mut i, &mut o).unwrap();

        // service 1 should have been invoked, not service 2
        assert!(atm_1.load(Ordering::Relaxed));
        assert!(!atm_2.load(Ordering::Relaxed));
    }

    #[test]
    fn should_route_call_to_correct_processor_if_no_separator_exists_and_default_processor_set() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_1 = svc_1.invoked.clone();
        let svc_2 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_2 = svc_2.invoked.clone();

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), true).unwrap(); // second processor is default

        // make the service call (it's an old client, so we have to be backwards compatible)
        let sent_ident = TMessageIdentifier::new("old_call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        p.process(&mut i, &mut o).unwrap();

        // service 2 should have been invoked, not service 1
        assert!(!atm_1.load(Ordering::Relaxed));
        assert!(atm_2.load(Ordering::Relaxed));
    }

    /// Build a binary input/output protocol pair over the two halves of a
    /// single in-memory `TBufferChannel`.
    fn build_objects() -> (
        TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
        TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
    ) {
        let c = TBufferChannel::with_capacity(128, 128);
        let (r_c, w_c) = c.split().unwrap();
        (
            TBinaryInputProtocol::new(r_c, true),
            TBinaryOutputProtocol::new(w_c, true),
        )
    }
}
354