1// Copyright 2020 Google LLC.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package idtoken
6
7import (
8	"context"
9	"encoding/json"
10	"fmt"
11	"net/http"
12	"strconv"
13	"strings"
14	"sync"
15	"time"
16)
17
18type cachingClient struct {
19	client *http.Client
20
21	// clock optionally specifies a func to return the current time.
22	// If nil, time.Now is used.
23	clock func() time.Time
24
25	mu    sync.Mutex
26	certs map[string]*cachedResponse
27}
28
29func newCachingClient(client *http.Client) *cachingClient {
30	return &cachingClient{
31		client: client,
32		certs:  make(map[string]*cachedResponse, 2),
33	}
34}
35
36type cachedResponse struct {
37	resp *certResponse
38	exp  time.Time
39}
40
41func (c *cachingClient) getCert(ctx context.Context, url string) (*certResponse, error) {
42	if response, ok := c.get(url); ok {
43		return response, nil
44	}
45	req, err := http.NewRequest(http.MethodGet, url, nil)
46	if err != nil {
47		return nil, err
48	}
49	req = req.WithContext(ctx)
50	resp, err := c.client.Do(req)
51	if err != nil {
52		return nil, err
53	}
54	defer resp.Body.Close()
55	if resp.StatusCode != http.StatusOK {
56		return nil, fmt.Errorf("idtoken: unable to retrieve cert, got status code %d", resp.StatusCode)
57	}
58
59	certResp := &certResponse{}
60	if err := json.NewDecoder(resp.Body).Decode(certResp); err != nil {
61		return nil, err
62
63	}
64	c.set(url, certResp, resp.Header)
65	return certResp, nil
66}
67
68func (c *cachingClient) now() time.Time {
69	if c.clock != nil {
70		return c.clock()
71	}
72	return time.Now()
73}
74
75func (c *cachingClient) get(url string) (*certResponse, bool) {
76	c.mu.Lock()
77	defer c.mu.Unlock()
78	cachedResp, ok := c.certs[url]
79	if !ok {
80		return nil, false
81	}
82	if c.now().After(cachedResp.exp) {
83		return nil, false
84	}
85	return cachedResp.resp, true
86}
87
88func (c *cachingClient) set(url string, resp *certResponse, headers http.Header) {
89	exp := c.calculateExpireTime(headers)
90	c.mu.Lock()
91	c.certs[url] = &cachedResponse{resp: resp, exp: exp}
92	c.mu.Unlock()
93}
94
95// calculateExpireTime will determine the expire time for the cache based on
96// HTTP headers. If there is any difficulty reading the headers the fallback is
97// to set the cache to expire now.
98func (c *cachingClient) calculateExpireTime(headers http.Header) time.Time {
99	var maxAge int
100	cc := strings.Split(headers.Get("cache-control"), ",")
101	for _, v := range cc {
102		if strings.Contains(v, "max-age") {
103			ss := strings.Split(v, "=")
104			if len(ss) < 2 {
105				return c.now()
106			}
107			ma, err := strconv.Atoi(ss[1])
108			if err != nil {
109				return c.now()
110			}
111			maxAge = ma
112		}
113	}
114	age, err := strconv.Atoi(headers.Get("age"))
115	if err != nil {
116		return c.now()
117	}
118	return c.now().Add(time.Duration(maxAge-age) * time.Second)
119}
120