1package mint 2 3import ( 4 "crypto/cipher" 5 "fmt" 6 "io" 7 "sync" 8) 9 10const ( 11 sequenceNumberLen = 8 // sequence number length 12 recordHeaderLenTLS = 5 // record header length (TLS) 13 recordHeaderLenDTLS = 13 // record header length (DTLS) 14 maxFragmentLen = 1 << 14 // max number of bytes in a record 15) 16 17type DecryptError string 18 19func (err DecryptError) Error() string { 20 return string(err) 21} 22 23type direction uint8 24 25const ( 26 directionWrite = direction(1) 27 directionRead = direction(2) 28) 29 30// struct { 31// ContentType type; 32// ProtocolVersion record_version [0301 for CH, 0303 for others] 33// uint16 length; 34// opaque fragment[TLSPlaintext.length]; 35// } TLSPlaintext; 36type TLSPlaintext struct { 37 // Omitted: record_version (static) 38 // Omitted: length (computed from fragment) 39 contentType RecordType 40 epoch Epoch 41 seq uint64 42 fragment []byte 43} 44 45type cipherState struct { 46 epoch Epoch // DTLS epoch 47 ivLength int // Length of the seq and nonce fields 48 seq uint64 // Zero-padded sequence number 49 iv []byte // Buffer for the IV 50 cipher cipher.AEAD // AEAD cipher 51} 52 53type RecordLayer struct { 54 sync.Mutex 55 label string 56 direction direction 57 version uint16 // The current version number 58 conn io.ReadWriter // The underlying connection 59 frame *frameReader // The buffered frame reader 60 nextData []byte // The next record to send 61 cachedRecord *TLSPlaintext // Last record read, cached to enable "peek" 62 cachedError error // Error on the last record read 63 64 cipher *cipherState 65 readCiphers map[Epoch]*cipherState 66 67 datagram bool 68} 69 70type recordLayerFrameDetails struct { 71 datagram bool 72} 73 74func (d recordLayerFrameDetails) headerLen() int { 75 if d.datagram { 76 return recordHeaderLenDTLS 77 } 78 return recordHeaderLenTLS 79} 80 81func (d recordLayerFrameDetails) defaultReadLen() int { 82 return d.headerLen() + maxFragmentLen 83} 84 85func (d recordLayerFrameDetails) frameLen(hdr []byte) (int, error) { 86 return (int(hdr[d.headerLen()-2]) << 8) | int(hdr[d.headerLen()-1]), nil 87} 88 89func newCipherStateNull() *cipherState { 90 return &cipherState{EpochClear, 0, 0, nil, nil} 91} 92 93func newCipherStateAead(epoch Epoch, factory aeadFactory, key []byte, iv []byte) (*cipherState, error) { 94 cipher, err := factory(key) 95 if err != nil { 96 return nil, err 97 } 98 99 return &cipherState{epoch, len(iv), 0, iv, cipher}, nil 100} 101 102func NewRecordLayerTLS(conn io.ReadWriter, dir direction) *RecordLayer { 103 r := RecordLayer{} 104 r.label = "" 105 r.direction = dir 106 r.conn = conn 107 r.frame = newFrameReader(recordLayerFrameDetails{false}) 108 r.cipher = newCipherStateNull() 109 r.version = tls10Version 110 return &r 111} 112 113func NewRecordLayerDTLS(conn io.ReadWriter, dir direction) *RecordLayer { 114 r := RecordLayer{} 115 r.label = "" 116 r.direction = dir 117 r.conn = conn 118 r.frame = newFrameReader(recordLayerFrameDetails{true}) 119 r.cipher = newCipherStateNull() 120 r.readCiphers = make(map[Epoch]*cipherState, 0) 121 r.readCiphers[0] = r.cipher 122 r.datagram = true 123 return &r 124} 125 126func (r *RecordLayer) SetVersion(v uint16) { 127 r.version = v 128} 129 130func (r *RecordLayer) ResetClear(seq uint64) { 131 r.cipher = newCipherStateNull() 132 r.cipher.seq = seq 133} 134 135func (r *RecordLayer) Rekey(epoch Epoch, factory aeadFactory, key []byte, iv []byte) error { 136 cipher, err := newCipherStateAead(epoch, factory, key, iv) 137 if err != nil { 138 return err 139 } 140 r.cipher = cipher 141 if r.datagram && r.direction == directionRead { 142 r.readCiphers[epoch] = cipher 143 } 144 return nil 145} 146 147// TODO(ekr@rtfm.com): This is never used, which is a bug. 148func (r *RecordLayer) DiscardReadKey(epoch Epoch) { 149 if !r.datagram { 150 return 151 } 152 153 _, ok := r.readCiphers[epoch] 154 assert(ok) 155 delete(r.readCiphers, epoch) 156} 157 158func (c *cipherState) combineSeq(datagram bool) uint64 { 159 seq := c.seq 160 if datagram { 161 seq |= uint64(c.epoch) << 48 162 } 163 return seq 164} 165 166func (c *cipherState) computeNonce(seq uint64) []byte { 167 nonce := make([]byte, len(c.iv)) 168 copy(nonce, c.iv) 169 170 s := seq 171 172 offset := len(c.iv) 173 for i := 0; i < 8; i++ { 174 nonce[(offset-i)-1] ^= byte(s & 0xff) 175 s >>= 8 176 } 177 logf(logTypeCrypto, "Computing nonce for sequence # %x -> %x", seq, nonce) 178 179 return nonce 180} 181 182func (c *cipherState) incrementSequenceNumber() { 183 if c.seq >= (1<<48 - 1) { 184 // Not allowed to let sequence number wrap. 185 // Instead, must renegotiate before it does. 186 // Not likely enough to bother. This is the 187 // DTLS limit. 188 panic("TLS: sequence number wraparound") 189 } 190 c.seq++ 191} 192 193func (c *cipherState) overhead() int { 194 if c.cipher == nil { 195 return 0 196 } 197 return c.cipher.Overhead() 198} 199 200func (r *RecordLayer) encrypt(cipher *cipherState, seq uint64, pt *TLSPlaintext, padLen int) *TLSPlaintext { 201 assert(r.direction == directionWrite) 202 logf(logTypeIO, "%s Encrypt seq=[%x]", r.label, seq) 203 // Expand the fragment to hold contentType, padding, and overhead 204 originalLen := len(pt.fragment) 205 plaintextLen := originalLen + 1 + padLen 206 ciphertextLen := plaintextLen + cipher.overhead() 207 208 // Assemble the revised plaintext 209 out := &TLSPlaintext{ 210 211 contentType: RecordTypeApplicationData, 212 fragment: make([]byte, ciphertextLen), 213 } 214 copy(out.fragment, pt.fragment) 215 out.fragment[originalLen] = byte(pt.contentType) 216 for i := 1; i <= padLen; i++ { 217 out.fragment[originalLen+i] = 0 218 } 219 220 // Encrypt the fragment 221 payload := out.fragment[:plaintextLen] 222 cipher.cipher.Seal(payload[:0], cipher.computeNonce(seq), payload, nil) 223 return out 224} 225 226func (r *RecordLayer) decrypt(pt *TLSPlaintext, seq uint64) (*TLSPlaintext, int, error) { 227 assert(r.direction == directionRead) 228 logf(logTypeIO, "%s Decrypt seq=[%x]", r.label, seq) 229 if len(pt.fragment) < r.cipher.overhead() { 230 msg := fmt.Sprintf("tls.record.decrypt: Record too short [%d] < [%d]", len(pt.fragment), r.cipher.overhead()) 231 return nil, 0, DecryptError(msg) 232 } 233 234 decryptLen := len(pt.fragment) - r.cipher.overhead() 235 out := &TLSPlaintext{ 236 contentType: pt.contentType, 237 fragment: make([]byte, decryptLen), 238 } 239 240 // Decrypt 241 _, err := r.cipher.cipher.Open(out.fragment[:0], r.cipher.computeNonce(seq), pt.fragment, nil) 242 if err != nil { 243 logf(logTypeIO, "%s AEAD decryption failure [%x]", r.label, pt) 244 return nil, 0, DecryptError("tls.record.decrypt: AEAD decrypt failed") 245 } 246 247 // Find the padding boundary 248 padLen := 0 249 for ; padLen < decryptLen+1 && out.fragment[decryptLen-padLen-1] == 0; padLen++ { 250 } 251 252 // Transfer the content type 253 newLen := decryptLen - padLen - 1 254 out.contentType = RecordType(out.fragment[newLen]) 255 256 // Truncate the message to remove contentType, padding, overhead 257 out.fragment = out.fragment[:newLen] 258 out.seq = seq 259 return out, padLen, nil 260} 261 262func (r *RecordLayer) PeekRecordType(block bool) (RecordType, error) { 263 var pt *TLSPlaintext 264 var err error 265 266 for { 267 pt, err = r.nextRecord(false) 268 if err == nil { 269 break 270 } 271 if !block || err != AlertWouldBlock { 272 return 0, err 273 } 274 } 275 return pt.contentType, nil 276} 277 278func (r *RecordLayer) ReadRecord() (*TLSPlaintext, error) { 279 pt, err := r.nextRecord(false) 280 281 // Consume the cached record if there was one 282 r.cachedRecord = nil 283 r.cachedError = nil 284 285 return pt, err 286} 287 288func (r *RecordLayer) readRecordAnyEpoch() (*TLSPlaintext, error) { 289 pt, err := r.nextRecord(true) 290 291 // Consume the cached record if there was one 292 r.cachedRecord = nil 293 r.cachedError = nil 294 295 return pt, err 296} 297 298func (r *RecordLayer) nextRecord(allowOldEpoch bool) (*TLSPlaintext, error) { 299 cipher := r.cipher 300 if r.cachedRecord != nil { 301 logf(logTypeIO, "%s Returning cached record", r.label) 302 return r.cachedRecord, r.cachedError 303 } 304 305 // Loop until one of three things happens: 306 // 307 // 1. We get a frame 308 // 2. We try to read off the socket and get nothing, in which case 309 // returnAlertWouldBlock 310 // 3. We get an error. 311 var err error 312 err = AlertWouldBlock 313 var header, body []byte 314 315 for err != nil { 316 if r.frame.needed() > 0 { 317 buf := make([]byte, r.frame.details.headerLen()+maxFragmentLen) 318 n, err := r.conn.Read(buf) 319 if err != nil { 320 logf(logTypeIO, "%s Error reading, %v", r.label, err) 321 return nil, err 322 } 323 324 if n == 0 { 325 return nil, AlertWouldBlock 326 } 327 328 logf(logTypeIO, "%s Read %v bytes", r.label, n) 329 330 buf = buf[:n] 331 r.frame.addChunk(buf) 332 } 333 334 header, body, err = r.frame.process() 335 // Loop around onAlertWouldBlock to see if some 336 // data is now available. 337 if err != nil && err != AlertWouldBlock { 338 return nil, err 339 } 340 } 341 342 pt := &TLSPlaintext{} 343 // Validate content type 344 switch RecordType(header[0]) { 345 default: 346 return nil, fmt.Errorf("tls.record: Unknown content type %02x", header[0]) 347 case RecordTypeAlert, RecordTypeHandshake, RecordTypeApplicationData, RecordTypeAck: 348 pt.contentType = RecordType(header[0]) 349 } 350 351 // Validate version 352 if !allowWrongVersionNumber && (header[1] != 0x03 || header[2] != 0x01) { 353 return nil, fmt.Errorf("tls.record: Invalid version %02x%02x", header[1], header[2]) 354 } 355 356 // Validate size < max 357 size := (int(header[len(header)-2]) << 8) + int(header[len(header)-1]) 358 359 if size > maxFragmentLen+256 { 360 return nil, fmt.Errorf("tls.record: Ciphertext size too big") 361 } 362 363 pt.fragment = make([]byte, size) 364 copy(pt.fragment, body) 365 366 // TODO(ekr@rtfm.com): Enforce that for epoch > 0, the content type is app data. 367 368 // Attempt to decrypt fragment 369 seq := cipher.seq 370 if r.datagram { 371 // TODO(ekr@rtfm.com): Handle duplicates. 372 seq, _ = decodeUint(header[3:11], 8) 373 epoch := Epoch(seq >> 48) 374 375 // Look up the cipher suite from the epoch 376 c, ok := r.readCiphers[epoch] 377 if !ok { 378 logf(logTypeIO, "%s Message from unknown epoch: [%v]", r.label, epoch) 379 return nil, AlertWouldBlock 380 } 381 382 if epoch != cipher.epoch { 383 logf(logTypeIO, "%s Message from non-current epoch: [%v != %v] out-of-epoch reads=%v", r.label, epoch, 384 cipher.epoch, allowOldEpoch) 385 if !allowOldEpoch { 386 return nil, AlertWouldBlock 387 } 388 cipher = c 389 } 390 } 391 392 if cipher.cipher != nil { 393 logf(logTypeIO, "%s RecordLayer.ReadRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), seq, pt.contentType, pt.fragment) 394 pt, _, err = r.decrypt(pt, seq) 395 if err != nil { 396 logf(logTypeIO, "%s Decryption failed", r.label) 397 return nil, err 398 } 399 } 400 pt.epoch = cipher.epoch 401 402 // Check that plaintext length is not too long 403 if len(pt.fragment) > maxFragmentLen { 404 return nil, fmt.Errorf("tls.record: Plaintext size too big") 405 } 406 407 logf(logTypeIO, "%s RecordLayer.ReadRecord [%d] [%x]", r.label, pt.contentType, pt.fragment) 408 409 r.cachedRecord = pt 410 cipher.incrementSequenceNumber() 411 return pt, nil 412} 413 414func (r *RecordLayer) WriteRecord(pt *TLSPlaintext) error { 415 return r.writeRecordWithPadding(pt, r.cipher, 0) 416} 417 418func (r *RecordLayer) WriteRecordWithPadding(pt *TLSPlaintext, padLen int) error { 419 return r.writeRecordWithPadding(pt, r.cipher, padLen) 420} 421 422func (r *RecordLayer) writeRecordWithPadding(pt *TLSPlaintext, cipher *cipherState, padLen int) error { 423 seq := cipher.combineSeq(r.datagram) 424 if cipher.cipher != nil { 425 logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] plaintext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) 426 pt = r.encrypt(cipher, seq, pt, padLen) 427 } else if padLen > 0 { 428 return fmt.Errorf("tls.record: Padding can only be done on encrypted records") 429 } 430 431 if len(pt.fragment) > maxFragmentLen { 432 return fmt.Errorf("tls.record: Record size too big") 433 } 434 435 length := len(pt.fragment) 436 var header []byte 437 438 if !r.datagram { 439 header = []byte{byte(pt.contentType), 440 byte(r.version >> 8), byte(r.version & 0xff), 441 byte(length >> 8), byte(length)} 442 } else { 443 header = make([]byte, 13) 444 version := dtlsConvertVersion(r.version) 445 copy(header, []byte{byte(pt.contentType), 446 byte(version >> 8), byte(version & 0xff), 447 }) 448 encodeUint(seq, 8, header[3:]) 449 encodeUint(uint64(length), 2, header[11:]) 450 } 451 record := append(header, pt.fragment...) 452 453 logf(logTypeIO, "%s RecordLayer.WriteRecord epoch=[%s] seq=[%x] [%d] ciphertext=[%x]", r.label, cipher.epoch.label(), cipher.seq, pt.contentType, pt.fragment) 454 455 cipher.incrementSequenceNumber() 456 _, err := r.conn.Write(record) 457 return err 458} 459