/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

using System;
using System.Net.Security;
using System.Net.Sockets;
using System.Security.Authentication;
using System.Security.Cryptography.X509Certificates;

namespace Thrift.Transport
{
    /// <summary>
    /// SSL Socket Wrapper class: lays an <see cref="SslStream"/> over a <see cref="TcpClient"/>
    /// and exposes it as a Thrift stream transport.
    /// </summary>
    public class TTLSSocket : TStreamTransport
    {
        /// <summary>
        /// Internal TCP Client. Null after <see cref="Close"/> or after a timed-out <see cref="Open"/>.
        /// </summary>
        private TcpClient client;

        /// <summary>
        /// The host to connect to; also used as the TLS target host in client mode.
        /// </summary>
        private string host;

        /// <summary>
        /// The port to connect to.
        /// </summary>
        private int port;

        /// <summary>
        /// The timeout used for connecting and for send/receive on the TCP client.
        /// 0 means the connect call blocks without a timeout.
        /// </summary>
        private int timeout;

        /// <summary>
        /// Internal SSL Stream for IO
        /// </summary>
        private SslStream secureStream;

        /// <summary>
        /// Defines whether or not this socket is a server socket<br/>
        /// This is used for the TLS-authentication
        /// </summary>
        private bool isServer;

        /// <summary>
        /// The certificate: the server certificate in server mode, or the optional
        /// client certificate in client mode.
        /// </summary>
        private X509Certificate certificate;

        /// <summary>
        /// User defined certificate validator.
        /// </summary>
        private RemoteCertificateValidationCallback certValidator;

        /// <summary>
        /// The function to determine which certificate to use.
        /// </summary>
        private LocalCertificateSelectionCallback localCertificateSelectionCallback;

        /// <summary>
        /// The SSL/TLS protocol versions allowed during authentication.
        /// </summary>
        private readonly SslProtocols sslProtocols;

        /// <summary>
        /// Initializes a new instance of the <see cref="TTLSSocket"/> class around an existing TCP client.
        /// TLS is not negotiated here; call <see cref="setupTLS"/> to do so.
        /// </summary>
        /// <remarks>
        /// NOTE(review): this constructor does not set <c>host</c>, so calling <see cref="setupTLS"/>
        /// in client mode passes a null target host to <c>AuthenticateAsClient</c> — confirm callers
        /// only use this overload for server-side sockets.
        /// </remarks>
        /// <param name="client">An already created TCP-client</param>
        /// <param name="certificate">The certificate.</param>
        /// <param name="isServer">if set to <c>true</c> [is server].</param>
        /// <param name="certValidator">User defined cert validator.</param>
        /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
        /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
        /// <exception cref="ArgumentException">Thrown when <paramref name="isServer"/> is true and no certificate is given.</exception>
        public TTLSSocket(
            TcpClient client,
            X509Certificate certificate,
            bool isServer = false,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            // TODO: Enable Tls11 and Tls12 (TLS 1.1 and 1.2) by default once we start using .NET 4.5+.
            SslProtocols sslProtocols = SslProtocols.Tls)
        {
            this.client = client;
            this.certificate = certificate;
            this.certValidator = certValidator;
            this.localCertificateSelectionCallback = localCertificateSelectionCallback;
            this.sslProtocols = sslProtocols;
            this.isServer = isServer;
            if (isServer && certificate == null)
            {
                throw new ArgumentException("TTLSSocket needs certificate to be used for server", "certificate");
            }

            // Until setupTLS() runs, IO goes over the plain network stream.
            if (IsOpen)
            {
                base.inputStream = client.GetStream();
                base.outputStream = client.GetStream();
            }
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
        /// </summary>
        /// <param name="host">The host, where the socket should connect to.</param>
        /// <param name="port">The port.</param>
        /// <param name="certificatePath">The certificate path.</param>
        /// <param name="certValidator">User defined cert validator.</param>
        /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
        /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
        public TTLSSocket(
            string host,
            int port,
            string certificatePath,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls)
            : this(host, port, 0, X509Certificate.CreateFromCertFile(certificatePath), certValidator, localCertificateSelectionCallback, sslProtocols)
        {
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="TTLSSocket"/> class.
        /// </summary>
        /// <param name="host">The host, where the socket should connect to.</param>
        /// <param name="port">The port.</param>
        /// <param name="certificate">The certificate.</param>
        /// <param name="certValidator">User defined cert validator.</param>
        /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
        /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
        public TTLSSocket(
            string host,
            int port,
            X509Certificate certificate = null,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls)
            : this(host, port, 0, certificate, certValidator, localCertificateSelectionCallback, sslProtocols)
        {
        }

        /// <summary>
        /// Initializes a new instance of the <see cref="TTLSSocket"/> class and creates
        /// (but does not connect) the underlying TCP client.
        /// </summary>
        /// <param name="host">The host, where the socket should connect to.</param>
        /// <param name="port">The port.</param>
        /// <param name="timeout">The timeout; 0 connects without a timeout.</param>
        /// <param name="certificate">The certificate.</param>
        /// <param name="certValidator">User defined cert validator.</param>
        /// <param name="localCertificateSelectionCallback">The callback to select which certificate to use.</param>
        /// <param name="sslProtocols">The SslProtocols value that represents the protocol used for authentication.</param>
        public TTLSSocket(
            string host,
            int port,
            int timeout,
            X509Certificate certificate,
            RemoteCertificateValidationCallback certValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls)
        {
            this.host = host;
            this.port = port;
            this.timeout = timeout;
            this.certificate = certificate;
            this.certValidator = certValidator;
            this.localCertificateSelectionCallback = localCertificateSelectionCallback;
            this.sslProtocols = sslProtocols;

            InitSocket();
        }

        /// <summary>
        /// Creates the TcpClient and sets the timeouts
        /// </summary>
        private void InitSocket()
        {
            client = TSocketVersionizer.CreateTcpClient();
            client.ReceiveTimeout = client.SendTimeout = timeout;
            client.Client.NoDelay = true;
        }

        /// <summary>
        /// Sets Send / Recv Timeout for IO. The value is also used as the connect
        /// timeout by <see cref="Open"/>.
        /// </summary>
        public int Timeout
        {
            set
            {
                this.timeout = value;

                // The client is null after Close() or a timed-out Open(); the stored value is
                // applied by InitSocket() when the next client is created.
                if (this.client != null)
                {
                    this.client.ReceiveTimeout = this.client.SendTimeout = value;
                }
            }
        }

        /// <summary>
        /// Gets the TCP client.
        /// </summary>
        public TcpClient TcpClient
        {
            get
            {
                return client;
            }
        }

        /// <summary>
        /// Gets the host.
        /// </summary>
        public string Host
        {
            get
            {
                return host;
            }
        }

        /// <summary>
        /// Gets the port.
        /// </summary>
        public int Port
        {
            get
            {
                return port;
            }
        }

        /// <summary>
        /// Gets a value indicating whether the TCP client is open (exists and is connected).
        /// </summary>
        public override bool IsOpen
        {
            get
            {
                if (this.client == null)
                {
                    return false;
                }

                return this.client.Connected;
            }
        }

        /// <summary>
        /// Default certificate validator used when no user validator is supplied.
        /// </summary>
        /// <param name="sender">The sender-object.</param>
        /// <param name="certificate">The used certificate.</param>
        /// <param name="chain">The certificate chain.</param>
        /// <param name="sslValidationErrors">An enum, which lists all the errors from the .NET certificate check.</param>
        /// <returns><c>true</c> only if the .NET check reported no policy errors.</returns>
        private bool DefaultCertificateValidator(object sender, X509Certificate certificate, X509Chain chain, SslPolicyErrors sslValidationErrors)
        {
            return (sslValidationErrors == SslPolicyErrors.None);
        }

        /// <summary>
        /// Connects to the host and starts the routine, which sets up the TLS
        /// </summary>
        /// <exception cref="TTransportException">
        /// AlreadyOpen if already connected; NotOpen if host or port is missing;
        /// TimedOut if the connect does not complete within the timeout.
        /// </exception>
        public override void Open()
        {
            if (IsOpen)
            {
                throw new TTransportException(TTransportException.ExceptionType.AlreadyOpen, "Socket already connected");
            }

            if (string.IsNullOrEmpty(host))
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open null host");
            }

            if (port <= 0)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "Cannot open without port");
            }

            // The client is dropped by Close() and by a timed-out connect; recreate it on reopen.
            if (client == null)
            {
                InitSocket();
            }

            if (timeout == 0) // no timeout -> infinite
            {
                client.Connect(host, port);
            }
            else // we have a timeout -> use it
            {
                ConnectHelper hlp = new ConnectHelper(client);
                IAsyncResult asyncres = client.BeginConnect(host, port, new AsyncCallback(ConnectCallback), hlp);
                bool bConnected = asyncres.AsyncWaitHandle.WaitOne(timeout) && client.Connected;
                if (!bConnected)
                {
                    lock (hlp.Mutex)
                    {
                        if (hlp.CallbackDone)
                        {
                            // The callback already ran, so clean up here.
                            asyncres.AsyncWaitHandle.Close();
                            client.Close();
                        }
                        else
                        {
                            // The callback is still pending; it disposes the client when it runs.
                            hlp.DoCleanup = true;
                        }

                        // Either way this instance no longer owns a usable client. Dropping the
                        // reference makes a retry of Open() create a fresh one instead of
                        // reusing a closed/disposed TcpClient.
                        client = null;
                    }
                    throw new TTransportException(TTransportException.ExceptionType.TimedOut, "Connect timed out");
                }
            }

            setupTLS();
        }

        /// <summary>
        /// Creates a TLS-stream and lays it over the existing socket, then authenticates
        /// as server or client depending on how this socket was constructed.
        /// On authentication failure the socket is closed and the exception is rethrown.
        /// </summary>
        public void setupTLS()
        {
            RemoteCertificateValidationCallback validator = this.certValidator ?? DefaultCertificateValidator;

            if (this.localCertificateSelectionCallback != null)
            {
                this.secureStream = new SslStream(
                    this.client.GetStream(),
                    false,
                    validator,
                    this.localCertificateSelectionCallback
                );
            }
            else
            {
                this.secureStream = new SslStream(
                    this.client.GetStream(),
                    false,
                    validator
                );
            }

            try
            {
                if (isServer)
                {
                    // Server authentication; a client certificate is only requested when the
                    // caller supplied a validator for it.
                    this.secureStream.AuthenticateAsServer(this.certificate, this.certValidator != null, sslProtocols, true);
                }
                else
                {
                    // Client authentication, optionally presenting our own certificate.
                    X509CertificateCollection certs = certificate != null
                        ? new X509CertificateCollection { certificate }
                        : new X509CertificateCollection();
                    this.secureStream.AuthenticateAsClient(host, certs, sslProtocols, true);
                }
            }
            catch (Exception)
            {
                this.Close();
                throw;
            }

            inputStream = this.secureStream;
            outputStream = this.secureStream;
        }

        /// <summary>
        /// Completion callback for the timed connect in <see cref="Open"/>. Finishes the
        /// connect (ignoring errors) and, if <see cref="Open"/> already gave up, disposes the client.
        /// </summary>
        /// <param name="asyncres">The async result whose state is a <see cref="ConnectHelper"/>.</param>
        private static void ConnectCallback(IAsyncResult asyncres)
        {
            ConnectHelper hlp = asyncres.AsyncState as ConnectHelper;
            lock (hlp.Mutex)
            {
                hlp.CallbackDone = true;

                try
                {
                    if (hlp.Client.Client != null)
                        hlp.Client.EndConnect(asyncres);
                }
                catch (Exception)
                {
                    // Best-effort: connect failures surface to Open() via client.Connected.
                }

                if (hlp.DoCleanup)
                {
                    try
                    {
                        asyncres.AsyncWaitHandle.Close();
                    }
                    catch (Exception) { }

                    try
                    {
                        if (hlp.Client is IDisposable)
                            ((IDisposable)hlp.Client).Dispose();
                    }
                    catch (Exception) { }
                    hlp.Client = null;
                }
            }
        }

        /// <summary>
        /// State shared between <see cref="Open"/> and <see cref="ConnectCallback"/> to decide
        /// which side disposes the client after a connect timeout.
        /// </summary>
        private class ConnectHelper
        {
            public object Mutex = new object();
            public bool DoCleanup = false;
            public bool CallbackDone = false;
            public TcpClient Client;

            public ConnectHelper(TcpClient client)
            {
                Client = client;
            }
        }

        /// <summary>
        /// Closes the SSL Socket: the transport streams, the TCP client and the SSL stream.
        /// </summary>
        public override void Close()
        {
            base.Close();
            if (this.client != null)
            {
                this.client.Close();
                this.client = null;
            }

            if (this.secureStream != null)
            {
                this.secureStream.Close();
                this.secureStream = null;
            }
        }
    }
}