1// Copyright 2012 The Gorilla Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sessions
6
7import (
8	"bytes"
9	"encoding/gob"
10	"net/http"
11	"net/http/httptest"
12	"testing"
13)
14
// NewRecorder builds a ResponseRecorder whose header map and body buffer are
// both allocated up front, leaving every other field at its zero value.
func NewRecorder() *httptest.ResponseRecorder {
	rec := new(httptest.ResponseRecorder)
	rec.HeaderMap = http.Header{}
	rec.Body = &bytes.Buffer{}
	return rec
}
22
// DefaultRemoteAddr is the fallback remote address intended to be reported as
// RemoteAddr when no explicit address has been configured on the
// ResponseRecorder.
//
// NOTE(review): nothing in the visible part of this file references this
// constant — confirm it is still needed.
const DefaultRemoteAddr = "1.2.3.4"
26
27// ----------------------------------------------------------------------------
28
// FlashMessage is a custom flash payload used by the tests in this file:
// TestFlashes stores one in a session, and init registers the type with gob.
type FlashMessage struct {
	// Type is an integer tag for the message; its meaning is not defined in
	// this file.
	Type int
	// Message is the text carried by the flash.
	Message string
}
33
34func TestFlashes(t *testing.T) {
35	var req *http.Request
36	var rsp *httptest.ResponseRecorder
37	var hdr http.Header
38	var err error
39	var ok bool
40	var cookies []string
41	var session *Session
42	var flashes []interface{}
43
44	store := NewCookieStore([]byte("secret-key"))
45
46	// Round 1 ----------------------------------------------------------------
47
48	req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
49	rsp = NewRecorder()
50	// Get a session.
51	if session, err = store.Get(req, "session-key"); err != nil {
52		t.Fatalf("Error getting session: %v", err)
53	}
54	// Get a flash.
55	flashes = session.Flashes()
56	if len(flashes) != 0 {
57		t.Errorf("Expected empty flashes; Got %v", flashes)
58	}
59	// Add some flashes.
60	session.AddFlash("foo")
61	session.AddFlash("bar")
62	// Custom key.
63	session.AddFlash("baz", "custom_key")
64	// Save.
65	if err = Save(req, rsp); err != nil {
66		t.Fatalf("Error saving session: %v", err)
67	}
68	hdr = rsp.Header()
69	cookies, ok = hdr["Set-Cookie"]
70	if !ok || len(cookies) != 1 {
71		t.Fatal("No cookies. Header:", hdr)
72	}
73
74	if _, err = store.Get(req, "session:key"); err.Error() != "sessions: invalid character in cookie name: session:key" {
75		t.Fatalf("Expected error due to invalid cookie name")
76	}
77
78	// Round 2 ----------------------------------------------------------------
79
80	req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
81	req.Header.Add("Cookie", cookies[0])
82	rsp = NewRecorder()
83	// Get a session.
84	if session, err = store.Get(req, "session-key"); err != nil {
85		t.Fatalf("Error getting session: %v", err)
86	}
87	// Check all saved values.
88	flashes = session.Flashes()
89	if len(flashes) != 2 {
90		t.Fatalf("Expected flashes; Got %v", flashes)
91	}
92	if flashes[0] != "foo" || flashes[1] != "bar" {
93		t.Errorf("Expected foo,bar; Got %v", flashes)
94	}
95	flashes = session.Flashes()
96	if len(flashes) != 0 {
97		t.Errorf("Expected dumped flashes; Got %v", flashes)
98	}
99	// Custom key.
100	flashes = session.Flashes("custom_key")
101	if len(flashes) != 1 {
102		t.Errorf("Expected flashes; Got %v", flashes)
103	} else if flashes[0] != "baz" {
104		t.Errorf("Expected baz; Got %v", flashes)
105	}
106	flashes = session.Flashes("custom_key")
107	if len(flashes) != 0 {
108		t.Errorf("Expected dumped flashes; Got %v", flashes)
109	}
110
111	// Round 3 ----------------------------------------------------------------
112	// Custom type
113
114	req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
115	rsp = NewRecorder()
116	// Get a session.
117	if session, err = store.Get(req, "session-key"); err != nil {
118		t.Fatalf("Error getting session: %v", err)
119	}
120	// Get a flash.
121	flashes = session.Flashes()
122	if len(flashes) != 0 {
123		t.Errorf("Expected empty flashes; Got %v", flashes)
124	}
125	// Add some flashes.
126	session.AddFlash(&FlashMessage{42, "foo"})
127	// Save.
128	if err = Save(req, rsp); err != nil {
129		t.Fatalf("Error saving session: %v", err)
130	}
131	hdr = rsp.Header()
132	cookies, ok = hdr["Set-Cookie"]
133	if !ok || len(cookies) != 1 {
134		t.Fatal("No cookies. Header:", hdr)
135	}
136
137	// Round 4 ----------------------------------------------------------------
138	// Custom type
139
140	req, _ = http.NewRequest("GET", "http://localhost:8080/", nil)
141	req.Header.Add("Cookie", cookies[0])
142	rsp = NewRecorder()
143	// Get a session.
144	if session, err = store.Get(req, "session-key"); err != nil {
145		t.Fatalf("Error getting session: %v", err)
146	}
147	// Check all saved values.
148	flashes = session.Flashes()
149	if len(flashes) != 1 {
150		t.Fatalf("Expected flashes; Got %v", flashes)
151	}
152	custom := flashes[0].(FlashMessage)
153	if custom.Type != 42 || custom.Message != "foo" {
154		t.Errorf("Expected %#v, got %#v", FlashMessage{42, "foo"}, custom)
155	}
156}
157
158func TestCookieStoreMapPanic(t *testing.T) {
159	defer func() {
160		err := recover()
161		if err != nil {
162			t.Fatal(err)
163		}
164	}()
165
166	store := NewCookieStore([]byte("aaa0defe5d2839cbc46fc4f080cd7adc"))
167	req, err := http.NewRequest("GET", "http://www.example.com", nil)
168	if err != nil {
169		t.Fatal("failed to create request", err)
170	}
171	w := httptest.NewRecorder()
172
173	session := NewSession(store, "hello")
174
175	session.Values["data"] = "hello-world"
176
177	err = session.Save(req, w)
178	if err != nil {
179		t.Fatal("failed to save session", err)
180	}
181}
182
183func init() {
184	gob.Register(FlashMessage{})
185}
186