1// Copyright 2013 Beego Authors 2// Copyright 2014 The Macaron Authors 3// 4// Licensed under the Apache License, Version 2.0 (the "License"): you may 5// not use this file except in compliance with the License. You may obtain 6// a copy of the License at 7// 8// http://www.apache.org/licenses/LICENSE-2.0 9// 10// Unless required by applicable law or agreed to in writing, software 11// distributed under the License is distributed on an "AS IS" BASIS, WITHOUT 12// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the 13// License for the specific language governing permissions and limitations 14// under the License. 15 16package session 17 18import ( 19 "database/sql" 20 "fmt" 21 "log" 22 "sync" 23 "time" 24 25 "gitea.com/go-chi/session" 26 _ "github.com/go-sql-driver/mysql" 27) 28 29// MysqlStore represents a mysql session store implementation. 30type MysqlStore struct { 31 c *sql.DB 32 sid string 33 lock sync.RWMutex 34 data map[interface{}]interface{} 35} 36 37// NewMysqlStore creates and returns a mysql session store. 38func NewMysqlStore(c *sql.DB, sid string, kv map[interface{}]interface{}) *MysqlStore { 39 return &MysqlStore{ 40 c: c, 41 sid: sid, 42 data: kv, 43 } 44} 45 46// Set sets value to given key in session. 47func (s *MysqlStore) Set(key, val interface{}) error { 48 s.lock.Lock() 49 defer s.lock.Unlock() 50 51 s.data[key] = val 52 return nil 53} 54 55// Get gets value by given key in session. 56func (s *MysqlStore) Get(key interface{}) interface{} { 57 s.lock.RLock() 58 defer s.lock.RUnlock() 59 60 return s.data[key] 61} 62 63// Delete delete a key from session. 64func (s *MysqlStore) Delete(key interface{}) error { 65 s.lock.Lock() 66 defer s.lock.Unlock() 67 68 delete(s.data, key) 69 return nil 70} 71 72// ID returns current session ID. 73func (s *MysqlStore) ID() string { 74 return s.sid 75} 76 77// Release releases resource and save data to provider. 78func (s *MysqlStore) Release() error { 79 // Skip encoding if the data is empty 80 if len(s.data) == 0 { 81 return nil 82 } 83 84 data, err := session.EncodeGob(s.data) 85 if err != nil { 86 return err 87 } 88 89 _, err = s.c.Exec("UPDATE session SET data=?, expiry=? WHERE `key`=?", 90 data, time.Now().Unix(), s.sid) 91 return err 92} 93 94// Flush deletes all session data. 95func (s *MysqlStore) Flush() error { 96 s.lock.Lock() 97 defer s.lock.Unlock() 98 99 s.data = make(map[interface{}]interface{}) 100 return nil 101} 102 103// MysqlProvider represents a mysql session provider implementation. 104type MysqlProvider struct { 105 c *sql.DB 106 expire int64 107} 108 109// Init initializes mysql session provider. 110// connStr: username:password@protocol(address)/dbname?param=value 111func (p *MysqlProvider) Init(expire int64, connStr string) (err error) { 112 p.expire = expire 113 114 p.c, err = sql.Open("mysql", connStr) 115 if err != nil { 116 return err 117 } 118 return p.c.Ping() 119} 120 121// Read returns raw session store by session ID. 122func (p *MysqlProvider) Read(sid string) (session.RawStore, error) { 123 now := time.Now().Unix() 124 var data []byte 125 expiry := now 126 err := p.c.QueryRow("SELECT data, expiry FROM session WHERE `key`=?", sid).Scan(&data, &expiry) 127 if err == sql.ErrNoRows { 128 _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", 129 sid, "", now) 130 } 131 if err != nil { 132 return nil, err 133 } 134 135 var kv map[interface{}]interface{} 136 if len(data) == 0 || expiry+p.expire <= now { 137 kv = make(map[interface{}]interface{}) 138 } else { 139 kv, err = session.DecodeGob(data) 140 if err != nil { 141 return nil, err 142 } 143 } 144 145 return NewMysqlStore(p.c, sid, kv), nil 146} 147 148// Exist returns true if session with given ID exists. 149func (p *MysqlProvider) Exist(sid string) bool { 150 var data []byte 151 err := p.c.QueryRow("SELECT data FROM session WHERE `key`=?", sid).Scan(&data) 152 if err != nil && err != sql.ErrNoRows { 153 panic("session/mysql: error checking existence: " + err.Error()) 154 } 155 return err != sql.ErrNoRows 156} 157 158// Destroy deletes a session by session ID. 159func (p *MysqlProvider) Destroy(sid string) error { 160 _, err := p.c.Exec("DELETE FROM session WHERE `key`=?", sid) 161 return err 162} 163 164// Regenerate regenerates a session store from old session ID to new one. 165func (p *MysqlProvider) Regenerate(oldsid, sid string) (_ session.RawStore, err error) { 166 if p.Exist(sid) { 167 return nil, fmt.Errorf("new sid '%s' already exists", sid) 168 } 169 170 if !p.Exist(oldsid) { 171 if _, err = p.c.Exec("INSERT INTO session(`key`,data,expiry) VALUES(?,?,?)", 172 oldsid, "", time.Now().Unix()); err != nil { 173 return nil, err 174 } 175 } 176 177 if _, err = p.c.Exec("UPDATE session SET `key`=? WHERE `key`=?", sid, oldsid); err != nil { 178 return nil, err 179 } 180 181 return p.Read(sid) 182} 183 184// Count counts and returns number of sessions. 185func (p *MysqlProvider) Count() (total int) { 186 if err := p.c.QueryRow("SELECT COUNT(*) AS NUM FROM session").Scan(&total); err != nil { 187 panic("session/mysql: error counting records: " + err.Error()) 188 } 189 return total 190} 191 192// GC calls GC to clean expired sessions. 193func (p *MysqlProvider) GC() { 194 if _, err := p.c.Exec("DELETE FROM session WHERE expiry + ? <= UNIX_TIMESTAMP(NOW())", p.expire); err != nil { 195 log.Printf("session/mysql: error garbage collecting: %v", err) 196 } 197} 198 199func init() { 200 session.Register("mysql", &MysqlProvider{}) 201} 202