1package main
2
3import (
4	"bufio"
5	"errors"
6	"flag"
7	"fmt"
8	"io"
9	"io/ioutil"
10	"os"
11	"path/filepath"
12	"strings"
13	"sync"
14	"time"
15
16	"github.com/klauspost/compress/s2"
17	"github.com/klauspost/compress/s2/cmd/internal/readahead"
18)
19
// Command-line options. Each flag.Bool / flag.Int call registers the option on
// the default flag set and returns a pointer whose value is filled in by
// flag.Parse.
var (
	safe   = flag.Bool("safe", false, "Do not overwrite output files")
	verify = flag.Bool("verify", false, "Verify files, but do not write output")
	stdout = flag.Bool("c", false, "Write all output to stdout. Multiple input files will be concatenated")
	remove = flag.Bool("rm", false, "Delete source file(s) after successful decompression")
	quiet  = flag.Bool("q", false, "Don't write any output to terminal, except errors")
	bench  = flag.Int("bench", 0, "Run benchmark n times. No output will be written")
	help   = flag.Bool("help", false, "Display help")
)

// Build metadata printed in the usage banner. These are plain package
// variables; presumably they are overridden at link time (-ldflags -X) by
// release builds — TODO confirm against the build scripts.
var (
	version = "(dev)"
	date    = "(unknown)"
)
32
33func main() {
34	flag.Parse()
35	r := s2.NewReader(nil)
36
37	// No args, use stdin/stdout
38	args := flag.Args()
39	if len(args) == 0 || *help {
40		_, _ = fmt.Fprintf(os.Stderr, "s2 decompress v%v, built at %v.\n\n", version, date)
41		_, _ = fmt.Fprintf(os.Stderr, "Copyright (c) 2011 The Snappy-Go Authors. All rights reserved.\n"+
42			"Copyright (c) 2019 Klaus Post. All rights reserved.\n\n")
43		_, _ = fmt.Fprintln(os.Stderr, `Usage: s2d [options] file1 file2
44
45Decompresses all files supplied as input. Input files must end with '.s2' or '.snappy'.
46Output file names have the extension removed. By default output files will be overwritten.
47Use - as the only file name to read from stdin and write to stdout.
48
49Wildcards are accepted: testdir/*.txt will compress all files in testdir ending with .txt
50Directories can be wildcards as well. testdir/*/*.txt will match testdir/subdir/b.txt
51
52Options:`)
53		flag.PrintDefaults()
54		os.Exit(0)
55	}
56	if len(args) == 1 && args[0] == "-" {
57		r.Reset(os.Stdin)
58		if !*verify {
59			_, err := io.Copy(os.Stdout, r)
60			exitErr(err)
61		} else {
62			_, err := io.Copy(ioutil.Discard, r)
63			exitErr(err)
64		}
65		return
66	}
67	var files []string
68
69	for _, pattern := range args {
70		found, err := filepath.Glob(pattern)
71		exitErr(err)
72		if len(found) == 0 {
73			exitErr(fmt.Errorf("unable to find file %v", pattern))
74		}
75		files = append(files, found...)
76	}
77
78	*quiet = *quiet || *stdout
79	allFiles := files
80	for i := 0; i < *bench; i++ {
81		files = append(files, allFiles...)
82	}
83
84	for _, filename := range files {
85		dstFilename := filename
86		switch {
87		case strings.HasSuffix(filename, ".s2"):
88			dstFilename = strings.TrimSuffix(filename, ".s2")
89		case strings.HasSuffix(filename, ".snappy"):
90			dstFilename = strings.TrimSuffix(filename, ".snappy")
91		default:
92			fmt.Println("Skipping", filename)
93			continue
94		}
95		if *bench > 0 {
96			dstFilename = "(discarded)"
97		}
98		if *verify {
99			dstFilename = "(verify)"
100		}
101
102		func() {
103			var closeOnce sync.Once
104			if !*quiet {
105				fmt.Print("Decompressing ", filename, " -> ", dstFilename)
106			}
107			// Input file.
108			file, err := os.Open(filename)
109			exitErr(err)
110			defer closeOnce.Do(func() { file.Close() })
111			rc := rCounter{in: file}
112			src, err := readahead.NewReaderSize(&rc, 2, 4<<20)
113			exitErr(err)
114			defer src.Close()
115			finfo, err := file.Stat()
116			exitErr(err)
117			mode := finfo.Mode() // use the same mode for the output file
118			if *safe {
119				_, err := os.Stat(dstFilename)
120				if !os.IsNotExist(err) {
121					exitErr(errors.New("destination files exists"))
122				}
123			}
124			var out io.Writer
125			switch {
126			case *bench > 0 || *verify:
127				out = ioutil.Discard
128			case *stdout:
129				out = os.Stdout
130			default:
131				dstFile, err := os.OpenFile(dstFilename, os.O_CREATE|os.O_WRONLY, mode)
132				exitErr(err)
133				defer dstFile.Close()
134				bw := bufio.NewWriterSize(dstFile, 4<<20)
135				defer bw.Flush()
136				out = bw
137			}
138			r.Reset(src)
139			start := time.Now()
140			output, err := io.Copy(out, r)
141			exitErr(err)
142			if !*quiet {
143				elapsed := time.Since(start)
144				mbPerSec := (float64(output) / (1024 * 1024)) / (float64(elapsed) / (float64(time.Second)))
145				pct := float64(output) * 100 / float64(rc.n)
146				fmt.Printf(" %d -> %d [%.02f%%]; %.01fMB/s\n", rc.n, output, pct, mbPerSec)
147			}
148			if *remove && !*verify {
149				closeOnce.Do(func() {
150					file.Close()
151					if !*quiet {
152						fmt.Println("Removing", filename)
153					}
154					err := os.Remove(filename)
155					exitErr(err)
156				})
157			}
158		}()
159	}
160}
161
// exitErr is a no-op for a nil error. Otherwise it prints the error to stderr,
// prefixed with a newline so it starts on a fresh line after any partial
// progress output, and terminates the process with exit status 2. Deferred
// calls do not run in that case.
func exitErr(err error) {
	if err == nil {
		return
	}
	fmt.Fprintln(os.Stderr, "\nERROR:", err.Error())
	os.Exit(2)
}
168
// rCounter is an io.Reader wrapper that keeps a running total of the bytes
// read through it.
type rCounter struct {
	n  int       // bytes returned by Read so far
	in io.Reader // wrapped source
}

// Read reads from the wrapped reader. The byte count it reports is added to
// the total even when an error is also returned.
func (c *rCounter) Read(p []byte) (int, error) {
	got, err := c.in.Read(p)
	c.n += got
	return got, err
}
180