1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package tlog implements a tamper-evident log
6// used in the Go module go.sum database server.
7//
8// This package follows the design of Certificate Transparency (RFC 6962)
9// and its proofs are compatible with that system.
10// See TestCertificateTransparency.
11//
12package tlog
13
14import (
15	"crypto/sha256"
16	"encoding/base64"
17	"errors"
18	"fmt"
19	"math/bits"
20)
21
// HashSize is the length of a Hash, in bytes.
const HashSize = 32

// A Hash is a fixed-size digest that identifies either a single
// log record or the root of a tree of records.
type Hash [HashSize]byte
27
28// String returns a base64 representation of the hash for printing.
29func (h Hash) String() string {
30	return base64.StdEncoding.EncodeToString(h[:])
31}
32
33// MarshalJSON marshals the hash as a JSON string containing the base64-encoded hash.
34func (h Hash) MarshalJSON() ([]byte, error) {
35	return []byte(`"` + h.String() + `"`), nil
36}
37
38// UnmarshalJSON unmarshals a hash from JSON string containing the a base64-encoded hash.
39func (h *Hash) UnmarshalJSON(data []byte) error {
40	if len(data) != 1+44+1 || data[0] != '"' || data[len(data)-2] != '=' || data[len(data)-1] != '"' {
41		return errors.New("cannot decode hash")
42	}
43
44	// As of Go 1.12, base64.StdEncoding.Decode insists on
45	// slicing into target[33:] even when it only writes 32 bytes.
46	// Since we already checked that the hash ends in = above,
47	// we can use base64.RawStdEncoding with the = removed;
48	// RawStdEncoding does not exhibit the same bug.
49	// We decode into a temporary to avoid writing anything to *h
50	// unless the entire input is well-formed.
51	var tmp Hash
52	n, err := base64.RawStdEncoding.Decode(tmp[:], data[1:len(data)-2])
53	if err != nil || n != HashSize {
54		return errors.New("cannot decode hash")
55	}
56	*h = tmp
57	return nil
58}
59
60// ParseHash parses the base64-encoded string form of a hash.
61func ParseHash(s string) (Hash, error) {
62	data, err := base64.StdEncoding.DecodeString(s)
63	if err != nil || len(data) != HashSize {
64		return Hash{}, fmt.Errorf("malformed hash")
65	}
66	var h Hash
67	copy(h[:], data)
68	return h, nil
69}
70
// maxpow2 returns k, the largest power of 2 strictly smaller than n,
// as well as l = log₂ k (so k = 1<<l).
//
// Callers pass n ≥ 2. For n ≤ 1 no such power of 2 exists and
// maxpow2 returns k = 1, l = 0.
//
// The result is derived from the bit length of n-1 instead of a
// shift-and-compare loop: 1<<63 overflows int64, so such a loop
// never terminates once n exceeds 1<<62.
func maxpow2(n int64) (k int64, l int) {
	if n <= 1 {
		return 1, 0
	}
	// For n ≥ 2 the highest set bit of n-1 is the largest power of 2
	// below n. When n is itself a power of 2, n-1 has its highest bit
	// one position lower, which is exactly the "strictly smaller" rule.
	l = bits.Len64(uint64(n-1)) - 1
	return 1 << uint(l), l
}
80
81var zeroPrefix = []byte{0x00}
82
83// RecordHash returns the content hash for the given record data.
84func RecordHash(data []byte) Hash {
85	// SHA256(0x00 || data)
86	// https://tools.ietf.org/html/rfc6962#section-2.1
87	h := sha256.New()
88	h.Write(zeroPrefix)
89	h.Write(data)
90	var h1 Hash
91	h.Sum(h1[:0])
92	return h1
93}
94
95// NodeHash returns the hash for an interior tree node with the given left and right hashes.
96func NodeHash(left, right Hash) Hash {
97	// SHA256(0x01 || left || right)
98	// https://tools.ietf.org/html/rfc6962#section-2.1
99	// We use a stack buffer to assemble the hash input
100	// to avoid allocating a hash struct with sha256.New.
101	var buf [1 + HashSize + HashSize]byte
102	buf[0] = 0x01
103	copy(buf[1:], left[:])
104	copy(buf[1+HashSize:], right[:])
105	return sha256.Sum256(buf[:])
106}
107
108// For information about the stored hash index ordering,
109// see section 3.3 of Crosby and Wallach's paper
110// "Efficient Data Structures for Tamper-Evident Logging".
111// https://www.usenix.org/legacy/event/sec09/tech/full_papers/crosby.pdf
112
// StoredHashIndex maps the tree coordinates (level, n)
// to a dense linear ordering that can be used for hash storage.
// Hash storage implementations that keep hashes in sequential
// storage can use this function to compute where to read or write
// a given hash.
func StoredHashIndex(level int, n int64) int64 {
	// The n'th hash at level L is written right after the (2n+1)'th
	// hash at level L-1, so descend one level at a time until we have
	// the level 0 record that triggers the write. The number of levels
	// descended is added back at the end.
	leaf := n
	for remaining := level; remaining > 0; remaining-- {
		leaf = 2*leaf + 1
	}

	// The level 0 hash for record leaf is stored at
	// leaf + leaf/2 + leaf/4 + ... (terms stop once they reach zero).
	var index int64
	for term := leaf; term > 0; term >>= 1 {
		index += term
	}

	return index + int64(level)
}
134
135// SplitStoredHashIndex is the inverse of StoredHashIndex.
136// That is, SplitStoredHashIndex(StoredHashIndex(level, n)) == level, n.
137func SplitStoredHashIndex(index int64) (level int, n int64) {
138	// Determine level 0 record before index.
139	// StoredHashIndex(0, n) < 2*n,
140	// so the n we want is in [index/2, index/2+log₂(index)].
141	n = index / 2
142	indexN := StoredHashIndex(0, n)
143	if indexN > index {
144		panic("bad math")
145	}
146	for {
147		// Each new record n adds 1 + trailingZeros(n) hashes.
148		x := indexN + 1 + int64(bits.TrailingZeros64(uint64(n+1)))
149		if x > index {
150			break
151		}
152		n++
153		indexN = x
154	}
155	// The hash we want was committed with record n,
156	// meaning it is one of (0, n), (1, n/2), (2, n/4), ...
157	level = int(index - indexN)
158	return level, n >> uint(level)
159}
160
161// StoredHashCount returns the number of stored hashes
162// that are expected for a tree with n records.
163func StoredHashCount(n int64) int64 {
164	if n == 0 {
165		return 0
166	}
167	// The tree will have the hashes up to the last leaf hash.
168	numHash := StoredHashIndex(0, n-1) + 1
169	// And it will have any hashes for subtrees completed by that leaf.
170	for i := uint64(n - 1); i&1 != 0; i >>= 1 {
171		numHash++
172	}
173	return numHash
174}
175
176// StoredHashes returns the hashes that must be stored when writing
177// record n with the given data. The hashes should be stored starting
178// at StoredHashIndex(0, n). The result will have at most 1 + log₂ n hashes,
179// but it will average just under two per call for a sequence of calls for n=1..k.
180//
181// StoredHashes may read up to log n earlier hashes from r
182// in order to compute hashes for completed subtrees.
183func StoredHashes(n int64, data []byte, r HashReader) ([]Hash, error) {
184	return StoredHashesForRecordHash(n, RecordHash(data), r)
185}
186
187// StoredHashesForRecordHash is like StoredHashes but takes
188// as its second argument RecordHash(data) instead of data itself.
189func StoredHashesForRecordHash(n int64, h Hash, r HashReader) ([]Hash, error) {
190	// Start with the record hash.
191	hashes := []Hash{h}
192
193	// Build list of indexes needed for hashes for completed subtrees.
194	// Each trailing 1 bit in the binary representation of n completes a subtree
195	// and consumes a hash from an adjacent subtree.
196	m := int(bits.TrailingZeros64(uint64(n + 1)))
197	indexes := make([]int64, m)
198	for i := 0; i < m; i++ {
199		// We arrange indexes in sorted order.
200		// Note that n>>i is always odd.
201		indexes[m-1-i] = StoredHashIndex(i, n>>uint(i)-1)
202	}
203
204	// Fetch hashes.
205	old, err := r.ReadHashes(indexes)
206	if err != nil {
207		return nil, err
208	}
209	if len(old) != len(indexes) {
210		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(old))
211	}
212
213	// Build new hashes.
214	for i := 0; i < m; i++ {
215		h = NodeHash(old[m-1-i], h)
216		hashes = append(hashes, h)
217	}
218	return hashes, nil
219}
220
// A HashReader can read hashes for nodes in the log's tree structure.
// The functions in this package that build stored hashes, tree hashes,
// and proofs use it to fetch previously stored hashes, and they pair
// the returned hashes with the requested indexes by position.
type HashReader interface {
	// ReadHashes returns the hashes with the given stored hash indexes
	// (see StoredHashIndex and SplitStoredHashIndex).
	// ReadHashes must return a slice of hashes the same length as indexes,
	// or else it must return a non-nil error.
	// ReadHashes may run faster if indexes is sorted in increasing order.
	ReadHashes(indexes []int64) ([]Hash, error)
}
230
// A HashReaderFunc is a function implementing HashReader.
// It lets an ordinary function be used wherever a HashReader is required.
type HashReaderFunc func([]int64) ([]Hash, error)

// ReadHashes implements HashReader by calling f(indexes)
// and returning its results unchanged.
func (f HashReaderFunc) ReadHashes(indexes []int64) ([]Hash, error) {
	return f(indexes)
}
237
238// TreeHash computes the hash for the root of the tree with n records,
239// using the HashReader to obtain previously stored hashes
240// (those returned by StoredHashes during the writes of those n records).
241// TreeHash makes a single call to ReadHash requesting at most 1 + log₂ n hashes.
242// The tree of size zero is defined to have an all-zero Hash.
243func TreeHash(n int64, r HashReader) (Hash, error) {
244	if n == 0 {
245		return Hash{}, nil
246	}
247	indexes := subTreeIndex(0, n, nil)
248	hashes, err := r.ReadHashes(indexes)
249	if err != nil {
250		return Hash{}, err
251	}
252	if len(hashes) != len(indexes) {
253		return Hash{}, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
254	}
255	hash, hashes := subTreeHash(0, n, hashes)
256	if len(hashes) != 0 {
257		panic("tlog: bad index math in TreeHash")
258	}
259	return hash, nil
260}
261
262// subTreeIndex returns the storage indexes needed to compute
263// the hash for the subtree containing records [lo, hi),
264// appending them to need and returning the result.
265// See https://tools.ietf.org/html/rfc6962#section-2.1
266func subTreeIndex(lo, hi int64, need []int64) []int64 {
267	// See subTreeHash below for commentary.
268	for lo < hi {
269		k, level := maxpow2(hi - lo + 1)
270		if lo&(k-1) != 0 {
271			panic("tlog: bad math in subTreeIndex")
272		}
273		need = append(need, StoredHashIndex(level, lo>>uint(level)))
274		lo += k
275	}
276	return need
277}
278
279// subTreeHash computes the hash for the subtree containing records [lo, hi),
280// assuming that hashes are the hashes corresponding to the indexes
281// returned by subTreeIndex(lo, hi).
282// It returns any leftover hashes.
283func subTreeHash(lo, hi int64, hashes []Hash) (Hash, []Hash) {
284	// Repeatedly partition the tree into a left side with 2^level nodes,
285	// for as large a level as possible, and a right side with the fringe.
286	// The left hash is stored directly and can be read from storage.
287	// The right side needs further computation.
288	numTree := 0
289	for lo < hi {
290		k, _ := maxpow2(hi - lo + 1)
291		if lo&(k-1) != 0 || lo >= hi {
292			panic("tlog: bad math in subTreeHash")
293		}
294		numTree++
295		lo += k
296	}
297
298	if len(hashes) < numTree {
299		panic("tlog: bad index math in subTreeHash")
300	}
301
302	// Reconstruct hash.
303	h := hashes[numTree-1]
304	for i := numTree - 2; i >= 0; i-- {
305		h = NodeHash(hashes[i], h)
306	}
307	return h, hashes[numTree:]
308}
309
// A RecordProof is a verifiable proof that a particular log root contains a particular record.
// RFC 6962 calls this a “Merkle audit path.”
//
// As built by ProveRecord, the hashes run from the one needed nearest
// the record's leaf to the one needed at the root; CheckRecord consumes
// them from the end of the slice as it descends from the root.
type RecordProof []Hash
313
314// ProveRecord returns the proof that the tree of size t contains the record with index n.
315func ProveRecord(t, n int64, r HashReader) (RecordProof, error) {
316	if t < 0 || n < 0 || n >= t {
317		return nil, fmt.Errorf("tlog: invalid inputs in ProveRecord")
318	}
319	indexes := leafProofIndex(0, t, n, nil)
320	if len(indexes) == 0 {
321		return RecordProof{}, nil
322	}
323	hashes, err := r.ReadHashes(indexes)
324	if err != nil {
325		return nil, err
326	}
327	if len(hashes) != len(indexes) {
328		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
329	}
330
331	p, hashes := leafProof(0, t, n, hashes)
332	if len(hashes) != 0 {
333		panic("tlog: bad index math in ProveRecord")
334	}
335	return p, nil
336}
337
338// leafProofIndex builds the list of indexes needed to construct the proof
339// that leaf n is contained in the subtree with leaves [lo, hi).
340// It appends those indexes to need and returns the result.
341// See https://tools.ietf.org/html/rfc6962#section-2.1.1
342func leafProofIndex(lo, hi, n int64, need []int64) []int64 {
343	// See leafProof below for commentary.
344	if !(lo <= n && n < hi) {
345		panic("tlog: bad math in leafProofIndex")
346	}
347	if lo+1 == hi {
348		return need
349	}
350	if k, _ := maxpow2(hi - lo); n < lo+k {
351		need = leafProofIndex(lo, lo+k, n, need)
352		need = subTreeIndex(lo+k, hi, need)
353	} else {
354		need = subTreeIndex(lo, lo+k, need)
355		need = leafProofIndex(lo+k, hi, n, need)
356	}
357	return need
358}
359
360// leafProof constructs the proof that leaf n is contained in the subtree with leaves [lo, hi).
361// It returns any leftover hashes as well.
362// See https://tools.ietf.org/html/rfc6962#section-2.1.1
363func leafProof(lo, hi, n int64, hashes []Hash) (RecordProof, []Hash) {
364	// We must have lo <= n < hi or else the code here has a bug.
365	if !(lo <= n && n < hi) {
366		panic("tlog: bad math in leafProof")
367	}
368
369	if lo+1 == hi { // n == lo
370		// Reached the leaf node.
371		// The verifier knows what the leaf hash is, so we don't need to send it.
372		return RecordProof{}, hashes
373	}
374
375	// Walk down the tree toward n.
376	// Record the hash of the path not taken (needed for verifying the proof).
377	var p RecordProof
378	var th Hash
379	if k, _ := maxpow2(hi - lo); n < lo+k {
380		// n is on left side
381		p, hashes = leafProof(lo, lo+k, n, hashes)
382		th, hashes = subTreeHash(lo+k, hi, hashes)
383	} else {
384		// n is on right side
385		th, hashes = subTreeHash(lo, lo+k, hashes)
386		p, hashes = leafProof(lo+k, hi, n, hashes)
387	}
388	return append(p, th), hashes
389}
390
// errProofFailed is the error returned by CheckRecord and CheckTree,
// and by the proof runners they call, when a proof does not verify.
var errProofFailed = errors.New("invalid transparency proof")
392
393// CheckRecord verifies that p is a valid proof that the tree of size t
394// with hash th has an n'th record with hash h.
395func CheckRecord(p RecordProof, t int64, th Hash, n int64, h Hash) error {
396	if t < 0 || n < 0 || n >= t {
397		return fmt.Errorf("tlog: invalid inputs in CheckRecord")
398	}
399	th2, err := runRecordProof(p, 0, t, n, h)
400	if err != nil {
401		return err
402	}
403	if th2 == th {
404		return nil
405	}
406	return errProofFailed
407}
408
409// runRecordProof runs the proof p that leaf n is contained in the subtree with leaves [lo, hi).
410// Running the proof means constructing and returning the implied hash of that
411// subtree.
412func runRecordProof(p RecordProof, lo, hi, n int64, leafHash Hash) (Hash, error) {
413	// We must have lo <= n < hi or else the code here has a bug.
414	if !(lo <= n && n < hi) {
415		panic("tlog: bad math in runRecordProof")
416	}
417
418	if lo+1 == hi { // m == lo
419		// Reached the leaf node.
420		// The proof must not have any unnecessary hashes.
421		if len(p) != 0 {
422			return Hash{}, errProofFailed
423		}
424		return leafHash, nil
425	}
426
427	if len(p) == 0 {
428		return Hash{}, errProofFailed
429	}
430
431	k, _ := maxpow2(hi - lo)
432	if n < lo+k {
433		th, err := runRecordProof(p[:len(p)-1], lo, lo+k, n, leafHash)
434		if err != nil {
435			return Hash{}, err
436		}
437		return NodeHash(th, p[len(p)-1]), nil
438	} else {
439		th, err := runRecordProof(p[:len(p)-1], lo+k, hi, n, leafHash)
440		if err != nil {
441			return Hash{}, err
442		}
443		return NodeHash(p[len(p)-1], th), nil
444	}
445}
446
// A TreeProof is a verifiable proof that a particular log tree contains
// as a prefix all records present in an earlier tree.
// RFC 6962 calls this a “Merkle consistency proof.”
//
// As built by ProveTree, the hashes run from the deepest split in the
// tree to the split at the root; CheckTree consumes them from the end
// of the slice as it descends from the root.
type TreeProof []Hash
451
452// ProveTree returns the proof that the tree of size t contains
453// as a prefix all the records from the tree of smaller size n.
454func ProveTree(t, n int64, h HashReader) (TreeProof, error) {
455	if t < 1 || n < 1 || n > t {
456		return nil, fmt.Errorf("tlog: invalid inputs in ProveTree")
457	}
458	indexes := treeProofIndex(0, t, n, nil)
459	if len(indexes) == 0 {
460		return TreeProof{}, nil
461	}
462	hashes, err := h.ReadHashes(indexes)
463	if err != nil {
464		return nil, err
465	}
466	if len(hashes) != len(indexes) {
467		return nil, fmt.Errorf("tlog: ReadHashes(%d indexes) = %d hashes", len(indexes), len(hashes))
468	}
469
470	p, hashes := treeProof(0, t, n, hashes)
471	if len(hashes) != 0 {
472		panic("tlog: bad index math in ProveTree")
473	}
474	return p, nil
475}
476
477// treeProofIndex builds the list of indexes needed to construct
478// the sub-proof related to the subtree containing records [lo, hi).
479// See https://tools.ietf.org/html/rfc6962#section-2.1.2.
480func treeProofIndex(lo, hi, n int64, need []int64) []int64 {
481	// See treeProof below for commentary.
482	if !(lo < n && n <= hi) {
483		panic("tlog: bad math in treeProofIndex")
484	}
485
486	if n == hi {
487		if lo == 0 {
488			return need
489		}
490		return subTreeIndex(lo, hi, need)
491	}
492
493	if k, _ := maxpow2(hi - lo); n <= lo+k {
494		need = treeProofIndex(lo, lo+k, n, need)
495		need = subTreeIndex(lo+k, hi, need)
496	} else {
497		need = subTreeIndex(lo, lo+k, need)
498		need = treeProofIndex(lo+k, hi, n, need)
499	}
500	return need
501}
502
503// treeProof constructs the sub-proof related to the subtree containing records [lo, hi).
504// It returns any leftover hashes as well.
505// See https://tools.ietf.org/html/rfc6962#section-2.1.2.
506func treeProof(lo, hi, n int64, hashes []Hash) (TreeProof, []Hash) {
507	// We must have lo < n <= hi or else the code here has a bug.
508	if !(lo < n && n <= hi) {
509		panic("tlog: bad math in treeProof")
510	}
511
512	// Reached common ground.
513	if n == hi {
514		if lo == 0 {
515			// This subtree corresponds exactly to the old tree.
516			// The verifier knows that hash, so we don't need to send it.
517			return TreeProof{}, hashes
518		}
519		th, hashes := subTreeHash(lo, hi, hashes)
520		return TreeProof{th}, hashes
521	}
522
523	// Interior node for the proof.
524	// Decide whether to walk down the left or right side.
525	var p TreeProof
526	var th Hash
527	if k, _ := maxpow2(hi - lo); n <= lo+k {
528		// m is on left side
529		p, hashes = treeProof(lo, lo+k, n, hashes)
530		th, hashes = subTreeHash(lo+k, hi, hashes)
531	} else {
532		// m is on right side
533		th, hashes = subTreeHash(lo, lo+k, hashes)
534		p, hashes = treeProof(lo+k, hi, n, hashes)
535	}
536	return append(p, th), hashes
537}
538
539// CheckTree verifies that p is a valid proof that the tree of size t with hash th
540// contains as a prefix the tree of size n with hash h.
541func CheckTree(p TreeProof, t int64, th Hash, n int64, h Hash) error {
542	if t < 1 || n < 1 || n > t {
543		return fmt.Errorf("tlog: invalid inputs in CheckTree")
544	}
545	h2, th2, err := runTreeProof(p, 0, t, n, h)
546	if err != nil {
547		return err
548	}
549	if th2 == th && h2 == h {
550		return nil
551	}
552	return errProofFailed
553}
554
555// runTreeProof runs the sub-proof p related to the subtree containing records [lo, hi),
556// where old is the hash of the old tree with n records.
557// Running the proof means constructing and returning the implied hashes of that
558// subtree in both the old and new tree.
559func runTreeProof(p TreeProof, lo, hi, n int64, old Hash) (Hash, Hash, error) {
560	// We must have lo < n <= hi or else the code here has a bug.
561	if !(lo < n && n <= hi) {
562		panic("tlog: bad math in runTreeProof")
563	}
564
565	// Reached common ground.
566	if n == hi {
567		if lo == 0 {
568			if len(p) != 0 {
569				return Hash{}, Hash{}, errProofFailed
570			}
571			return old, old, nil
572		}
573		if len(p) != 1 {
574			return Hash{}, Hash{}, errProofFailed
575		}
576		return p[0], p[0], nil
577	}
578
579	if len(p) == 0 {
580		return Hash{}, Hash{}, errProofFailed
581	}
582
583	// Interior node for the proof.
584	k, _ := maxpow2(hi - lo)
585	if n <= lo+k {
586		oh, th, err := runTreeProof(p[:len(p)-1], lo, lo+k, n, old)
587		if err != nil {
588			return Hash{}, Hash{}, err
589		}
590		return oh, NodeHash(th, p[len(p)-1]), nil
591	} else {
592		oh, th, err := runTreeProof(p[:len(p)-1], lo+k, hi, n, old)
593		if err != nil {
594			return Hash{}, Hash{}, err
595		}
596		return NodeHash(p[len(p)-1], oh), NodeHash(p[len(p)-1], th), nil
597	}
598}
599