1/*
2Copyright (c) 2017 VMware, Inc. All Rights Reserved.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package hgfs
18
19import (
20	"archive/tar"
21	"bufio"
22	"bytes"
23	"compress/gzip"
24	"io"
25	"io/ioutil"
26	"log"
27	"math"
28	"net/url"
29	"os"
30	"path/filepath"
31	"strings"
32	"sync"
33	"time"
34
35	"github.com/vmware/govmomi/toolbox/vix"
36)
37
// ArchiveScheme is the default scheme used to register the archive FileHandler.
var ArchiveScheme = "archive"
40
// ArchiveHandler implements a FileHandler for transferring directories.
// The Read and Write hooks do the actual (un)packing; NewArchiveHandler
// populates them with archiveRead and archiveWrite.
type ArchiveHandler struct {
	// Read is invoked by newArchiveToGuest with a tar.Reader over the incoming stream.
	Read  func(*url.URL, *tar.Reader) error
	// Write is invoked by newArchiveFromGuest with a tar.Writer feeding the outgoing stream.
	Write func(*url.URL, *tar.Writer) error
}
46
47// NewArchiveHandler returns a FileHandler implementation for transferring directories using gzip'd tar files.
48func NewArchiveHandler() FileHandler {
49	return &ArchiveHandler{
50		Read:  archiveRead,
51		Write: archiveWrite,
52	}
53}
54
55// Stat implements FileHandler.Stat
56func (*ArchiveHandler) Stat(u *url.URL) (os.FileInfo, error) {
57	switch u.Query().Get("format") {
58	case "", "tar", "tgz":
59		// ok
60	default:
61		log.Printf("unknown archive format: %q", u)
62		return nil, vix.Error(vix.InvalidArg)
63	}
64
65	return &archive{
66		name: u.Path,
67		size: math.MaxInt64,
68	}, nil
69}
70
71// Open implements FileHandler.Open
72func (h *ArchiveHandler) Open(u *url.URL, mode int32) (File, error) {
73	switch mode {
74	case OpenModeReadOnly:
75		return h.newArchiveFromGuest(u)
76	case OpenModeWriteOnly:
77		return h.newArchiveToGuest(u)
78	default:
79		return nil, os.ErrNotExist
80	}
81}
82
// archive implements the hgfs.File and os.FileInfo interfaces.
type archive struct {
	name string // reported by Name; set from the request URL path
	size int64  // reported by Size
	done func() error // invoked by Close

	// Read and Write are provided by the embedded interfaces, which the
	// newArchiveFromGuest / newArchiveToGuest constructors wire to an io.Pipe.
	io.Reader
	io.Writer
}
92
// Name implementation of the os.FileInfo interface method.
// It returns the name the archive was created with.
func (a *archive) Name() string {
	return a.name
}
97
// Size implementation of the os.FileInfo interface method.
// It returns the size field as stored; the real archive length is not computed here.
func (a *archive) Size() int64 {
	return a.size
}
102
103// Mode implementation of the os.FileInfo interface method.
104func (a *archive) Mode() os.FileMode {
105	return 0600
106}
107
108// ModTime implementation of the os.FileInfo interface method.
109func (a *archive) ModTime() time.Time {
110	return time.Now()
111}
112
113// IsDir implementation of the os.FileInfo interface method.
114func (a *archive) IsDir() bool {
115	return false
116}
117
118// Sys implementation of the os.FileInfo interface method.
119func (a *archive) Sys() interface{} {
120	return nil
121}
122
// gzipHeader holds the leading bytes of a gzip stream. newArchiveToGuest peeks
// for it to detect gzip'd input, and newArchiveFromGuest writes it after the
// archive data as a trailer when gzipTrailer is set.
//
// The trailer is required since TransferFromGuest requires a Content-Length,
// which toolbox doesn't know ahead of time as the gzip'd tarball never touches the disk.
// HTTP clients need to be aware of this and stop reading when they see the 2nd gzip header.
var gzipHeader = []byte{0x1f, 0x8b, 0x08} // rfc1952 {ID1, ID2, CM}

// gzipTrailer controls whether newArchiveFromGuest appends gzipHeader after the archive data.
var gzipTrailer = true
129
130// newArchiveFromGuest returns an hgfs.File implementation to read a directory as a gzip'd tar.
131func (h *ArchiveHandler) newArchiveFromGuest(u *url.URL) (File, error) {
132	r, w := io.Pipe()
133
134	a := &archive{
135		name:   u.Path,
136		done:   r.Close,
137		Reader: r,
138		Writer: w,
139	}
140
141	var z io.Writer = w
142	var c io.Closer = ioutil.NopCloser(nil)
143
144	switch u.Query().Get("format") {
145	case "tgz":
146		gz := gzip.NewWriter(w)
147		z = gz
148		c = gz
149	}
150
151	tw := tar.NewWriter(z)
152
153	go func() {
154		err := h.Write(u, tw)
155
156		_ = tw.Close()
157		_ = c.Close()
158		if gzipTrailer {
159			_, _ = w.Write(gzipHeader)
160		}
161		_ = w.CloseWithError(err)
162	}()
163
164	return a, nil
165}
166
// newArchiveToGuest returns an hgfs.File implementation to expand a gzip'd tar into a directory.
// Data written to the returned File goes into an io.Pipe; a background goroutine reads the
// other end, transparently gunzips it when it starts with gzipHeader, and hands the resulting
// tar.Reader to h.Read. Close waits for that goroutine and returns its error.
func (h *ArchiveHandler) newArchiveToGuest(u *url.URL) (File, error) {
	r, w := io.Pipe()

	// Buffered so the goroutine below can Peek at the stream header without consuming it.
	buf := bufio.NewReader(r)

	a := &archive{
		name:   u.Path,
		Reader: buf,
		Writer: w,
	}

	// cerr is written by the goroutine and only read by done after wg.Wait.
	var cerr error
	var wg sync.WaitGroup

	a.done = func() error {
		// Closing the write end signals EOF to the unpacking goroutine.
		_ = w.Close()
		// We need to wait for unpack to finish to complete its work
		// and to propagate the error if any to Close.
		wg.Wait()
		return cerr
	}

	wg.Add(1)
	go func() {
		defer wg.Done()

		// c is the cleanup step run after h.Read; this default is used for plain tar input.
		c := func() error {
			// Drain the pipe of tar trailer data (two null blocks)
			if cerr == nil {
				_, _ = io.Copy(ioutil.Discard, a.Reader)
			}
			return nil
		}

		// The Peek error is ignored: a short or failed read simply won't match gzipHeader
		// and the stream is then treated as plain tar.
		header, _ := buf.Peek(len(gzipHeader))

		if bytes.Equal(header, gzipHeader) {
			gz, err := gzip.NewReader(a.Reader)
			if err != nil {
				// Close the read end with the error so the writing side stops, and record it for Close.
				_ = r.CloseWithError(err)
				cerr = err
				return
			}

			// For gzip'd input the cleanup step is closing the gzip reader instead of draining.
			c = gz.Close
			a.Reader = gz
		}

		tr := tar.NewReader(a.Reader)

		cerr = h.Read(u, tr)

		_ = c()
		// Close the read end, passing along the result of h.Read.
		_ = r.CloseWithError(cerr)
	}()

	return a, nil
}
226
// Close runs the done callback installed when the archive was created and returns its result.
func (a *archive) Close() error {
	return a.done()
}
230
// archiveRead writes the contents of the given tar.Reader to the given directory.
//
// Directories, regular files and symlinks are extracted below u.Path; entries of any
// other type are skipped. It returns nil once the end of the archive is reached, or
// the first error encountered.
//
// An entry whose name would resolve outside of u.Path (for example "../../etc/passwd")
// is rejected with an *os.PathError wrapping os.ErrInvalid rather than being written.
// Only the entry name is checked; symlink targets are created as given.
func archiveRead(u *url.URL, tr *tar.Reader) error {
	root := filepath.Clean(u.Path)

	for {
		header, err := tr.Next()
		if err != nil {
			if err == io.EOF {
				return nil
			}
			return err
		}

		name := filepath.Join(u.Path, header.Name)

		// Refuse names that climb out of the destination directory.
		rel, rerr := filepath.Rel(root, name)
		if rerr != nil || rel == ".." || strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
			return &os.PathError{Op: "extract", Path: header.Name, Err: os.ErrInvalid}
		}

		mode := os.FileMode(header.Mode)

		switch header.Typeflag {
		case tar.TypeDir:
			err = os.MkdirAll(name, mode)
		case tar.TypeReg:
			// Best effort: the archive may not contain an entry for the parent directory.
			// Any failure here resurfaces as an error from OpenFile below.
			_ = os.MkdirAll(filepath.Dir(name), 0755)

			var f *os.File

			f, err = os.OpenFile(name, os.O_CREATE|os.O_RDWR|os.O_TRUNC, mode)
			if err == nil {
				_, cerr := io.Copy(f, tr)
				err = f.Close()
				// A copy failure takes precedence over a close failure.
				if cerr != nil {
					err = cerr
				}
			}
		case tar.TypeSymlink:
			err = os.Symlink(header.Linkname, name)
		}

		// TODO: Uid/Gid may not be meaningful here without some mapping.
		// The other option to consider would be making use of the guest auth user ID.
		// os.Lchown(name, header.Uid, header.Gid)

		if err != nil {
			return err
		}
	}
}
274
// archiveWrite writes the contents of the given source directory to the given tar.Writer.
//
// If u.Path is a directory, everything below it is archived with names relative to it;
// if it is a single file, just that file is archived under its base name. The optional
// "prefix" query parameter is prepended verbatim to every entry name.
//
// Archiving is best effort for individual entries: files that cannot be opened due to
// permissions, symlinks that cannot be read and file types tar cannot represent
// (such as sockets) are skipped. Errors writing to tw are returned.
func archiveWrite(u *url.URL, tw *tar.Writer) error {
	info, err := os.Stat(u.Path)
	if err != nil {
		return err
	}

	// Note that the VMX will trim any trailing slash.  For example:
	// "/foo/bar/?prefix=bar/" will end up here as "/foo/bar/?prefix=bar"
	// Escape to avoid this: "/for/bar/?prefix=bar%2F"
	prefix := u.Query().Get("prefix")

	dir := u.Path

	add := func(file string, fi os.FileInfo, err error) error {
		if err != nil {
			return filepath.SkipDir
		}

		name := strings.TrimPrefix(file, dir)
		name = strings.TrimPrefix(name, "/")

		if name == "" {
			return nil // this is u.Path itself (which may or may not have a trailing "/")
		}

		name = prefix + name

		// tar.FileInfoHeader expects the link target for symlinks, not the entry name.
		var target string
		if fi.Mode()&os.ModeSymlink != 0 {
			target, err = os.Readlink(file)
			if err != nil {
				log.Printf("archive: skipping %q: %s", file, err)
				return nil
			}
		}

		header, err := tar.FileInfoHeader(fi, target)
		if err != nil {
			// Unsupported file type (e.g. a socket): header is nil, so skip the entry.
			log.Printf("archive: skipping %q: %s", file, err)
			return nil
		}

		header.Name = name

		if header.Typeflag == tar.TypeDir {
			header.Name += "/"
		}

		var src *os.File

		// Open before writing the header, so an unreadable file leaves no entry behind.
		if header.Typeflag == tar.TypeReg && fi.Size() != 0 {
			src, err = os.Open(file)
			if err != nil {
				if os.IsPermission(err) {
					return nil
				}
				return err
			}
			defer src.Close()
		}

		if err = tw.WriteHeader(header); err != nil {
			return err
		}

		if src == nil {
			return nil
		}

		_, err = io.Copy(tw, src)
		return err
	}

	if info.IsDir() {
		return filepath.Walk(u.Path, add)
	}

	// Single file: make names relative to its parent directory.
	dir = filepath.Dir(dir)

	return add(u.Path, info, nil)
}
343