1// Copyright 2012 The Gorilla Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package sessions 6 7import ( 8 "encoding/gob" 9 "fmt" 10 "net/http" 11 "time" 12 13 "github.com/gorilla/context" 14) 15 16// Default flashes key. 17const flashesKey = "_flash" 18 19// Session -------------------------------------------------------------------- 20 21// NewSession is called by session stores to create a new session instance. 22func NewSession(store Store, name string) *Session { 23 return &Session{ 24 Values: make(map[interface{}]interface{}), 25 store: store, 26 name: name, 27 Options: new(Options), 28 } 29} 30 31// Session stores the values and optional configuration for a session. 32type Session struct { 33 // The ID of the session, generated by stores. It should not be used for 34 // user data. 35 ID string 36 // Values contains the user-data for the session. 37 Values map[interface{}]interface{} 38 Options *Options 39 IsNew bool 40 store Store 41 name string 42} 43 44// Flashes returns a slice of flash messages from the session. 45// 46// A single variadic argument is accepted, and it is optional: it defines 47// the flash key. If not defined "_flash" is used by default. 48func (s *Session) Flashes(vars ...string) []interface{} { 49 var flashes []interface{} 50 key := flashesKey 51 if len(vars) > 0 { 52 key = vars[0] 53 } 54 if v, ok := s.Values[key]; ok { 55 // Drop the flashes and return it. 56 delete(s.Values, key) 57 flashes = v.([]interface{}) 58 } 59 return flashes 60} 61 62// AddFlash adds a flash message to the session. 63// 64// A single variadic argument is accepted, and it is optional: it defines 65// the flash key. If not defined "_flash" is used by default. 66func (s *Session) AddFlash(value interface{}, vars ...string) { 67 key := flashesKey 68 if len(vars) > 0 { 69 key = vars[0] 70 } 71 var flashes []interface{} 72 if v, ok := s.Values[key]; ok { 73 flashes = v.([]interface{}) 74 } 75 s.Values[key] = append(flashes, value) 76} 77 78// Save is a convenience method to save this session. It is the same as calling 79// store.Save(request, response, session). You should call Save before writing to 80// the response or returning from the handler. 81func (s *Session) Save(r *http.Request, w http.ResponseWriter) error { 82 return s.store.Save(r, w, s) 83} 84 85// Name returns the name used to register the session. 86func (s *Session) Name() string { 87 return s.name 88} 89 90// Store returns the session store used to register the session. 91func (s *Session) Store() Store { 92 return s.store 93} 94 95// Registry ------------------------------------------------------------------- 96 97// sessionInfo stores a session tracked by the registry. 98type sessionInfo struct { 99 s *Session 100 e error 101} 102 103// contextKey is the type used to store the registry in the context. 104type contextKey int 105 106// registryKey is the key used to store the registry in the context. 107const registryKey contextKey = 0 108 109// GetRegistry returns a registry instance for the current request. 110func GetRegistry(r *http.Request) *Registry { 111 registry := context.Get(r, registryKey) 112 if registry != nil { 113 return registry.(*Registry) 114 } 115 newRegistry := &Registry{ 116 request: r, 117 sessions: make(map[string]sessionInfo), 118 } 119 context.Set(r, registryKey, newRegistry) 120 return newRegistry 121} 122 123// Registry stores sessions used during a request. 124type Registry struct { 125 request *http.Request 126 sessions map[string]sessionInfo 127} 128 129// Get registers and returns a session for the given name and session store. 130// 131// It returns a new session if there are no sessions registered for the name. 132func (s *Registry) Get(store Store, name string) (session *Session, err error) { 133 if !isCookieNameValid(name) { 134 return nil, fmt.Errorf("sessions: invalid character in cookie name: %s", name) 135 } 136 if info, ok := s.sessions[name]; ok { 137 session, err = info.s, info.e 138 } else { 139 session, err = store.New(s.request, name) 140 session.name = name 141 s.sessions[name] = sessionInfo{s: session, e: err} 142 } 143 session.store = store 144 return 145} 146 147// Save saves all sessions registered for the current request. 148func (s *Registry) Save(w http.ResponseWriter) error { 149 var errMulti MultiError 150 for name, info := range s.sessions { 151 session := info.s 152 if session.store == nil { 153 errMulti = append(errMulti, fmt.Errorf( 154 "sessions: missing store for session %q", name)) 155 } else if err := session.store.Save(s.request, w, session); err != nil { 156 errMulti = append(errMulti, fmt.Errorf( 157 "sessions: error saving session %q -- %v", name, err)) 158 } 159 } 160 if errMulti != nil { 161 return errMulti 162 } 163 return nil 164} 165 166// Helpers -------------------------------------------------------------------- 167 168func init() { 169 gob.Register([]interface{}{}) 170} 171 172// Save saves all sessions used during the current request. 173func Save(r *http.Request, w http.ResponseWriter) error { 174 return GetRegistry(r).Save(w) 175} 176 177// NewCookie returns an http.Cookie with the options set. It also sets 178// the Expires field calculated based on the MaxAge value, for Internet 179// Explorer compatibility. 180func NewCookie(name, value string, options *Options) *http.Cookie { 181 cookie := newCookieFromOptions(name, value, options) 182 if options.MaxAge > 0 { 183 d := time.Duration(options.MaxAge) * time.Second 184 cookie.Expires = time.Now().Add(d) 185 } else if options.MaxAge < 0 { 186 // Set it to the past to expire now. 187 cookie.Expires = time.Unix(1, 0) 188 } 189 return cookie 190} 191 192// Error ---------------------------------------------------------------------- 193 194// MultiError stores multiple errors. 195// 196// Borrowed from the App Engine SDK. 197type MultiError []error 198 199func (m MultiError) Error() string { 200 s, n := "", 0 201 for _, e := range m { 202 if e != nil { 203 if n == 0 { 204 s = e.Error() 205 } 206 n++ 207 } 208 } 209 switch n { 210 case 0: 211 return "(0 errors)" 212 case 1: 213 return s 214 case 2: 215 return s + " (and 1 other error)" 216 } 217 return fmt.Sprintf("%s (and %d other errors)", s, n-1) 218} 219