1package cookie
2
3import (
4	"errors"
5	"fmt"
6	"net/http"
7	"regexp"
8	"time"
9
10	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
11	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
12	pkgcookies "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies"
13	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption"
14	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/logger"
15)
16
const (
	// maxCookieLength is the upper bound, in bytes, on the serialized form of a
	// single cookie, i.e. (http.Cookie).String(), which includes the cookie
	// name, value and attributes.
	// Most browsers cap a cookie at 4096 bytes; we stay slightly below that to
	// give ourselves some leeway.
	maxCookieLength = 4000
)

// Ensure SessionStore implements the sessions.SessionStore interface
var _ sessions.SessionStore = &SessionStore{}

// SessionStore is an implementation of the sessions.SessionStore
// interface that stores sessions in client side cookies
type SessionStore struct {
	// Cookie holds the cookie options (name, secret, expiry, ...) used when
	// reading and writing the session cookie(s).
	Cookie       *options.Cookie
	// CookieCipher is passed to the session state encoder/decoder.
	CookieCipher encryption.Cipher
	// Minimal, when true, strips the access, ID and refresh tokens from the
	// session before it is encoded into the cookie.
	Minimal      bool
}
34
35// Save takes a sessions.SessionState and stores the information from it
36// within Cookies set on the HTTP response writer
37func (s *SessionStore) Save(rw http.ResponseWriter, req *http.Request, ss *sessions.SessionState) error {
38	if ss.CreatedAt == nil || ss.CreatedAt.IsZero() {
39		ss.CreatedAtNow()
40	}
41	value, err := s.cookieForSession(ss)
42	if err != nil {
43		return err
44	}
45	return s.setSessionCookie(rw, req, value, *ss.CreatedAt)
46}
47
48// Load reads sessions.SessionState information from Cookies within the
49// HTTP request object
50func (s *SessionStore) Load(req *http.Request) (*sessions.SessionState, error) {
51	c, err := loadCookie(req, s.Cookie.Name)
52	if err != nil {
53		// always http.ErrNoCookie
54		return nil, err
55	}
56	val, _, ok := encryption.Validate(c, s.Cookie.Secret, s.Cookie.Expire)
57	if !ok {
58		return nil, errors.New("cookie signature not valid")
59	}
60
61	session, err := sessions.DecodeSessionState(val, s.CookieCipher, true)
62	if err != nil {
63		return nil, err
64	}
65	return session, nil
66}
67
68// Clear clears any saved session information by writing a cookie to
69// clear the session
70func (s *SessionStore) Clear(rw http.ResponseWriter, req *http.Request) error {
71	// matches CookieName, CookieName_<number>
72	var cookieNameRegex = regexp.MustCompile(fmt.Sprintf("^%s(_\\d+)?$", s.Cookie.Name))
73
74	for _, c := range req.Cookies() {
75		if cookieNameRegex.MatchString(c.Name) {
76			clearCookie := s.makeCookie(req, c.Name, "", time.Hour*-1, time.Now())
77
78			http.SetCookie(rw, clearCookie)
79		}
80	}
81
82	return nil
83}
84
85// cookieForSession serializes a session state for storage in a cookie
86func (s *SessionStore) cookieForSession(ss *sessions.SessionState) ([]byte, error) {
87	if s.Minimal && (ss.AccessToken != "" || ss.IDToken != "" || ss.RefreshToken != "") {
88		minimal := *ss
89		minimal.AccessToken = ""
90		minimal.IDToken = ""
91		minimal.RefreshToken = ""
92
93		return minimal.EncodeSessionState(s.CookieCipher, true)
94	}
95
96	return ss.EncodeSessionState(s.CookieCipher, true)
97}
98
99// setSessionCookie adds the user's session cookie to the response
100func (s *SessionStore) setSessionCookie(rw http.ResponseWriter, req *http.Request, val []byte, created time.Time) error {
101	cookies, err := s.makeSessionCookie(req, val, created)
102	if err != nil {
103		return err
104	}
105	for _, c := range cookies {
106		http.SetCookie(rw, c)
107	}
108	return nil
109}
110
111// makeSessionCookie creates an http.Cookie containing the authenticated user's
112// authentication details
113func (s *SessionStore) makeSessionCookie(req *http.Request, value []byte, now time.Time) ([]*http.Cookie, error) {
114	strValue := string(value)
115	if strValue != "" {
116		var err error
117		strValue, err = encryption.SignedValue(s.Cookie.Secret, s.Cookie.Name, value, now)
118		if err != nil {
119			return nil, err
120		}
121	}
122	c := s.makeCookie(req, s.Cookie.Name, strValue, s.Cookie.Expire, now)
123	if len(c.String()) > maxCookieLength {
124		return splitCookie(c), nil
125	}
126	return []*http.Cookie{c}, nil
127}
128
129func (s *SessionStore) makeCookie(req *http.Request, name string, value string, expiration time.Duration, now time.Time) *http.Cookie {
130	return pkgcookies.MakeCookieFromOptions(
131		req,
132		name,
133		value,
134		s.Cookie,
135		expiration,
136		now,
137	)
138}
139
140// NewCookieSessionStore initialises a new instance of the SessionStore from
141// the configuration given
142func NewCookieSessionStore(opts *options.SessionOptions, cookieOpts *options.Cookie) (sessions.SessionStore, error) {
143	cipher, err := encryption.NewCFBCipher(encryption.SecretBytes(cookieOpts.Secret))
144	if err != nil {
145		return nil, fmt.Errorf("error initialising cipher: %v", err)
146	}
147
148	return &SessionStore{
149		CookieCipher: cipher,
150		Cookie:       cookieOpts,
151		Minimal:      opts.Cookie.Minimal,
152	}, nil
153}
154
155// splitCookie reads the full cookie generated to store the session and splits
156// it into a slice of cookies which fit within the 4kb cookie limit indexing
157// the cookies from 0
158func splitCookie(c *http.Cookie) []*http.Cookie {
159	if len(c.String()) < maxCookieLength {
160		return []*http.Cookie{c}
161	}
162
163	logger.Errorf("WARNING: Multiple cookies are required for this session as it exceeds the 4kb cookie limit. Please use server side session storage (eg. Redis) instead.")
164
165	cookies := []*http.Cookie{}
166	valueBytes := []byte(c.Value)
167	count := 0
168	for len(valueBytes) > 0 {
169		newCookie := copyCookie(c)
170		newCookie.Name = splitCookieName(c.Name, count)
171		count++
172
173		newCookie.Value = string(valueBytes)
174		cookieLength := len(newCookie.String())
175		if cookieLength <= maxCookieLength {
176			valueBytes = []byte{}
177		} else {
178			overflow := cookieLength - maxCookieLength
179			valueSize := len(valueBytes) - overflow
180
181			newValue := valueBytes[:valueSize]
182			valueBytes = valueBytes[valueSize:]
183			newCookie.Value = string(newValue)
184		}
185		cookies = append(cookies, newCookie)
186	}
187	return cookies
188}
189
// splitCookieName returns the name of the count-th part of a split cookie,
// "<name>_<count>". If that would be longer than 256 bytes, name is truncated
// so the result is exactly 256 bytes, keeping the numeric suffix intact.
func splitCookieName(name string, count int) string {
	const maxNameLength = 256
	suffix := fmt.Sprintf("_%d", count)
	if keep := maxNameLength - len(suffix); len(name) > keep {
		name = name[:keep]
	}
	return name + suffix
}
198
199// loadCookie retreieves the sessions state cookie from the http request.
200// If a single cookie is present this will be returned, otherwise it attempts
201// to reconstruct a cookie split up by splitCookie
202func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) {
203	c, err := req.Cookie(cookieName)
204	if err == nil {
205		return c, nil
206	}
207	cookies := []*http.Cookie{}
208	err = nil
209	count := 0
210	for err == nil {
211		var c *http.Cookie
212		c, err = req.Cookie(splitCookieName(cookieName, count))
213		if err == nil {
214			cookies = append(cookies, c)
215			count++
216		}
217	}
218	if len(cookies) == 0 {
219		return nil, http.ErrNoCookie
220	}
221	return joinCookies(cookies, cookieName)
222}
223
224// joinCookies takes a slice of cookies from the request and reconstructs the
225// full session cookie
226func joinCookies(cookies []*http.Cookie, cookieName string) (*http.Cookie, error) {
227	if len(cookies) == 0 {
228		return nil, fmt.Errorf("list of cookies must be > 0")
229	}
230	if len(cookies) == 1 {
231		return cookies[0], nil
232	}
233	c := copyCookie(cookies[0])
234	for i := 1; i < len(cookies); i++ {
235		c.Value += cookies[i].Value
236	}
237	c.Name = cookieName
238	return c, nil
239}
240
// copyCookie returns a shallow copy of c. It copies the struct value instead
// of listing fields, so every http.Cookie field is carried over. That
// includes fields added in newer Go releases, such as Partitioned and Quoted.
// As a result, split and joined cookies keep all attributes of the original.
// Slice fields (Unparsed) still share their backing array with c, exactly as
// before.
func copyCookie(c *http.Cookie) *http.Cookie {
	copied := *c
	return &copied
}
257