1// Copyright 2020 Google LLC. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package idtoken 6 7import ( 8 "context" 9 "encoding/json" 10 "fmt" 11 "net/http" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16) 17 18type cachingClient struct { 19 client *http.Client 20 21 // clock optionally specifies a func to return the current time. 22 // If nil, time.Now is used. 23 clock func() time.Time 24 25 mu sync.Mutex 26 certs map[string]*cachedResponse 27} 28 29func newCachingClient(client *http.Client) *cachingClient { 30 return &cachingClient{ 31 client: client, 32 certs: make(map[string]*cachedResponse, 2), 33 } 34} 35 36type cachedResponse struct { 37 resp *certResponse 38 exp time.Time 39} 40 41func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) { 42 if response, ok := c.get(url); ok { 43 return response, nil 44 } 45 req, err := http.NewRequest(http.MethodGet, url, nil) 46 if err != nil { 47 return nil, err 48 } 49 req = req.WithContext(ctx) 50 resp, err := c.client.Do(req) 51 if err != nil { 52 return nil, err 53 } 54 defer resp.Body.Close() 55 if resp.StatusCode != http.StatusOK { 56 return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode) 57 } 58 59 certResp := &certResponse{} 60 if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil { 61 return nil, err 62 63 } 64 c.set(url, certResp, resp.Header) 65 return certResp, nil 66} 67 68func (c *cachingClient) now() time.Time { 69 if c.clock != nil { 70 return c.clock() 71 } 72 return time.Now() 73} 74 75func (c *cachingClient) get(url string) (*certResponse, bool) { 76 c.mu.Lock() 77 defer c.mu.Unlock() 78 cachedResp, ok := c.certs[url] 79 if !ok { 80 return nil, false 81 } 82 if c.now().After(cachedResp.exp) { 83 return nil, false 84 } 85 return cachedResp.resp, true 86} 87 88func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) { 89 exp := c.calculateExpireTime(headers) 90 c.mu.Lock() 91 c.certs[url] = &cachedResponse{resp: resp, exp: exp} 92 c.mu.Unlock() 93} 94 95// calculateExpireTime will determine the expire time for the cache based on 96// HTTP headers. If there is any difficulty reading the headers the fallback is 97// to set the cache to expire now. 98func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time { 99 var maxAge int 100 cc := strings.Split(headers.Get("cache-control"), ",") 101 for _, v := range cc { 102 if strings.Contains(v, "max-age") { 103 ss := strings.Split(v, "=") 104 if len(ss) < 2 { 105 return c.now() 106 } 107 ma, err := strconv.Atoi(ss[1]) 108 if err != nil { 109 return c.now() 110 } 111 maxAge = ma 112 } 113 } 114 age, err := strconv.Atoi(headers.Get("age")) 115 if err != nil { 116 return c.now() 117 } 118 return c.now().Add(time.Duration(maxAge-age) * time.Second) 119} 120