1 // Licensed to the Apache Software Foundation(ASF) under one
2 // or more contributor license agreements.See the NOTICE file
3 // distributed with this work for additional information
4 // regarding copyright ownership.The ASF licenses this file
5 // to you under the Apache License, Version 2.0 (the
6 // "License"); you may not use this file except in compliance
7 // with the License. You may obtain a copy of the License at
8 //
9 //     http://www.apache.org/licenses/LICENSE-2.0
10 //
11 // Unless required by applicable law or agreed to in writing,
12 // software distributed under the License is distributed on an
13 // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14 // KIND, either express or implied. See the License for the
15 // specific language governing permissions and limitations
16 // under the License.
17 
18 using System;
19 using System.Net;
20 using System.Net.Security;
21 using System.Net.Sockets;
22 using System.Security.Authentication;
23 using System.Security.Cryptography.X509Certificates;
24 using System.Threading;
25 using System.Threading.Tasks;
26 using Thrift.Transport.Client;
27 
namespace Thrift.Transport.Server
{
    /// <summary>
    /// Server transport that accepts TCP connections from a <see cref="TcpListener"/> and wraps each
    /// accepted client in a <see cref="TTlsSocketTransport"/> (server mode) whose TLS setup has been
    /// completed before the transport is handed to the caller.
    /// </summary>
    // ReSharper disable once InconsistentNaming
    public class TTlsServerSocketTransport : TServerTransport
    {
        private readonly RemoteCertificateValidationCallback _clientCertValidator;

        // Applied as both SendTimeout and ReceiveTimeout of every accepted TcpClient.
        // Always 0 here, which for TcpClient means "no timeout".
        private readonly int _clientTimeout = 0;

        private readonly LocalCertificateSelectionCallback _localCertificateSelectionCallback;
        private readonly X509Certificate2 _serverCertificate;
        private readonly SslProtocols _sslProtocols;

        // Underlying listener; null when none was supplied or after Close() succeeded.
        private TcpListener _server;

        /// <summary>
        /// Creates a TLS server transport on top of an existing <see cref="TcpListener"/>.
        /// </summary>
        /// <param name="listener">
        /// Listener to accept connections from. May be <c>null</c>; in that case <see cref="Listen"/> does
        /// nothing and accepting fails with <see cref="TTransportException.ExceptionType.NotOpen"/>.
        /// </param>
        /// <param name="config">Thrift configuration passed to the base transport and to accepted clients.</param>
        /// <param name="certificate">Server certificate; must contain a private key.</param>
        /// <param name="clientCertValidator">Optional callback used to validate client certificates.</param>
        /// <param name="localCertificateSelectionCallback">Optional callback used to select the local certificate.</param>
        /// <param name="sslProtocols">TLS protocol versions allowed for accepted connections.</param>
        /// <exception cref="ArgumentNullException"><paramref name="certificate"/> is <c>null</c>.</exception>
        /// <exception cref="TTransportException"><paramref name="certificate"/> has no private key.</exception>
        public TTlsServerSocketTransport(
            TcpListener listener,
            TConfiguration config,
            X509Certificate2 certificate,
            RemoteCertificateValidationCallback clientCertValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
            : base(config)
        {
            if (certificate == null)
            {
                throw new ArgumentNullException(nameof(certificate));
            }

            // A server cannot complete a TLS handshake without the certificate's private key.
            if (!certificate.HasPrivateKey)
            {
                throw new TTransportException(TTransportException.ExceptionType.Unknown,
                    "Your server-certificate needs to have a private key");
            }

            _serverCertificate = certificate;
            _clientCertValidator = clientCertValidator;
            _localCertificateSelectionCallback = localCertificateSelectionCallback;
            _sslProtocols = sslProtocols;
            _server = listener;
        }

        /// <summary>
        /// Creates a TLS server transport that listens on all local IP addresses (<see cref="IPAddress.Any"/>)
        /// at the given port.
        /// </summary>
        /// <param name="port">TCP port to listen on.</param>
        /// <param name="config">Thrift configuration passed to the base transport and to accepted clients.</param>
        /// <param name="certificate">Server certificate; must contain a private key.</param>
        /// <param name="clientCertValidator">Optional callback used to validate client certificates.</param>
        /// <param name="localCertificateSelectionCallback">Optional callback used to select the local certificate.</param>
        /// <param name="sslProtocols">TLS protocol versions allowed for accepted connections.</param>
        /// <exception cref="TTransportException">The listener could not be created or configured.</exception>
        public TTlsServerSocketTransport(
            int port,
            TConfiguration config,
            X509Certificate2 certificate,
            RemoteCertificateValidationCallback clientCertValidator = null,
            LocalCertificateSelectionCallback localCertificateSelectionCallback = null,
            SslProtocols sslProtocols = SslProtocols.Tls12)
            : this(null, config, certificate, clientCertValidator, localCertificateSelectionCallback, sslProtocols)
        {
            TcpListener listener = null;
            try
            {
                // Create server socket
                listener = new TcpListener(IPAddress.Any, port);
                listener.Server.NoDelay = true;
            }
            catch (Exception ex)
            {
                // Release the socket allocated by the TcpListener constructor if configuring it failed.
                listener?.Server?.Dispose();
                throw new TTransportException($"Could not create ServerSocket on port {port}.", ex);
            }

            _server = listener;
        }

        /// <summary>
        /// Returns <c>true</c> when a listener is present and its underlying socket is bound.
        /// </summary>
        public override bool IsOpen()
        {
            var socket = _server?.Server;
            return socket != null && socket.IsBound;
        }

        /// <summary>
        /// Returns the local port the listener's socket is bound to.
        /// </summary>
        /// <exception cref="TTransportException">
        /// There is no listener or it has no local endpoint, or the endpoint is not an IP endpoint.
        /// </exception>
        public int GetPort()
        {
            var endPoint = _server?.Server?.LocalEndPoint;
            if (endPoint == null)
            {
                throw new TTransportException("ServerSocket is not open");
            }

            if (endPoint is IPEndPoint ipEndPoint)
            {
                return ipEndPoint.Port;
            }

            throw new TTransportException("ServerSocket is not a network socket");
        }

        /// <summary>
        /// Starts the listener so it accepts incoming connection requests. Does nothing when there is no listener.
        /// </summary>
        /// <exception cref="TTransportException">The listener failed to start.</exception>
        public override void Listen()
        {
            if (_server == null)
            {
                return;
            }

            try
            {
                _server.Start();
            }
            catch (SocketException sx)
            {
                throw new TTransportException($"Could not accept on listening socket: {sx.Message}", sx);
            }
        }

        /// <summary>
        /// Returns whether a connection request is waiting on the listener.
        /// </summary>
        /// <exception cref="TTransportException">There is no underlying listener (never supplied, or closed).</exception>
        public override bool IsClientPending()
        {
            var server = _server;
            if (server == null)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No underlying server socket.");
            }

            return server.Pending();
        }

        /// <summary>
        /// Accepts one TCP client, wraps it in a server-side <see cref="TTlsSocketTransport"/>, and completes
        /// its TLS setup before returning it. If anything fails, the accepted connection is disposed.
        /// </summary>
        /// <param name="cancellationToken">Checked once before accepting.</param>
        /// <returns>The TLS-wrapped client transport.</returns>
        /// <exception cref="OperationCanceledException"><paramref name="cancellationToken"/> was already canceled.</exception>
        /// <exception cref="TTransportException">There is no listener, or accepting/TLS setup failed.</exception>
        protected override async ValueTask<TTransport> AcceptImplementationAsync(CancellationToken cancellationToken)
        {
            cancellationToken.ThrowIfCancellationRequested();

            // Read the field once so a concurrent Close() cannot null it between the check and the use.
            var server = _server;
            if (server == null)
            {
                throw new TTransportException(TTransportException.ExceptionType.NotOpen, "No underlying server socket.");
            }

            TcpClient client = null;
            TTlsSocketTransport tTlsSocket = null;
            try
            {
                // NOTE(review): the token is not passed to the accept call, so a pending accept is not
                // cancelled by it; presumably Close() is what interrupts it — confirm against the servers.
                client = await server.AcceptTcpClientAsync().ConfigureAwait(false);
                client.SendTimeout = client.ReceiveTimeout = _clientTimeout;

                // Wrap the client in a TLS socket (server mode) using the server certificate.
                tTlsSocket = new TTlsSocketTransport(
                    client, Configuration,
                    _serverCertificate, true, _clientCertValidator,
                    _localCertificateSelectionCallback, _sslProtocols);

                await tTlsSocket.SetupTlsAsync().ConfigureAwait(false);

                return tTlsSocket;
            }
            catch (Exception ex)
            {
                // Do not leak the accepted connection when setup fails.
                tTlsSocket?.Dispose();
                client?.Dispose();

                // Preserve transport errors (and their ExceptionType) as thrown.
                if (ex is TTransportException)
                {
                    throw;
                }

                throw new TTransportException(ex.Message, ex);
            }
        }

        /// <summary>
        /// Stops the listener and releases the reference to it. Does nothing when there is no listener.
        /// </summary>
        /// <exception cref="TTransportException">Stopping the listener failed; the listener is kept.</exception>
        public override void Close()
        {
            if (_server == null)
            {
                return;
            }

            try
            {
                _server.Stop();
            }
            catch (Exception ex)
            {
                throw new TTransportException($"WARNING: Could not close server socket: {ex}", ex);
            }

            _server = null;
        }
    }
}
179