1package persistence
2
3import (
4	"encoding/base64"
5	"errors"
6	"fmt"
7	"time"
8
9	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options"
10	"github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions"
11	. "github.com/onsi/ginkgo"
12	. "github.com/onsi/ginkgo/extensions/table"
13	. "github.com/onsi/gomega"
14)
15
16var _ = Describe("Session Ticket Tests", func() {
17	Context("encodeTicket & decodeTicket", func() {
18		type ticketTableInput struct {
19			ticket        *ticket
20			encodedTicket string
21			expectedError error
22		}
23
24		DescribeTable("encodeTicket should decodeTicket back when valid",
25			func(in ticketTableInput) {
26				if in.ticket != nil {
27					enc := in.ticket.encodeTicket()
28					Expect(enc).To(Equal(in.encodedTicket))
29
30					dec, err := decodeTicket(enc, in.ticket.options)
31					Expect(err).ToNot(HaveOccurred())
32					Expect(dec).To(Equal(in.ticket))
33				} else {
34					_, err := decodeTicket(in.encodedTicket, nil)
35					Expect(err).To(MatchError(in.expectedError))
36				}
37			},
38			Entry("with a valid ticket", ticketTableInput{
39				ticket: &ticket{
40					id:     "dummy-0123456789abcdef",
41					secret: []byte("0123456789abcdef"),
42					options: &options.Cookie{
43						Name: "dummy",
44					},
45				},
46				encodedTicket: fmt.Sprintf("%s.%s",
47					"dummy-0123456789abcdef",
48					base64.RawURLEncoding.EncodeToString([]byte("0123456789abcdef"))),
49				expectedError: nil,
50			}),
51			Entry("with an invalid encoded ticket with 1 part", ticketTableInput{
52				ticket:        nil,
53				encodedTicket: "dummy-0123456789abcdef",
54				expectedError: errors.New("failed to decode ticket"),
55			}),
56			Entry("with an invalid base64 encoded secret", ticketTableInput{
57				ticket:        nil,
58				encodedTicket: "dummy-0123456789abcdef.@)#($*@)#(*$@)#(*$",
59				expectedError: fmt.Errorf("failed to decode encryption secret: illegal base64 data at input byte 0"),
60			}),
61		)
62	})
63
64	Context("saveSession", func() {
65		It("uses the passed save function", func() {
66			t, err := newTicket(&options.Cookie{Name: "dummy"})
67			Expect(err).ToNot(HaveOccurred())
68
69			c, err := t.makeCipher()
70			Expect(err).ToNot(HaveOccurred())
71
72			ss := &sessions.SessionState{User: "foobar"}
73			store := map[string][]byte{}
74			err = t.saveSession(ss, func(k string, v []byte, e time.Duration) error {
75				store[k] = v
76				return nil
77			})
78			Expect(err).ToNot(HaveOccurred())
79
80			stored, err := sessions.DecodeSessionState(store[t.id], c, false)
81			Expect(err).ToNot(HaveOccurred())
82			Expect(stored).To(Equal(ss))
83		})
84
85		It("errors when the saveFunc errors", func() {
86			t, err := newTicket(&options.Cookie{Name: "dummy"})
87			Expect(err).ToNot(HaveOccurred())
88
89			err = t.saveSession(
90				&sessions.SessionState{User: "foobar"},
91				func(k string, v []byte, e time.Duration) error {
92					return errors.New("save error")
93				})
94			Expect(err).To(MatchError(errors.New("save error")))
95		})
96	})
97
98	Context("loadSession", func() {
99		It("uses the passed load function", func() {
100			t, err := newTicket(&options.Cookie{Name: "dummy"})
101			Expect(err).ToNot(HaveOccurred())
102
103			c, err := t.makeCipher()
104			Expect(err).ToNot(HaveOccurred())
105
106			ss := &sessions.SessionState{
107				User: "foobar",
108				Lock: &sessions.NoOpLock{},
109			}
110			loadedSession, err := t.loadSession(
111				func(k string) ([]byte, error) {
112					return ss.EncodeSessionState(c, false)
113				},
114				func(k string) sessions.Lock {
115					return &sessions.NoOpLock{}
116				})
117			Expect(err).ToNot(HaveOccurred())
118			Expect(loadedSession).To(Equal(ss))
119		})
120
121		It("errors when the loadFunc errors", func() {
122			t, err := newTicket(&options.Cookie{Name: "dummy"})
123			Expect(err).ToNot(HaveOccurred())
124
125			data, err := t.loadSession(
126				func(k string) ([]byte, error) {
127					return nil, errors.New("load error")
128				},
129				func(k string) sessions.Lock {
130					return &sessions.NoOpLock{}
131				})
132			Expect(data).To(BeNil())
133			Expect(err).To(MatchError(errors.New("failed to load the session state with the ticket: load error")))
134		})
135	})
136
137	Context("clearSession", func() {
138		It("uses the passed clear function", func() {
139			t, err := newTicket(&options.Cookie{Name: "dummy"})
140			Expect(err).ToNot(HaveOccurred())
141
142			var tracker string
143			err = t.clearSession(func(k string) error {
144				tracker = k
145				return nil
146			})
147			Expect(err).ToNot(HaveOccurred())
148			Expect(tracker).To(Equal(t.id))
149		})
150
151		It("errors when the clearFunc errors", func() {
152			t, err := newTicket(&options.Cookie{Name: "dummy"})
153			Expect(err).ToNot(HaveOccurred())
154
155			err = t.clearSession(func(k string) error {
156				return errors.New("clear error")
157			})
158			Expect(err).To(MatchError(errors.New("clear error")))
159		})
160	})
161})
162