1 // Licensed to the Apache Software Foundation (ASF) under one
2 // or more contributor license agreements. See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership. The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 //   http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 use log::debug;
19 
20 use std::collections::HashMap;
21 use std::convert::Into;
22 use std::fmt;
23 use std::fmt::{Debug, Formatter};
24 use std::sync::{Arc, Mutex};
25 
26 use crate::protocol::{TInputProtocol, TMessageIdentifier, TOutputProtocol, TStoredInputProtocol};
27 
28 use super::{handle_process_result, TProcessor};
29 
/// Error text used when an incoming message name has no `service:` prefix
/// and no default processor has been registered to fall back on.
const MISSING_SEPARATOR_AND_NO_DEFAULT: &str =
    "missing service separator and no default processor set";
32 type ThreadSafeProcessor = Box<dyn TProcessor + Send + Sync>;
33 
34 /// A `TProcessor` that can demux service calls to multiple underlying
35 /// Thrift services.
36 ///
37 /// Users register service-specific `TProcessor` instances with a
38 /// `TMultiplexedProcessor`, and then register that processor with a server
39 /// implementation. Following that, all incoming service calls are automatically
40 /// routed to the service-specific `TProcessor`.
41 ///
42 /// A `TMultiplexedProcessor` can only handle messages sent by a
43 /// `TMultiplexedOutputProtocol`.
44 #[derive(Default)]
45 pub struct TMultiplexedProcessor {
46     stored: Mutex<StoredProcessors>,
47 }
48 
49 #[derive(Default)]
50 struct StoredProcessors {
51     processors: HashMap<String, Arc<ThreadSafeProcessor>>,
52     default_processor: Option<Arc<ThreadSafeProcessor>>,
53 }
54 
55 impl TMultiplexedProcessor {
56     /// Create a new `TMultiplexedProcessor` with no registered service-specific
57     /// processors.
new() -> TMultiplexedProcessor58     pub fn new() -> TMultiplexedProcessor {
59         TMultiplexedProcessor {
60             stored: Mutex::new(StoredProcessors {
61                 processors: HashMap::new(),
62                 default_processor: None,
63             }),
64         }
65     }
66 
67     /// Register a service-specific `processor` for the service named
68     /// `service_name`. This implementation is also backwards-compatible with
69     /// non-multiplexed clients. Set `as_default` to `true` to allow
70     /// non-namespaced requests to be dispatched to a default processor.
71     ///
72     /// Returns success if a new entry was inserted. Returns an error if:
73     /// * A processor exists for `service_name`
74     /// * You attempt to register a processor as default, and an existing default exists
75     #[allow(clippy::map_entry)]
register<S: Into<String>>( &mut self, service_name: S, processor: Box<dyn TProcessor + Send + Sync>, as_default: bool, ) -> crate::Result<()>76     pub fn register<S: Into<String>>(
77         &mut self,
78         service_name: S,
79         processor: Box<dyn TProcessor + Send + Sync>,
80         as_default: bool,
81     ) -> crate::Result<()> {
82         let mut stored = self.stored.lock().unwrap();
83 
84         let name = service_name.into();
85         if !stored.processors.contains_key(&name) {
86             let processor = Arc::new(processor);
87 
88             if as_default {
89                 if stored.default_processor.is_none() {
90                     stored.processors.insert(name, processor.clone());
91                     stored.default_processor = Some(processor.clone());
92                     Ok(())
93                 } else {
94                     Err("cannot reset default processor".into())
95                 }
96             } else {
97                 stored.processors.insert(name, processor);
98                 Ok(())
99             }
100         } else {
101             Err(format!("cannot overwrite existing processor for service {}", name).into())
102         }
103     }
104 
process_message( &self, msg_ident: &TMessageIdentifier, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol, ) -> crate::Result<()>105     fn process_message(
106         &self,
107         msg_ident: &TMessageIdentifier,
108         i_prot: &mut dyn TInputProtocol,
109         o_prot: &mut dyn TOutputProtocol,
110     ) -> crate::Result<()> {
111         let (svc_name, svc_call) = split_ident_name(&msg_ident.name);
112         debug!("routing svc_name {:?} svc_call {}", &svc_name, &svc_call);
113 
114         let processor: Option<Arc<ThreadSafeProcessor>> = {
115             let stored = self.stored.lock().unwrap();
116             if let Some(name) = svc_name {
117                 stored.processors.get(name).cloned()
118             } else {
119                 stored.default_processor.clone()
120             }
121         };
122 
123         match processor {
124             Some(arc) => {
125                 let new_msg_ident = TMessageIdentifier::new(
126                     svc_call,
127                     msg_ident.message_type,
128                     msg_ident.sequence_number,
129                 );
130                 let mut proxy_i_prot = TStoredInputProtocol::new(i_prot, new_msg_ident);
131                 (*arc).process(&mut proxy_i_prot, o_prot)
132             }
133             None => Err(missing_processor_message(svc_name).into()),
134         }
135     }
136 }
137 
138 impl TProcessor for TMultiplexedProcessor {
process(&self, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol) -> crate::Result<()>139     fn process(&self, i_prot: &mut dyn TInputProtocol, o_prot: &mut dyn TOutputProtocol) -> crate::Result<()> {
140         let msg_ident = i_prot.read_message_begin()?;
141 
142         debug!("process incoming msg id:{:?}", &msg_ident);
143         let res = self.process_message(&msg_ident, i_prot, o_prot);
144 
145         handle_process_result(&msg_ident, res, o_prot)
146     }
147 }
148 
149 impl Debug for TMultiplexedProcessor {
fmt(&self, f: &mut Formatter<'_>) -> fmt::Result150     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
151         let stored = self.stored.lock().unwrap();
152         write!(
153             f,
154             "TMultiplexedProcess {{ registered_count: {:?} default: {:?} }}",
155             stored.processors.keys().len(),
156             stored.default_processor.is_some()
157         )
158     }
159 }
160 
/// Split a multiplexed message name of the form `service:call` into its
/// service and call parts.
///
/// Only the first `:` acts as the separator, so everything after it
/// (including any further colons) belongs to the call name. Without a
/// separator the service is `None` and the whole name is returned as the
/// call, which is what lets unprefixed messages reach the default processor.
fn split_ident_name(ident_name: &str) -> (Option<&str>, &str) {
    match ident_name.find(':') {
        // `:` is a single ASCII byte, so `pos + 1` is always a char boundary.
        Some(pos) => (Some(&ident_name[..pos]), &ident_name[pos + 1..]),
        None => (None, ident_name),
    }
}
172 
missing_processor_message(svc_name: Option<&str>) -> String173 fn missing_processor_message(svc_name: Option<&str>) -> String {
174     match svc_name {
175         Some(name) => format!("no processor found for service {}", name),
176         None => MISSING_SEPARATOR_AND_NO_DEFAULT.to_owned(),
177     }
178 }
179 
#[cfg(test)]
mod tests {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    use crate::protocol::{TBinaryInputProtocol, TBinaryOutputProtocol, TMessageIdentifier, TMessageType};
    use crate::transport::{ReadHalf, TBufferChannel, TIoChannel, WriteHalf};
    use crate::{ApplicationError, ApplicationErrorKind};

    use super::*;

    #[test]
    fn should_split_name_into_proper_separator_and_service_call() {
        let ident_name = "foo:bar_call";
        let (serv, call) = split_ident_name(ident_name);
        assert_eq!(serv, Some("foo"));
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_return_full_ident_if_no_separator_exists() {
        let ident_name = "bar_call";
        let (serv, call) = split_ident_name(ident_name);
        assert_eq!(serv, None);
        assert_eq!(call, "bar_call");
    }

    #[test]
    fn should_split_only_on_first_separator() {
        let (serv, call) = split_ident_name("svc:nested:call");
        assert_eq!(serv, Some("svc"));
        assert_eq!(call, "nested:call");
    }

    #[test]
    fn should_write_error_if_no_separator_found_and_no_default_processor_exists() {
        let (mut i, mut o) = build_objects();

        let sent_ident = TMessageIdentifier::new("foo", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        i.transport.set_readable_bytes(&o.transport.write_bytes());
        let rcvd_ident = i.read_message_begin().unwrap();
        let expected_ident = TMessageIdentifier::new("foo", TMessageType::Exception, 10);
        assert_eq!(rcvd_ident, expected_ident);
        let rcvd_err = crate::Error::read_application_error_from_in_protocol(&mut i).unwrap();
        let expected_err = ApplicationError::new(
            ApplicationErrorKind::Unknown,
            MISSING_SEPARATOR_AND_NO_DEFAULT,
        );
        assert_eq!(rcvd_err, expected_err);
    }

    #[test]
    fn should_write_error_if_separator_exists_and_no_processor_found() {
        let (mut i, mut o) = build_objects();

        let sent_ident = TMessageIdentifier::new("missing:call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        let p = TMultiplexedProcessor::new();
        p.process(&mut i, &mut o).unwrap(); // at this point an error should be written out

        i.transport.set_readable_bytes(&o.transport.write_bytes());
        let rcvd_ident = i.read_message_begin().unwrap();
        let expected_ident = TMessageIdentifier::new("missing:call", TMessageType::Exception, 10);
        assert_eq!(rcvd_ident, expected_ident);
        let rcvd_err = crate::Error::read_application_error_from_in_protocol(&mut i).unwrap();
        let expected_err = ApplicationError::new(
            ApplicationErrorKind::Unknown,
            missing_processor_message(Some("missing")),
        );
        assert_eq!(rcvd_err, expected_err);
    }

    #[test]
    fn should_reject_duplicate_service_and_second_default() {
        let mut p = TMultiplexedProcessor::new();
        p.register("svc", Box::new(Service::default()), true).unwrap();

        // same name twice is rejected, whether or not it asks to be the default
        assert!(p.register("svc", Box::new(Service::default()), false).is_err());
        assert!(p.register("svc", Box::new(Service::default()), true).is_err());

        // a second default is rejected, and the rejected name is not registered
        assert!(p.register("other", Box::new(Service::default()), true).is_err());
        p.register("other", Box::new(Service::default()), false).unwrap();
    }

    /// Test processor that records whether it has been invoked.
    ///
    /// `process` succeeds on the first invocation and fails on any later one.
    #[derive(Default)]
    struct Service {
        pub invoked: Arc<AtomicBool>,
    }

    impl TProcessor for Service {
        fn process(&self, _: &mut dyn TInputProtocol, _: &mut dyn TOutputProtocol) -> crate::Result<()> {
            // `compare_exchange` succeeds only if the flag was still `false`,
            // i.e. this is the first invocation.
            match self
                .invoked
                .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed)
            {
                Ok(_) => Ok(()),
                Err(_) => Err("failed swap".into()),
            }
        }
    }

    #[test]
    fn should_route_call_to_correct_processor() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_1 = svc_1.invoked.clone();
        let svc_2 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_2 = svc_2.invoked.clone();

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), false).unwrap();

        // make the service call
        let sent_ident = TMessageIdentifier::new("service_1:call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        p.process(&mut i, &mut o).unwrap();

        // service 1 should have been invoked, not service 2
        assert!(atm_1.load(Ordering::Relaxed));
        assert!(!atm_2.load(Ordering::Relaxed));
        // the routed call succeeded, so no exception should have been written back
        assert!(o.transport.write_bytes().is_empty());
    }

    #[test]
    fn should_route_call_to_correct_processor_if_no_separator_exists_and_default_processor_set() {
        let (mut i, mut o) = build_objects();

        // build the services
        let svc_1 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_1 = svc_1.invoked.clone();
        let svc_2 = Service {
            invoked: Arc::new(AtomicBool::new(false)),
        };
        let atm_2 = svc_2.invoked.clone();

        // register them
        let mut p = TMultiplexedProcessor::new();
        p.register("service_1", Box::new(svc_1), false).unwrap();
        p.register("service_2", Box::new(svc_2), true).unwrap(); // second processor is default

        // make the service call (it's an old client, so we have to be backwards compatible)
        let sent_ident = TMessageIdentifier::new("old_call", TMessageType::Call, 10);
        o.write_message_begin(&sent_ident).unwrap();
        o.flush().unwrap();
        o.transport.copy_write_buffer_to_read_buffer();
        o.transport.empty_write_buffer();

        p.process(&mut i, &mut o).unwrap();

        // service 2 should have been invoked, not service 1
        assert!(!atm_1.load(Ordering::Relaxed));
        assert!(atm_2.load(Ordering::Relaxed));
        // the routed call succeeded, so no exception should have been written back
        assert!(o.transport.write_bytes().is_empty());
    }

    fn build_objects() -> (
        TBinaryInputProtocol<ReadHalf<TBufferChannel>>,
        TBinaryOutputProtocol<WriteHalf<TBufferChannel>>,
    ) {
        let c = TBufferChannel::with_capacity(128, 128);
        let (r_c, w_c) = c.split().unwrap();
        (
            TBinaryInputProtocol::new(r_c, true),
            TBinaryOutputProtocol::new(w_c, true),
        )
    }
}
354