1package tls 2 3import ( 4 "crypto/tls" 5 "encoding/pem" 6 "fmt" 7 "io/ioutil" 8 "net/http" 9 "net/http/httptest" 10 "net/url" 11 "os" 12 "strings" 13 "testing" 14 "time" 15) 16 17const ( 18 httpContent = "Hello, TLS!" 19 20 certHash = "SHA256:448f628a8a65aa18560e53a80c53acb38c51b427df0334082349141147dc9bf6" 21) 22 23var ( 24 certNotBefore = time.Unix(0, 0) 25 certNotAfter = certNotBefore.Add(1000000 * time.Hour) 26) 27 28type handshakeError string 29 30func (he handshakeError) Error() string { 31 return string(he) 32} 33 34// createCAFile writes a PEM encoded version of the certificate out to a 35// temporary file, for use by libtls. 36func createCAFile(cert []byte) (string, error) { 37 f, err := ioutil.TempFile("", "tls") 38 if err != nil { 39 return "", fmt.Errorf("failed to create file: %v", err) 40 } 41 defer f.Close() 42 block := &pem.Block{ 43 Type: "CERTIFICATE", 44 Bytes: cert, 45 } 46 if err := pem.Encode(f, block); err != nil { 47 return "", fmt.Errorf("failed to encode certificate: %v", err) 48 } 49 return f.Name(), nil 50} 51 52func newTestServer(tlsCfg *tls.Config) (*httptest.Server, *url.URL, string, error) { 53 ts := httptest.NewUnstartedServer( 54 http.HandlerFunc( 55 func(w http.ResponseWriter, r *http.Request) { 56 fmt.Fprintln(w, httpContent) 57 }, 58 ), 59 ) 60 ts.TLS = tlsCfg 61 ts.StartTLS() 62 63 u, err := url.Parse(ts.URL) 64 if err != nil { 65 return nil, nil, "", fmt.Errorf("failed to parse URL %q: %v", ts.URL, err) 66 } 67 68 caFile, err := createCAFile(ts.TLS.Certificates[0].Certificate[0]) 69 if err != nil { 70 return nil, nil, "", fmt.Errorf("failed to create CA file: %v", err) 71 } 72 73 return ts, u, caFile, nil 74} 75 76func handshakeVersionTest(tlsCfg *tls.Config) (ProtocolVersion, error) { 77 ts, u, caFile, err := newTestServer(tlsCfg) 78 if err != nil { 79 return 0, fmt.Errorf("failed to start test server: %v", err) 80 } 81 defer os.Remove(caFile) 82 defer ts.Close() 83 84 if err := Init(); err != nil { 85 return 0, err 86 } 87 88 cfg, err := NewConfig() 89 if err != nil { 90 return 0, err 91 } 92 defer cfg.Free() 93 if err := cfg.SetCAFile(caFile); err != nil { 94 return 0, err 95 } 96 if err := cfg.SetCiphers("compat"); err != nil { 97 return 0, err 98 } 99 if err := cfg.SetProtocols(ProtocolsAll); err != nil { 100 return 0, err 101 } 102 103 tls, err := NewClient(cfg) 104 if err != nil { 105 return 0, err 106 } 107 defer tls.Free() 108 109 if err := tls.Connect(u.Host, ""); err != nil { 110 return 0, err 111 } 112 if err := tls.Handshake(); err != nil { 113 return 0, handshakeError(err.Error()) 114 } 115 version, err := tls.ConnVersion() 116 if err != nil { 117 return 0, err 118 } 119 if err := tls.Close(); err != nil { 120 return 0, err 121 } 122 return version, nil 123} 124 125func TestTLSBasic(t *testing.T) { 126 ts, u, caFile, err := newTestServer(nil) 127 if err != nil { 128 t.Fatalf("Failed to start test server: %v", err) 129 } 130 defer os.Remove(caFile) 131 defer ts.Close() 132 133 if err := Init(); err != nil { 134 t.Fatal(err) 135 } 136 137 cfg, err := NewConfig() 138 if err != nil { 139 t.Fatal(err) 140 } 141 defer cfg.Free() 142 if err := cfg.SetCAFile(caFile); err != nil { 143 t.Fatal(err) 144 } 145 146 tls, err := NewClient(cfg) 147 if err != nil { 148 t.Fatal(err) 149 } 150 defer tls.Free() 151 152 t.Logf("Connecting to %s", u.Host) 153 154 if err := tls.Connect(u.Host, ""); err != nil { 155 t.Fatal(err) 156 } 157 defer func() { 158 if err := tls.Close(); err != nil { 159 t.Fatalf("Close failed: %v", err) 160 } 161 }() 162 163 n, err := tls.Write([]byte("GET / HTTP/1.0\n\n")) 164 if err != nil { 165 t.Fatal(err) 166 } 167 t.Logf("Wrote %d bytes...", n) 168 169 buf := make([]byte, 1024) 170 n, err = tls.Read(buf) 171 if err != nil { 172 t.Fatal(err) 173 } 174 t.Logf("Read %d bytes...", n) 175 176 if !strings.Contains(string(buf), httpContent) { 177 t.Errorf("Response does not contain %q", httpContent) 178 } 179} 180 181func TestTLSVersions(t *testing.T) { 182 tests := []struct { 183 minVersion uint16 184 maxVersion uint16 185 wantVersion ProtocolVersion 186 wantHandshakeErr bool 187 }{ 188 {tls.VersionSSL30, tls.VersionTLS12, ProtocolTLSv12, false}, 189 {tls.VersionTLS10, tls.VersionTLS12, ProtocolTLSv12, false}, 190 {tls.VersionTLS11, tls.VersionTLS12, ProtocolTLSv12, false}, 191 {tls.VersionSSL30, tls.VersionTLS11, ProtocolTLSv11, false}, 192 {tls.VersionSSL30, tls.VersionTLS10, ProtocolTLSv10, false}, 193 {tls.VersionSSL30, tls.VersionSSL30, 0, true}, 194 {tls.VersionTLS10, tls.VersionTLS10, ProtocolTLSv10, false}, 195 {tls.VersionTLS11, tls.VersionTLS11, ProtocolTLSv11, false}, 196 {tls.VersionTLS12, tls.VersionTLS12, ProtocolTLSv12, false}, 197 } 198 for i, test := range tests { 199 t.Logf("Testing handshake with protocols %x:%x", test.minVersion, test.maxVersion) 200 tlsCfg := &tls.Config{ 201 MinVersion: test.minVersion, 202 MaxVersion: test.maxVersion, 203 } 204 version, err := handshakeVersionTest(tlsCfg) 205 switch { 206 case test.wantHandshakeErr && err == nil: 207 t.Errorf("Test %d - handshake %x:%x succeeded, want handshake error", 208 i, test.minVersion, test.maxVersion) 209 case test.wantHandshakeErr && err != nil: 210 if _, ok := err.(handshakeError); !ok { 211 t.Errorf("Test %d - handshake %x:%x; got unknown error, want handshake error: %v", 212 i, test.minVersion, test.maxVersion, err) 213 } 214 case !test.wantHandshakeErr && err != nil: 215 t.Errorf("Test %d - handshake %x:%x failed: %v", i, test.minVersion, test.maxVersion, err) 216 case !test.wantHandshakeErr && err == nil: 217 if got, want := version, test.wantVersion; got != want { 218 t.Errorf("Test %d - handshake %x:%x; got protocol version %v, want %v", 219 i, test.minVersion, test.maxVersion, got, want) 220 } 221 } 222 } 223} 224 225func TestTLSSingleByteReadWrite(t *testing.T) { 226 ts, u, caFile, err := newTestServer(nil) 227 if err != nil { 228 t.Fatalf("Failed to start test server: %v", err) 229 } 230 defer os.Remove(caFile) 231 defer ts.Close() 232 233 if err := Init(); err != nil { 234 t.Fatal(err) 235 } 236 237 cfg, err := NewConfig() 238 if err != nil { 239 t.Fatal(err) 240 } 241 defer cfg.Free() 242 if err := cfg.SetCAFile(caFile); err != nil { 243 t.Fatal(err) 244 } 245 246 tls, err := NewClient(cfg) 247 if err != nil { 248 t.Fatal(err) 249 } 250 defer tls.Free() 251 252 t.Logf("Connecting to %s", u.Host) 253 254 if err := tls.Connect(u.Host, ""); err != nil { 255 t.Fatal(err) 256 } 257 defer func() { 258 if err := tls.Close(); err != nil { 259 t.Fatalf("Close failed: %v", err) 260 } 261 }() 262 263 for _, b := range []byte("GET / HTTP/1.0\n\n") { 264 n, err := tls.Write([]byte{b}) 265 if err != nil { 266 t.Fatal(err) 267 } 268 if n != 1 { 269 t.Fatalf("Wrote byte %v, got length %d, want 1", b, n) 270 } 271 } 272 273 var body []byte 274 for { 275 buf := make([]byte, 1) 276 n, err := tls.Read(buf) 277 if err != nil { 278 t.Fatal(err) 279 } 280 if n == 0 { 281 break 282 } 283 if n != 1 { 284 t.Fatalf("Read single byte, got length %d, want 1", n) 285 } 286 body = append(body, buf...) 287 } 288 289 if !strings.Contains(string(body), httpContent) { 290 t.Errorf("Response does not contain %q", httpContent) 291 } 292} 293 294func TestTLSInfo(t *testing.T) { 295 ts, u, caFile, err := newTestServer(nil) 296 if err != nil { 297 t.Fatalf("Failed to start test server: %v", err) 298 } 299 defer os.Remove(caFile) 300 defer ts.Close() 301 302 if err := Init(); err != nil { 303 t.Fatal(err) 304 } 305 306 cfg, err := NewConfig() 307 if err != nil { 308 t.Fatal(err) 309 } 310 defer cfg.Free() 311 if err := cfg.SetCAFile(caFile); err != nil { 312 t.Fatal(err) 313 } 314 315 tls, err := NewClient(cfg) 316 if err != nil { 317 t.Fatal(err) 318 } 319 defer tls.Free() 320 321 t.Logf("Connecting to %s", u.Host) 322 323 if err := tls.Connect(u.Host, ""); err != nil { 324 t.Fatal(err) 325 } 326 defer func() { 327 if err := tls.Close(); err != nil { 328 t.Fatalf("Close failed: %v", err) 329 } 330 }() 331 332 // All of these should fail since the handshake has not completed. 333 if _, err := tls.ConnVersion(); err == nil { 334 t.Error("ConnVersion() return nil error, want error") 335 } 336 if _, err := tls.ConnCipher(); err == nil { 337 t.Error("ConnCipher() return nil error, want error") 338 } 339 if _, err := tls.ConnCipherStrength(); err == nil { 340 t.Error("ConnCipherStrength() return nil error, want error") 341 } 342 343 if got, want := tls.PeerCertProvided(), false; got != want { 344 t.Errorf("PeerCertProvided() = %v, want %v", got, want) 345 } 346 for _, name := range []string{"127.0.0.1", "::1", "example.com"} { 347 if got, want := tls.PeerCertContainsName(name), false; got != want { 348 t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want) 349 } 350 } 351 352 if _, err := tls.PeerCertIssuer(); err == nil { 353 t.Error("PeerCertIssuer() returned nil error, want error") 354 } 355 if _, err := tls.PeerCertSubject(); err == nil { 356 t.Error("PeerCertSubject() returned nil error, want error") 357 } 358 if _, err := tls.PeerCertHash(); err == nil { 359 t.Error("PeerCertHash() returned nil error, want error") 360 } 361 if _, err := tls.PeerCertNotBefore(); err == nil { 362 t.Error("PeerCertNotBefore() returned nil error, want error") 363 } 364 if _, err := tls.PeerCertNotAfter(); err == nil { 365 t.Error("PeerCertNotAfter() returned nil error, want error") 366 } 367 368 // Complete the handshake... 369 if err := tls.Handshake(); err != nil { 370 t.Fatalf("Handshake failed: %v", err) 371 } 372 373 if version, err := tls.ConnVersion(); err != nil { 374 t.Errorf("ConnVersion() returned error: %v", err) 375 } else { 376 t.Logf("Protocol version: %v", version) 377 } 378 if cipher, err := tls.ConnCipher(); err != nil { 379 t.Errorf("ConnCipher() returned error: %v", err) 380 } else { 381 t.Logf("Cipher: %v", cipher) 382 } 383 if strength, err := tls.ConnCipherStrength(); err != nil { 384 t.Errorf("ConnCipherStrength() return ederror: %v", err) 385 } else { 386 t.Logf("Cipher Strength: %v bits", strength) 387 } 388 389 if got, want := tls.PeerCertProvided(), true; got != want { 390 t.Errorf("PeerCertProvided() = %v, want %v", got, want) 391 } 392 for _, name := range []string{"127.0.0.1", "::1", "example.com"} { 393 if got, want := tls.PeerCertContainsName(name), true; got != want { 394 t.Errorf("PeerCertContainsName(%q) = %v, want %v", name, got, want) 395 } 396 } 397 398 if issuer, err := tls.PeerCertIssuer(); err != nil { 399 t.Errorf("PeerCertIssuer() returned error: %v", err) 400 } else { 401 t.Logf("Issuer: %v", issuer) 402 } 403 if subject, err := tls.PeerCertSubject(); err != nil { 404 t.Errorf("PeerCertSubject() returned error: %v", err) 405 } else { 406 t.Logf("Subject: %v", subject) 407 } 408 if hash, err := tls.PeerCertHash(); err != nil { 409 t.Errorf("PeerCertHash() returned error: %v", err) 410 } else if hash != certHash { 411 t.Errorf("Got cert hash %q, want %q", hash, certHash) 412 } else { 413 t.Logf("Hash: %v", hash) 414 } 415 if notBefore, err := tls.PeerCertNotBefore(); err != nil { 416 t.Errorf("PeerCertNotBefore() returned error: %v", err) 417 } else if !certNotBefore.Equal(notBefore) { 418 t.Errorf("Got cert notBefore %v, want %v", notBefore.UTC(), certNotBefore.UTC()) 419 } else { 420 t.Logf("NotBefore: %v", notBefore.UTC()) 421 } 422 if notAfter, err := tls.PeerCertNotAfter(); err != nil { 423 t.Errorf("PeerCertNotAfter() returned error: %v", err) 424 } else if !certNotAfter.Equal(notAfter) { 425 t.Errorf("Got cert notAfter %v, want %v", notAfter.UTC(), certNotAfter.UTC()) 426 } else { 427 t.Logf("NotAfter: %v", notAfter.UTC()) 428 } 429} 430