1 //
2 // Copyright (c) ZeroC, Inc. All rights reserved.
3 //
4 
5 using System;
6 using System.Collections;
7 using System.Collections.Generic;
8 using System.Diagnostics;
9 using System.Threading;
10 using Ice.Instrumentation;
11 
12 namespace IceInternal
13 {
14     public class CollocatedRequestHandler : RequestHandler, ResponseHandler
15     {
16         private void
fillInValue(Ice.OutputStream os, int pos, int value)17         fillInValue(Ice.OutputStream os, int pos, int value)
18         {
19             os.rewriteInt(value, pos);
20         }
21 
22         public
23         CollocatedRequestHandler(Reference @ref, Ice.ObjectAdapter adapter)
24         {
25             _reference = @ref;
26             _dispatcher = _reference.getInstance().initializationData().dispatcher != null;
27             _response = _reference.getMode() == Reference.Mode.ModeTwoway;
28             _adapter = (Ice.ObjectAdapterI)adapter;
29 
30             _logger = _reference.getInstance().initializationData().logger; // Cached for better performance.
31             _traceLevels = _reference.getInstance().traceLevels(); // Cached for better performance.
32             _requestId = 0;
33         }
34 
update(RequestHandler previousHandler, RequestHandler newHandler)35         public RequestHandler update(RequestHandler previousHandler, RequestHandler newHandler)
36         {
37             return previousHandler == this ? newHandler : this;
38         }
39 
sendAsyncRequest(ProxyOutgoingAsyncBase outAsync)40         public int sendAsyncRequest(ProxyOutgoingAsyncBase outAsync)
41         {
42             return outAsync.invokeCollocated(this);
43         }
44 
asyncRequestCanceled(OutgoingAsyncBase outAsync, Ice.LocalException ex)45         public void asyncRequestCanceled(OutgoingAsyncBase outAsync, Ice.LocalException ex)
46         {
47             lock(this)
48             {
49                 int requestId;
50                 if(_sendAsyncRequests.TryGetValue(outAsync, out requestId))
51                 {
52                     if(requestId > 0)
53                     {
54                         _asyncRequests.Remove(requestId);
55                     }
56                     _sendAsyncRequests.Remove(outAsync);
57                     if(outAsync.exception(ex))
58                     {
59                         outAsync.invokeExceptionAsync();
60                     }
61                     _adapter.decDirectCount(); // invokeAll won't be called, decrease the direct count.
62                     return;
63                 }
64                 if(outAsync is OutgoingAsync)
65                 {
66                     OutgoingAsync o = (OutgoingAsync)outAsync;
67                     Debug.Assert(o != null);
68                     foreach(KeyValuePair<int, OutgoingAsyncBase> e in _asyncRequests)
69                     {
70                         if(e.Value == o)
71                         {
72                             _asyncRequests.Remove(e.Key);
73                             if(outAsync.exception(ex))
74                             {
75                                 outAsync.invokeExceptionAsync();
76                             }
77                             return;
78                         }
79                     }
80                 }
81             }
82         }
83 
sendResponse(int requestId, Ice.OutputStream os, byte status, bool amd)84         public void sendResponse(int requestId, Ice.OutputStream os, byte status, bool amd)
85         {
86             OutgoingAsyncBase outAsync;
87             lock(this)
88             {
89                 Debug.Assert(_response);
90 
91                 if(_traceLevels.protocol >= 1)
92                 {
93                     fillInValue(os, 10, os.size());
94                 }
95 
96                 // Adopt the OutputStream's buffer.
97                 Ice.InputStream iss = new Ice.InputStream(os.instance(), os.getEncoding(), os.getBuffer(), true);
98 
99                 iss.pos(Protocol.replyHdr.Length + 4);
100 
101                 if(_traceLevels.protocol >= 1)
102                 {
103                     TraceUtil.traceRecv(iss, _logger, _traceLevels);
104                 }
105 
106                 if(_asyncRequests.TryGetValue(requestId, out outAsync))
107                 {
108                     outAsync.getIs().swap(iss);
109                     if(!outAsync.response())
110                     {
111                         outAsync = null;
112                     }
113                     _asyncRequests.Remove(requestId);
114                 }
115             }
116 
117             if(outAsync != null)
118             {
119                 if(amd)
120                 {
121                     outAsync.invokeResponseAsync();
122                 }
123                 else
124                 {
125                     outAsync.invokeResponse();
126                 }
127             }
128             _adapter.decDirectCount();
129         }
130 
131         public void
sendNoResponse()132         sendNoResponse()
133         {
134             _adapter.decDirectCount();
135         }
136 
137         public bool
systemException(int requestId, Ice.SystemException ex, bool amd)138         systemException(int requestId, Ice.SystemException ex, bool amd)
139         {
140             handleException(requestId, ex, amd);
141             _adapter.decDirectCount();
142             return true;
143         }
144 
145         public void
invokeException(int requestId, Ice.LocalException ex, int invokeNum, bool amd)146         invokeException(int requestId, Ice.LocalException ex, int invokeNum, bool amd)
147         {
148             handleException(requestId, ex, amd);
149             _adapter.decDirectCount();
150         }
151 
152         public Reference
getReference()153         getReference()
154         {
155             return _reference;
156         }
157 
158         public Ice.ConnectionI
getConnection()159         getConnection()
160         {
161             return null;
162         }
163 
invokeAsyncRequest(OutgoingAsyncBase outAsync, int batchRequestNum, bool synchronous)164         public int invokeAsyncRequest(OutgoingAsyncBase outAsync, int batchRequestNum, bool synchronous)
165         {
166             //
167             // Increase the direct count to prevent the thread pool from being destroyed before
168             // invokeAll is called. This will also throw if the object adapter has been deactivated.
169             //
170             _adapter.incDirectCount();
171 
172             int requestId = 0;
173             try
174             {
175                 lock(this)
176                 {
177                     outAsync.cancelable(this); // This will throw if the request is canceled
178 
179                     if(_response)
180                     {
181                         requestId = ++_requestId;
182                         _asyncRequests.Add(requestId, outAsync);
183                     }
184 
185                     _sendAsyncRequests.Add(outAsync, requestId);
186                 }
187             }
188             catch(Exception)
189             {
190                 _adapter.decDirectCount();
191                 throw;
192             }
193 
194             outAsync.attachCollocatedObserver(_adapter, requestId);
195             if(!synchronous || !_response || _reference.getInvocationTimeout() > 0)
196             {
197                 // Don't invoke from the user thread if async or invocation timeout is set
198                 _adapter.getThreadPool().dispatch(
199                     () =>
200                     {
201                         if (sentAsync(outAsync))
202                         {
203                             invokeAll(outAsync.getOs(), requestId, batchRequestNum);
204                         }
205                     }, null);
206             }
207             else if(_dispatcher)
208             {
209                 _adapter.getThreadPool().dispatchFromThisThread(
210                     () =>
211                     {
212                         if (sentAsync(outAsync))
213                         {
214                             invokeAll(outAsync.getOs(), requestId, batchRequestNum);
215                         }
216                     }, null);
217             }
218             else // Optimization: directly call invokeAll if there's no dispatcher.
219             {
220                 if(sentAsync(outAsync))
221                 {
222                     invokeAll(outAsync.getOs(), requestId, batchRequestNum);
223                 }
224             }
225             return OutgoingAsyncBase.AsyncStatusQueued;
226         }
227 
sentAsync(OutgoingAsyncBase outAsync)228         private bool sentAsync(OutgoingAsyncBase outAsync)
229         {
230             lock(this)
231             {
232                 if(!_sendAsyncRequests.Remove(outAsync))
233                 {
234                     return false; // The request timed-out.
235                 }
236 
237                 if(!outAsync.sent())
238                 {
239                     return true;
240                 }
241             }
242             outAsync.invokeSent();
243             return true;
244         }
245 
invokeAll(Ice.OutputStream os, int requestId, int batchRequestNum)246         private void invokeAll(Ice.OutputStream os, int requestId, int batchRequestNum)
247         {
248             if(_traceLevels.protocol >= 1)
249             {
250                 fillInValue(os, 10, os.size());
251                 if(requestId > 0)
252                 {
253                     fillInValue(os, Protocol.headerSize, requestId);
254                 }
255                 else if(batchRequestNum > 0)
256                 {
257                     fillInValue(os, Protocol.headerSize, batchRequestNum);
258                 }
259                 TraceUtil.traceSend(os, _logger, _traceLevels);
260             }
261 
262             Ice.InputStream iss = new Ice.InputStream(os.instance(), os.getEncoding(), os.getBuffer(), false);
263 
264             if(batchRequestNum > 0)
265             {
266                 iss.pos(Protocol.requestBatchHdr.Length);
267             }
268             else
269             {
270                 iss.pos(Protocol.requestHdr.Length);
271             }
272 
273             int invokeNum = batchRequestNum > 0 ? batchRequestNum : 1;
274             ServantManager servantManager = _adapter.getServantManager();
275             try
276             {
277                 while(invokeNum > 0)
278                 {
279                     //
280                     // Increase the direct count for the dispatch. We increase it again here for
281                     // each dispatch. It's important for the direct count to be > 0 until the last
282                     // collocated request response is sent to make sure the thread pool isn't
283                     // destroyed before.
284                     //
285                     try
286                     {
287                         _adapter.incDirectCount();
288                     }
289                     catch(Ice.ObjectAdapterDeactivatedException ex)
290                     {
291                         handleException(requestId, ex, false);
292                         break;
293                     }
294 
295                     Incoming inS = new Incoming(_reference.getInstance(), this, null, _adapter, _response, (byte)0,
296                                                 requestId);
297                     inS.invoke(servantManager, iss);
298                     --invokeNum;
299                 }
300             }
301             catch(Ice.LocalException ex)
302             {
303                 invokeException(requestId, ex, invokeNum, false); // Fatal invocation exception
304             }
305 
306             _adapter.decDirectCount();
307         }
308 
309         void
handleException(int requestId, Ice.Exception ex, bool amd)310         handleException(int requestId, Ice.Exception ex, bool amd)
311         {
312             if(requestId == 0)
313             {
314                 return; // Ignore exception for oneway messages.
315             }
316 
317             OutgoingAsyncBase outAsync;
318             lock(this)
319             {
320                 if(_asyncRequests.TryGetValue(requestId, out outAsync))
321                 {
322                     if(!outAsync.exception(ex))
323                     {
324                         outAsync = null;
325                     }
326                     _asyncRequests.Remove(requestId);
327                 }
328             }
329 
330             if(outAsync != null)
331             {
332                 //
333                 // If called from an AMD dispatch, invoke asynchronously
334                 // the completion callback since this might be called from
335                 // the user code.
336                 //
337                 if(amd)
338                 {
339                     outAsync.invokeExceptionAsync();
340                 }
341                 else
342                 {
343                     outAsync.invokeException();
344                 }
345             }
346         }
347 
348         private readonly Reference _reference;
349         private readonly bool _dispatcher;
350         private readonly bool _response;
351         private readonly Ice.ObjectAdapterI _adapter;
352         private readonly Ice.Logger _logger;
353         private readonly TraceLevels _traceLevels;
354 
355         private int _requestId;
356 
357         private Dictionary<OutgoingAsyncBase, int> _sendAsyncRequests = new Dictionary<OutgoingAsyncBase, int>();
358         private Dictionary<int, OutgoingAsyncBase> _asyncRequests = new Dictionary<int, OutgoingAsyncBase>();
359     }
360 }
361