1package api
2
3import (
4	"bufio"
5	"bytes"
6	"crypto/sha256"
7	"errors"
8	"fmt"
9	"io"
10	"io/ioutil"
11	"net/http"
12	"os"
13	"path/filepath"
14	"strings"
15	"sync"
16	"time"
17)
18
19func NewCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client {
20	cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache")
21	return &http.Client{
22		Transport: CacheResponse(cacheTTL, cacheDir)(httpClient.Transport),
23	}
24}
25
// isCacheableRequest reports whether req is eligible for the response cache.
// Every GET and HEAD request qualifies (method compared case-insensitively),
// as does a POST whose URL path is exactly "/graphql" or "/api/graphql".
// Anything else is never cached.
func isCacheableRequest(req *http.Request) bool {
	switch {
	case strings.EqualFold(req.Method, "GET"), strings.EqualFold(req.Method, "HEAD"):
		return true
	case strings.EqualFold(req.Method, "POST"):
		path := req.URL.Path
		return path == "/graphql" || path == "/api/graphql"
	default:
		return false
	}
}
37
// isCacheableResponse reports whether res may be written to the cache:
// any status below 500 except 403 Forbidden.
func isCacheableResponse(res *http.Response) bool {
	code := res.StatusCode
	if code >= 500 || code == 403 {
		return false
	}
	return true
}
41
42// CacheResponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time
43func CacheResponse(ttl time.Duration, dir string) ClientOption {
44	fs := fileStorage{
45		dir: dir,
46		ttl: ttl,
47		mu:  &sync.RWMutex{},
48	}
49
50	return func(tr http.RoundTripper) http.RoundTripper {
51		return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) {
52			if !isCacheableRequest(req) {
53				return tr.RoundTrip(req)
54			}
55
56			key, keyErr := cacheKey(req)
57			if keyErr == nil {
58				if res, err := fs.read(key); err == nil {
59					res.Request = req
60					return res, nil
61				}
62			}
63
64			res, err := tr.RoundTrip(req)
65			if err == nil && keyErr == nil && isCacheableResponse(res) {
66				_ = fs.store(key, res)
67			}
68			return res, err
69		}}
70	}
71}
72
// copyStream splits r into two streams and returns them as (buffered, tee).
//
// Every byte read through tee is also appended to an in-memory buffer, and
// buffered reads from that buffer. buffered therefore holds the complete
// content only after tee has been read to the end. Closing tee closes r;
// closing buffered does nothing.
func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) {
	var captured bytes.Buffer
	tee := &readCloser{
		Closer: r,
		Reader: io.TeeReader(r, &captured),
	}
	return ioutil.NopCloser(&captured), tee
}

// readCloser combines an independent io.Reader and io.Closer into a single
// io.ReadCloser, so reads and Close can be routed to different values.
type readCloser struct {
	io.Reader
	io.Closer
}
86
// cacheKey derives the cache key for req. The key is the hex-encoded SHA-256
// of the method, the full URL, and the Accept and Authorization headers (each
// followed by ':'), followed by the request body bytes.
//
// Hashing the body consumes it. On success, req.Body is replaced with an
// in-memory copy holding the same bytes, and the original body is closed.
// A nil body and http.NoBody contribute nothing to the hash and are left as
// they are.
//
// If reading the body fails, the error is returned and req.Body is replaced
// with a stream that replays the bytes already read and then continues with
// the original, still open, body. The caller's transport therefore sees the
// original stream's behaviour instead of a silently truncated payload.
func cacheKey(req *http.Request) (string, error) {
	h := sha256.New()
	fmt.Fprintf(h, "%s:", req.Method)
	fmt.Fprintf(h, "%s:", req.URL.String())
	fmt.Fprintf(h, "%s:", req.Header.Get("Accept"))
	fmt.Fprintf(h, "%s:", req.Header.Get("Authorization"))

	if req.Body != nil && req.Body != http.NoBody {
		orig := req.Body
		consumed := &bytes.Buffer{}
		if _, err := io.Copy(io.MultiWriter(h, consumed), orig); err != nil {
			// Don't close orig here: whoever sends the request closes req.Body,
			// which closes orig through this wrapper.
			req.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(consumed, orig), orig}
			return "", err
		}
		// The bytes are fully buffered, so a close error cannot affect the request.
		_ = orig.Close()
		req.Body = ioutil.NopCloser(consumed)
	}

	return fmt.Sprintf("%x", h.Sum(nil)), nil
}
106
// fileStorage keeps serialized HTTP responses as files under dir. An entry
// whose file was last written more than ttl ago is treated as missing. mu
// serializes reads and writes made through this value; it does not coordinate
// with other processes using the same directory.
type fileStorage struct {
	dir string
	ttl time.Duration
	mu  *sync.RWMutex
}

// filePath maps key to its path on disk. A key of six or more characters is
// split into two nested two-character directories followed by the remainder
// (e.g. "abcdef12" -> <dir>/ab/cd/ef12). Shorter keys sit directly in dir.
func (fs *fileStorage) filePath(key string) string {
	if len(key) >= 6 {
		return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:])
	}
	return filepath.Join(fs.dir, key)
}

// read returns the cached response stored under key. It fails in three cases:
// there is no entry, the entry's file was last modified more than ttl ago, or
// the entry cannot be parsed. The whole entry is loaded into memory, so the
// returned body stays readable after the file is closed.
//
// NOTE(review): the entry is parsed with a nil *http.Request, so
// http.ReadResponse assumes a GET. A cached HEAD response that carries a
// Content-Length may therefore report a truncated body; confirm this is acceptable.
func (fs *fileStorage) read(key string) (*http.Response, error) {
	cacheFile := fs.filePath(key)

	fs.mu.RLock()
	defer fs.mu.RUnlock()

	f, err := os.Open(cacheFile)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	stat, err := f.Stat()
	if err != nil {
		return nil, err
	}

	if age := time.Since(stat.ModTime()); age > fs.ttl {
		return nil, errors.New("cache expired")
	}

	body := &bytes.Buffer{}
	if _, err := io.Copy(body, f); err != nil {
		return nil, err
	}

	return http.ReadResponse(bufio.NewReader(body), nil)
}

// store serializes res into the cache entry for key.
//
// The response body is drained into memory first, and res.Body is replaced
// with an in-memory copy. The caller can therefore always read the complete
// body after store returns, whether or not caching succeeded.
//
// If draining the body fails, nothing is cached and the read error is
// returned. res.Body is then set to replay the bytes received so far,
// followed by the original body, which is left open for the caller. The
// failure is therefore not hidden behind a silently truncated body.
//
// The entry is written to a temporary file and renamed into place, so a failed
// or interrupted write never leaves a partial entry for read to serve.
func (fs *fileStorage) store(key string, res *http.Response) error {
	// Drain the body before taking the lock: it may be a slow network stream.
	var body []byte
	if res.Body != nil {
		orig := res.Body
		data, err := ioutil.ReadAll(orig)
		if err != nil {
			// The remainder of orig typically reports the same error again,
			// so the caller still observes the failure.
			res.Body = struct {
				io.Reader
				io.Closer
			}{io.MultiReader(bytes.NewReader(data), orig), orig}
			return err
		}
		// Everything is buffered; a close error cannot affect the caller's copy.
		_ = orig.Close()
		body = data
		res.Body = ioutil.NopCloser(bytes.NewReader(body))
	}

	cacheFile := fs.filePath(key)
	cacheDir := filepath.Dir(cacheFile)

	fs.mu.Lock()
	defer fs.mu.Unlock()

	if err := os.MkdirAll(cacheDir, 0755); err != nil {
		return err
	}

	tmp, err := ioutil.TempFile(cacheDir, ".tmp-*")
	if err != nil {
		return err
	}
	tmpName := tmp.Name()
	renamed := false
	defer func() {
		if !renamed {
			_ = os.Remove(tmpName)
		}
	}()

	// Response.Write reads and closes the body it is given, so serialize a
	// shallow copy with its own reader and leave res.Body intact for the caller.
	snapshot := *res
	if res.Body != nil {
		snapshot.Body = ioutil.NopCloser(bytes.NewReader(body))
	}

	w := bufio.NewWriter(tmp)
	writeErr := snapshot.Write(w)
	if writeErr == nil {
		writeErr = w.Flush()
	}
	closeErr := tmp.Close()
	if writeErr != nil {
		return writeErr
	}
	if closeErr != nil {
		return closeErr
	}

	if err := os.Rename(tmpName, cacheFile); err != nil {
		return err
	}
	renamed = true
	return nil
}
180