1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements. See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership. The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License. You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied. See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 using System;
21 using System.Net.Security;
22 using System.Net.Sockets;
23 using System.Security.Authentication;
24 using System.Security.Cryptography.X509Certificates;
25 
26 namespace Thrift.Transport
27 {
28     /// <summary>
29     /// SSL Socket Wrapper class
30     /// </summary>
31     public class TTLSSocket : TStreamTransport
32     {
33         /// <summary>
34         /// Internal TCP Client
35         /// </summary>
36         private TcpClient client;
37 
38         /// <summary>
39         /// The host
40         /// </summary>
41         private string host;
42 
43         /// <summary>
44         /// The port
45         /// </summary>
46         private int port;
47 
48         /// <summary>
49         /// The timeout for the connection
50         /// </summary>
51         private int timeout;
52 
53         /// <summary>
54         /// Internal SSL Stream for IO
55         /// </summary>
56         private SslStream secureStream;
57 
58         /// <summary>
59         /// Defines wheter or not this socket is a server socket<br/>
60         /// This is used for the TLS-authentication
61         /// </summary>
62         private bool isServer;
63 
64         /// <summary>
65         /// The certificate
66         /// </summary>
67         private X509Certificate certificate;
68 
69         /// <summary>
70         /// User defined certificate validator.
71         /// </summary>
72         private RemoteCertificateValidationCallback certValidator;
73 
74         /// <summary>
75         /// The function to determine which certificate to use.
76         /// </summary>
77         private LocalCertificateSelectionCallback localCertificateSelectionCallback;
78 
79         /// <summary>
80         /// The SslProtocols value that represents the protocol used for authentication.SSL protocols to be used.
81         /// </summary>
82         private readonly SslProtocols sslProtocols;
83 
84         /// <summary>
85         /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
86         /// </summary>
87         /// <param name="client">An already created TCP-client</param>
88         /// <param name="certificate">The certificate.</param>
89         /// <param name="isServer">if set to <c>true</c> [is server].</param>
90         /// <param name="certValidator">User defined cert validator.</param>
91         /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
92         /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
TTLSSocket( TcpClient client, X509Certificate certificate, bool isServer = false, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls)93         public TTLSSocket(
94             TcpClient client,
95             X509Certificate certificate,
96             bool isServer = false,
97             RemoteCertificateValidationCallback certValidator = null,
98             LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
99             // TODO: Enable Tls11 and Tls12 (TLS 1.1 and 1.2) by default once we start using .NET 4.5+.
100             SslProtocols sslProtocols = SslProtocols.Tls)
101         {
102             this.client = client;
103             this.certificate = certificate;
104             this.certValidator = certValidator;
105             this.localCertificateSelectionCallback = localCertificateSelectionCallback;
106             this.sslProtocols = sslProtocols;
107             this.isServer = isServer;
108             if (isServer && certificate == null)
109             {
110                 throw new ArgumentException("TTLSSocket needs certificate to be used for server", "certificate");
111             }
112 
113             if (IsOpen)
114             {
115                 base.inputStream = client.GetStream();
116                 base.outputStream = client.GetStream();
117             }
118         }
119 
120         /// <summary>
121         /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
122         /// </summary>
123         /// <param name="host">The host, where the socket should connect to.</param>
124         /// <param name="port">The port.</param>
125         /// <param name="certificatePath">The certificate path.</param>
126         /// <param name="certValidator">User defined cert validator.</param>
127         /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
128         /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
TTLSSocket( string host, int port, string certificatePath, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls)129         public TTLSSocket(
130             string host,
131             int port,
132             string certificatePath,
133             RemoteCertificateValidationCallback certValidator = null,
134             LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
135             SslProtocols sslProtocols = SslProtocols.Tls)
136             : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator, localCertificateSelectionCallback, sslProtocols)
137         {
138         }
139 
140         /// <summary>
141         /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
142         /// </summary>
143         /// <param name="host">The host, where the socket should connect to.</param>
144         /// <param name="port">The port.</param>
145         /// <param name="certificate">The certificate.</param>
146         /// <param name="certValidator">User defined cert validator.</param>
147         /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
148         /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
TTLSSocket( string host, int port, X509Certificate certificate = null, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls)149         public TTLSSocket(
150             string host,
151             int port,
152             X509Certificate certificate = null,
153             RemoteCertificateValidationCallback certValidator = null,
154             LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
155             SslProtocols sslProtocols = SslProtocols.Tls)
156             : this(host, port, 0, certificate, certValidator, localCertificateSelectionCallback, sslProtocols)
157         {
158         }
159 
160         /// <summary>
161         /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
162         /// </summary>
163         /// <param name="host">The host, where the socket should connect to.</param>
164         /// <param name="port">The port.</param>
165         /// <param name="timeout">The timeout.</param>
166         /// <param name="certificate">The certificate.</param>
167         /// <param name="certValidator">User defined cert validator.</param>
168         /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
169         /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
TTLSSocket( string host, int port, int timeout, X509Certificate certificate, RemoteCertificateValidationCallback certValidator = null, LocalCertificateSelectionCallback localCertificateSelectionCallback = null, SslProtocols sslProtocols = SslProtocols.Tls)170         public TTLSSocket(
171             string host,
172             int port,
173             int timeout,
174             X509Certificate certificate,
175             RemoteCertificateValidationCallback certValidator = null,
176             LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
177             SslProtocols sslProtocols = SslProtocols.Tls)
178         {
179             this.host = host;
180             this.port = port;
181             this.timeout = timeout;
182             this.certificate = certificate;
183             this.certValidator = certValidator;
184             this.localCertificateSelectionCallback = localCertificateSelectionCallback;
185             this.sslProtocols = sslProtocols;
186 
187             InitSocket();
188         }
189 
190         /// <summary>
191         /// Creates the TcpClient and sets the timeouts
192         /// </summary>
InitSocket()193         private void InitSocket()
194         {
195             client = TSocketVersionizer.CreateTcpClient();
196             client.ReceiveTimeout = client.SendTimeout = timeout;
197             client.Client.NoDelay = true;
198         }
199 
200         /// <summary>
201         /// Sets Send / Recv Timeout for IO
202         /// </summary>
203         public int Timeout
204         {
205             set
206             {
207                 this.client.ReceiveTimeout = this.client.SendTimeout = this.timeout = value;
208             }
209         }
210 
211         /// <summary>
212         /// Gets the TCP client.
213         /// </summary>
214         public TcpClient TcpClient
215         {
216             get
217             {
218                 return client;
219             }
220         }
221 
222         /// <summary>
223         /// Gets the host.
224         /// </summary>
225         public string Host
226         {
227             get
228             {
229                 return host;
230             }
231         }
232 
233         /// <summary>
234         /// Gets the port.
235         /// </summary>
236         public int Port
237         {
238             get
239             {
240                 return port;
241             }
242         }
243 
244         /// <summary>
245         /// Gets a value indicating whether TCP Client is Cpen
246         /// </summary>
247         public override bool IsOpen
248         {
249             get
250             {
251                 if (this.client == null)
252                 {
253                     return false;
254                 }
255 
256                 return this.client.Connected;
257             }
258         }
259 
260         /// <summary>
261         /// Validates the certificates!<br/>
262         /// </summary>
263         /// <param name="sender">The sender-object.</param>
264         /// <param name="certificate">The used certificate.</param>
265         /// <param name="chain">The certificate chain.</param>
266         /// <param name="sslValidationErrors">An enum, which lists all the errors from the .NET certificate check.</param>
267         /// <returns></returns>
DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors)268         private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors)
269         {
270             return (sslValidationErrors == SslPolicyErrors.None);
271         }
272 
273         /// <summary>
274         /// Connects to the host and starts the routine, which sets up the TLS
275         /// </summary>
Open()276         public override void Open()
277         {
278             if (IsOpen)
279             {
280                 throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected");
281             }
282 
283             if (string.IsNullOrEmpty(host))
284             {
285                 throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host");
286             }
287 
288             if (port <= 0)
289             {
290                 throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port");
291             }
292 
293             if (client == null)
294             {
295                 InitSocket();
296             }
297 
298             if (timeout == 0)            // no timeout -> infinite
299             {
300                 client.Connect(host, port);
301             }
302             else                        // we have a timeout -> use it
303             {
304                 ConnectHelper hlp = new ConnectHelper(client);
305                 IAsyncResult asyncres = client.BeginConnect(host, port, new AsyncCallback(ConnectCallback), hlp);
306                 bool bConnected = asyncres.AsyncWaitHandle.WaitOne(timeout) && client.Connected;
307                 if (!bConnected)
308                 {
309                     lock (hlp.Mutex)
310                     {
311                         if (hlp.CallbackDone)
312                         {
313                             asyncres.AsyncWaitHandle.Close();
314                             client.Close();
315                         }
316                         else
317                         {
318                             hlp.DoCleanup = true;
319                             client = null;
320                         }
321                     }
322                     throw new TTransportException(TTransportException.ExceptionType.TimedOut, "Connect timed out");
323                 }
324             }
325 
326             setupTLS();
327         }
328 
329         /// <summary>
330         /// Creates a TLS-stream and lays it over the existing socket
331         /// </summary>
setupTLS()332         public void setupTLS()
333         {
334             RemoteCertificateValidationCallback validator = this.certValidator ?? DefaultCertificateValidator;
335 
336             if (this.localCertificateSelectionCallback != null)
337             {
338                 this.secureStream = new SslStream(
339                     this.client.GetStream(),
340                     false,
341                     validator,
342                     this.localCertificateSelectionCallback
343                 );
344             }
345             else
346             {
347                 this.secureStream = new SslStream(
348                     this.client.GetStream(),
349                     false,
350                     validator
351                 );
352             }
353 
354             try
355             {
356                 if (isServer)
357                 {
358                     // Server authentication
359                     this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, sslProtocols, true);
360                 }
361                 else
362                 {
363                     // Client authentication
364                     X509CertificateCollection certs = certificate != null ? new X509CertificateCollection { certificate } : new X509CertificateCollection();
365                     this.secureStream.AuthenticateAsClient(host, certs, sslProtocols, true);
366                 }
367             }
368             catch (Exception)
369             {
370                 this.Close();
371                 throw;
372             }
373 
374             inputStream = this.secureStream;
375             outputStream = this.secureStream;
376         }
377 
ConnectCallback(IAsyncResult asyncres)378         static void ConnectCallback(IAsyncResult asyncres)
379         {
380             ConnectHelper hlp = asyncres.AsyncState as ConnectHelper;
381             lock (hlp.Mutex)
382             {
383                 hlp.CallbackDone = true;
384 
385                 try
386                 {
387                     if (hlp.Client.Client != null)
388                         hlp.Client.EndConnect(asyncres);
389                 }
390                 catch (Exception)
391                 {
392                     // catch that away
393                 }
394 
395                 if (hlp.DoCleanup)
396                 {
397                     try
398                     {
399                         asyncres.AsyncWaitHandle.Close();
400                     }
401                     catch (Exception) { }
402 
403                     try
404                     {
405                         if (hlp.Client is IDisposable)
406                             ((IDisposable)hlp.Client).Dispose();
407                     }
408                     catch (Exception) { }
409                     hlp.Client = null;
410                 }
411             }
412         }
413 
414         private class ConnectHelper
415         {
416             public object Mutex = new object();
417             public bool DoCleanup = false;
418             public bool CallbackDone = false;
419             public TcpClient Client;
ConnectHelper(TcpClient client)420             public ConnectHelper(TcpClient client)
421             {
422                 Client = client;
423             }
424         }
425 
426         /// <summary>
427         /// Closes the SSL Socket
428         /// </summary>
Close()429         public override void Close()
430         {
431             base.Close();
432             if (this.client != null)
433             {
434                 this.client.Close();
435                 this.client = null;
436             }
437 
438             if (this.secureStream != null)
439             {
440                 this.secureStream.Close();
441                 this.secureStream = null;
442             }
443         }
444     }
445 }
446