1package restorer
2
3import (
4	"bufio"
5	"context"
6	"io"
7	"math"
8	"path/filepath"
9	"sort"
10	"sync"
11
12	"golang.org/x/sync/errgroup"
13
14	"github.com/restic/restic/internal/crypto"
15	"github.com/restic/restic/internal/debug"
16	"github.com/restic/restic/internal/errors"
17	"github.com/restic/restic/internal/restic"
18)
19
20// TODO if a blob is corrupt, there may be good blob copies in other packs
21// TODO evaluate if it makes sense to split download and processing workers
22//      pro: can (slowly) read network and decrypt/write files concurrently
23//      con: each worker needs to keep one pack in memory
24
const (
	// workerCount is the number of concurrent pack-download workers started
	// by restoreFiles; it is also the writer concurrency passed to
	// newFilesWriter.
	workerCount = 8

	// largeFileBlobCount is the blob-count threshold above which a file is
	// treated as "large": restoreFiles then replaces its flat blob list with
	// a per-pack blob map so downloadPack does not have to re-walk all blobs
	// of the file for every pack.
	largeFileBlobCount = 25
)
30
// information about regular file being restored
type fileInfo struct {
	lock       sync.Mutex // guards inProgress
	inProgress bool       // set once the first blob write has been scheduled (see downloadPack)
	size       int64      // file size; passed as createSize to filesWriter on the first write
	location   string     // file on local filesystem relative to restorer basedir
	// blobs of the file: restic.IDs for small files, or
	// map[restic.ID][]fileBlobInfo (pack -> blobs in that pack) for files
	// with more than largeFileBlobCount blobs (converted in restoreFiles)
	blobs interface{} // blobs of the file
}
39
// fileBlobInfo locates one blob within a restored file; used in the
// pack -> blobs map kept for large files (see restoreFiles).
type fileBlobInfo struct {
	id     restic.ID // the blob id
	offset int64     // blob offset in the file
}
44
// information about a data pack required to restore one or more files
type packInfo struct {
	id    restic.ID              // the pack id
	files map[*fileInfo]struct{} // set of files that use blobs from this pack
}
50
// fileRestorer restores set of files
type fileRestorer struct {
	key        *crypto.Key                                 // key used to decrypt blobs
	idx        func(restic.BlobHandle) []restic.PackedBlob // looks up the packs containing a blob
	packLoader func(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error

	filesWriter *filesWriter // writes blobs to the restored files

	dst   string      // restore destination base directory
	files []*fileInfo // files to restore, registered via addFile
	// Error is invoked with a file location and the error that occurred for
	// it; returning non-nil aborts the restore. Defaults to
	// restorerAbortOnAllErrors.
	Error func(string, error) error
}
63
64func newFileRestorer(dst string,
65	packLoader func(ctx context.Context, h restic.Handle, length int, offset int64, fn func(rd io.Reader) error) error,
66	key *crypto.Key,
67	idx func(restic.BlobHandle) []restic.PackedBlob) *fileRestorer {
68
69	return &fileRestorer{
70		key:         key,
71		idx:         idx,
72		packLoader:  packLoader,
73		filesWriter: newFilesWriter(workerCount),
74		dst:         dst,
75		Error:       restorerAbortOnAllErrors,
76	}
77}
78
79func (r *fileRestorer) addFile(location string, content restic.IDs, size int64) {
80	r.files = append(r.files, &fileInfo{location: location, blobs: content, size: size})
81}
82
83func (r *fileRestorer) targetPath(location string) string {
84	return filepath.Join(r.dst, location)
85}
86
87func (r *fileRestorer) forEachBlob(blobIDs []restic.ID, fn func(packID restic.ID, packBlob restic.Blob)) error {
88	if len(blobIDs) == 0 {
89		return nil
90	}
91
92	for _, blobID := range blobIDs {
93		packs := r.idx(restic.BlobHandle{ID: blobID, Type: restic.DataBlob})
94		if len(packs) == 0 {
95			return errors.Errorf("Unknown blob %s", blobID.String())
96		}
97		fn(packs[0].PackID, packs[0].Blob)
98	}
99
100	return nil
101}
102
// restoreFiles builds the pack -> files mapping for all registered files,
// then downloads the packs with a pool of workerCount workers. Packs are
// scheduled in order of first access to approximate sequential file writes.
// It returns the first error reported by any worker.
func (r *fileRestorer) restoreFiles(ctx context.Context) error {

	packs := make(map[restic.ID]*packInfo) // all packs
	// Process packs in order of first access. While this cannot guarantee
	// that file chunks are restored sequentially, it offers a good enough
	// approximation to shorten restore times by up to 19% in some tests.
	var packOrder restic.IDs

	// create packInfo from fileInfo
	for _, file := range r.files {
		fileBlobs := file.blobs.(restic.IDs)
		largeFile := len(fileBlobs) > largeFileBlobCount
		var packsMap map[restic.ID][]fileBlobInfo
		if largeFile {
			packsMap = make(map[restic.ID][]fileBlobInfo)
		}
		fileOffset := int64(0)
		err := r.forEachBlob(fileBlobs, func(packID restic.ID, blob restic.Blob) {
			if largeFile {
				packsMap[packID] = append(packsMap[packID], fileBlobInfo{id: blob.ID, offset: fileOffset})
				// blob.Length includes the crypto overhead; subtract it to
				// advance by the plaintext size written to the file
				fileOffset += int64(blob.Length) - crypto.Extension
			}
			pack, ok := packs[packID]
			if !ok {
				pack = &packInfo{
					id:    packID,
					files: make(map[*fileInfo]struct{}),
				}
				packs[packID] = pack
				// record first-access order for scheduling below
				packOrder = append(packOrder, packID)
			}
			pack.files[file] = struct{}{}
		})
		if err != nil {
			// repository index is messed up, can't do anything
			return err
		}
		if largeFile {
			// replace the flat blob list with the pack -> blobs map so
			// downloadPack does not re-walk all blobs of the file per pack
			file.blobs = packsMap
		}
	}

	wg, ctx := errgroup.WithContext(ctx)
	downloadCh := make(chan *packInfo)

	// each worker downloads and processes one pack at a time until the
	// channel is closed by the scheduling goroutine below
	worker := func() error {
		for pack := range downloadCh {
			if err := r.downloadPack(ctx, pack); err != nil {
				return err
			}
		}
		return nil
	}
	for i := 0; i < workerCount; i++ {
		wg.Go(worker)
	}

	// the main restore loop
	wg.Go(func() error {
		for _, id := range packOrder {
			pack := packs[id]
			// ctx is canceled by the errgroup when any worker fails, which
			// unblocks this send
			select {
			case <-ctx.Done():
				return ctx.Err()
			case downloadCh <- pack:
				debug.Log("Scheduled download pack %s", pack.id.Str())
			}
		}
		close(downloadCh)
		return nil
	})

	return wg.Wait()
}
177
// maxBufferSize caps the bufio read buffer used while streaming a pack.
const maxBufferSize = 4 * 1024 * 1024

// downloadPack fetches the single byte range of the pack that covers every
// blob needed by the files in pack.files, then decrypts each blob and writes
// it to all file positions referencing it. Per-file errors are filtered
// through r.Error (via sanitizeError); a failed pack download is reported
// once for every affected file.
func (r *fileRestorer) downloadPack(ctx context.Context, pack *packInfo) error {

	// calculate pack byte range and blob->[]files->[]offsets mappings
	start, end := int64(math.MaxInt64), int64(0)
	blobs := make(map[restic.ID]struct {
		offset int64                 // offset of the blob in the pack
		length int                   // length of the blob
		files  map[*fileInfo][]int64 // file -> offsets (plural!) of the blob in the file
	})
	for file := range pack.files {
		addBlob := func(blob restic.Blob, fileOffset int64) {
			// widen the pack byte range to cover this blob
			if start > int64(blob.Offset) {
				start = int64(blob.Offset)
			}
			if end < int64(blob.Offset+blob.Length) {
				end = int64(blob.Offset + blob.Length)
			}
			blobInfo, ok := blobs[blob.ID]
			if !ok {
				blobInfo.offset = int64(blob.Offset)
				blobInfo.length = int(blob.Length)
				blobInfo.files = make(map[*fileInfo][]int64)
				blobs[blob.ID] = blobInfo
			}
			// blobInfo.files is the same map stored in blobs[blob.ID], so
			// appending here also updates the stored entry
			blobInfo.files[file] = append(blobInfo.files[file], fileOffset)
		}
		// small files keep a flat blob list; large files were converted to a
		// pack -> blobs map by restoreFiles
		if fileBlobs, ok := file.blobs.(restic.IDs); ok {
			fileOffset := int64(0)
			err := r.forEachBlob(fileBlobs, func(packID restic.ID, blob restic.Blob) {
				if packID.Equal(pack.id) {
					addBlob(blob, fileOffset)
				}
				fileOffset += int64(blob.Length) - crypto.Extension
			})
			if err != nil {
				// restoreFiles should have caught this error before
				panic(err)
			}
		} else if packsMap, ok := file.blobs.(map[restic.ID][]fileBlobInfo); ok {
			for _, blob := range packsMap[pack.id] {
				// re-resolve the blob's position within this specific pack
				idxPacks := r.idx(restic.BlobHandle{ID: blob.id, Type: restic.DataBlob})
				for _, idxPack := range idxPacks {
					if idxPack.PackID.Equal(pack.id) {
						addBlob(idxPack.Blob, blob.offset)
						break
					}
				}
			}
		}
	}

	// process blobs in ascending pack offset so the pack can be consumed as
	// a forward-only stream
	sortedBlobs := make([]restic.ID, 0, len(blobs))
	for blobID := range blobs {
		sortedBlobs = append(sortedBlobs, blobID)
	}
	sort.Slice(sortedBlobs, func(i, j int) bool {
		return blobs[sortedBlobs[i]].offset < blobs[sortedBlobs[j]].offset
	})

	// sanitizeError routes a per-file error through the restorer's Error
	// callback; a nil return means the error was swallowed and restore
	// continues
	sanitizeError := func(file *fileInfo, err error) error {
		if err != nil {
			err = r.Error(file.location, err)
		}
		return err
	}

	h := restic.Handle{Type: restic.PackFile, Name: pack.id.String()}
	err := r.packLoader(ctx, h, int(end-start), start, func(rd io.Reader) error {
		bufferSize := int(end - start)
		if bufferSize > maxBufferSize {
			bufferSize = maxBufferSize
		}
		bufRd := bufio.NewReaderSize(rd, bufferSize)
		currentBlobEnd := start
		var blobData, buf []byte
		for _, blobID := range sortedBlobs {
			blob := blobs[blobID]
			// skip the gap between the end of the previous blob and this one
			_, err := bufRd.Discard(int(blob.offset - currentBlobEnd))
			if err != nil {
				return err
			}
			buf, err = r.downloadBlob(bufRd, blobID, blob.length, buf)
			if err != nil {
				return err
			}
			blobData, err = r.decryptBlob(blobID, buf)
			if err != nil {
				// a bad blob fails only the files that need it; keep going
				// with the remaining blobs unless Error aborts
				for file := range blob.files {
					if errFile := sanitizeError(file, err); errFile != nil {
						return errFile
					}
				}
				continue
			}
			currentBlobEnd = blob.offset + int64(blob.length)
			for file, offsets := range blob.files {
				for _, offset := range offsets {
					writeToFile := func() error {
						// this looks overly complicated and needs explanation
						// two competing requirements:
						// - must create the file once and only once
						// - should allow concurrent writes to the file
						// so write the first blob while holding file lock
						// write other blobs after releasing the lock
						createSize := int64(-1)
						file.lock.Lock()
						if file.inProgress {
							file.lock.Unlock()
						} else {
							defer file.lock.Unlock()
							file.inProgress = true
							createSize = file.size
						}
						return r.filesWriter.writeToFile(r.targetPath(file.location), blobData, offset, createSize)
					}
					err := sanitizeError(file, writeToFile())
					if err != nil {
						return err
					}
				}
			}
		}
		return nil
	})

	if err != nil {
		// the whole pack failed to load: report the error for every file
		// that needed blobs from it
		for file := range pack.files {
			if errFile := sanitizeError(file, err); errFile != nil {
				return errFile
			}
		}
	}

	return nil
}
315
316func (r *fileRestorer) downloadBlob(rd io.Reader, blobID restic.ID, length int, buf []byte) ([]byte, error) {
317	// TODO reconcile with Repository#loadBlob implementation
318
319	if cap(buf) < length {
320		buf = make([]byte, length)
321	} else {
322		buf = buf[:length]
323	}
324
325	n, err := io.ReadFull(rd, buf)
326	if err != nil {
327		return nil, err
328	}
329
330	if n != length {
331		return nil, errors.Errorf("error loading blob %v: wrong length returned, want %d, got %d", blobID.Str(), length, n)
332	}
333	return buf, nil
334}
335
336func (r *fileRestorer) decryptBlob(blobID restic.ID, buf []byte) ([]byte, error) {
337	// TODO reconcile with Repository#loadBlob implementation
338
339	// decrypt
340	nonce, ciphertext := buf[:r.key.NonceSize()], buf[r.key.NonceSize():]
341	plaintext, err := r.key.Open(ciphertext[:0], nonce, ciphertext, nil)
342	if err != nil {
343		return nil, errors.Errorf("decrypting blob %v failed: %v", blobID, err)
344	}
345
346	// check hash
347	if !restic.Hash(plaintext).Equal(blobID) {
348		return nil, errors.Errorf("blob %v returned invalid hash", blobID)
349	}
350
351	return plaintext, nil
352}
353