1// TODO: more expire/cleanup tests?
2
3package gormstore
4
5import (
6	"flag"
7	"fmt"
8	"net/http"
9	"net/http/httptest"
10	"os"
11	"strings"
12	"testing"
13	"time"
14
15	_ "github.com/go-sql-driver/mysql"
16	"github.com/jinzhu/gorm"
17	_ "github.com/lib/pq"
18	_ "github.com/mattn/go-sqlite3"
19)
20
// dbURI names the database the tests run against, in "<driver>://<dsn>" form.
// The default is a shared in-memory SQLite database; TestMain replaces it with
// the DATABASE_URI environment variable when that is set.
var dbURI = "sqlite3://file:dummy?mode=memory&cache=shared"
23
// parseCookies parses a single Set-Cookie response header value and returns
// the cookies it defines, keyed by cookie name.
//
// The value is parsed with the response-side (Set-Cookie) parser, so cookie
// attributes such as Path, Expires, Max-Age or HttpOnly end up on the returned
// *http.Cookie. Parsing it as a request Cookie header would instead turn those
// attributes into bogus extra cookies. An empty or unparseable value yields an
// empty map.
func parseCookies(value string) map[string]*http.Cookie {
	resp := &http.Response{Header: http.Header{"Set-Cookie": {value}}}
	cookies := resp.Cookies()

	m := make(map[string]*http.Cookie, len(cookies))
	for _, c := range cookies {
		m[c.Name] = c
	}
	return m
}
32
// connectDbURI opens a gorm connection from a URI of the form
// "<driver>://<dsn>", e.g. "sqlite3://file:dummy?mode=memory&cache=shared".
// The part before "://" is passed to gorm.Open as the dialect and the rest is
// passed through unchanged as the data source name.
//
// Opening is retried (up to 50 attempts, 500ms apart) to give a freshly
// started database server time to accept connections. If every attempt fails,
// the error from the last attempt is returned. A URI without "://" is rejected
// immediately.
func connectDbURI(uri string) (*gorm.DB, error) {
	const (
		attempts   = 50
		retryDelay = 500 * time.Millisecond
	)

	parts := strings.SplitN(uri, "://", 2)
	if len(parts) != 2 || parts[0] == "" {
		// The URI is deliberately not echoed: it may contain credentials.
		return nil, fmt.Errorf("invalid database URI: expected <driver>://<dsn>")
	}
	driver, dsn := parts[0], parts[1]

	var lastErr error
	for i := 0; i < attempts; i++ {
		db, err := gorm.Open(driver, dsn)
		if err == nil {
			return db, nil
		}
		// Remember the failure outside the loop scope so it can be reported
		// after the final attempt (a := here alone would lose it).
		lastErr = err
		if i < attempts-1 {
			time.Sleep(retryDelay)
		}
	}

	return nil, fmt.Errorf("opening %s database after %d attempts: %v", driver, attempts, lastErr)
}
50
51// create new shared in memory db
52func newDB() *gorm.DB {
53	var err error
54	var db *gorm.DB
55	if db, err = connectDbURI(dbURI); err != nil {
56		panic(err)
57	}
58
59	//db.LogMode(true)
60
61	// cleanup db
62	if err := db.DropTableIfExists(
63		&gormSession{tableName: "abc"},
64		&gormSession{tableName: "sessions"},
65	).Error; err != nil {
66		panic(err)
67	}
68
69	return db
70}
71
// req runs handler on a GET request for http://test and returns the recorded
// response. When sessionCookie is non-nil, its name and value are sent in a
// Cookie request header.
func req(handler http.HandlerFunc, sessionCookie *http.Cookie) *httptest.ResponseRecorder {
	// The URL is a constant, so NewRequest cannot fail here.
	request, _ := http.NewRequest("GET", "http://test", nil)

	if sessionCookie != nil {
		pair := fmt.Sprintf("%s=%s", sessionCookie.Name, sessionCookie.Value)
		request.Header.Add("Cookie", pair)
	}

	recorder := httptest.NewRecorder()
	handler.ServeHTTP(recorder, request)
	return recorder
}
81
// match checks that resp has the given HTTP status code and that its body,
// with leading and trailing newlines trimmed, equals body. Mismatches are
// reported with t.Errorf, so the calling test keeps running.
func match(t *testing.T, resp *httptest.ResponseRecorder, code int, body string) {
	// Attribute failures to the calling test line instead of to this helper.
	t.Helper()

	if resp.Code != code {
		t.Errorf("status code: expected %v, actual %v", code, resp.Code)
	}
	// http.Error (used by the handlers in this file) appends a "\n" to the body.
	if actual := strings.Trim(resp.Body.String(), "\n"); actual != body {
		t.Errorf("body: expected %q, actual %q", body, actual)
	}
}
91
// findSession loads the session row with the given ID from the store's table
// (store.opts.TableName). It returns nil when no such row exists and panics on
// any other database error. A failing query must never be mistaken for an
// existing session, or the existence/deletion checks in these tests would pass
// or fail for the wrong reason.
func findSession(db *gorm.DB, store *Store, id string) *gormSession {
	s := &gormSession{tableName: store.opts.TableName}
	res := db.Where("id = ?", id).First(s)
	if res.RecordNotFound() {
		return nil
	}
	if res.Error != nil {
		panic(fmt.Sprintf("findSession(%q): %v", id, res.Error))
	}
	return s
}
99
// makeCountHandler returns a handler that works on the session called name in
// store. It increments the session's integer "count" value (a missing or
// non-int value counts as 0), saves the session and replies 200 with the new
// count as the body. The session ID is exposed in the X-Session response
// header so tests can inspect the row in the database. Errors from the store
// panic.
func makeCountHandler(name string, store *Store) http.HandlerFunc {
	return func(w http.ResponseWriter, r *http.Request) {
		sess, err := store.Get(r, name)
		if err != nil {
			panic(err)
		}

		n, _ := sess.Values["count"].(int) // zero when absent or not an int
		n++
		sess.Values["count"] = n

		if err = store.Save(r, w, sess); err != nil {
			panic(err)
		}

		// Headers must be set before http.Error writes the status line.
		w.Header().Add("X-Session", sess.ID)
		http.Error(w, fmt.Sprintf("%d", n), http.StatusOK)
	}
}
118
// TestBasic checks that a session cookie issued by the first request is
// accepted by the second, so the stored count goes from 1 to 2.
func TestBasic(t *testing.T) {
	store := New(newDB(), []byte("secret"))
	handler := makeCountHandler("session", store)

	first := req(handler, nil)
	match(t, first, 200, "1")

	cookie := parseCookies(first.Header().Get("Set-Cookie"))["session"]
	second := req(handler, cookie)
	match(t, second, 200, "2")
}
126
// TestExpire checks that a session whose row is still in the database but
// whose ExpiresAt lies in the past is not reused (the count restarts at 1),
// and that store.Cleanup removes that row.
func TestExpire(t *testing.T) {
	db := newDB()
	store := New(db, []byte("secret"))
	countFn := makeCountHandler("session", store)

	r1 := req(countFn, nil)
	match(t, r1, 200, "1")

	// Keep the row in the db but move its expiry 40 days into the past.
	id := r1.Header().Get("X-Session")
	s := findSession(db, store, id)
	if s == nil {
		t.Fatalf("session %q not found after first request", id)
	}
	s.ExpiresAt = gorm.NowFunc().Add(-40 * 24 * time.Hour)
	if err := db.Save(s).Error; err != nil {
		t.Fatalf("expiring session %q: %v", id, err)
	}

	r2 := req(countFn, parseCookies(r1.Header().Get("Set-Cookie"))["session"])
	match(t, r2, 200, "1")

	store.Cleanup()

	if findSession(db, store, id) != nil {
		t.Error("Expected session to be deleted")
	}
}
150
// TestBrokenCookie checks that a tampered session cookie does not break the
// handler. The count restarts at 1 instead of continuing the old session.
func TestBrokenCookie(t *testing.T) {
	db := newDB()
	store := New(db, []byte("secret"))
	countFn := makeCountHandler("session", store)

	r1 := req(countFn, nil)
	match(t, r1, 200, "1")

	cookie := parseCookies(r1.Header().Get("Set-Cookie"))["session"]
	if cookie == nil {
		t.Fatal("Expected a session cookie in the first response")
	}
	// Corrupt the cookie value; the store is expected to discard it.
	cookie.Value += "junk"
	r2 := req(countFn, cookie)
	match(t, r2, 200, "1")
}
164
// TestMaxAgeNegative checks that saving a session with MaxAge = -1 clears the
// session cookie (empty value in Set-Cookie) and deletes the session row.
func TestMaxAgeNegative(t *testing.T) {
	db := newDB()
	store := New(db, []byte("secret"))
	countFn := makeCountHandler("session", store)

	r1 := req(countFn, nil)
	match(t, r1, 200, "1")

	r2 := req(func(w http.ResponseWriter, r *http.Request) {
		session, err := store.Get(r, "session")
		if err != nil {
			panic(err)
		}

		session.Options.MaxAge = -1
		// req calls the handler synchronously, so t may be used here.
		if err := store.Save(r, w, session); err != nil {
			t.Errorf("Save with MaxAge -1: %v", err)
		}

		http.Error(w, "", http.StatusOK)
	}, parseCookies(r1.Header().Get("Set-Cookie"))["session"])

	match(t, r2, 200, "")
	c := parseCookies(r2.Header().Get("Set-Cookie"))["session"]
	if c == nil {
		t.Error("Expected a Set-Cookie header for the session cookie")
	} else if c.Value != "" {
		t.Error("Expected empty Set-Cookie session header", c)
	}

	id := r1.Header().Get("X-Session")
	if findSession(db, store, id) != nil {
		t.Error("Expected session to be deleted")
	}
}
196
// TestMaxLength checks that store.Save reports an error for a session that
// does not fit the limit configured with store.MaxLength.
func TestMaxLength(t *testing.T) {
	store := New(newDB(), []byte("secret"))
	store.MaxLength(10)

	// The value stored below is expected to exceed MaxLength(10) once encoded.
	oversized := func(w http.ResponseWriter, r *http.Request) {
		sess, err := store.Get(r, "session")
		if err != nil {
			panic(err)
		}

		sess.Values["a"] = "aaaaaaaaaaaaaaaaaaaaaaaa"
		if saveErr := store.Save(r, w, sess); saveErr == nil {
			t.Error("Expected too large error")
		}

		http.Error(w, "", http.StatusOK)
	}

	match(t, req(oversized, nil), 200, "")
}
216
// TestTableName checks that Options.TableName is honoured. The table is
// created under that name, sessions round-trip through it, and Cleanup
// deletes expired rows from it.
func TestTableName(t *testing.T) {
	db := newDB()
	store := NewOptions(db, Options{TableName: "abc"}, []byte("secret"))
	countFn := makeCountHandler("session", store)

	if !db.HasTable(&gormSession{tableName: store.opts.TableName}) {
		t.Error("Expected abc table created")
	}

	r1 := req(countFn, nil)
	match(t, r1, 200, "1")
	r2 := req(countFn, parseCookies(r1.Header().Get("Set-Cookie"))["session"])
	match(t, r2, 200, "2")

	id := r2.Header().Get("X-Session")
	s := findSession(db, store, id)
	if s == nil {
		t.Fatalf("session %q not found in table %q", id, store.opts.TableName)
	}
	// Set ExpiresAt to MaxAge+1 seconds ago so Cleanup sees the row as expired.
	s.ExpiresAt = gorm.NowFunc().Add(-time.Duration(store.SessionOpts.MaxAge+1) * time.Second)
	if err := db.Save(s).Error; err != nil {
		t.Fatalf("expiring session %q: %v", id, err)
	}

	store.Cleanup()

	if findSession(db, store, id) != nil {
		t.Error("Expected session to be deleted")
	}
}
242
// TestSkipCreateTable checks that Options.SkipCreateTable stops NewOptions
// from creating the session table.
func TestSkipCreateTable(t *testing.T) {
	db := newDB()
	store := NewOptions(db, Options{SkipCreateTable: true}, []byte("secret"))

	table := &gormSession{tableName: store.opts.TableName}
	if exists := db.HasTable(table); exists {
		t.Error("Expected no table created")
	}
}
251
// TestMultiSessions checks that two differently named sessions kept in the
// same store count independently of each other.
func TestMultiSessions(t *testing.T) {
	store := New(newDB(), []byte("secret"))

	names := []string{"session1", "session2"}
	handlers := make([]http.HandlerFunc, len(names))
	for i, name := range names {
		handlers[i] = makeCountHandler(name, store)
	}

	// First pass: each handler starts its own session.
	first := make([]*httptest.ResponseRecorder, len(names))
	for i := range names {
		first[i] = req(handlers[i], nil)
		match(t, first[i], 200, "1")
	}

	// Second pass: replaying each session's own cookie continues that session.
	for i, name := range names {
		cookie := parseCookies(first[i].Header().Get("Set-Cookie"))[name]
		match(t, req(handlers[i], cookie), 200, "2")
	}
}
267
// TestPeriodicCleanup checks that PeriodicCleanup keeps deleting expired
// sessions while it runs, and stops doing so once its quit channel is closed.
// Sessions get MaxAge 1s; each check sleeps 2s so the session has expired and
// at least one cleanup tick (every 200ms) has happened since.
func TestPeriodicCleanup(t *testing.T) {
	db := newDB()
	store := New(db, []byte("secret"))
	store.SessionOpts.MaxAge = 1
	countFn := makeCountHandler("session", store)

	// newSessionID issues a cookie-less request and returns the new session's ID.
	newSessionID := func() string {
		return req(countFn, nil).Header().Get("X-Session")
	}
	stored := func(id string) bool {
		return findSession(db, store, id) != nil
	}

	quit := make(chan struct{})
	go store.PeriodicCleanup(200*time.Millisecond, quit)

	// While cleanup is running: verify expiry-driven deletion twice, so the
	// test proves cleanup ran more than once.
	id1 := newSessionID()
	if !stored(id1) {
		t.Error("Expected r1 session to exist")
	}
	time.Sleep(2 * time.Second)
	if stored(id1) {
		t.Error("Expected r1 session to be deleted")
	}

	id2 := newSessionID()
	if !stored(id2) {
		t.Error("Expected r2 session to exist")
	}
	time.Sleep(2 * time.Second)
	if stored(id2) {
		t.Error("Expected r2 session to be deleted")
	}

	close(quit)

	// After quit is closed, an expired session must no longer be removed.
	id3 := newSessionID()
	if !stored(id3) {
		t.Error("Expected r3 session to exist")
	}
	time.Sleep(2 * time.Second)
	if !stored(id3) {
		t.Error("Expected r3 session to exist")
	}
}
322
// TestMain lets the DATABASE_URI environment variable override the default
// dbURI, prints the URI in use, and then runs the tests.
func TestMain(m *testing.M) {
	flag.Parse()

	if override := os.Getenv("DATABASE_URI"); override != "" {
		dbURI = override
	}
	fmt.Printf("DATABASE_URI=%s\n", dbURI)

	exitCode := m.Run()
	os.Exit(exitCode)
}
333