1unit lNetSSL;
2
3{$mode objfpc}{$H+}
4
5interface
6
7uses
8  SysUtils, Classes, cTypes, OpenSSL,
9  lNet, lEvents;
10
type
  { Protocol family requested for the SSL context; mapped to an OpenSSL
    method in TLSSLSession.CreateSSLContext. }
  TLSSLMethod = (msSSLv2or3, msSSLv2, msSSLv3, msTLSv1);

  { Per-socket SSL progress. slNone = no handshake pending (connected when an
    SSL object exists), slConnect/slActivateTLS = handshake running,
    slShutdown = SSL layer being torn down. }
  TLSSLStatus = (slNone, slConnect, slActivateTLS, slShutdown);

  { Password callback installed with SslCtxSetDefaultPasswdCb when a key file
    is configured. }
  TLPasswordCB = function(buf: pChar; num, rwflag: cInt; userdata: Pointer): cInt; cdecl;

  { TLSSLSocket
    TLSocket descendant that routes I/O through an OpenSSL SSL object while
    ssSSLActive is in its socket state, and behaves like a plain TLSocket
    otherwise. }

  TLSSLSocket = class(TLSocket)
   protected
    FSSL: PSSL;               // per-connection SSL object, nil until SetupSSLSocket succeeds
    FSSLContext: PSSL_CTX;    // context copied from the owning session in InitHandle
    FSSLStatus: TLSSLStatus;  // handshake / shutdown progress
    FIsAcceptor: Boolean;     // True when the handshake runs server side (SSLAccept)
    function GetConnected: Boolean; override; deprecated;
    function GetConnectionStatus: TLSocketConnectionStatus; override;

    function DoSend(const aData; const aSize: Integer): Integer; override;
    function DoGet(out aData; const aSize: Integer): Integer; override;

    function HandleResult(const aResult: Integer; aOp: TLSocketOperation): Integer; override;

    function SetActiveSSL(const AValue: Boolean): Boolean;

    procedure SetupSSLSocket;
    procedure ActivateTLSEvent;
    procedure ConnectEvent;
    procedure AcceptEvent;
    procedure ConnectSSL;
    procedure AcceptSSL;
    procedure ShutdownSSL;

    function LogError(const msg: string; const ernum: Integer): Boolean; override;
   public
    destructor Destroy; override;

    function SetState(const aState: TLSocketState; const TurnOn: Boolean = True): Boolean; override;

    procedure Disconnect(const Forced: Boolean = True); override;

    property SSLStatus: TLSSLStatus read FSSLStatus;
  end;

  { TLSSLSession
    Session that owns the SSL_CTX shared by its sockets and drives their SSL
    handshakes before forwarding connect/accept/receive events.
    NOTE(review): no destructor is declared here, so FSSLContext is not freed
    by this class - confirm whether that leak is intended. }

  TLSSLSession = class(TLSession)
   protected
    FOnSSLConnect, FOnSSLAccept: TLSocketEvent;  // fired when a client/server handshake completes
    FSSLActive: Boolean;                         // SSL state applied to new handles in InitHandle
    FSSLContext: PSSL_CTX;                       // rebuilt by CreateSSLContext
    FPassword, FCAFile, FKeyFile: string;
    FMethod: TLSSLMethod;
    FPasswordCallback: TLPasswordCB;             // defaults to the unit's PasswordCB (see Create)

    procedure CallOnSSLConnect(aSocket: TLSocket);
    procedure CallOnSSLAccept(aSocket: TLSocket);

    procedure SetSSLActive(const AValue: Boolean);
    procedure SetCAFile(AValue: string);
    procedure SetKeyFile(AValue: string);
    procedure SetPassword(const AValue: string);
    procedure SetMethod(const AValue: TLSSLMethod);
    procedure SetPasswordCallback(const AValue: TLPasswordCB);

    procedure CreateSSLContext; virtual;
   public
    constructor Create(aOwner: TComponent); override;

    procedure RegisterWithComponent(aConnection: TLConnection); override;

    procedure InitHandle(aHandle: TLHandle); override;

    procedure ConnectEvent(aHandle: TLHandle); override;
    procedure ReceiveEvent(aHandle: TLHandle); override;
    procedure AcceptEvent(aHandle: TLHandle); override;
    function HandleSSLConnection(aSocket: TLSSLSocket; const DoAccept: Boolean = False): Boolean;

    property Password: string read FPassword write SetPassword;
    property CAFile: string read FCAFile write SetCAFile;
    property KeyFile: string read FKeyFile write SetKeyFile;
    property Method: TLSSLMethod read FMethod write SetMethod;
    property PasswordCallback: TLPasswordCB read FPasswordCallback write SetPasswordCallback;
    property SSLContext: PSSL_CTX read FSSLContext;
    property SSLActive: Boolean read FSSLActive write SetSSLActive;
    property OnSSLConnect: TLSocketEvent read FOnSSLConnect write FOnSSLConnect;
    property OnSSLAccept: TLSocketEvent read FOnSSLAccept write FOnSSLAccept;
  end;

  { True for SSL_get_error codes that only mean "retry when the socket is ready". }
  function IsSSLBlockError(const anError: Longint): Boolean; inline;
  { True for SSL_ERROR_SYSCALL results whose underlying error is non-fatal. }
  function IsSSLNonFatalError(const anError, aRet: Longint): Boolean; inline;
104
105implementation
106
107uses
108  {Math,} lCommon;
109
function PasswordCB(buf: pChar; num, rwflag: cInt; userdata: Pointer): cInt; cdecl;
{ Default password callback: copies the owning TLSSLSession's Password into
  buf. userdata is the session registered with
  SslCtxSetDefaultPasswdCbUserdata in TLSSLSession.CreateSSLContext.
  Returns the number of password bytes written, or 0 when there is no session,
  no password, or the password (plus a terminating #0) does not fit in num. }
var
  Pwd: string;
  Len: Integer;
begin
  Result := 0;
  if userdata = nil then
    Exit;

  // Read the property once instead of three times.
  Pwd := TLSSLSession(userdata).Password;
  Len := Length(Pwd);

  // Avoid indexing Pwd[1] on an empty string (raises under {$R+}); keep the
  // original requirement of one spare byte for a terminator.
  if (Len = 0) or (num < Len + 1) then
    Exit;

  Move(Pointer(Pwd)^, buf^, Len);
  buf[Len] := #0; // room guaranteed by the size check above
  Result := Len;
end;
122
function IsSSLBlockError(const anError: Longint): Boolean; inline;
{ True when an SSL_get_error code only asks to retry once the socket is
  readable (SSL_ERROR_WANT_READ) or writable (SSL_ERROR_WANT_WRITE). }
begin
  case anError of
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE: Result := True;
  else
    Result := False;
  end;
end;
127
function IsSSLNonFatalError(const anError, aRet: Longint): Boolean; inline;
{ Only SSL_ERROR_SYSCALL can be non-fatal. The OpenSSL error queue is drained
  (ErrGetError until it yields a non-positive value) and the verdict of
  IsNonFatalError on the last examined code is returned. When the queue is
  empty the SSL call's return value aRet is checked instead; a non-positive
  aRet (EOF / BIO condition) counts as fatal unless a queued error already
  decided otherwise. }
var
  QueuedErr: Longint;
begin
  Result := False;
  if anError <> SSL_ERROR_SYSCALL then
    Exit;

  repeat
    QueuedErr := ErrGetError();
    if QueuedErr <> 0 then
      Result := IsNonFatalError(QueuedErr)   // inspect the queued error
    else if aRet > 0 then
      Result := IsNonFatalError(aRet)        // queue empty: fall back to the return value
    else
      Exit;                                  // queue empty and EOF/BIO result: keep verdict
  until QueuedErr <= 0;                      // keep going until the queue is empty
end;
142
143{ TLSSLSocket }
144
function TLSSLSocket.SetActiveSSL(const AValue: Boolean): Boolean;
{ Switches SSL mode on this socket.
  - Turning it on while FConnectionStatus = scConnected starts a handshake on
    the existing connection (ActivateTLSEvent).
  - Turning it off on a connected socket calls ShutdownSSL.
  - Turning it off while a handshake is running raises Exception; the
    ssSSLActive flag is restored first so the socket is left unchanged.
  Returns True when the socket ends up in the requested mode. }
begin
  Result := False;

  // Already in the requested mode: nothing to do.
  if (ssSSLActive in FSocketState) = AValue then
    Exit(True);

  if AValue then
    FSocketState := FSocketState + [ssSSLActive]
  else
    FSocketState := FSocketState - [ssSSLActive];

  if AValue and (FConnectionStatus = scConnected) then
    ActivateTLSEvent;

  if not AValue then begin
    // ssSSLActive is already cleared here, so GetConnectionStatus falls back
    // to the inherited status (assuming ConnectionStatus is backed by it).
    if ConnectionStatus = scConnected then
      ShutdownSSL
    else if FSSLStatus in [slConnect, slActivateTLS] then begin
      // Roll back the flag so a refused switch does not leave the socket
      // half-switched (SSL handshake running, SSL flag off).
      FSocketState := FSocketState + [ssSSLActive];
      raise Exception.Create('Switching SSL mode on socket during SSL handshake is not supported');
    end;
  end;

  Result := True;
end;
167
procedure TLSSLSocket.SetupSSLSocket;
{ (Re)creates the SSL object for this socket from FSSLContext and binds it to
  FHandle. On failure FSSL is left nil and Bail is called. }
begin
  // Release the SSL object of a previous connection/handshake, if any.
  if Assigned(FSSL) then begin
    SslFree(FSSL);
    FSSL := nil;
  end;

  FSSL := SSLNew(FSSLContext);
  if not Assigned(FSSL) then begin
    Bail('SSLNew error', -1);
    Exit;
  end;

  if SslSetFd(FSSL, FHandle) = 0 then begin
    // Free the object we just created rather than leaking it; clear FSSL
    // before Bail so nothing triggered by it touches a freed object.
    SslFree(FSSL);
    FSSL := nil;
    Bail('SSL setFD error', -1);
  end;
end;
185
procedure TLSSLSocket.ActivateTLSEvent;
{ Starts a client-side (SSLConnect) handshake on an already connected socket;
  called from SetActiveSSL. }
begin
  SetupSSLSocket;
  // Status is set even on failure so HandleSSLConnection never reports a
  // completed handshake for a socket whose setup failed.
  FSSLStatus := slActivateTLS;
  // SetupSSLSocket leaves FSSL nil (after calling Bail) on failure;
  // SSLConnect(nil) would crash.
  if Assigned(FSSL) then
    ConnectSSL;
end;
192
function TLSSLSocket.GetConnected: Boolean;
{ Deprecated connectivity check. Plain sockets defer to TLSocket; SSL
  sockets count as connected only when an SSL object exists and no
  handshake or shutdown is pending. }
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited GetConnected;
    Exit;
  end;

  Result := (FSSLStatus = slNone) and Assigned(FSSL);
end;
200
function TLSSLSocket.GetConnectionStatus: TLSocketConnectionStatus;
{ While SSL is active the SSL progress decides the status:
  handshake -> scConnecting, shutdown -> scDisconnecting, otherwise
  scConnected if an SSL object exists, else scNone.
  Plain sockets report the inherited status. }
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited GetConnectionStatus;
    Exit;
  end;

  if FSSLStatus in [slConnect, slActivateTLS] then
    Result := scConnecting
  else if FSSLStatus = slShutdown then
    Result := scDisconnecting
  else if Assigned(FSSL) then      // slNone with a live SSL object
    Result := scConnected
  else
    Result := scNone;
end;
214
function TLSSLSocket.DoSend(const aData; const aSize: Integer): Integer;
{ Sends up to aSize bytes of aData: through SSLWrite when SSL is active,
  through the plain socket otherwise. Returns the raw result of that call. }
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited DoSend(aData, aSize);
    Exit;
  end;

  Result := SSLWrite(FSSL, @aData, aSize);
end;
231
function TLSSLSocket.DoGet(out aData; const aSize: Integer): Integer;
{ Reads up to aSize bytes into aData: through SSLRead when SSL is active,
  through the plain socket otherwise. Returns the raw result of that call. }
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited DoGet(aData, aSize);
    Exit;
  end;

  Result := SSLRead(FSSL, @aData, aSize);
end;
239
function TLSSLSocket.HandleResult(const aResult: Integer; aOp: TLSocketOperation): Integer;
{ Post-processes a DoSend/DoGet result. Plain sockets use the inherited
  handling. For SSL sockets positive results pass through unchanged; for a
  non-positive result SslGetError classifies the failure:
    - would-block: drop ssCanSend/ssCanReceive and stop ignoring that
      direction so the socket is watched again;
    - non-fatal SYSCALL error: log only;
    - pipe error on send: hard disconnect;
    - anything else: Bail.
  In all of those cases 0 is returned.
  NOTE(review): IsPipeError is applied to an SSL_get_error code rather than
  an OS errno - confirm that is intended. }
const
  GSStr: array[TLSocketOperation] of string = ('SSLWrite', 'SSLRead');
var
  SSLErr: cInt;
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited HandleResult(aResult, aOp);
    Exit;
  end;

  Result := aResult;
  if Result > 0 then
    Exit; // data was transferred

  SSLErr := SslGetError(FSSL, Result);

  if IsSSLBlockError(SSLErr) then begin
    if aOp = soSend then begin
      FSocketState := FSocketState - [ssCanSend];
      IgnoreWrite := False;
    end else begin
      FSocketState := FSocketState - [ssCanReceive];
      IgnoreRead := False;
    end;
  end else if IsSSLNonFatalError(SSLErr, Result) then
    LogError(GSStr[aOp] + ' error', SSLErr)
  else if (aOp = soSend) and IsPipeError(SSLErr) then
    HardDisconnect(True)
  else
    Bail(GSStr[aOp] + ' error', SSLErr);

  Result := 0;
end;
272
procedure TLSSLSocket.ConnectEvent;
{ Begins the client-side SSL handshake after the TCP connection is up. }
begin
  SetupSSLSocket;
  // Status is set even on failure so HandleSSLConnection does not report a
  // completed handshake for a socket whose setup failed.
  FSSLStatus := slConnect;
  // SetupSSLSocket leaves FSSL nil (after calling Bail) on failure;
  // SSLConnect(nil) would crash.
  if Assigned(FSSL) then
    ConnectSSL;
end;
279
procedure TLSSLSocket.AcceptEvent;
{ Begins the server-side SSL handshake on an accepted connection. }
begin
  SetupSSLSocket;
  // Status is set even on failure so HandleSSLConnection does not report a
  // completed handshake for a socket whose setup failed.
  FSSLStatus := slConnect;
  // SetupSSLSocket leaves FSSL nil (after calling Bail) on failure;
  // SSLAccept(nil) would crash.
  if Assigned(FSSL) then
    AcceptSSL;
end;
286
function TLSSLSocket.LogError(const msg: string; const ernum: Integer): Boolean;
{ Error reporting. Plain sockets use the inherited implementation and return
  its result. SSL sockets call OnError (if assigned) with msg, appending the
  ErrErrorString text when ernum is positive, and return False. }
var
  ErrText: string;
begin
  if not (ssSSLActive in FSocketState) then begin
    Result := inherited LogError(msg, ernum);
    Exit;
  end;

  Result := False;
  if not Assigned(FOnError) then
    Exit;

  if ernum <= 0 then begin
    FOnError(Self, msg);
    Exit;
  end;

  SetLength(ErrText, 1024);
  ErrErrorString(ernum, ErrText, Length(ErrText));
  FOnError(Self, msg + ': ' + ErrText);
end;
303
destructor TLSSLSocket.Destroy;
{ Frees the SSL object after the inherited destructor has run, so anything
  the inherited destructor may trigger (e.g. an SSL shutdown - NOTE(review):
  confirm what TLSocket.Destroy does) still sees a valid FSSL. }
begin
  inherited Destroy;
  // Same nil check as SetupSSLSocket; clear the pointer once released.
  if Assigned(FSSL) then begin
    SslFree(FSSL);
    FSSL := nil;
  end;
end;
309
function TLSSLSocket.SetState(const aState: TLSocketState; const TurnOn: Boolean): Boolean;
{ ssSSLActive is handled by SetActiveSSL; every other state goes to TLSocket. }
begin
  if aState <> ssSSLActive then
    Result := inherited SetState(aState, TurnOn)
  else
    Result := SetActiveSSL(TurnOn);
end;
319
procedure TLSSLSocket.ConnectSSL;
{ Runs one step of the client-side handshake (SSLConnect).
  - Success: FSSLStatus becomes slNone and the session's OnSSLConnect fires.
  - WANT_READ / WANT_WRITE: the matching direction is re-armed so the
    handshake continues on the next event.
  - Any other error, or a missing SSL object: Bail. }
var
  Ret, SSLErr: cInt;
begin
  // Defensive: SSLConnect(nil) would crash (e.g. after a failed SetupSSLSocket).
  if not Assigned(FSSL) then begin
    Bail('SSL connect error', -1);
    Exit;
  end;

  Ret := SSLConnect(FSSL);
  if Ret > 0 then begin
    FSSLStatus := slNone;
    TLSSLSession(FSession).CallOnSSLConnect(Self);
    Exit;
  end;

  SSLErr := SslGetError(FSSL, Ret);
  case SSLErr of
    SSL_ERROR_WANT_READ  : begin // make sure we're watching for reads and flag status
                             FSocketState := FSocketState - [ssCanReceive];
                             IgnoreRead := False;
                           end;
    SSL_ERROR_WANT_WRITE : begin // make sure we're watching for writes and flag status
                             FSocketState := FSocketState - [ssCanSend];
                             IgnoreWrite := False;
                           end;
  else
    Bail('SSL connect error', SSLErr);
  end;
end;
347
procedure TLSSLSocket.AcceptSSL;
{ Runs one step of the server-side handshake (SSLAccept).
  - Success: FSSLStatus becomes slNone and the session's OnSSLAccept fires.
  - WANT_READ / WANT_WRITE: the matching direction is re-armed so the
    handshake continues on the next event.
  - Any other error, or a missing SSL object: Bail. }
var
  Ret, SSLErr: cInt;
begin
  // Defensive: SSLAccept(nil) would crash (e.g. after a failed SetupSSLSocket).
  if not Assigned(FSSL) then begin
    Bail('SSL accept error', -1);
    Exit;
  end;

  Ret := SSLAccept(FSSL);
  if Ret > 0 then begin
    FSSLStatus := slNone;
    TLSSLSession(FSession).CallOnSSLAccept(Self);
    Exit;
  end;

  SSLErr := SslGetError(FSSL, Ret);
  case SSLErr of
    SSL_ERROR_WANT_READ  : begin // make sure we're watching for reads and flag status
                             FSocketState := FSocketState - [ssCanReceive];
                             IgnoreRead := False;
                           end;
    SSL_ERROR_WANT_WRITE : begin // make sure we're watching for writes and flag status
                             FSocketState := FSocketState - [ssCanSend];
                             IgnoreWrite := False;
                           end;
  else
    Bail('SSL accept error', SSLErr);
  end;
end;
375
procedure TLSSLSocket.ShutdownSSL;
{ Best-effort SSL shutdown: calls SSLShutdown once (the peer's close_notify
  is not awaited) and resets FSSLStatus to slNone. Only a negative result
  whose SSL_get_error code is not WANT_READ/WANT_WRITE/SYSCALL bails out. }
var
  Ret, SSLErr: cInt;
begin
  if not Assigned(FSSL) then
    Exit;

  FSSLStatus := slNone; // for now
  Ret := SSLShutdown(FSSL);

  // 1 = shutdown complete, 0 = our close_notify sent but the peer's not yet
  // received. Neither is an error; per the OpenSSL docs SSL_get_error may
  // report a spurious error for a 0 result, so only negative results are
  // inspected.
  if Ret >= 0 then
    Exit;

  SSLErr := SslGetError(FSSL, Ret);
  case SSLErr of
    SSL_ERROR_WANT_READ,
    SSL_ERROR_WANT_WRITE,
    SSL_ERROR_SYSCALL     : begin end; // ignore
  else
    Bail('SSL shutdown error', SSLErr);
  end;
end;
395
procedure TLSSLSocket.Disconnect(const Forced: Boolean = True);
{ Tears down the SSL layer first (status slShutdown, then SetActiveSSL(False),
  which calls ShutdownSSL on a connected socket), then disconnects the
  underlying socket. }
var
  WasSSL: Boolean;
begin
  WasSSL := ssSSLActive in FSocketState;
  if WasSSL then begin
    FSSLStatus := slShutdown;
    SetActiveSSL(False);
  end;

  inherited Disconnect(Forced);
end;
405
406{ TLSSLSession }
407
procedure TLSSLSession.SetSSLActive(const AValue: Boolean);
{ Setter for SSLActive (the SSL state InitHandle applies to new handles).
  Enabling rebuilds the SSL context; disabling leaves the current one as is. }
begin
  if AValue <> FSSLActive then begin
    FSSLActive := AValue;
    if FSSLActive then
      CreateSSLContext;
  end;
end;
415
procedure TLSSLSession.CallOnSSLConnect(aSocket: TLSocket);
{ Fires OnSSLConnect for aSocket when a handler is assigned. }
begin
  if not Assigned(FOnSSLConnect) then
    Exit;
  FOnSSLConnect(aSocket);
end;
421
procedure TLSSLSession.CallOnSSLAccept(aSocket: TLSocket);
{ Fires OnSSLAccept for aSocket when a handler is assigned. }
begin
  if not Assigned(FOnSSLAccept) then
    Exit;
  FOnSSLAccept(aSocket);
end;
427
procedure TLSSLSession.SetCAFile(AValue: string);
{ Setter for CAFile (loaded with SSLCTXLoadVerifyLocations). Directory
  separators are normalised first; a changed value rebuilds the context. }
begin
  DoDirSeparators(AValue);
  if AValue <> FCAFile then begin
    FCAFile := AValue;
    CreateSSLContext;
  end;
end;
435
procedure TLSSLSession.SetKeyFile(AValue: string);
{ Setter for KeyFile, a PEM file used for both the certificate chain and the
  private key. Directory separators are normalised first; a changed value
  rebuilds the context. }
begin
  DoDirSeparators(AValue);
  if AValue <> FKeyFile then begin
    FKeyFile := AValue;
    CreateSSLContext;
  end;
end;
443
procedure TLSSLSession.SetPassword(const AValue: string);
{ Setter for Password (served by the default PasswordCB); a change rebuilds
  the context. }
begin
  if AValue <> FPassword then begin
    FPassword := AValue;
    CreateSSLContext;
  end;
end;
450
procedure TLSSLSession.SetMethod(const AValue: TLSSLMethod);
{ Setter for Method; a change rebuilds the context with the new protocol. }
begin
  if AValue <> FMethod then begin
    FMethod := AValue;
    CreateSSLContext;
  end;
end;
457
procedure TLSSLSession.SetPasswordCallback(const AValue: TLPasswordCB);
{ Setter for PasswordCallback; a different callback rebuilds the context. }
begin
  if AValue <> FPasswordCallback then begin
    FPasswordCallback := AValue;
    CreateSSLContext;
  end;
end;
464
procedure TLSSLSession.CreateSSLContext;
{ (Re)builds FSSLContext from Method, KeyFile, PasswordCallback/Password and
  CAFile. Any existing context is freed first; when SSL is not active the
  session is left without a context (FSSLContext = nil).
  Raises Exception when any OpenSSL setup step fails.
  NOTE(review): sockets initialised earlier keep their own copy of the old
  context pointer (see InitHandle) - confirm they never reuse it after it
  has been freed here. }
var
  aMethod: PSSL_METHOD;
begin
  if Assigned(FSSLContext) then begin
    SSLCTXFree(FSSLContext);
    // Clear the field: the inactive-SSL exit below would otherwise leave a
    // dangling pointer that the next call frees a second time.
    FSSLContext := nil;
  end;

  if not FSSLActive then
    Exit;

  case FMethod of
    msSSLv2or3 : aMethod := SslMethodV23;
    msSSLv2    : aMethod := SslMethodV2;
    msSSLv3    : aMethod := SslMethodV3;
    msTLSv1    : aMethod := SslMethodTLSV1;
  end;

  FSSLContext := SSLCTXNew(aMethod);
  if not Assigned(FSSLContext) then
    raise Exception.Create('Error creating SSL CTX: SSLCTXNew');

  // SSLCTXSetMode returns the resulting mode mask; the requested bit must be set.
  if SSLCTXSetMode(FSSLContext, SSL_MODE_ENABLE_PARTIAL_WRITE) and SSL_MODE_ENABLE_PARTIAL_WRITE <> SSL_MODE_ENABLE_PARTIAL_WRITE then
    raise Exception.Create('Error setting partial write mode on CTX');
  if SSLCTXSetMode(FSSLContext, SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER) and SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER <> SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER then
    raise Exception.Create('Error setting accept moving buffer mode on CTX');

  // KeyFile supplies both the certificate chain and the (PEM) private key;
  // the password callback receives Self as userdata.
  if Length(FKeyFile) > 0 then begin
    if SslCtxUseCertificateChainFile(FSSLContext, FKeyFile) = 0 then
      raise Exception.Create('Error creating SSL CTX: SslCtxUseCertificateChainFile');

    SslCtxSetDefaultPasswdCb(FSSLContext, FPasswordCallback);
    SslCtxSetDefaultPasswdCbUserdata(FSSLContext, Self);

    if SSLCTXUsePrivateKeyFile(FSSLContext, FKeyfile, SSL_FILETYPE_PEM) = 0 then
      raise Exception.Create('Error creating SSL CTX: SSLCTXUsePrivateKeyFile');
  end;

  if Length(FCAFile) > 0 then
    if SSLCTXLoadVerifyLocations(FSSLContext, FCAFile, pChar(nil)) = 0 then
      raise Exception.Create('Error creating SSL CTX: SSLCTXLoadVerifyLocations');

  OPENSSLaddallalgorithms;
end;
508
constructor TLSSLSession.Create(aOwner: TComponent);
{ SSL starts enabled, the built-in PasswordCB serves the Password property,
  and an initial context is built right away. }
begin
  inherited Create(aOwner);
  FSSLActive := True;
  FPasswordCallback := @PasswordCB;
  CreateSSLContext;
end;
516
procedure TLSSLSession.RegisterWithComponent(aConnection: TLConnection);
{ Ensures the connection creates SSL-capable sockets; a socket class that
  already descends from TLSSLSocket is kept. }
begin
  inherited RegisterWithComponent(aConnection);

  if aConnection.SocketClass.InheritsFrom(TLSSLSocket) then
    Exit;
  aConnection.SocketClass := TLSSLSocket;
end;
524
procedure TLSSLSession.InitHandle(aHandle: TLHandle);
{ Hands the session's SSL context to a new socket and applies SSLActive to it
  (assumes aHandle is a TLSSLSocket, as RegisterWithComponent arranges). }
var
  Sock: TLSSLSocket;
begin
  inherited InitHandle(aHandle);

  Sock := TLSSLSocket(aHandle);
  Sock.FSSLContext := FSSLContext;
  Sock.SetState(ssSSLActive, FSSLActive);
end;
532
procedure TLSSLSession.ConnectEvent(aHandle: TLHandle);
{ For SSL sockets the connect event is reported only after the client
  handshake has completed; plain sockets use the inherited handling. }
var
  Sock: TLSSLSocket;
begin
  Sock := TLSSLSocket(aHandle);
  if ssSSLActive in Sock.SocketState then begin
    if HandleSSLConnection(Sock) then
      CallConnectEvent(aHandle);
  end else
    inherited ConnectEvent(aHandle);
end;
540
procedure TLSSLSession.ReceiveEvent(aHandle: TLHandle);
{ While an SSL handshake is running, readability advances the handshake
  instead of reaching the user. When a slConnect handshake completes, the
  accept event (ssServerSocket) or the connect event is fired. Otherwise the
  normal receive event is delivered. }
var
  Sock: TLSSLSocket;
begin
  Sock := TLSSLSocket(aHandle);
  if not (ssSSLActive in Sock.SocketState) then begin
    inherited ReceiveEvent(aHandle);
    Exit;
  end;

  if Sock.SSLStatus = slConnect then begin
    if HandleSSLConnection(Sock) then
      if ssServerSocket in Sock.SocketState then
        CallAcceptEvent(aHandle)
      else
        CallConnectEvent(aHandle);
  end else if Sock.SSLStatus = slActivateTLS then
    HandleSSLConnection(Sock)
  else
    CallReceiveEvent(aHandle);
end;
558
procedure TLSSLSession.AcceptEvent(aHandle: TLHandle);
{ For SSL sockets the accept event is reported only after the server-side
  handshake has completed; plain sockets use the inherited handling. }
var
  Sock: TLSSLSocket;
begin
  Sock := TLSSLSocket(aHandle);
  if ssSSLActive in Sock.SocketState then begin
    if HandleSSLConnection(Sock, True) then
      CallAcceptEvent(aHandle);
  end else
    inherited AcceptEvent(aHandle);
end;
566
function TLSSLSession.HandleSSLConnection(aSocket: TLSSLSocket; const DoAccept: Boolean = False): Boolean;
{ Advances aSocket's SSL handshake by one step.
  - slNone: starts a handshake, server side if DoAccept, else client side,
    and records the side in FIsAcceptor.
  - slConnect / slActivateTLS: continues it on the recorded side.
  - slShutdown: raises Exception.
  Returns True when the socket's SSLStatus is slNone afterwards.
  Raises Exception when the session has no SSL context. }
begin
  if not Assigned(FSSLContext) then
    raise Exception.Create('Context not created during SSL connect/accept');

  case aSocket.FSSLStatus of
    slNone:
      begin
        aSocket.FIsAcceptor := DoAccept;
        if DoAccept then
          aSocket.AcceptEvent
        else
          aSocket.ConnectEvent;
      end;
    slActivateTLS, slConnect:
      if aSocket.FIsAcceptor then
        aSocket.AcceptSSL
      else
        aSocket.ConnectSSL;
    slShutdown:
      raise Exception.Create('Got ConnectEvent or AcceptEvent on socket with ssShutdown status');
  end;

  Result := aSocket.SSLStatus = slNone;
end;
602
initialization
  // Unit start-up: initialise the OpenSSL library and load its error strings.
  SslLibraryInit;
  SslLoadErrorStrings;

finalization
  // Unit shutdown: release the OpenSSL interface via DestroySSLInterface.
  DestroySSLInterface;

end.
611
612