1package archiver
2
3import (
4	"archive/zip"
5	"bytes"
6	"compress/flate"
7	"fmt"
8	"io"
9	"io/ioutil"
10	"log"
11	"os"
12	"path"
13	"path/filepath"
14	"strings"
15
16	"github.com/dsnet/compress/bzip2"
17	"github.com/klauspost/compress/zstd"
18	"github.com/ulikunitz/xz"
19)
20
// ZipCompressionMethod identifies the compression method recorded for
// an entry of a ZIP archive.
type ZipCompressionMethod uint16

// Compression method identifiers; see
// https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT.
//
// Note on LZMA: disabled, because 7z isn't able to unpack ZIP+LZMA or
// ZIP+LZMA2 archives made this way - and vice versa.
const (
	Store   = ZipCompressionMethod(0)
	Deflate = ZipCompressionMethod(8)
	BZIP2   = ZipCompressionMethod(12)
	LZMA    = ZipCompressionMethod(14)
	ZSTD    = ZipCompressionMethod(93)
	XZ      = ZipCompressionMethod(95)
)
35
// Zip provides facilities for operating ZIP archives.
// See https://pkware.cachefly.net/webdocs/casestudies/APPNOTE.TXT.
//
// The exported fields are configuration; the unexported fields hold the
// state of the archive currently opened with Create (writing) or Open
// (reading) and are reset by Close.
type Zip struct {
	// The compression level to use, as described
	// in the compress/flate package.
	CompressionLevel int

	// Whether to overwrite existing files; if false,
	// an error is returned if the file exists.
	OverwriteExisting bool

	// Whether to make all the directories necessary
	// to create a zip archive in the desired path.
	MkdirAll bool

	// If enabled, selective compression will only
	// compress files which are not already in a
	// compressed format; this is decided based
	// simply on file extension.
	SelectiveCompression bool

	// A single top-level folder can be implicitly
	// created by the Archive or Unarchive methods
	// if the files to be added to the archive
	// or the files to be extracted from the archive
	// do not all have a common root. This roughly
	// mimics the behavior of archival tools integrated
	// into OS file browsers which create a subfolder
	// to avoid unexpectedly littering the destination
	// folder with potentially many files, causing a
	// problematic cleanup/organization situation.
	// This feature is available for both creation
	// and extraction of archives, but may be slightly
	// inefficient with lots and lots of files,
	// especially on extraction.
	ImplicitTopLevelFolder bool

	// Strip number of leading paths. This feature is available
	// only during unpacking of the entire archive.
	StripComponents int

	// If true, errors encountered during reading
	// or writing a single file will be logged and
	// the operation will continue on remaining files.
	ContinueOnError bool

	// Compression method that Write records for regular file
	// entries. Directories are always stored, and so are files
	// with an already-compressed extension when
	// SelectiveCompression is enabled.
	FileMethod ZipCompressionMethod
	zw         *zip.Writer // active writer; set by Create, cleared by Close
	zr         *zip.Reader // active reader; set by Open, cleared by Close
	ridx       int         // index into zr.File of the next entry returned by Read
	//decinitialized bool
}
89
90// CheckExt ensures the file extension matches the format.
91func (*Zip) CheckExt(filename string) error {
92	if !strings.HasSuffix(filename, ".zip") {
93		return fmt.Errorf("filename must have a .zip extension")
94	}
95	return nil
96}
97
98// Registering a global decompressor is not reentrant and may panic
99func registerDecompressor(zr *zip.Reader) {
100	// register zstd decompressor
101	zr.RegisterDecompressor(uint16(ZSTD), func(r io.Reader) io.ReadCloser {
102		zr, err := zstd.NewReader(r)
103		if err != nil {
104			return nil
105		}
106		return zr.IOReadCloser()
107	})
108	zr.RegisterDecompressor(uint16(BZIP2), func(r io.Reader) io.ReadCloser {
109		bz2r, err := bzip2.NewReader(r, nil)
110		if err != nil {
111			return nil
112		}
113		return bz2r
114	})
115	zr.RegisterDecompressor(uint16(XZ), func(r io.Reader) io.ReadCloser {
116		xr, err := xz.NewReader(r)
117		if err != nil {
118			return nil
119		}
120		return ioutil.NopCloser(xr)
121	})
122}
123
124// CheckPath ensures the file extension matches the format.
125func (*Zip) CheckPath(to, filename string) error {
126	to, _ = filepath.Abs(to) //explicit the destination folder to prevent that 'string.HasPrefix' check can be 'bypassed' when no destination folder is supplied in input
127	dest := filepath.Join(to, filename)
128	//prevent path traversal attacks
129	if !strings.HasPrefix(dest, to) {
130		return &IllegalPathError{AbsolutePath: dest, Filename: filename}
131	}
132	return nil
133}
134
135// Archive creates a .zip file at destination containing
136// the files listed in sources. The destination must end
137// with ".zip". File paths can be those of regular files
138// or directories. Regular files are stored at the 'root'
139// of the archive, and directories are recursively added.
140func (z *Zip) Archive(sources []string, destination string) error {
141	err := z.CheckExt(destination)
142	if err != nil {
143		return fmt.Errorf("checking extension: %v", err)
144	}
145	if !z.OverwriteExisting && fileExists(destination) {
146		return fmt.Errorf("file already exists: %s", destination)
147	}
148
149	// make the folder to contain the resulting archive
150	// if it does not already exist
151	destDir := filepath.Dir(destination)
152	if z.MkdirAll && !fileExists(destDir) {
153		err := mkdir(destDir, 0755)
154		if err != nil {
155			return fmt.Errorf("making folder for destination: %v", err)
156		}
157	}
158
159	out, err := os.Create(destination)
160	if err != nil {
161		return fmt.Errorf("creating %s: %v", destination, err)
162	}
163	defer out.Close()
164
165	err = z.Create(out)
166	if err != nil {
167		return fmt.Errorf("creating zip: %v", err)
168	}
169	defer z.Close()
170
171	var topLevelFolder string
172	if z.ImplicitTopLevelFolder && multipleTopLevels(sources) {
173		topLevelFolder = folderNameFromFileName(destination)
174	}
175
176	for _, source := range sources {
177		err := z.writeWalk(source, topLevelFolder, destination)
178		if err != nil {
179			return fmt.Errorf("walking %s: %v", source, err)
180		}
181	}
182
183	return nil
184}
185
186// Unarchive unpacks the .zip file at source to destination.
187// Destination will be treated as a folder name.
188func (z *Zip) Unarchive(source, destination string) error {
189	if !fileExists(destination) && z.MkdirAll {
190		err := mkdir(destination, 0755)
191		if err != nil {
192			return fmt.Errorf("preparing destination: %v", err)
193		}
194	}
195
196	file, err := os.Open(source)
197	if err != nil {
198		return fmt.Errorf("opening source file: %v", err)
199	}
200	defer file.Close()
201
202	fileInfo, err := file.Stat()
203	if err != nil {
204		return fmt.Errorf("statting source file: %v", err)
205	}
206
207	err = z.Open(file, fileInfo.Size())
208	if err != nil {
209		return fmt.Errorf("opening zip archive for reading: %v", err)
210	}
211	defer z.Close()
212
213	// if the files in the archive do not all share a common
214	// root, then make sure we extract to a single subfolder
215	// rather than potentially littering the destination...
216	if z.ImplicitTopLevelFolder {
217		files := make([]string, len(z.zr.File))
218		for i := range z.zr.File {
219			files[i] = z.zr.File[i].Name
220		}
221		if multipleTopLevels(files) {
222			destination = filepath.Join(destination, folderNameFromFileName(source))
223		}
224	}
225
226	for {
227		err := z.extractNext(destination)
228		if err == io.EOF {
229			break
230		}
231		if err != nil {
232			if z.ContinueOnError || IsIllegalPathError(err) {
233				log.Printf("[ERROR] Reading file in zip archive: %v", err)
234				continue
235			}
236			return fmt.Errorf("reading file in zip archive: %v", err)
237		}
238	}
239
240	return nil
241}
242
243func (z *Zip) extractNext(to string) error {
244	f, err := z.Read()
245	if err != nil {
246		return err // don't wrap error; calling loop must break on io.EOF
247	}
248	defer f.Close()
249
250	header, ok := f.Header.(zip.FileHeader)
251	if !ok {
252		return fmt.Errorf("expected header to be zip.FileHeader but was %T", f.Header)
253	}
254
255	errPath := z.CheckPath(to, header.Name)
256	if errPath != nil {
257		return fmt.Errorf("checking path traversal attempt: %v", errPath)
258	}
259
260	if z.StripComponents > 0 {
261		if strings.Count(header.Name, "/") < z.StripComponents {
262			return nil // skip path with fewer components
263		}
264
265		for i := 0; i < z.StripComponents; i++ {
266			slash := strings.Index(header.Name, "/")
267			header.Name = header.Name[slash+1:]
268		}
269	}
270	return z.extractFile(f, to, &header)
271}
272
273func (z *Zip) extractFile(f File, to string, header *zip.FileHeader) error {
274	to = filepath.Join(to, header.Name)
275
276	// if a directory, no content; simply make the directory and return
277	if f.IsDir() {
278		return mkdir(to, f.Mode())
279	}
280
281	// do not overwrite existing files, if configured
282	if !z.OverwriteExisting && fileExists(to) {
283		return fmt.Errorf("file already exists: %s", to)
284	}
285
286	// extract symbolic links as symbolic links
287	if isSymlink(header.FileInfo()) {
288		// symlink target is the contents of the file
289		buf := new(bytes.Buffer)
290		_, err := io.Copy(buf, f)
291		if err != nil {
292			return fmt.Errorf("%s: reading symlink target: %v", header.Name, err)
293		}
294		return writeNewSymbolicLink(to, strings.TrimSpace(buf.String()))
295	}
296
297	return writeNewFile(to, f, f.Mode())
298}
299
300func (z *Zip) writeWalk(source, topLevelFolder, destination string) error {
301	sourceInfo, err := os.Stat(source)
302	if err != nil {
303		return fmt.Errorf("%s: stat: %v", source, err)
304	}
305	destAbs, err := filepath.Abs(destination)
306	if err != nil {
307		return fmt.Errorf("%s: getting absolute path of destination %s: %v", source, destination, err)
308	}
309
310	return filepath.Walk(source, func(fpath string, info os.FileInfo, err error) error {
311		handleErr := func(err error) error {
312			if z.ContinueOnError {
313				log.Printf("[ERROR] Walking %s: %v", fpath, err)
314				return nil
315			}
316			return err
317		}
318		if err != nil {
319			return handleErr(fmt.Errorf("traversing %s: %v", fpath, err))
320		}
321		if info == nil {
322			return handleErr(fmt.Errorf("%s: no file info", fpath))
323		}
324
325		// make sure we do not copy the output file into the output
326		// file; that results in an infinite loop and disk exhaustion!
327		fpathAbs, err := filepath.Abs(fpath)
328		if err != nil {
329			return handleErr(fmt.Errorf("%s: getting absolute path: %v", fpath, err))
330		}
331		if within(fpathAbs, destAbs) {
332			return nil
333		}
334
335		// build the name to be used within the archive
336		nameInArchive, err := makeNameInArchive(sourceInfo, source, topLevelFolder, fpath)
337		if err != nil {
338			return handleErr(err)
339		}
340
341		var file io.ReadCloser
342		if info.Mode().IsRegular() {
343			file, err = os.Open(fpath)
344			if err != nil {
345				return handleErr(fmt.Errorf("%s: opening: %v", fpath, err))
346			}
347			defer file.Close()
348		}
349		err = z.Write(File{
350			FileInfo: FileInfo{
351				FileInfo:   info,
352				CustomName: nameInArchive,
353			},
354			ReadCloser: file,
355		})
356		if err != nil {
357			return handleErr(fmt.Errorf("%s: writing: %s", fpath, err))
358		}
359
360		return nil
361	})
362}
363
364// Create opens z for writing a ZIP archive to out.
365func (z *Zip) Create(out io.Writer) error {
366	if z.zw != nil {
367		return fmt.Errorf("zip archive is already created for writing")
368	}
369	z.zw = zip.NewWriter(out)
370	if z.CompressionLevel != flate.DefaultCompression {
371		z.zw.RegisterCompressor(zip.Deflate, func(out io.Writer) (io.WriteCloser, error) {
372			return flate.NewWriter(out, z.CompressionLevel)
373		})
374	}
375	switch z.FileMethod {
376	case BZIP2:
377		z.zw.RegisterCompressor(uint16(BZIP2), func(out io.Writer) (io.WriteCloser, error) {
378			return bzip2.NewWriter(out, &bzip2.WriterConfig{Level: z.CompressionLevel})
379		})
380	case ZSTD:
381		z.zw.RegisterCompressor(uint16(ZSTD), func(out io.Writer) (io.WriteCloser, error) {
382			return zstd.NewWriter(out)
383		})
384	case XZ:
385		z.zw.RegisterCompressor(uint16(XZ), func(out io.Writer) (io.WriteCloser, error) {
386			return xz.NewWriter(out)
387		})
388	}
389	return nil
390}
391
392// Write writes f to z, which must have been opened for writing first.
393func (z *Zip) Write(f File) error {
394	if z.zw == nil {
395		return fmt.Errorf("zip archive was not created for writing first")
396	}
397	if f.FileInfo == nil {
398		return fmt.Errorf("no file info")
399	}
400	if f.FileInfo.Name() == "" {
401		return fmt.Errorf("missing file name")
402	}
403
404	header, err := zip.FileInfoHeader(f)
405	if err != nil {
406		return fmt.Errorf("%s: getting header: %v", f.Name(), err)
407	}
408
409	if f.IsDir() {
410		header.Name += "/" // required - strangely no mention of this in zip spec? but is in godoc...
411		header.Method = zip.Store
412	} else {
413		ext := strings.ToLower(path.Ext(header.Name))
414		if _, ok := compressedFormats[ext]; ok && z.SelectiveCompression {
415			header.Method = zip.Store
416		} else {
417			header.Method = uint16(z.FileMethod)
418		}
419	}
420
421	writer, err := z.zw.CreateHeader(header)
422	if err != nil {
423		return fmt.Errorf("%s: making header: %v", f.Name(), err)
424	}
425
426	return z.writeFile(f, writer)
427}
428
429func (z *Zip) writeFile(f File, writer io.Writer) error {
430	if f.IsDir() {
431		return nil // directories have no contents
432	}
433	if isSymlink(f) {
434		// file body for symlinks is the symlink target
435		linkTarget, err := os.Readlink(f.Name())
436		if err != nil {
437			return fmt.Errorf("%s: readlink: %v", f.Name(), err)
438		}
439		_, err = writer.Write([]byte(filepath.ToSlash(linkTarget)))
440		if err != nil {
441			return fmt.Errorf("%s: writing symlink target: %v", f.Name(), err)
442		}
443		return nil
444	}
445
446	if f.ReadCloser == nil {
447		return fmt.Errorf("%s: no way to read file contents", f.Name())
448	}
449	_, err := io.Copy(writer, f)
450	if err != nil {
451		return fmt.Errorf("%s: copying contents: %v", f.Name(), err)
452	}
453
454	return nil
455}
456
457// Open opens z for reading an archive from in,
458// which is expected to have the given size and
459// which must be an io.ReaderAt.
460func (z *Zip) Open(in io.Reader, size int64) error {
461	inRdrAt, ok := in.(io.ReaderAt)
462	if !ok {
463		return fmt.Errorf("reader must be io.ReaderAt")
464	}
465	if z.zr != nil {
466		return fmt.Errorf("zip archive is already open for reading")
467	}
468	var err error
469	z.zr, err = zip.NewReader(inRdrAt, size)
470	if err != nil {
471		return fmt.Errorf("creating reader: %v", err)
472	}
473	registerDecompressor(z.zr)
474	z.ridx = 0
475	return nil
476}
477
478// Read reads the next file from z, which must have
479// already been opened for reading. If there are no
480// more files, the error is io.EOF. The File must
481// be closed when finished reading from it.
482func (z *Zip) Read() (File, error) {
483	if z.zr == nil {
484		return File{}, fmt.Errorf("zip archive is not open")
485	}
486	if z.ridx >= len(z.zr.File) {
487		return File{}, io.EOF
488	}
489
490	// access the file and increment counter so that
491	// if there is an error processing this file, the
492	// caller can still iterate to the next file
493	zf := z.zr.File[z.ridx]
494	z.ridx++
495
496	file := File{
497		FileInfo: zf.FileInfo(),
498		Header:   zf.FileHeader,
499	}
500
501	rc, err := zf.Open()
502	if err != nil {
503		return file, fmt.Errorf("%s: open compressed file: %v", zf.Name, err)
504	}
505	file.ReadCloser = rc
506
507	return file, nil
508}
509
510// Close closes the zip archive(s) opened by Create and Open.
511func (z *Zip) Close() error {
512	if z.zr != nil {
513		z.zr = nil
514	}
515	if z.zw != nil {
516		zw := z.zw
517		z.zw = nil
518		return zw.Close()
519	}
520	return nil
521}
522
523// Walk calls walkFn for each visited item in archive.
524func (z *Zip) Walk(archive string, walkFn WalkFunc) error {
525	zr, err := zip.OpenReader(archive)
526	if err != nil {
527		return fmt.Errorf("opening zip reader: %v", err)
528	}
529	defer zr.Close()
530	registerDecompressor(&zr.Reader)
531	for _, zf := range zr.File {
532		zfrc, err := zf.Open()
533		if err != nil {
534			if zfrc != nil {
535				zfrc.Close()
536			}
537			if z.ContinueOnError {
538				log.Printf("[ERROR] Opening %s: %v", zf.Name, err)
539				continue
540			}
541			return fmt.Errorf("opening %s: %v", zf.Name, err)
542		}
543
544		err = walkFn(File{
545			FileInfo:   zf.FileInfo(),
546			Header:     zf.FileHeader,
547			ReadCloser: zfrc,
548		})
549		zfrc.Close()
550		if err != nil {
551			if err == ErrStopWalk {
552				break
553			}
554			if z.ContinueOnError {
555				log.Printf("[ERROR] Walking %s: %v", zf.Name, err)
556				continue
557			}
558			return fmt.Errorf("walking %s: %v", zf.Name, err)
559		}
560	}
561
562	return nil
563}
564
565// Extract extracts a single file from the zip archive.
566// If the target is a directory, the entire folder will
567// be extracted into destination.
568func (z *Zip) Extract(source, target, destination string) error {
569	// target refers to a path inside the archive, which should be clean also
570	target = path.Clean(target)
571
572	// if the target ends up being a directory, then
573	// we will continue walking and extracting files
574	// until we are no longer within that directory
575	var targetDirPath string
576
577	return z.Walk(source, func(f File) error {
578		zfh, ok := f.Header.(zip.FileHeader)
579		if !ok {
580			return fmt.Errorf("expected header to be zip.FileHeader but was %T", f.Header)
581		}
582
583		// importantly, cleaning the path strips tailing slash,
584		// which must be appended to folders within the archive
585		name := path.Clean(zfh.Name)
586		if f.IsDir() && target == name {
587			targetDirPath = path.Dir(name)
588		}
589
590		if within(target, zfh.Name) {
591			// either this is the exact file we want, or is
592			// in the directory we want to extract
593
594			// build the filename we will extract to
595			end, err := filepath.Rel(targetDirPath, zfh.Name)
596			if err != nil {
597				return fmt.Errorf("relativizing paths: %v", err)
598			}
599			joined := filepath.Join(destination, end)
600
601			err = z.extractFile(f, joined, &zfh)
602			if err != nil {
603				return fmt.Errorf("extracting file %s: %v", zfh.Name, err)
604			}
605
606			// if our target was not a directory, stop walk
607			if targetDirPath == "" {
608				return ErrStopWalk
609			}
610		} else if targetDirPath != "" {
611			// finished walking the entire directory
612			return ErrStopWalk
613		}
614
615		return nil
616	})
617}
618
619// Match returns true if the format of file matches this
620// type's format. It should not affect reader position.
621func (*Zip) Match(file io.ReadSeeker) (bool, error) {
622	currentPos, err := file.Seek(0, io.SeekCurrent)
623	if err != nil {
624		return false, err
625	}
626	_, err = file.Seek(0, 0)
627	if err != nil {
628		return false, err
629	}
630	defer func() {
631		_, _ = file.Seek(currentPos, io.SeekStart)
632	}()
633
634	buf := make([]byte, 4)
635	if n, err := file.Read(buf); err != nil || n < 4 {
636		return false, nil
637	}
638	return bytes.Equal(buf, []byte("PK\x03\x04")), nil
639}
640
641func (z *Zip) String() string { return "zip" }
642
643// NewZip returns a new, default instance ready to be customized and used.
644func NewZip() *Zip {
645	return &Zip{
646		CompressionLevel:     flate.DefaultCompression,
647		MkdirAll:             true,
648		SelectiveCompression: true,
649		FileMethod:           Deflate,
650	}
651}
652
653// Compile-time checks to ensure type implements desired interfaces.
654var (
655	_ = Reader(new(Zip))
656	_ = Writer(new(Zip))
657	_ = Archiver(new(Zip))
658	_ = Unarchiver(new(Zip))
659	_ = Walker(new(Zip))
660	_ = Extractor(new(Zip))
661	_ = Matcher(new(Zip))
662	_ = ExtensionChecker(new(Zip))
663	_ = FilenameChecker(new(Zip))
664)
665
// compressedFormats is a (non-exhaustive) set of lowercased file
// extensions, including the leading dot, of formats that are typically
// already compressed. Compressing such files again is inefficient, so
// this set is consulted to store them as-is instead.
var compressedFormats = func() map[string]struct{} {
	extensions := []string{
		".7z", ".avi", ".br", ".bz2", ".cab", ".docx", ".gif", ".gz",
		".jar", ".jpeg", ".jpg", ".lz", ".lz4", ".lzma", ".m4v", ".mov",
		".mp3", ".mp4", ".mpeg", ".mpg", ".png", ".pptx", ".rar", ".sz",
		".tbz2", ".tgz", ".tsz", ".txz", ".xlsx", ".xz", ".zip", ".zipx",
	}
	set := make(map[string]struct{}, len(extensions))
	for _, ext := range extensions {
		set[ext] = struct{}{}
	}
	return set
}()
704
// DefaultZip is a default instance that is conveniently ready to use.
// It is created once, at package initialization, by calling NewZip.
var DefaultZip = NewZip()
707