1/**
2 * Reed-Solomon Coding over 8-bit values.
3 *
4 * Copyright 2015, Klaus Post
5 * Copyright 2015, Backblaze, Inc.
6 */
7
8package reedsolomon
9
10import (
11	"bytes"
12	"errors"
13	"fmt"
14	"io"
15	"sync"
16)
17
// StreamEncoder is an interface to encode Reed-Solomon parity sets for your data.
// It provides a fully streaming interface, and processes data in blocks of up to 4MB.
//
// For small shard sizes, 10MB and below, it is recommended to use the in-memory interface,
// since the streaming interface has a start up overhead.
//
// For all operations, readers and writers should not assume any order/size of
// individual reads/writes.
//
// For usage examples, see "stream-encoder.go" and "stream-decoder.go" in the examples
// folder.
type StreamEncoder interface {
	// Encode parity shards for a set of data shards.
	//
	// Input is 'shards' containing readers for data shards followed by parity shards
	// io.Writer.
	//
	// The number of shards must match the number given to NewStream().
	//
	// Each reader must supply the same number of bytes.
	//
	// The parity shards will be written to the writer.
	// The number of bytes written will match the input size.
	//
	// If a data stream returns an error, a StreamReadError type error
	// will be returned. If a parity writer returns an error, a
	// StreamWriteError will be returned.
	Encode(data []io.Reader, parity []io.Writer) error

	// Verify returns true if the parity shards contain correct data.
	//
	// The number of shards must match the total number of data+parity shards
	// given to NewStream().
	//
	// Each reader must supply the same number of bytes.
	// If a shard stream returns an error, a StreamReadError type error
	// will be returned.
	Verify(shards []io.Reader) (bool, error)

	// Reconstruct will recreate the missing shards if possible.
	//
	// Given a list of valid shards (to read) and invalid shards (to write).
	//
	// You indicate that a shard is missing by setting it to nil in the 'valid'
	// slice and at the same time setting a non-nil writer in "fill".
	// An index cannot contain both a non-nil 'valid' and 'fill' entry.
	// If both are provided 'ErrReconstructMismatch' is returned.
	//
	// If there are too few shards to reconstruct the missing
	// ones, ErrTooFewShards will be returned.
	//
	// The reconstructed shard set is complete, but integrity is not verified.
	// Use the Verify function to check if the data set is ok.
	Reconstruct(valid []io.Reader, fill []io.Writer) error

	// Split an input stream into the number of shards given to the encoder.
	//
	// The data will be split into equally sized shards.
	// If the data size isn't divisible by the number of shards,
	// the last shard will contain extra zeros.
	//
	// You must supply the total size of your input.
	// 'ErrShortData' will be returned if it is unable to retrieve the
	// number of bytes indicated.
	Split(data io.Reader, dst []io.Writer, size int64) (err error)

	// Join the shards and write the data segment to dst.
	//
	// Only the data shards are considered.
	//
	// You must supply the exact output size you want.
	// If there are too few shards given, ErrTooFewShards will be returned.
	// If the total data size is less than outSize, ErrShortData will be returned.
	Join(dst io.Writer, shards []io.Reader, outSize int64) error
}
93
// StreamReadError is returned when a read error is encountered
// that relates to a supplied stream.
// It records which reader failed so the caller can identify the
// offending shard stream.
type StreamReadError struct {
	Err    error // The underlying read error
	Stream int   // Index of the stream on which the error occurred
}

// Error implements the error interface.
func (s StreamReadError) Error() string {
	return fmt.Sprintf("error reading stream %d: %s", s.Stream, s.Err)
}

// String returns the same message as Error.
func (s StreamReadError) String() string {
	return s.Error()
}
111
// StreamWriteError is returned when a write error is encountered
// that relates to a supplied stream.
// It records which writer failed so the caller can identify the
// offending shard stream.
type StreamWriteError struct {
	Err    error // The underlying write error
	Stream int   // Index of the stream on which the error occurred
}

// Error implements the error interface.
func (s StreamWriteError) Error() string {
	return fmt.Sprintf("error writing stream %d: %s", s.Stream, s.Err)
}

// String returns the same message as Error.
func (s StreamWriteError) String() string {
	return s.Error()
}
129
// rsStream wraps an in-memory reedSolomon codec for a specific
// distribution of data and parity shards, adding the block buffering
// needed to process streams of arbitrary length.
// Construct it using NewStream().
type rsStream struct {
	r *reedSolomon
	o options

	// readShards fills dst from the given readers; either the serial
	// readShards or the concurrent cReadShards, chosen by options.
	readShards func(dst [][]byte, in []io.Reader) error
	// writeShards writes the shard buffers to the writers; either the
	// serial writeShards or the concurrent cWriteShards.
	writeShards func(out []io.Writer, in [][]byte) error

	// blockPool recycles the per-call [][]byte scratch blocks
	// (one streamBS-sized buffer per shard).
	blockPool sync.Pool
}
144
145// NewStream creates a new encoder and initializes it to
146// the number of data shards and parity shards that
147// you want to use. You can reuse this encoder.
148// Note that the maximum number of data shards is 256.
149func NewStream(dataShards, parityShards int, o ...Option) (StreamEncoder, error) {
150	r := rsStream{o: defaultOptions}
151	for _, opt := range o {
152		opt(&r.o)
153	}
154	// Override block size if shard size is set.
155	if r.o.streamBS == 0 && r.o.shardSize > 0 {
156		r.o.streamBS = r.o.shardSize
157	}
158	if r.o.streamBS <= 0 {
159		r.o.streamBS = 4 << 20
160	}
161	if r.o.shardSize == 0 && r.o.maxGoroutines == defaultOptions.maxGoroutines {
162		o = append(o, WithAutoGoroutines(r.o.streamBS))
163	}
164
165	enc, err := New(dataShards, parityShards, o...)
166	if err != nil {
167		return nil, err
168	}
169	r.r = enc.(*reedSolomon)
170
171	r.blockPool.New = func() interface{} {
172		out := make([][]byte, dataShards+parityShards)
173		for i := range out {
174			out[i] = make([]byte, r.o.streamBS)
175		}
176		return out
177	}
178	r.readShards = readShards
179	r.writeShards = writeShards
180	if r.o.concReads {
181		r.readShards = cReadShards
182	}
183	if r.o.concWrites {
184		r.writeShards = cWriteShards
185	}
186
187	return &r, err
188}
189
190// NewStreamC creates a new encoder and initializes it to
191// the number of data shards and parity shards given.
192//
193// This functions as 'NewStream', but allows you to enable CONCURRENT reads and writes.
194func NewStreamC(dataShards, parityShards int, conReads, conWrites bool, o ...Option) (StreamEncoder, error) {
195	return NewStream(dataShards, parityShards, append(o, WithConcurrentStreamReads(conReads), WithConcurrentStreamWrites(conWrites))...)
196}
197
198func (r *rsStream) createSlice() [][]byte {
199	out := r.blockPool.Get().([][]byte)
200	for i := range out {
201		out[i] = out[i][:r.o.streamBS]
202	}
203	return out
204}
205
206// Encodes parity shards for a set of data shards.
207//
208// Input is 'shards' containing readers for data shards followed by parity shards
209// io.Writer.
210//
211// The number of shards must match the number given to NewStream().
212//
213// Each reader must supply the same number of bytes.
214//
215// The parity shards will be written to the writer.
216// The number of bytes written will match the input size.
217//
218// If a data stream returns an error, a StreamReadError type error
219// will be returned. If a parity writer returns an error, a
220// StreamWriteError will be returned.
221func (r *rsStream) Encode(data []io.Reader, parity []io.Writer) error {
222	if len(data) != r.r.DataShards {
223		return ErrTooFewShards
224	}
225
226	if len(parity) != r.r.ParityShards {
227		return ErrTooFewShards
228	}
229
230	all := r.createSlice()
231	defer r.blockPool.Put(all)
232	in := all[:r.r.DataShards]
233	out := all[r.r.DataShards:]
234	read := 0
235
236	for {
237		err := r.readShards(in, data)
238		switch err {
239		case nil:
240		case io.EOF:
241			if read == 0 {
242				return ErrShardNoData
243			}
244			return nil
245		default:
246			return err
247		}
248		out = trimShards(out, shardSize(in))
249		read += shardSize(in)
250		err = r.r.Encode(all)
251		if err != nil {
252			return err
253		}
254		err = r.writeShards(parity, out)
255		if err != nil {
256			return err
257		}
258	}
259}
260
261// Trim the shards so they are all the same size
// trimShards resizes every non-empty shard to exactly 'size' bytes, so
// all shards in the set have matching lengths. Shards that end up
// shorter than 'size' are emptied entirely.
func trimShards(in [][]byte, size int) [][]byte {
	for i, shard := range in {
		if len(shard) != 0 {
			shard = shard[:size]
		}
		if len(shard) < size {
			shard = shard[:0]
		}
		in[i] = shard
	}
	return in
}
273
274func readShards(dst [][]byte, in []io.Reader) error {
275	if len(in) != len(dst) {
276		panic("internal error: in and dst size do not match")
277	}
278	size := -1
279	for i := range in {
280		if in[i] == nil {
281			dst[i] = dst[i][:0]
282			continue
283		}
284		n, err := io.ReadFull(in[i], dst[i])
285		// The error is EOF only if no bytes were read.
286		// If an EOF happens after reading some but not all the bytes,
287		// ReadFull returns ErrUnexpectedEOF.
288		switch err {
289		case io.ErrUnexpectedEOF, io.EOF:
290			if size < 0 {
291				size = n
292			} else if n != size {
293				// Shard sizes must match.
294				return ErrShardSize
295			}
296			dst[i] = dst[i][0:n]
297		case nil:
298			continue
299		default:
300			return StreamReadError{Err: err, Stream: i}
301		}
302	}
303	if size == 0 {
304		return io.EOF
305	}
306	return nil
307}
308
309func writeShards(out []io.Writer, in [][]byte) error {
310	if len(out) != len(in) {
311		panic("internal error: in and out size do not match")
312	}
313	for i := range in {
314		if out[i] == nil {
315			continue
316		}
317		n, err := out[i].Write(in[i])
318		if err != nil {
319			return StreamWriteError{Err: err, Stream: i}
320		}
321		//
322		if n != len(in[i]) {
323			return StreamWriteError{Err: io.ErrShortWrite, Stream: i}
324		}
325	}
326	return nil
327}
328
// readResult carries the outcome of one concurrent shard read.
type readResult struct {
	n    int   // index of the stream this result belongs to
	size int   // number of bytes read into that shard's buffer
	err  error // error reported by io.ReadFull, if any
}
334
335// cReadShards reads shards concurrently
336func cReadShards(dst [][]byte, in []io.Reader) error {
337	if len(in) != len(dst) {
338		panic("internal error: in and dst size do not match")
339	}
340	var wg sync.WaitGroup
341	wg.Add(len(in))
342	res := make(chan readResult, len(in))
343	for i := range in {
344		if in[i] == nil {
345			dst[i] = dst[i][:0]
346			wg.Done()
347			continue
348		}
349		go func(i int) {
350			defer wg.Done()
351			n, err := io.ReadFull(in[i], dst[i])
352			// The error is EOF only if no bytes were read.
353			// If an EOF happens after reading some but not all the bytes,
354			// ReadFull returns ErrUnexpectedEOF.
355			res <- readResult{size: n, err: err, n: i}
356
357		}(i)
358	}
359	wg.Wait()
360	close(res)
361	size := -1
362	for r := range res {
363		switch r.err {
364		case io.ErrUnexpectedEOF, io.EOF:
365			if size < 0 {
366				size = r.size
367			} else if r.size != size {
368				// Shard sizes must match.
369				return ErrShardSize
370			}
371			dst[r.n] = dst[r.n][0:r.size]
372		case nil:
373		default:
374			return StreamReadError{Err: r.err, Stream: r.n}
375		}
376	}
377	if size == 0 {
378		return io.EOF
379	}
380	return nil
381}
382
383// cWriteShards writes shards concurrently
384func cWriteShards(out []io.Writer, in [][]byte) error {
385	if len(out) != len(in) {
386		panic("internal error: in and out size do not match")
387	}
388	var errs = make(chan error, len(out))
389	var wg sync.WaitGroup
390	wg.Add(len(out))
391	for i := range in {
392		go func(i int) {
393			defer wg.Done()
394			if out[i] == nil {
395				errs <- nil
396				return
397			}
398			n, err := out[i].Write(in[i])
399			if err != nil {
400				errs <- StreamWriteError{Err: err, Stream: i}
401				return
402			}
403			if n != len(in[i]) {
404				errs <- StreamWriteError{Err: io.ErrShortWrite, Stream: i}
405			}
406		}(i)
407	}
408	wg.Wait()
409	close(errs)
410	for err := range errs {
411		if err != nil {
412			return err
413		}
414	}
415
416	return nil
417}
418
419// Verify returns true if the parity shards contain correct data.
420//
421// The number of shards must match the number total data+parity shards
422// given to NewStream().
423//
424// Each reader must supply the same number of bytes.
425// If a shard stream returns an error, a StreamReadError type error
426// will be returned.
427func (r *rsStream) Verify(shards []io.Reader) (bool, error) {
428	if len(shards) != r.r.Shards {
429		return false, ErrTooFewShards
430	}
431
432	read := 0
433	all := r.createSlice()
434	defer r.blockPool.Put(all)
435	for {
436		err := r.readShards(all, shards)
437		if err == io.EOF {
438			if read == 0 {
439				return false, ErrShardNoData
440			}
441			return true, nil
442		}
443		if err != nil {
444			return false, err
445		}
446		read += shardSize(all)
447		ok, err := r.r.Verify(all)
448		if !ok || err != nil {
449			return ok, err
450		}
451	}
452}
453
// ErrReconstructMismatch is returned by the StreamEncoder if you supply
// both a "valid" and a "fill" stream on the same index, since it is
// then impossible to tell whether you consider the shard valid
// or would like to have it reconstructed.
var ErrReconstructMismatch = errors.New("valid shards and fill shards are mutually exclusive")
459
460// Reconstruct will recreate the missing shards if possible.
461//
462// Given a list of valid shards (to read) and invalid shards (to write)
463//
464// You indicate that a shard is missing by setting it to nil in the 'valid'
465// slice and at the same time setting a non-nil writer in "fill".
466// An index cannot contain both non-nil 'valid' and 'fill' entry.
467//
468// If there are too few shards to reconstruct the missing
469// ones, ErrTooFewShards will be returned.
470//
471// The reconstructed shard set is complete when explicitly asked for all missing shards.
472// However its integrity is not automatically verified.
473// Use the Verify function to check in case the data set is complete.
474func (r *rsStream) Reconstruct(valid []io.Reader, fill []io.Writer) error {
475	if len(valid) != r.r.Shards {
476		return ErrTooFewShards
477	}
478	if len(fill) != r.r.Shards {
479		return ErrTooFewShards
480	}
481
482	all := r.createSlice()
483	defer r.blockPool.Put(all)
484	reconDataOnly := true
485	for i := range valid {
486		if valid[i] != nil && fill[i] != nil {
487			return ErrReconstructMismatch
488		}
489		if i >= r.r.DataShards && fill[i] != nil {
490			reconDataOnly = false
491		}
492	}
493
494	read := 0
495	for {
496		err := r.readShards(all, valid)
497		if err == io.EOF {
498			if read == 0 {
499				return ErrShardNoData
500			}
501			return nil
502		}
503		if err != nil {
504			return err
505		}
506		read += shardSize(all)
507		all = trimShards(all, shardSize(all))
508
509		if reconDataOnly {
510			err = r.r.ReconstructData(all) // just reconstruct missing data shards
511		} else {
512			err = r.r.Reconstruct(all) //  reconstruct all missing shards
513		}
514		if err != nil {
515			return err
516		}
517		err = r.writeShards(fill, all)
518		if err != nil {
519			return err
520		}
521	}
522}
523
524// Join the shards and write the data segment to dst.
525//
526// Only the data shards are considered.
527//
528// You must supply the exact output size you want.
529// If there are to few shards given, ErrTooFewShards will be returned.
530// If the total data size is less than outSize, ErrShortData will be returned.
531func (r *rsStream) Join(dst io.Writer, shards []io.Reader, outSize int64) error {
532	// Do we have enough shards?
533	if len(shards) < r.r.DataShards {
534		return ErrTooFewShards
535	}
536
537	// Trim off parity shards if any
538	shards = shards[:r.r.DataShards]
539	for i := range shards {
540		if shards[i] == nil {
541			return StreamReadError{Err: ErrShardNoData, Stream: i}
542		}
543	}
544	// Join all shards
545	src := io.MultiReader(shards...)
546
547	// Copy data to dst
548	n, err := io.CopyN(dst, src, outSize)
549	if err == io.EOF {
550		return ErrShortData
551	}
552	if err != nil {
553		return err
554	}
555	if n != outSize {
556		return ErrShortData
557	}
558	return nil
559}
560
561// Split a an input stream into the number of shards given to the encoder.
562//
563// The data will be split into equally sized shards.
564// If the data size isn't dividable by the number of shards,
565// the last shard will contain extra zeros.
566//
567// You must supply the total size of your input.
568// 'ErrShortData' will be returned if it is unable to retrieve the
569// number of bytes indicated.
570func (r *rsStream) Split(data io.Reader, dst []io.Writer, size int64) error {
571	if size == 0 {
572		return ErrShortData
573	}
574	if len(dst) != r.r.DataShards {
575		return ErrInvShardNum
576	}
577
578	for i := range dst {
579		if dst[i] == nil {
580			return StreamWriteError{Err: ErrShardNoData, Stream: i}
581		}
582	}
583
584	// Calculate number of bytes per shard.
585	perShard := (size + int64(r.r.DataShards) - 1) / int64(r.r.DataShards)
586
587	// Pad data to r.Shards*perShard.
588	padding := make([]byte, (int64(r.r.Shards)*perShard)-size)
589	data = io.MultiReader(data, bytes.NewBuffer(padding))
590
591	// Split into equal-length shards and copy.
592	for i := range dst {
593		n, err := io.CopyN(dst[i], data, perShard)
594		if err != io.EOF && err != nil {
595			return err
596		}
597		if n != perShard {
598			return ErrShortData
599		}
600	}
601
602	return nil
603}
604