1package fsutil 2 3import ( 4 "context" 5 "io" 6 "os" 7 "sync" 8 9 "github.com/pkg/errors" 10 "github.com/tonistiigi/fsutil/types" 11 "golang.org/x/sync/errgroup" 12) 13 14type ReceiveOpt struct { 15 NotifyHashed ChangeFunc 16 ContentHasher ContentHasher 17 ProgressCb func(int, bool) 18 Merge bool 19 Filter FilterFunc 20} 21 22func Receive(ctx context.Context, conn Stream, dest string, opt ReceiveOpt) error { 23 ctx, cancel := context.WithCancel(context.Background()) 24 defer cancel() 25 26 r := &receiver{ 27 conn: &syncStream{Stream: conn}, 28 dest: dest, 29 files: make(map[string]uint32), 30 pipes: make(map[uint32]io.WriteCloser), 31 notifyHashed: opt.NotifyHashed, 32 contentHasher: opt.ContentHasher, 33 progressCb: opt.ProgressCb, 34 merge: opt.Merge, 35 filter: opt.Filter, 36 } 37 return r.run(ctx) 38} 39 40type receiver struct { 41 dest string 42 conn Stream 43 files map[string]uint32 44 pipes map[uint32]io.WriteCloser 45 mu sync.RWMutex 46 muPipes sync.RWMutex 47 progressCb func(int, bool) 48 merge bool 49 filter FilterFunc 50 51 notifyHashed ChangeFunc 52 contentHasher ContentHasher 53 orderValidator Validator 54 hlValidator Hardlinks 55} 56 57type dynamicWalker struct { 58 walkChan chan *currentPath 59 err error 60 closeCh chan struct{} 61} 62 63func newDynamicWalker() *dynamicWalker { 64 return &dynamicWalker{ 65 walkChan: make(chan *currentPath, 128), 66 closeCh: make(chan struct{}), 67 } 68} 69 70func (w *dynamicWalker) update(p *currentPath) error { 71 select { 72 case <-w.closeCh: 73 return errors.Wrap(w.err, "walker is closed") 74 default: 75 } 76 if p == nil { 77 close(w.walkChan) 78 return nil 79 } 80 select { 81 case w.walkChan <- p: 82 return nil 83 case <-w.closeCh: 84 return errors.Wrap(w.err, "walker is closed") 85 } 86} 87 88func (w *dynamicWalker) fill(ctx context.Context, pathC chan<- *currentPath) error { 89 for { 90 select { 91 case p, ok := <-w.walkChan: 92 if !ok { 93 return nil 94 } 95 pathC <- p 96 case <-ctx.Done(): 97 w.err = ctx.Err() 98 close(w.closeCh) 99 return ctx.Err() 100 } 101 } 102 return nil 103} 104 105func (r *receiver) run(ctx context.Context) error { 106 g, ctx := errgroup.WithContext(ctx) 107 108 dw, err := NewDiskWriter(ctx, r.dest, DiskWriterOpt{ 109 AsyncDataCb: r.asyncDataFunc, 110 NotifyCb: r.notifyHashed, 111 ContentHasher: r.contentHasher, 112 Filter: r.filter, 113 }) 114 if err != nil { 115 return err 116 } 117 118 w := newDynamicWalker() 119 120 g.Go(func() (retErr error) { 121 defer func() { 122 if retErr != nil { 123 r.conn.SendMsg(&types.Packet{Type: types.PACKET_ERR, Data: []byte(retErr.Error())}) 124 } 125 }() 126 destWalker := emptyWalker 127 if !r.merge { 128 destWalker = GetWalkerFn(r.dest) 129 } 130 err := doubleWalkDiff(ctx, dw.HandleChange, destWalker, w.fill) 131 if err != nil { 132 return err 133 } 134 if err := dw.Wait(ctx); err != nil { 135 return err 136 } 137 r.conn.SendMsg(&types.Packet{Type: types.PACKET_FIN}) 138 return nil 139 }) 140 141 g.Go(func() error { 142 var i uint32 = 0 143 144 size := 0 145 if r.progressCb != nil { 146 defer func() { 147 r.progressCb(size, true) 148 }() 149 } 150 var p types.Packet 151 for { 152 p = types.Packet{Data: p.Data[:0]} 153 if err := r.conn.RecvMsg(&p); err != nil { 154 return err 155 } 156 if r.progressCb != nil { 157 size += p.Size() 158 r.progressCb(size, false) 159 } 160 161 switch p.Type { 162 case types.PACKET_ERR: 163 return errors.Errorf("error from sender: %s", p.Data) 164 case types.PACKET_STAT: 165 if p.Stat == nil { 166 if err := w.update(nil); err != nil { 167 return err 168 } 169 break 170 } 171 if fileCanRequestData(os.FileMode(p.Stat.Mode)) { 172 r.mu.Lock() 173 r.files[p.Stat.Path] = i 174 r.mu.Unlock() 175 } 176 i++ 177 cp := ¤tPath{path: p.Stat.Path, f: &StatInfo{p.Stat}} 178 if err := r.orderValidator.HandleChange(ChangeKindAdd, cp.path, cp.f, nil); err != nil { 179 return err 180 } 181 if err := r.hlValidator.HandleChange(ChangeKindAdd, cp.path, cp.f, nil); err != nil { 182 return err 183 } 184 if err := w.update(cp); err != nil { 185 return err 186 } 187 case types.PACKET_DATA: 188 r.muPipes.Lock() 189 pw, ok := r.pipes[p.ID] 190 r.muPipes.Unlock() 191 if !ok { 192 return errors.Errorf("invalid file request %d", p.ID) 193 } 194 if len(p.Data) == 0 { 195 if err := pw.Close(); err != nil { 196 return err 197 } 198 } else { 199 if _, err := pw.Write(p.Data); err != nil { 200 return err 201 } 202 } 203 case types.PACKET_FIN: 204 for { 205 var p types.Packet 206 if err := r.conn.RecvMsg(&p); err != nil { 207 if err == io.EOF { 208 return nil 209 } 210 return err 211 } 212 } 213 } 214 } 215 }) 216 return g.Wait() 217} 218 219func (r *receiver) asyncDataFunc(ctx context.Context, p string, wc io.WriteCloser) error { 220 r.mu.Lock() 221 id, ok := r.files[p] 222 if !ok { 223 r.mu.Unlock() 224 return errors.Errorf("invalid file request %s", p) 225 } 226 delete(r.files, p) 227 r.mu.Unlock() 228 229 wwc := newWrappedWriteCloser(wc) 230 r.muPipes.Lock() 231 r.pipes[id] = wwc 232 r.muPipes.Unlock() 233 if err := r.conn.SendMsg(&types.Packet{Type: types.PACKET_REQ, ID: id}); err != nil { 234 return err 235 } 236 err := wwc.Wait(ctx) 237 if err != nil { 238 return err 239 } 240 r.muPipes.Lock() 241 delete(r.pipes, id) 242 r.muPipes.Unlock() 243 return nil 244} 245 246type wrappedWriteCloser struct { 247 io.WriteCloser 248 err error 249 once sync.Once 250 done chan struct{} 251} 252 253func newWrappedWriteCloser(wc io.WriteCloser) *wrappedWriteCloser { 254 return &wrappedWriteCloser{WriteCloser: wc, done: make(chan struct{})} 255} 256 257func (w *wrappedWriteCloser) Close() error { 258 w.err = w.WriteCloser.Close() 259 w.once.Do(func() { close(w.done) }) 260 return w.err 261} 262 263func (w *wrappedWriteCloser) Wait(ctx context.Context) error { 264 select { 265 case <-ctx.Done(): 266 return ctx.Err() 267 case <-w.done: 268 return w.err 269 } 270} 271