1package api 2 3import ( 4 "bufio" 5 "bytes" 6 "crypto/sha256" 7 "errors" 8 "fmt" 9 "io" 10 "io/ioutil" 11 "net/http" 12 "os" 13 "path/filepath" 14 "strings" 15 "sync" 16 "time" 17) 18 19func NewCachedClient(httpClient *http.Client, cacheTTL time.Duration) *http.Client { 20 cacheDir := filepath.Join(os.TempDir(), "gh-cli-cache") 21 return &http.Client{ 22 Transport: CacheResponse(cacheTTL, cacheDir)(httpClient.Transport), 23 } 24} 25 26func isCacheableRequest(req *http.Request) bool { 27 if strings.EqualFold(req.Method, "GET") || strings.EqualFold(req.Method, "HEAD") { 28 return true 29 } 30 31 if strings.EqualFold(req.Method, "POST") && (req.URL.Path == "/graphql" || req.URL.Path == "/api/graphql") { 32 return true 33 } 34 35 return false 36} 37 38func isCacheableResponse(res *http.Response) bool { 39 return res.StatusCode < 500 && res.StatusCode != 403 40} 41 42// CacheResponse produces a RoundTripper that caches HTTP responses to disk for a specified amount of time 43func CacheResponse(ttl time.Duration, dir string) ClientOption { 44 fs := fileStorage{ 45 dir: dir, 46 ttl: ttl, 47 mu: &sync.RWMutex{}, 48 } 49 50 return func(tr http.RoundTripper) http.RoundTripper { 51 return &funcTripper{roundTrip: func(req *http.Request) (*http.Response, error) { 52 if !isCacheableRequest(req) { 53 return tr.RoundTrip(req) 54 } 55 56 key, keyErr := cacheKey(req) 57 if keyErr == nil { 58 if res, err := fs.read(key); err == nil { 59 res.Request = req 60 return res, nil 61 } 62 } 63 64 res, err := tr.RoundTrip(req) 65 if err == nil && keyErr == nil && isCacheableResponse(res) { 66 _ = fs.store(key, res) 67 } 68 return res, err 69 }} 70 } 71} 72 73func copyStream(r io.ReadCloser) (io.ReadCloser, io.ReadCloser) { 74 b := &bytes.Buffer{} 75 nr := io.TeeReader(r, b) 76 return ioutil.NopCloser(b), &readCloser{ 77 Reader: nr, 78 Closer: r, 79 } 80} 81 82type readCloser struct { 83 io.Reader 84 io.Closer 85} 86 87func cacheKey(req *http.Request) (string, error) { 88 h := sha256.New() 89 fmt.Fprintf(h, "%s:", req.Method) 90 fmt.Fprintf(h, "%s:", req.URL.String()) 91 fmt.Fprintf(h, "%s:", req.Header.Get("Accept")) 92 fmt.Fprintf(h, "%s:", req.Header.Get("Authorization")) 93 94 if req.Body != nil { 95 var bodyCopy io.ReadCloser 96 req.Body, bodyCopy = copyStream(req.Body) 97 defer bodyCopy.Close() 98 if _, err := io.Copy(h, bodyCopy); err != nil { 99 return "", err 100 } 101 } 102 103 digest := h.Sum(nil) 104 return fmt.Sprintf("%x", digest), nil 105} 106 107type fileStorage struct { 108 dir string 109 ttl time.Duration 110 mu *sync.RWMutex 111} 112 113func (fs *fileStorage) filePath(key string) string { 114 if len(key) >= 6 { 115 return filepath.Join(fs.dir, key[0:2], key[2:4], key[4:]) 116 } 117 return filepath.Join(fs.dir, key) 118} 119 120func (fs *fileStorage) read(key string) (*http.Response, error) { 121 cacheFile := fs.filePath(key) 122 123 fs.mu.RLock() 124 defer fs.mu.RUnlock() 125 126 f, err := os.Open(cacheFile) 127 if err != nil { 128 return nil, err 129 } 130 defer f.Close() 131 132 stat, err := f.Stat() 133 if err != nil { 134 return nil, err 135 } 136 137 age := time.Since(stat.ModTime()) 138 if age > fs.ttl { 139 return nil, errors.New("cache expired") 140 } 141 142 body := &bytes.Buffer{} 143 _, err = io.Copy(body, f) 144 if err != nil { 145 return nil, err 146 } 147 148 res, err := http.ReadResponse(bufio.NewReader(body), nil) 149 return res, err 150} 151 152func (fs *fileStorage) store(key string, res *http.Response) error { 153 cacheFile := fs.filePath(key) 154 155 fs.mu.Lock() 156 defer fs.mu.Unlock() 157 158 err := os.MkdirAll(filepath.Dir(cacheFile), 0755) 159 if err != nil { 160 return err 161 } 162 163 f, err := os.OpenFile(cacheFile, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0600) 164 if err != nil { 165 return err 166 } 167 defer f.Close() 168 169 var origBody io.ReadCloser 170 if res.Body != nil { 171 origBody, res.Body = copyStream(res.Body) 172 defer res.Body.Close() 173 } 174 err = res.Write(f) 175 if origBody != nil { 176 res.Body = origBody 177 } 178 return err 179} 180