1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package tlog implements a tamper-evident log
6// used in the Go module go.sum database server.
7//
8// This package is part of a DRAFT of what the go.sum database server will look like.
9// Do not assume the details here are final!
10//
11// This package follows the design of Certificate Transparency (RFC 6962)
12// and its proofs are compatible with that system.
13// See TestCertificateTransparency.
14//
15package tlog
16
17import (
18	"crypto/sha256"
19	"encoding/base64"
20	"errors"
21	"fmt"
22	"math/bits"
23)
24
// HashSize is the size of a Hash in bytes.
const HashSize = 32

// A Hash is a hash identifying a log record or tree root.
type Hash [HashSize]byte

// String returns the padded standard base64 encoding of the hash, for printing.
func (h Hash) String() string {
	return base64.StdEncoding.EncodeToString(h[:])
}

// MarshalJSON marshals the hash as a JSON string containing the base64-encoded hash.
// The returned error is always nil.
func (h Hash) MarshalJSON() ([]byte, error) {
	encoded := h.String()
	out := make([]byte, 0, len(encoded)+2)
	out = append(out, '"')
	out = append(out, encoded...)
	out = append(out, '"')
	return out, nil
}

// UnmarshalJSON unmarshals a hash from a JSON string containing the base64-encoded hash.
// The input must be exactly the form produced by MarshalJSON: a double-quoted,
// 44-character base64 string ending in '='. On failure *h is left unmodified.
func (h *Hash) UnmarshalJSON(data []byte) error {
	// Opening quote, 44 base64 characters (the last being '=' padding), closing quote.
	const quoted = 1 + 44 + 1
	framed := len(data) == quoted &&
		data[0] == '"' &&
		data[quoted-2] == '=' &&
		data[quoted-1] == '"'
	if !framed {
		return errors.New("cannot decode hash")
	}

	// As of Go 1.12, base64.StdEncoding.Decode insists on
	// slicing into target[33:] even when it only writes 32 bytes.
	// The trailing = was verified above, so the remaining 43 characters
	// can be handed to base64.RawStdEncoding, which does not exhibit
	// the same bug.
	// Decoding goes into a scratch value so that *h is only written
	// once the entire input is known to be well-formed.
	var decoded Hash
	digits := data[1 : quoted-2]
	if n, err := base64.RawStdEncoding.Decode(decoded[:], digits); err != nil || n != HashSize {
		return errors.New("cannot decode hash")
	}
	*h = decoded
	return nil
}
62
63// ParseHash parses the base64-encoded string form of a hash.
64func ParseHash(s string) (Hash, error) {
65	data, err := base64.StdEncoding.DecodeString(s)
66	if err != nil || len(data) != HashSize {
67		return Hash{}, fmt.Errorf("malformed hash")
68	}
69	var h Hash
70	copy(h[:], data)
71	return h, nil
72}
73
// maxpow2 returns k, the maximum power of 2 smaller than n,
// as well as l = log₂ k (so k = 1<<l).
// For n < 2 no power of 2 is smaller than n; maxpow2 then returns k = 1, l = 0.
func maxpow2(n int64) (k int64, l int) {
	if n < 2 {
		return 1, 0
	}
	// The largest power of 2 strictly below n is the highest set bit of n-1.
	// Deriving it from the bit length avoids shifting a candidate upward
	// until it reaches n: for n > 1<<62 such a candidate overflows int64
	// before it can reach n, and the search never terminates.
	l = bits.Len64(uint64(n-1)) - 1
	return 1 << uint(l), l
}
83
84var zeroPrefix = []byte{0x00}
85
86// RecordHash returns the content hash for the given record data.
87func RecordHash(data []byte) Hash {
88	// SHA256(0x00 || data)
89	// https://tools.ietf.org/html/rfc6962#section-2.1
90	h := sha256.New()
91	h.Write(zeroPrefix)
92	h.Write(data)
93	var h1 Hash
94	h.Sum(h1[:0])
95	return h1
96}
97
98// NodeHash returns the hash for an interior tree node with the given left and right hashes.
99func NodeHash(left, right Hash) Hash {
100	// SHA256(0x01 || left || right)
101	// https://tools.ietf.org/html/rfc6962#section-2.1
102	// We use a stack buffer to assemble the hash input
103	// to avoid allocating a hash struct with sha256.New.
104	var buf [1 + HashSize + HashSize]byte
105	buf[0] = 0x01
106	copy(buf[1:], left[:])
107	copy(buf[1+HashSize:], right[:])
108	return sha256.Sum256(buf[:])
109}
110
111// For information about the stored hash index ordering,
112// see section 3.3 of Crosby and Wallach's paper
113// "Efficient Data Structures for Tamper-Evident Logging".
114// https://www.usenix.org/legacy/event/sec09/tech/full_papers/crosby.pdf
115
// StoredHashIndex maps the tree coordinates (level, n)
// to a dense linear ordering that can be used for hash storage.
// Hash storage implementations that store hashes in sequential
// storage can use this function to compute where to read or write
// a given hash.
func StoredHashIndex(level int, n int64) int64 {
	// Walk down to level 0, replacing (l, n) by (l-1, 2n+1) at each step.
	// The hash we want is stored level slots after the level-0 hash
	// reached this way, so the result starts out at level.
	leaf := n
	for down := 0; down < level; down++ {
		leaf = 2*leaf + 1
	}

	// Level 0's n'th hash is written at n+n/2+n/4+... (eventually n/2ⁱ hits zero).
	index := int64(level)
	for rest := leaf; rest > 0; rest >>= 1 {
		index += rest
	}
	return index
}
137
138// SplitStoredHashIndex is the inverse of StoredHashIndex.
139// That is, SplitStoredHashIndex(StoredHashIndex(level, n)) == level, n.
140func SplitStoredHashIndex(index int64) (level int, n int64) {
141	// Determine level 0 record before index.
142	// StoredHashIndex(0, n) < 2*n,
143	// so the n we want is in [index/2, index/2+log₂(index)].
144	n = index / 2
145	indexN := StoredHashIndex(0, n)
146	if indexN > index {
147		panic("bad math")
148	}
149	for {
150		// Each new record n adds 1 + trailingZeros(n) hashes.
151		x := indexN + 1 + int64(bits.TrailingZeros64(uint64(n+1)))
152		if x > index {
153			break
154		}
155		n++
156		indexN = x
157	}
158	// The hash we want was commited with record n,
159	// meaning it is one of (0, n), (1, n/2), (2, n/4), ...
160	level = int(index - indexN)
161	return level, n >> uint(level)
162}
163
164// StoredHashCount returns the number of stored hashes
165// that are expected for a tree with n records.
166func StoredHashCount(n int64) int64 {
167	if n == 0 {
168		return 0
169	}
170	// The tree will have the hashes up to the last leaf hash.
171	numHash := StoredHashIndex(0, n-1) + 1
172	// And it will have any hashes for subtrees completed by that leaf.
173	for i := uint64(n - 1); i&1 != 0; i >>= 1 {
174		numHash++
175	}
176	return numHash
177}
178
179// StoredHashes returns the hashes that must be stored when writing
180// record n with the given data. The hashes should be stored starting
181// at StoredHashIndex(0, n). The result will have at most 1 + log₂ n hashes,
182// but it will average just under two per call for a sequence of calls for n=1..k.
183//
184// StoredHashes may read up to log n earlier hashes from r
185// in order to compute hashes for completed subtrees.
186func StoredHashes(n int64, data []byte, r HashReader) ([]Hash, error) {
187	return StoredHashesForRecordHash(n, RecordHash(data), r)
188}
189
190// StoredHashesForRecordHash is like StoredHashes but takes
191// as its second argument RecordHash(data) instead of data itself.
192func StoredHashesForRecordHash(n int64, h Hash, r HashReader) ([]Hash, error) {
193	// Start with the record hash.
194	hashes := []Hash{h}
195
196	// Build list of indexes needed for hashes for completed subtrees.
197	// Each trailing 1 bit in the binary representation of n completes a subtree
198	// and consumes a hash from an adjacent subtree.
199	m := int(bits.TrailingZeros64(uint64(n + 1)))
200	indexes := make([]int64, m)
201	for i := 0; i < m; i++ {
202		// We arrange indexes in sorted order.
203		// Note that n>>i is always odd.
204		indexes[m-1-i] = StoredHashIndex(i, n>>uint(i)-1)
205	}
206
207	// Fetch hashes.
208	old, err := r.ReadHashes(indexes)
209	if err != nil {
210		return nil, err
211	}
212	if len(old) != len(indexes) {
213		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(old))
214	}
215
216	// Build new hashes.
217	for i := 0; i < m; i++ {
218		h = NodeHash(old[m-1-i], h)
219		hashes = append(hashes, h)
220	}
221	return hashes, nil
222}
223
// A HashReader can read hashes for nodes in the log's tree structure.
// The functions in this package that take a HashReader use it to fetch
// previously stored hashes by their stored hash index.
type HashReader interface {
	// ReadHashes returns the hashes with the given stored hash indexes
	// (see StoredHashIndex and SplitStoredHashIndex).
	// ReadHashes must return a slice of hashes the same length as indexes,
	// or else it must return a non-nil error.
	// Callers in this package report a result of the wrong length as an error.
	// ReadHashes may run faster if indexes is sorted in increasing order.
	ReadHashes(indexes []int64) ([]Hash, error)
}
233
// A HashReaderFunc is a function implementing HashReader.
type HashReaderFunc func([]int64) ([]Hash, error)

// ReadHashes calls f(indexes) and returns its results unchanged.
func (f HashReaderFunc) ReadHashes(indexes []int64) ([]Hash, error) {
	return f(indexes)
}
240
241// TreeHash computes the hash for the root of the tree with n records,
242// using the HashReader to obtain previously stored hashes
243// (those returned by StoredHashes during the writes of those n records).
244// TreeHash makes a single call to ReadHash requesting at most 1 + log₂ n hashes.
245// The tree of size zero is defined to have an all-zero Hash.
246func TreeHash(n int64, r HashReader) (Hash, error) {
247	if n == 0 {
248		return Hash{}, nil
249	}
250	indexes := subTreeIndex(0, n, nil)
251	hashes, err := r.ReadHashes(indexes)
252	if err != nil {
253		return Hash{}, err
254	}
255	if len(hashes) != len(indexes) {
256		return Hash{}, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
257	}
258	hash, hashes := subTreeHash(0, n, hashes)
259	if len(hashes) != 0 {
260		panic("tlog: bad index math in TreeHash")
261	}
262	return hash, nil
263}
264
265// subTreeIndex returns the storage indexes needed to compute
266// the hash for the subtree containing records [lo, hi),
267// appending them to need and returning the result.
268// See https://tools.ietf.org/html/rfc6962#section-2.1
269func subTreeIndex(lo, hi int64, need []int64) []int64 {
270	// See subTreeHash below for commentary.
271	for lo < hi {
272		k, level := maxpow2(hi - lo + 1)
273		if lo&(k-1) != 0 {
274			panic("tlog: bad math in subTreeIndex")
275		}
276		need = append(need, StoredHashIndex(level, lo>>uint(level)))
277		lo += k
278	}
279	return need
280}
281
282// subTreeHash computes the hash for the subtree containing records [lo, hi),
283// assuming that hashes are the hashes corresponding to the indexes
284// returned by subTreeIndex(lo, hi).
285// It returns any leftover hashes.
286func subTreeHash(lo, hi int64, hashes []Hash) (Hash, []Hash) {
287	// Repeatedly partition the tree into a left side with 2^level nodes,
288	// for as large a level as possible, and a right side with the fringe.
289	// The left hash is stored directly and can be read from storage.
290	// The right side needs further computation.
291	numTree := 0
292	for lo < hi {
293		k, _ := maxpow2(hi - lo + 1)
294		if lo&(k-1) != 0 || lo >= hi {
295			panic("tlog: bad math in subTreeHash")
296		}
297		numTree++
298		lo += k
299	}
300
301	if len(hashes) < numTree {
302		panic("tlog: bad index math in subTreeHash")
303	}
304
305	// Reconstruct hash.
306	h := hashes[numTree-1]
307	for i := numTree - 2; i >= 0; i-- {
308		h = NodeHash(hashes[i], h)
309	}
310	return h, hashes[numTree:]
311}
312
// A RecordProof is a verifiable proof that a particular log root contains a particular record.
// RFC 6962 calls this a “Merkle audit path.”
// Proofs built by ProveRecord list hashes from the record upward:
// the last element is the hash combined at the root of the tree.
type RecordProof []Hash
316
317// ProveRecord returns the proof that the tree of size t contains the record with index n.
318func ProveRecord(t, n int64, r HashReader) (RecordProof, error) {
319	if t < 0 || n < 0 || n >= t {
320		return nil, fmt.Errorf("tlog: invalid inputs in ProveRecord")
321	}
322	indexes := leafProofIndex(0, t, n, nil)
323	if len(indexes) == 0 {
324		return RecordProof{}, nil
325	}
326	hashes, err := r.ReadHashes(indexes)
327	if err != nil {
328		return nil, err
329	}
330	if len(hashes) != len(indexes) {
331		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
332	}
333
334	p, hashes := leafProof(0, t, n, hashes)
335	if len(hashes) != 0 {
336		panic("tlog: bad index math in ProveRecord")
337	}
338	return p, nil
339}
340
341// leafProofIndex builds the list of indexes needed to construct the proof
342// that leaf n is contained in the subtree with leaves [lo, hi).
343// It appends those indexes to need and returns the result.
344// See https://tools.ietf.org/html/rfc6962#section-2.1.1
345func leafProofIndex(lo, hi, n int64, need []int64) []int64 {
346	// See leafProof below for commentary.
347	if !(lo <= n && n < hi) {
348		panic("tlog: bad math in leafProofIndex")
349	}
350	if lo+1 == hi {
351		return need
352	}
353	if k, _ := maxpow2(hi - lo); n < lo+k {
354		need = leafProofIndex(lo, lo+k, n, need)
355		need = subTreeIndex(lo+k, hi, need)
356	} else {
357		need = subTreeIndex(lo, lo+k, need)
358		need = leafProofIndex(lo+k, hi, n, need)
359	}
360	return need
361}
362
363// leafProof constructs the proof that leaf n is contained in the subtree with leaves [lo, hi).
364// It returns any leftover hashes as well.
365// See https://tools.ietf.org/html/rfc6962#section-2.1.1
366func leafProof(lo, hi, n int64, hashes []Hash) (RecordProof, []Hash) {
367	// We must have lo <= n < hi or else the code here has a bug.
368	if !(lo <= n && n < hi) {
369		panic("tlog: bad math in leafProof")
370	}
371
372	if lo+1 == hi { // n == lo
373		// Reached the leaf node.
374		// The verifier knows what the leaf hash is, so we don't need to send it.
375		return RecordProof{}, hashes
376	}
377
378	// Walk down the tree toward n.
379	// Record the hash of the path not taken (needed for verifying the proof).
380	var p RecordProof
381	var th Hash
382	if k, _ := maxpow2(hi - lo); n < lo+k {
383		// n is on left side
384		p, hashes = leafProof(lo, lo+k, n, hashes)
385		th, hashes = subTreeHash(lo+k, hi, hashes)
386	} else {
387		// n is on right side
388		th, hashes = subTreeHash(lo, lo+k, hashes)
389		p, hashes = leafProof(lo+k, hi, n, hashes)
390	}
391	return append(p, th), hashes
392}
393
// errProofFailed is the error returned by CheckRecord and CheckTree
// (and the helpers that run proofs for them) when a proof has the
// wrong number of hashes or does not reproduce the expected hashes.
var errProofFailed = errors.New("invalid transparency proof")
395
396// CheckRecord verifies that p is a valid proof that the tree of size t
397// with hash th has an n'th record with hash h.
398func CheckRecord(p RecordProof, t int64, th Hash, n int64, h Hash) error {
399	if t < 0 || n < 0 || n >= t {
400		return fmt.Errorf("tlog: invalid inputs in CheckRecord")
401	}
402	th2, err := runRecordProof(p, 0, t, n, h)
403	if err != nil {
404		return err
405	}
406	if th2 == th {
407		return nil
408	}
409	return errProofFailed
410}
411
412// runRecordProof runs the proof p that leaf n is contained in the subtree with leaves [lo, hi).
413// Running the proof means constructing and returning the implied hash of that
414// subtree.
415func runRecordProof(p RecordProof, lo, hi, n int64, leafHash Hash) (Hash, error) {
416	// We must have lo <= n < hi or else the code here has a bug.
417	if !(lo <= n && n < hi) {
418		panic("tlog: bad math in runRecordProof")
419	}
420
421	if lo+1 == hi { // m == lo
422		// Reached the leaf node.
423		// The proof must not have any unnecessary hashes.
424		if len(p) != 0 {
425			return Hash{}, errProofFailed
426		}
427		return leafHash, nil
428	}
429
430	if len(p) == 0 {
431		return Hash{}, errProofFailed
432	}
433
434	k, _ := maxpow2(hi - lo)
435	if n < lo+k {
436		th, err := runRecordProof(p[:len(p)-1], lo, lo+k, n, leafHash)
437		if err != nil {
438			return Hash{}, err
439		}
440		return NodeHash(th, p[len(p)-1]), nil
441	} else {
442		th, err := runRecordProof(p[:len(p)-1], lo+k, hi, n, leafHash)
443		if err != nil {
444			return Hash{}, err
445		}
446		return NodeHash(p[len(p)-1], th), nil
447	}
448}
449
// A TreeProof is a verifiable proof that a particular log tree contains
// as a prefix all records present in an earlier tree.
// RFC 6962 calls this a “Merkle consistency proof.”
// Proofs built by ProveTree end with the hash combined at the root of the larger tree.
type TreeProof []Hash
454
455// ProveTree returns the proof that the tree of size t contains
456// as a prefix all the records from the tree of smaller size n.
457func ProveTree(t, n int64, h HashReader) (TreeProof, error) {
458	if t < 1 || n < 1 || n > t {
459		return nil, fmt.Errorf("tlog: invalid inputs in ProveTree")
460	}
461	indexes := treeProofIndex(0, t, n, nil)
462	if len(indexes) == 0 {
463		return TreeProof{}, nil
464	}
465	hashes, err := h.ReadHashes(indexes)
466	if err != nil {
467		return nil, err
468	}
469	if len(hashes) != len(indexes) {
470		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
471	}
472
473	p, hashes := treeProof(0, t, n, hashes)
474	if len(hashes) != 0 {
475		panic("tlog: bad index math in ProveTree")
476	}
477	return p, nil
478}
479
480// treeProofIndex builds the list of indexes needed to construct
481// the sub-proof related to the subtree containing records [lo, hi).
482// See https://tools.ietf.org/html/rfc6962#section-2.1.2.
483func treeProofIndex(lo, hi, n int64, need []int64) []int64 {
484	// See treeProof below for commentary.
485	if !(lo < n && n <= hi) {
486		panic("tlog: bad math in treeProofIndex")
487	}
488
489	if n == hi {
490		if lo == 0 {
491			return need
492		}
493		return subTreeIndex(lo, hi, need)
494	}
495
496	if k, _ := maxpow2(hi - lo); n <= lo+k {
497		need = treeProofIndex(lo, lo+k, n, need)
498		need = subTreeIndex(lo+k, hi, need)
499	} else {
500		need = subTreeIndex(lo, lo+k, need)
501		need = treeProofIndex(lo+k, hi, n, need)
502	}
503	return need
504}
505
506// treeProof constructs the sub-proof related to the subtree containing records [lo, hi).
507// It returns any leftover hashes as well.
508// See https://tools.ietf.org/html/rfc6962#section-2.1.2.
509func treeProof(lo, hi, n int64, hashes []Hash) (TreeProof, []Hash) {
510	// We must have lo < n <= hi or else the code here has a bug.
511	if !(lo < n && n <= hi) {
512		panic("tlog: bad math in treeProof")
513	}
514
515	// Reached common ground.
516	if n == hi {
517		if lo == 0 {
518			// This subtree corresponds exactly to the old tree.
519			// The verifier knows that hash, so we don't need to send it.
520			return TreeProof{}, hashes
521		}
522		th, hashes := subTreeHash(lo, hi, hashes)
523		return TreeProof{th}, hashes
524	}
525
526	// Interior node for the proof.
527	// Decide whether to walk down the left or right side.
528	var p TreeProof
529	var th Hash
530	if k, _ := maxpow2(hi - lo); n <= lo+k {
531		// m is on left side
532		p, hashes = treeProof(lo, lo+k, n, hashes)
533		th, hashes = subTreeHash(lo+k, hi, hashes)
534	} else {
535		// m is on right side
536		th, hashes = subTreeHash(lo, lo+k, hashes)
537		p, hashes = treeProof(lo+k, hi, n, hashes)
538	}
539	return append(p, th), hashes
540}
541
542// CheckTree verifies that p is a valid proof that the tree of size t with hash th
543// contains as a prefix the tree of size n with hash h.
544func CheckTree(p TreeProof, t int64, th Hash, n int64, h Hash) error {
545	if t < 1 || n < 1 || n > t {
546		return fmt.Errorf("tlog: invalid inputs in CheckTree")
547	}
548	h2, th2, err := runTreeProof(p, 0, t, n, h)
549	if err != nil {
550		return err
551	}
552	if th2 == th && h2 == h {
553		return nil
554	}
555	return errProofFailed
556}
557
558// runTreeProof runs the sub-proof p related to the subtree containing records [lo, hi),
559// where old is the hash of the old tree with n records.
560// Running the proof means constructing and returning the implied hashes of that
561// subtree in both the old and new tree.
562func runTreeProof(p TreeProof, lo, hi, n int64, old Hash) (Hash, Hash, error) {
563	// We must have lo < n <= hi or else the code here has a bug.
564	if !(lo < n && n <= hi) {
565		panic("tlog: bad math in runTreeProof")
566	}
567
568	// Reached common ground.
569	if n == hi {
570		if lo == 0 {
571			if len(p) != 0 {
572				return Hash{}, Hash{}, errProofFailed
573			}
574			return old, old, nil
575		}
576		if len(p) != 1 {
577			return Hash{}, Hash{}, errProofFailed
578		}
579		return p[0], p[0], nil
580	}
581
582	if len(p) == 0 {
583		return Hash{}, Hash{}, errProofFailed
584	}
585
586	// Interior node for the proof.
587	k, _ := maxpow2(hi - lo)
588	if n <= lo+k {
589		oh, th, err := runTreeProof(p[:len(p)-1], lo, lo+k, n, old)
590		if err != nil {
591			return Hash{}, Hash{}, err
592		}
593		return oh, NodeHash(th, p[len(p)-1]), nil
594	} else {
595		oh, th, err := runTreeProof(p[:len(p)-1], lo+k, hi, n, old)
596		if err != nil {
597			return Hash{}, Hash{}, err
598		}
599		return NodeHash(p[len(p)-1], oh), NodeHash(p[len(p)-1], th), nil
600	}
601}
602