1 //! See [RequestDispatcher].
2 use std::{fmt, panic, thread};
3 
4 use serde::{de::DeserializeOwned, Serialize};
5 
6 use crate::{
7     global_state::{GlobalState, GlobalStateSnapshot},
8     lsp_utils::is_cancelled,
9     main_loop::Task,
10     LspError, Result,
11 };
12 
13 /// A visitor for routing a raw JSON request to an appropriate handler function.
14 ///
15 /// Most requests are read-only and async and are handled on the threadpool
16 /// (`on` method).
17 ///
18 /// Some read-only requests are latency sensitive, and are immediately handled
19 /// on the main loop thread (`on_sync`). These are typically typing-related
20 /// requests.
21 ///
22 /// Some requests modify the state, and are run on the main thread to get
23 /// `&mut` (`on_sync_mut`).
24 ///
25 /// Read-only requests are wrapped into `catch_unwind` -- they don't modify the
26 /// state, so it's OK to recover from their failures.
27 pub(crate) struct RequestDispatcher<'a> {
28     pub(crate) req: Option<lsp_server::Request>,
29     pub(crate) global_state: &'a mut GlobalState,
30 }
31 
32 impl<'a> RequestDispatcher<'a> {
33     /// Dispatches the request onto the current thread, given full access to
34     /// mutable global state. Unlike all other methods here, this one isn't
35     /// guarded by `catch_unwind`, so, please, don't make bugs :-)
on_sync_mut<R>( &mut self, f: fn(&mut GlobalState, R::Params) -> Result<R::Result>, ) -> Result<&mut Self> where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug + 'static, R::Result: Serialize + 'static,36     pub(crate) fn on_sync_mut<R>(
37         &mut self,
38         f: fn(&mut GlobalState, R::Params) -> Result<R::Result>,
39     ) -> Result<&mut Self>
40     where
41         R: lsp_types::request::Request + 'static,
42         R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug + 'static,
43         R::Result: Serialize + 'static,
44     {
45         let (id, params, panic_context) = match self.parse::<R>() {
46             Some(it) => it,
47             None => return Ok(self),
48         };
49         let _pctx = stdx::panic_context::enter(panic_context);
50 
51         let result = f(&mut self.global_state, params);
52         let response = result_to_response::<R>(id, result);
53 
54         self.global_state.respond(response);
55         Ok(self)
56     }
57 
58     /// Dispatches the request onto the current thread.
on_sync<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> Result<&mut Self> where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug + 'static, R::Result: Serialize + 'static,59     pub(crate) fn on_sync<R>(
60         &mut self,
61         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
62     ) -> Result<&mut Self>
63     where
64         R: lsp_types::request::Request + 'static,
65         R::Params: DeserializeOwned + panic::UnwindSafe + fmt::Debug + 'static,
66         R::Result: Serialize + 'static,
67     {
68         let (id, params, panic_context) = match self.parse::<R>() {
69             Some(it) => it,
70             None => return Ok(self),
71         };
72         let global_state_snapshot = self.global_state.snapshot();
73 
74         let result = panic::catch_unwind(move || {
75             let _pctx = stdx::panic_context::enter(panic_context);
76             f(global_state_snapshot, params)
77         });
78         let response = thread_result_to_response::<R>(id, result);
79 
80         self.global_state.respond(response);
81         Ok(self)
82     }
83 
84     /// Dispatches the request onto thread pool
on<R>( &mut self, f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>, ) -> &mut Self where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug + 'static, R::Result: Serialize + 'static,85     pub(crate) fn on<R>(
86         &mut self,
87         f: fn(GlobalStateSnapshot, R::Params) -> Result<R::Result>,
88     ) -> &mut Self
89     where
90         R: lsp_types::request::Request + 'static,
91         R::Params: DeserializeOwned + panic::UnwindSafe + Send + fmt::Debug + 'static,
92         R::Result: Serialize + 'static,
93     {
94         let (id, params, panic_context) = match self.parse::<R>() {
95             Some(it) => it,
96             None => return self,
97         };
98 
99         self.global_state.task_pool.handle.spawn({
100             let world = self.global_state.snapshot();
101             move || {
102                 let result = panic::catch_unwind(move || {
103                     let _pctx = stdx::panic_context::enter(panic_context);
104                     f(world, params)
105                 });
106                 let response = thread_result_to_response::<R>(id, result);
107                 Task::Response(response)
108             }
109         });
110 
111         self
112     }
113 
finish(&mut self)114     pub(crate) fn finish(&mut self) {
115         if let Some(req) = self.req.take() {
116             tracing::error!("unknown request: {:?}", req);
117             let response = lsp_server::Response::new_err(
118                 req.id,
119                 lsp_server::ErrorCode::MethodNotFound as i32,
120                 "unknown request".to_string(),
121             );
122             self.global_state.respond(response);
123         }
124     }
125 
parse<R>(&mut self) -> Option<(lsp_server::RequestId, R::Params, String)> where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + fmt::Debug + 'static,126     fn parse<R>(&mut self) -> Option<(lsp_server::RequestId, R::Params, String)>
127     where
128         R: lsp_types::request::Request + 'static,
129         R::Params: DeserializeOwned + fmt::Debug + 'static,
130     {
131         let req = match &self.req {
132             Some(req) if req.method == R::METHOD => self.req.take().unwrap(),
133             _ => return None,
134         };
135 
136         let res = crate::from_json(R::METHOD, req.params);
137         match res {
138             Ok(params) => {
139                 let panic_context =
140                     format!("\nversion: {}\nrequest: {} {:#?}", env!("REV"), R::METHOD, params);
141                 Some((req.id, params, panic_context))
142             }
143             Err(err) => {
144                 let response = lsp_server::Response::new_err(
145                     req.id,
146                     lsp_server::ErrorCode::InvalidParams as i32,
147                     err.to_string(),
148                 );
149                 self.global_state.respond(response);
150                 None
151             }
152         }
153     }
154 }
155 
thread_result_to_response<R>( id: lsp_server::RequestId, result: thread::Result<Result<R::Result>>, ) -> lsp_server::Response where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + 'static, R::Result: Serialize + 'static,156 fn thread_result_to_response<R>(
157     id: lsp_server::RequestId,
158     result: thread::Result<Result<R::Result>>,
159 ) -> lsp_server::Response
160 where
161     R: lsp_types::request::Request + 'static,
162     R::Params: DeserializeOwned + 'static,
163     R::Result: Serialize + 'static,
164 {
165     match result {
166         Ok(result) => result_to_response::<R>(id, result),
167         Err(panic) => {
168             let mut message = "server panicked".to_string();
169 
170             let panic_message = panic
171                 .downcast_ref::<String>()
172                 .map(String::as_str)
173                 .or_else(|| panic.downcast_ref::<&str>().copied());
174 
175             if let Some(panic_message) = panic_message {
176                 message.push_str(": ");
177                 message.push_str(panic_message)
178             };
179 
180             lsp_server::Response::new_err(id, lsp_server::ErrorCode::InternalError as i32, message)
181         }
182     }
183 }
184 
result_to_response<R>( id: lsp_server::RequestId, result: Result<R::Result>, ) -> lsp_server::Response where R: lsp_types::request::Request + 'static, R::Params: DeserializeOwned + 'static, R::Result: Serialize + 'static,185 fn result_to_response<R>(
186     id: lsp_server::RequestId,
187     result: Result<R::Result>,
188 ) -> lsp_server::Response
189 where
190     R: lsp_types::request::Request + 'static,
191     R::Params: DeserializeOwned + 'static,
192     R::Result: Serialize + 'static,
193 {
194     match result {
195         Ok(resp) => lsp_server::Response::new_ok(id, &resp),
196         Err(e) => match e.downcast::<LspError>() {
197             Ok(lsp_error) => lsp_server::Response::new_err(id, lsp_error.code, lsp_error.message),
198             Err(e) => {
199                 if is_cancelled(&*e) {
200                     lsp_server::Response::new_err(
201                         id,
202                         lsp_server::ErrorCode::ContentModified as i32,
203                         "content modified".to_string(),
204                     )
205                 } else {
206                     lsp_server::Response::new_err(
207                         id,
208                         lsp_server::ErrorCode::InternalError as i32,
209                         e.to_string(),
210                     )
211                 }
212             }
213         },
214     }
215 }
216 
217 pub(crate) struct NotificationDispatcher<'a> {
218     pub(crate) not: Option<lsp_server::Notification>,
219     pub(crate) global_state: &'a mut GlobalState,
220 }
221 
222 impl<'a> NotificationDispatcher<'a> {
on<N>( &mut self, f: fn(&mut GlobalState, N::Params) -> Result<()>, ) -> Result<&mut Self> where N: lsp_types::notification::Notification + 'static, N::Params: DeserializeOwned + Send + 'static,223     pub(crate) fn on<N>(
224         &mut self,
225         f: fn(&mut GlobalState, N::Params) -> Result<()>,
226     ) -> Result<&mut Self>
227     where
228         N: lsp_types::notification::Notification + 'static,
229         N::Params: DeserializeOwned + Send + 'static,
230     {
231         let not = match self.not.take() {
232             Some(it) => it,
233             None => return Ok(self),
234         };
235         let params = match not.extract::<N::Params>(N::METHOD) {
236             Ok(it) => it,
237             Err(not) => {
238                 self.not = Some(not);
239                 return Ok(self);
240             }
241         };
242         let _pctx = stdx::panic_context::enter(format!(
243             "\nversion: {}\nnotification: {}",
244             env!("REV"),
245             N::METHOD
246         ));
247         f(self.global_state, params)?;
248         Ok(self)
249     }
250 
finish(&mut self)251     pub(crate) fn finish(&mut self) {
252         if let Some(not) = &self.not {
253             if !not.method.starts_with("$/") {
254                 tracing::error!("unhandled notification: {:?}", not);
255             }
256         }
257     }
258 }
259