1package api
2
3import (
4	"database/sql"
5	"fmt"
6	"net/http"
7	"strings"
8	"time"
9
10	"github.com/ansible-semaphore/semaphore/db"
11	"github.com/ansible-semaphore/semaphore/util"
12	"github.com/gorilla/context"
13)
14
15//nolint: gocyclo
16func authentication(next http.Handler) http.Handler {
17	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
18		var userID int
19
20		if authHeader := strings.ToLower(r.Header.Get("authorization")); len(authHeader) > 0 && strings.Contains(authHeader, "bearer") {
21			var token db.APIToken
22			if err := db.Mysql.SelectOne(&token, "select * from user__token where id=? and expired=0", strings.Replace(authHeader, "bearer ", "", 1)); err != nil {
23				if err == sql.ErrNoRows {
24					w.WriteHeader(http.StatusUnauthorized)
25					return
26				}
27
28				panic(err)
29			}
30
31			userID = token.UserID
32		} else {
33			// fetch session from cookie
34			cookie, err := r.Cookie("semaphore")
35			if err != nil {
36				w.WriteHeader(http.StatusUnauthorized)
37				return
38			}
39
40			value := make(map[string]interface{})
41			if err = util.Cookie.Decode("semaphore", cookie.Value, &value); err != nil {
42				w.WriteHeader(http.StatusUnauthorized)
43				return
44			}
45
46			user, ok := value["user"]
47			sessionVal, okSession := value["session"]
48			if !ok || !okSession {
49				w.WriteHeader(http.StatusUnauthorized)
50				return
51			}
52
53			userID = user.(int)
54			sessionID := sessionVal.(int)
55
56			// fetch session
57			var session db.Session
58			if err := db.Mysql.SelectOne(&session, "select * from session where id=? and user_id=? and expired=0", sessionID, userID); err != nil {
59				w.WriteHeader(http.StatusUnauthorized)
60				return
61			}
62
63			if time.Since(session.LastActive).Hours() > 7*24 {
64				// more than week old unused session
65				// destroy.
66				if _, err := db.Mysql.Exec("update session set expired=1 where id=?", sessionID); err != nil {
67					panic(err)
68				}
69
70				w.WriteHeader(http.StatusUnauthorized)
71				return
72			}
73
74			if _, err := db.Mysql.Exec("update session set last_active=UTC_TIMESTAMP() where id=?", sessionID); err != nil {
75				panic(err)
76			}
77		}
78
79		user, err := db.FetchUser(userID)
80		if err != nil {
81			fmt.Println("Can't find user", err)
82			w.WriteHeader(http.StatusUnauthorized)
83			return
84		}
85
86		context.Set(r, "user", user)
87
88		next.ServeHTTP(w, r)
89	})
90}
91