1// Copyright 2012 The Gorilla Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package sessions
6
7import (
8	"encoding/gob"
9	"fmt"
10	"net/http"
11	"time"
12
13	"github.com/gorilla/context"
14)
15
const (
	// flashesKey is the Values key that Flashes and AddFlash use when the
	// caller does not pass an explicit flash key.
	flashesKey = "_flash"
)
18
19// Session --------------------------------------------------------------------
20
21// NewSession is called by session stores to create a new session instance.
22func NewSession(store Store, name string) *Session {
23	return &Session{
24		Values:  make(map[interface{}]interface{}),
25		store:   store,
26		name:    name,
27		Options: new(Options),
28	}
29}
30
// Session stores the values and optional configuration for a session.
//
// NewSession and Registry.Get both set the unexported store and name
// fields; Save, Name and Store read them back.
type Session struct {
	// The ID of the session, generated by stores. It should not be used for
	// user data.
	ID string
	// Values contains the user-data for the session. Flash messages added by
	// AddFlash are kept here too, as a []interface{} under their flash key.
	Values  map[interface{}]interface{}
	Options *Options // session configuration; NewSession sets a zero Options
	IsNew   bool     // not set in this file; presumably marks fresh sessions — confirm
	store   Store    // the Store that Save delegates to
	name    string   // the name the session is registered under
}
43
44// Flashes returns a slice of flash messages from the session.
45//
46// A single variadic argument is accepted, and it is optional: it defines
47// the flash key. If not defined "_flash" is used by default.
48func (s *Session) Flashes(vars ...string) []interface{} {
49	var flashes []interface{}
50	key := flashesKey
51	if len(vars) > 0 {
52		key = vars[0]
53	}
54	if v, ok := s.Values[key]; ok {
55		// Drop the flashes and return it.
56		delete(s.Values, key)
57		flashes = v.([]interface{})
58	}
59	return flashes
60}
61
62// AddFlash adds a flash message to the session.
63//
64// A single variadic argument is accepted, and it is optional: it defines
65// the flash key. If not defined "_flash" is used by default.
66func (s *Session) AddFlash(value interface{}, vars ...string) {
67	key := flashesKey
68	if len(vars) > 0 {
69		key = vars[0]
70	}
71	var flashes []interface{}
72	if v, ok := s.Values[key]; ok {
73		flashes = v.([]interface{})
74	}
75	s.Values[key] = append(flashes, value)
76}
77
78// Save is a convenience method to save this session. It is the same as calling
79// store.Save(request, response, session). You should call Save before writing to
80// the response or returning from the handler.
81func (s *Session) Save(r *http.Request, w http.ResponseWriter) error {
82	return s.store.Save(r, w, s)
83}
84
85// Name returns the name used to register the session.
86func (s *Session) Name() string {
87	return s.name
88}
89
90// Store returns the session store used to register the session.
91func (s *Session) Store() Store {
92	return s.store
93}
94
95// Registry -------------------------------------------------------------------
96
// sessionInfo stores a session tracked by the registry.
//
// Registry.Get caches the pair returned by Store.New here, so repeated
// lookups of the same name yield the same session and the same error.
type sessionInfo struct {
	s *Session // the cached session
	e error    // the error Store.New returned alongside it, if any
}
102
// contextKey is an unexported key type for values this package stores in
// gorilla/context, so its keys cannot collide with other packages' keys.
type contextKey int

// registryKey is the context key under which GetRegistry keeps the
// per-request *Registry.
const registryKey = contextKey(0)
108
109// GetRegistry returns a registry instance for the current request.
110func GetRegistry(r *http.Request) *Registry {
111	registry := context.Get(r, registryKey)
112	if registry != nil {
113		return registry.(*Registry)
114	}
115	newRegistry := &Registry{
116		request:  r,
117		sessions: make(map[string]sessionInfo),
118	}
119	context.Set(r, registryKey, newRegistry)
120	return newRegistry
121}
122
// Registry stores sessions used during a request.
//
// GetRegistry creates one per request; Get adds sessions to it and Save
// persists all of them. It uses a plain map with no locking, so it must not
// be used from multiple goroutines concurrently.
type Registry struct {
	request  *http.Request          // request passed to Store.New and Store.Save
	sessions map[string]sessionInfo // registered sessions, keyed by name
}
128
129// Get registers and returns a session for the given name and session store.
130//
131// It returns a new session if there are no sessions registered for the name.
132func (s *Registry) Get(store Store, name string) (session *Session, err error) {
133	if !isCookieNameValid(name) {
134		return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name)
135	}
136	if info, ok := s.sessions[name]; ok {
137		session, err = info.s, info.e
138	} else {
139		session, err = store.New(s.request, name)
140		session.name = name
141		s.sessions[name] = sessionInfo{s: session, e: err}
142	}
143	session.store = store
144	return
145}
146
147// Save saves all sessions registered for the current request.
148func (s *Registry) Save(w http.ResponseWriter) error {
149	var errMulti MultiError
150	for name, info := range s.sessions {
151		session := info.s
152		if session.store == nil {
153			errMulti = append(errMulti, fmt.Errorf(
154				"sessions: missing store for session %q", name))
155		} else if err := session.store.Save(s.request, w, session); err != nil {
156			errMulti = append(errMulti, fmt.Errorf(
157				"sessions: error saving session %q -- %v", name, err))
158		}
159	}
160	if errMulti != nil {
161		return errMulti
162	}
163	return nil
164}
165
166// Helpers --------------------------------------------------------------------
167
func init() {
	// Flash messages live in Values (static type interface{}) as a
	// []interface{}; gob can only encode/decode an interface-typed value whose
	// concrete type has been registered.
	gob.Register(make([]interface{}, 0))
}
171
172// Save saves all sessions used during the current request.
173func Save(r *http.Request, w http.ResponseWriter) error {
174	return GetRegistry(r).Save(w)
175}
176
177// NewCookie returns an http.Cookie with the options set. It also sets
178// the Expires field calculated based on the MaxAge value, for Internet
179// Explorer compatibility.
180func NewCookie(name, value string, options *Options) *http.Cookie {
181	cookie := newCookieFromOptions(name, value, options)
182	if options.MaxAge > 0 {
183		d := time.Duration(options.MaxAge) * time.Second
184		cookie.Expires = time.Now().Add(d)
185	} else if options.MaxAge < 0 {
186		// Set it to the past to expire now.
187		cookie.Expires = time.Unix(1, 0)
188	}
189	return cookie
190}
191
192// Error ----------------------------------------------------------------------
193
// MultiError collects several errors into a single error value.
//
// Borrowed from the App Engine SDK.
type MultiError []error

// Error reports the message of the first non-nil error in m, followed by how
// many further non-nil errors were collected. Nil entries are skipped.
func (m MultiError) Error() string {
	var first error
	count := 0
	for _, err := range m {
		if err == nil {
			continue
		}
		if first == nil {
			first = err
		}
		count++
	}
	switch {
	case count == 0:
		return "(0 errors)"
	case count == 1:
		return first.Error()
	case count == 2:
		return first.Error() + " (and 1 other error)"
	default:
		return fmt.Sprintf("%s (and %d other errors)", first.Error(), count-1)
	}
}
219