1 /* Copyright (C) 2002-2005 RealVNC Ltd.  All Rights Reserved.
2  * Copyright (C) 2005 Martin Koegler
3  * Copyright (C) 2010 TigerVNC Team
4  * Copyright (C) 2012-2021 Pierre Ossman for Cendio AB
5  *
6  * This is free software; you can redistribute it and/or modify
7  * it under the terms of the GNU General Public License as published by
8  * the Free Software Foundation; either version 2 of the License, or
9  * (at your option) any later version.
10  *
11  * This software is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14  * GNU General Public License for more details.
15  *
16  * You should have received a copy of the GNU General Public License
17  * along with this software; if not, write to the Free Software
18  * Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307,
19  * USA.
20  */
21 
22 #ifdef HAVE_CONFIG_H
23 #include <config.h>
24 #endif
25 
26 #include <rdr/Exception.h>
27 #include <rdr/TLSException.h>
28 #include <rdr/TLSOutStream.h>
29 #include <rfb/LogWriter.h>
30 #include <errno.h>
31 
32 #ifdef HAVE_GNUTLS
using namespace rdr;

// Logger used for reporting failures in the GnuTLS push callback.
static rfb::LogWriter vlog("TLSOutStream");

// Size in bytes of the plaintext buffer allocated by the constructor.
enum { DEFAULT_BUF_SIZE = 16384 };
38 
push(gnutls_transport_ptr_t str,const void * data,size_t size)39 ssize_t TLSOutStream::push(gnutls_transport_ptr_t str, const void* data,
40 				   size_t size)
41 {
42   TLSOutStream* self= (TLSOutStream*) str;
43   OutStream *out = self->out;
44 
45   delete self->saved_exception;
46   self->saved_exception = NULL;
47 
48   try {
49     out->writeBytes(data, size);
50     out->flush();
51   } catch (SystemException &e) {
52     vlog.error("Failure sending TLS data: %s", e.str());
53     gnutls_transport_set_errno(self->session, e.err);
54     self->saved_exception = new SystemException(e);
55     return -1;
56   } catch (Exception& e) {
57     vlog.error("Failure sending TLS data: %s", e.str());
58     gnutls_transport_set_errno(self->session, EINVAL);
59     self->saved_exception = new Exception(e);
60     return -1;
61   }
62 
63   return size;
64 }
65 
TLSOutStream(OutStream * _out,gnutls_session_t _session)66 TLSOutStream::TLSOutStream(OutStream* _out, gnutls_session_t _session)
67   : session(_session), out(_out), bufSize(DEFAULT_BUF_SIZE), offset(0),
68     saved_exception(NULL)
69 {
70   gnutls_transport_ptr_t recv, send;
71 
72   ptr = start = new U8[bufSize];
73   end = start + bufSize;
74 
75   gnutls_transport_set_push_function(session, push);
76   gnutls_transport_get_ptr2(session, &recv, &send);
77   gnutls_transport_set_ptr2(session, recv, this);
78 }
79 
~TLSOutStream()80 TLSOutStream::~TLSOutStream()
81 {
82 #if 0
83   try {
84 //    flush();
85   } catch (Exception&) {
86   }
87 #endif
88   gnutls_transport_set_push_function(session, NULL);
89 
90   delete [] start;
91   delete saved_exception;
92 }
93 
length()94 size_t TLSOutStream::length()
95 {
96   return offset + ptr - start;
97 }
98 
flush()99 void TLSOutStream::flush()
100 {
101   U8* sentUpTo;
102 
103   // Only give GnuTLS larger chunks if corked to minimize overhead
104   if (corked && ((ptr - start) < 1024))
105     return;
106 
107   sentUpTo = start;
108   while (sentUpTo < ptr) {
109     size_t n = writeTLS(sentUpTo, ptr - sentUpTo);
110     sentUpTo += n;
111     offset += n;
112   }
113 
114   ptr = start;
115   out->flush();
116 }
117 
cork(bool enable)118 void TLSOutStream::cork(bool enable)
119 {
120   OutStream::cork(enable);
121 
122   out->cork(enable);
123 }
124 
overrun(size_t needed)125 void TLSOutStream::overrun(size_t needed)
126 {
127   if (needed > bufSize)
128     throw Exception("TLSOutStream overrun: buffer size exceeded");
129 
130   // A cork might prevent the flush, so disable it temporarily
131   corked = false;
132   flush();
133   corked = true;
134 }
135 
writeTLS(const U8 * data,size_t length)136 size_t TLSOutStream::writeTLS(const U8* data, size_t length)
137 {
138   int n;
139 
140   n = gnutls_record_send(session, data, length);
141   if (n == GNUTLS_E_INTERRUPTED || n == GNUTLS_E_AGAIN)
142     return 0;
143 
144   if (n == GNUTLS_E_PUSH_ERROR)
145     throw *saved_exception;
146 
147   if (n < 0)
148     throw TLSException("writeTLS", n);
149 
150   return n;
151 }
152 
153 #endif
154