1// Copyright 2012 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package xts implements the XTS cipher mode as specified in IEEE P1619/D16.
6//
7// XTS mode is typically used for disk encryption, which presents a number of
8// novel problems that make more common modes inapplicable. The disk is
9// conceptually an array of sectors and we must be able to encrypt and decrypt
10// a sector in isolation. However, an attacker must not be able to transpose
11// two sectors of plaintext by transposing their ciphertext.
12//
// XTS wraps a block cipher with Rogaway's XEX mode in order to build a
// tweakable block cipher. This allows each sector to have a unique tweak,
// which effectively gives each sector its own key.
16//
17// XTS does not provide any authentication. An attacker can manipulate the
18// ciphertext and randomise a block (16 bytes) of the plaintext.
19//
// (Note: this package does not implement ciphertext stealing, so sectors
// must be a multiple of 16 bytes.)
22package xts // import "golang.org/x/crypto/xts"
23
24import (
25	"crypto/cipher"
26	"encoding/binary"
27	"errors"
28
29	"golang.org/x/crypto/internal/subtle"
30)
31
// Cipher holds the two expanded block-cipher keys used by XTS: k1 processes
// the data blocks and k2 encrypts the sector number to form the initial tweak.
// It holds no mutable state, so a single Cipher may be used concurrently.
type Cipher struct {
	k1 cipher.Block // data key
	k2 cipher.Block // tweak key
}
37
const (
	// blockSize is the block length, in bytes, that the underlying cipher
	// must have: XTS is only defined for 128-bit (16-byte) block ciphers.
	blockSize = 16
)
41
42// NewCipher creates a Cipher given a function for creating the underlying
43// block cipher (which must have a block size of 16 bytes). The key must be
44// twice the length of the underlying cipher's key.
45func NewCipher(cipherFunc func([]byte) (cipher.Block, error), key []byte) (c *Cipher, err error) {
46	c = new(Cipher)
47	if c.k1, err = cipherFunc(key[:len(key)/2]); err != nil {
48		return
49	}
50	c.k2, err = cipherFunc(key[len(key)/2:])
51
52	if c.k1.BlockSize() != blockSize {
53		err = errors.New("xts: cipher does not have a block size of 16")
54	}
55
56	return
57}
58
59// Encrypt encrypts a sector of plaintext and puts the result into ciphertext.
60// Plaintext and ciphertext must overlap entirely or not at all.
61// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
62func (c *Cipher) Encrypt(ciphertext, plaintext []byte, sectorNum uint64) {
63	if len(ciphertext) < len(plaintext) {
64		panic("xts: ciphertext is smaller than plaintext")
65	}
66	if len(plaintext)%blockSize != 0 {
67		panic("xts: plaintext is not a multiple of the block size")
68	}
69	if subtle.InexactOverlap(ciphertext[:len(plaintext)], plaintext) {
70		panic("xts: invalid buffer overlap")
71	}
72
73	var tweak [blockSize]byte
74	binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
75
76	c.k2.Encrypt(tweak[:], tweak[:])
77
78	for len(plaintext) > 0 {
79		for j := range tweak {
80			ciphertext[j] = plaintext[j] ^ tweak[j]
81		}
82		c.k1.Encrypt(ciphertext, ciphertext)
83		for j := range tweak {
84			ciphertext[j] ^= tweak[j]
85		}
86		plaintext = plaintext[blockSize:]
87		ciphertext = ciphertext[blockSize:]
88
89		mul2(&tweak)
90	}
91}
92
93// Decrypt decrypts a sector of ciphertext and puts the result into plaintext.
94// Plaintext and ciphertext must overlap entirely or not at all.
95// Sectors must be a multiple of 16 bytes and less than 2²⁴ bytes.
96func (c *Cipher) Decrypt(plaintext, ciphertext []byte, sectorNum uint64) {
97	if len(plaintext) < len(ciphertext) {
98		panic("xts: plaintext is smaller than ciphertext")
99	}
100	if len(ciphertext)%blockSize != 0 {
101		panic("xts: ciphertext is not a multiple of the block size")
102	}
103	if subtle.InexactOverlap(plaintext[:len(ciphertext)], ciphertext) {
104		panic("xts: invalid buffer overlap")
105	}
106
107	var tweak [blockSize]byte
108	binary.LittleEndian.PutUint64(tweak[:8], sectorNum)
109
110	c.k2.Encrypt(tweak[:], tweak[:])
111
112	for len(ciphertext) > 0 {
113		for j := range tweak {
114			plaintext[j] = ciphertext[j] ^ tweak[j]
115		}
116		c.k1.Decrypt(plaintext, plaintext)
117		for j := range tweak {
118			plaintext[j] ^= tweak[j]
119		}
120		plaintext = plaintext[blockSize:]
121		ciphertext = ciphertext[blockSize:]
122
123		mul2(&tweak)
124	}
125}
126
127// mul2 multiplies tweak by 2 in GF(2¹²⁸) with an irreducible polynomial of
128// x¹²⁸ + x⁷ + x² + x + 1.
129func mul2(tweak *[blockSize]byte) {
130	var carryIn byte
131	for j := range tweak {
132		carryOut := tweak[j] >> 7
133		tweak[j] = (tweak[j] << 1) + carryIn
134		carryIn = carryOut
135	}
136	if carryIn != 0 {
137		// If we have a carry bit then we need to subtract a multiple
138		// of the irreducible polynomial (x¹²⁸ + x⁷ + x² + x + 1).
139		// By dropping the carry bit, we're subtracting the x^128 term
140		// so all that remains is to subtract x⁷ + x² + x + 1.
141		// Subtraction (and addition) in this representation is just
142		// XOR.
143		tweak[0] ^= 1<<7 | 1<<2 | 1<<1 | 1
144	}
145}
146