1package fsutil
2
3import (
4	"context"
5	"io"
6	"os"
7	"sync"
8
9	"github.com/pkg/errors"
10	"github.com/tonistiigi/fsutil/types"
11	"golang.org/x/sync/errgroup"
12)
13
// ReceiveOpt configures Receive.
type ReceiveOpt struct {
	// NotifyHashed is handed to the disk writer as DiskWriterOpt.NotifyCb.
	NotifyHashed  ChangeFunc
	// ContentHasher is handed to the disk writer as DiskWriterOpt.ContentHasher.
	ContentHasher ContentHasher
	// ProgressCb, if non-nil, is called after every received packet with the
	// running total of packet sizes (types.Packet.Size) and false, and once
	// more with the final total and true when the packet-reading loop exits.
	ProgressCb    func(int, bool)
	// Merge, when true, diffs the incoming tree against an empty walker
	// instead of walking dest, so the current contents of dest are not
	// compared against. NOTE(review): presumably this leaves files the sender
	// did not mention in place — confirm against doubleWalkDiff.
	Merge         bool
	// Filter is handed to the disk writer as DiskWriterOpt.Filter.
	Filter        FilterFunc
}
21
22func Receive(ctx context.Context, conn Stream, dest string, opt ReceiveOpt) error {
23	ctx, cancel := context.WithCancel(context.Background())
24	defer cancel()
25
26	r := &receiver{
27		conn:          &syncStream{Stream: conn},
28		dest:          dest,
29		files:         make(map[string]uint32),
30		pipes:         make(map[uint32]io.WriteCloser),
31		notifyHashed:  opt.NotifyHashed,
32		contentHasher: opt.ContentHasher,
33		progressCb:    opt.ProgressCb,
34		merge:         opt.Merge,
35		filter:        opt.Filter,
36	}
37	return r.run(ctx)
38}
39
// receiver holds the state of a single Receive call.
type receiver struct {
	// dest is the directory the disk writer writes into. Unless merge is
	// set, it is also walked as the "existing" side of the diff.
	dest       string
	conn       Stream
	// files maps the path of each received stat that can carry file data to
	// the sequential ID assigned to it by the packet-reading loop in run.
	// asyncDataFunc removes an entry when it requests that file. Guarded by mu.
	files      map[string]uint32
	// pipes maps a file ID to the writer that receives that file's
	// PACKET_DATA payloads. Registered and removed by asyncDataFunc.
	// Guarded by muPipes.
	pipes      map[uint32]io.WriteCloser
	mu         sync.RWMutex // guards files
	muPipes    sync.RWMutex // guards pipes
	progressCb func(int, bool)
	merge      bool
	filter     FilterFunc

	notifyHashed  ChangeFunc
	contentHasher ContentHasher
	// orderValidator and hlValidator are fed every incoming stat before it
	// reaches the walker. Any error they return aborts the receive.
	// NOTE(review): presumably they check path ordering and hardlink
	// consistency respectively — confirm against their definitions.
	orderValidator Validator
	hlValidator    Hardlinks
}
56
// dynamicWalker bridges the stat packets arriving on the connection to the
// walker function that doubleWalkDiff consumes. The packet-reading loop
// pushes paths in with update, and fill (passed to doubleWalkDiff alongside
// the destination walker) forwards them on.
type dynamicWalker struct {
	// walkChan carries paths from update to fill. update(nil) closes it to
	// mark the end of the stream.
	walkChan chan *currentPath
	// err records why fill stopped early (its ctx.Err()). It is written
	// before closeCh is closed, and update reads it only after observing
	// that close.
	err      error
	// closeCh is closed by fill when its context is done, so that update
	// stops blocking and reports err.
	closeCh  chan struct{}
}
62
63func newDynamicWalker() *dynamicWalker {
64	return &dynamicWalker{
65		walkChan: make(chan *currentPath, 128),
66		closeCh:  make(chan struct{}),
67	}
68}
69
70func (w *dynamicWalker) update(p *currentPath) error {
71	select {
72	case <-w.closeCh:
73		return errors.Wrap(w.err, "walker is closed")
74	default:
75	}
76	if p == nil {
77		close(w.walkChan)
78		return nil
79	}
80	select {
81	case w.walkChan <- p:
82		return nil
83	case <-w.closeCh:
84		return errors.Wrap(w.err, "walker is closed")
85	}
86}
87
88func (w *dynamicWalker) fill(ctx context.Context, pathC chan<- *currentPath) error {
89	for {
90		select {
91		case p, ok := <-w.walkChan:
92			if !ok {
93				return nil
94			}
95			pathC <- p
96		case <-ctx.Done():
97			w.err = ctx.Err()
98			close(w.closeCh)
99			return ctx.Err()
100		}
101	}
102	return nil
103}
104
105func (r *receiver) run(ctx context.Context) error {
106	g, ctx := errgroup.WithContext(ctx)
107
108	dw, err := NewDiskWriter(ctx, r.dest, DiskWriterOpt{
109		AsyncDataCb:   r.asyncDataFunc,
110		NotifyCb:      r.notifyHashed,
111		ContentHasher: r.contentHasher,
112		Filter:        r.filter,
113	})
114	if err != nil {
115		return err
116	}
117
118	w := newDynamicWalker()
119
120	g.Go(func() (retErr error) {
121		defer func() {
122			if retErr != nil {
123				r.conn.SendMsg(&types.Packet{Type: types.PACKET_ERR, Data: []byte(retErr.Error())})
124			}
125		}()
126		destWalker := emptyWalker
127		if !r.merge {
128			destWalker = GetWalkerFn(r.dest)
129		}
130		err := doubleWalkDiff(ctx, dw.HandleChange, destWalker, w.fill)
131		if err != nil {
132			return err
133		}
134		if err := dw.Wait(ctx); err != nil {
135			return err
136		}
137		r.conn.SendMsg(&types.Packet{Type: types.PACKET_FIN})
138		return nil
139	})
140
141	g.Go(func() error {
142		var i uint32 = 0
143
144		size := 0
145		if r.progressCb != nil {
146			defer func() {
147				r.progressCb(size, true)
148			}()
149		}
150		var p types.Packet
151		for {
152			p = types.Packet{Data: p.Data[:0]}
153			if err := r.conn.RecvMsg(&p); err != nil {
154				return err
155			}
156			if r.progressCb != nil {
157				size += p.Size()
158				r.progressCb(size, false)
159			}
160
161			switch p.Type {
162			case types.PACKET_ERR:
163				return errors.Errorf("error from sender: %s", p.Data)
164			case types.PACKET_STAT:
165				if p.Stat == nil {
166					if err := w.update(nil); err != nil {
167						return err
168					}
169					break
170				}
171				if fileCanRequestData(os.FileMode(p.Stat.Mode)) {
172					r.mu.Lock()
173					r.files[p.Stat.Path] = i
174					r.mu.Unlock()
175				}
176				i++
177				cp := &currentPath{path: p.Stat.Path, f: &StatInfo{p.Stat}}
178				if err := r.orderValidator.HandleChange(ChangeKindAdd, cp.path, cp.f, nil); err != nil {
179					return err
180				}
181				if err := r.hlValidator.HandleChange(ChangeKindAdd, cp.path, cp.f, nil); err != nil {
182					return err
183				}
184				if err := w.update(cp); err != nil {
185					return err
186				}
187			case types.PACKET_DATA:
188				r.muPipes.Lock()
189				pw, ok := r.pipes[p.ID]
190				r.muPipes.Unlock()
191				if !ok {
192					return errors.Errorf("invalid file request %d", p.ID)
193				}
194				if len(p.Data) == 0 {
195					if err := pw.Close(); err != nil {
196						return err
197					}
198				} else {
199					if _, err := pw.Write(p.Data); err != nil {
200						return err
201					}
202				}
203			case types.PACKET_FIN:
204				for {
205					var p types.Packet
206					if err := r.conn.RecvMsg(&p); err != nil {
207						if err == io.EOF {
208							return nil
209						}
210						return err
211					}
212				}
213			}
214		}
215	})
216	return g.Wait()
217}
218
219func (r *receiver) asyncDataFunc(ctx context.Context, p string, wc io.WriteCloser) error {
220	r.mu.Lock()
221	id, ok := r.files[p]
222	if !ok {
223		r.mu.Unlock()
224		return errors.Errorf("invalid file request %s", p)
225	}
226	delete(r.files, p)
227	r.mu.Unlock()
228
229	wwc := newWrappedWriteCloser(wc)
230	r.muPipes.Lock()
231	r.pipes[id] = wwc
232	r.muPipes.Unlock()
233	if err := r.conn.SendMsg(&types.Packet{Type: types.PACKET_REQ, ID: id}); err != nil {
234		return err
235	}
236	err := wwc.Wait(ctx)
237	if err != nil {
238		return err
239	}
240	r.muPipes.Lock()
241	delete(r.pipes, id)
242	r.muPipes.Unlock()
243	return nil
244}
245
// wrappedWriteCloser wraps the io.WriteCloser handed to asyncDataFunc. The
// goroutine that requested a file can block in Wait until the packet-reading
// loop has delivered the whole file and closed the writer.
type wrappedWriteCloser struct {
	io.WriteCloser
	err  error         // result of closing the wrapped writer; valid once done is closed
	once sync.Once     // ensures the wrapped writer is closed exactly once
	done chan struct{} // closed after the wrapped writer has been closed
}

// newWrappedWriteCloser returns a wrapper around wc whose Wait blocks until
// Close is called.
func newWrappedWriteCloser(wc io.WriteCloser) *wrappedWriteCloser {
	return &wrappedWriteCloser{WriteCloser: wc, done: make(chan struct{})}
}

// Close closes the wrapped writer and wakes any Wait caller. Only the first
// call closes the underlying writer; later calls return that same error.
// A peer can send more than one end-of-file packet for an ID while it is
// still registered. Doing the close and the err assignment inside once avoids
// a double close of the inner writer and a data race with Wait reading err.
func (w *wrappedWriteCloser) Close() error {
	w.once.Do(func() {
		w.err = w.WriteCloser.Close()
		close(w.done)
	})
	return w.err
}

// Wait blocks until Close has been called or ctx is done. It returns the
// error from closing the wrapped writer, or ctx.Err().
func (w *wrappedWriteCloser) Wait(ctx context.Context) error {
	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-w.done:
		return w.err
	}
}
271