1/*
2 * Copyright © 2019-2021 A Bunch Tell LLC.
3 *
4 * This file is part of WriteFreely.
5 *
6 * WriteFreely is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License, included
8 * in the LICENSE file in this source code package.
9 */
10
11package writefreely
12
13import (
14	"context"
15	"encoding/json"
16	"fmt"
17	"io"
18	"io/ioutil"
19	"net/http"
20	"net/url"
21	"strings"
22	"time"
23
24	"github.com/gorilla/mux"
25	"github.com/gorilla/sessions"
26	"github.com/writeas/impart"
27	"github.com/writeas/web-core/log"
28	"github.com/writefreely/writefreely/config"
29)
30
// OAuthButtons holds display information for different OAuth providers we support.
//
// Each *Enabled flag says whether the matching provider's login button should
// be offered; each *DisplayName is the label shown for providers whose name
// is configurable.
type OAuthButtons struct {
	SlackEnabled       bool   // Slack OAuth has a client ID configured.
	WriteAsEnabled     bool   // Write.as OAuth has a client ID configured.
	GitLabEnabled      bool   // GitLab OAuth has a client ID configured.
	GitLabDisplayName  string // Configured GitLab label, or the package default.
	GiteaEnabled       bool   // Gitea OAuth has a client ID configured.
	GiteaDisplayName   string // Configured Gitea label, or the package default.
	GenericEnabled     bool   // Generic OAuth has a client ID configured.
	GenericDisplayName string // Configured generic-provider label, or the package default.
}
42
43// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration.
44func NewOAuthButtons(cfg *config.Config) *OAuthButtons {
45	return &OAuthButtons{
46		SlackEnabled:       cfg.SlackOauth.ClientID != "",
47		WriteAsEnabled:     cfg.WriteAsOauth.ClientID != "",
48		GitLabEnabled:      cfg.GitlabOauth.ClientID != "",
49		GitLabDisplayName:  config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName),
50		GiteaEnabled:       cfg.GiteaOauth.ClientID != "",
51		GiteaDisplayName:   config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName),
52		GenericEnabled:     cfg.GenericOauth.ClientID != "",
53		GenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName),
54	}
55}
56
// TokenResponse contains data returned when a token is created either
// through a code exchange or using a refresh token.
//
// Fields are decoded from the JSON keys named in their struct tags.
type TokenResponse struct {
	AccessToken  string `json:"access_token"`  // Token later passed to inspectOauthAccessToken.
	ExpiresIn    int    `json:"expires_in"`    // Lifetime as reported by the provider; presumably seconds — confirm per provider.
	RefreshToken string `json:"refresh_token"` // Refresh token, when the provider issues one.
	TokenType    string `json:"token_type"`    // Token type as reported by the provider.
	Error        string `json:"error"`         // Error text from the provider's "error" key, if any.
}
66
// InspectResponse contains data returned when an access token is inspected.
//
// Fields are decoded from the JSON keys named in their struct tags, except
// DisplayName, which is excluded from JSON and must be set by code.
type InspectResponse struct {
	ClientID    string    `json:"client_id"`  // OAuth client the token belongs to.
	UserID      string    `json:"user_id"`    // Remote user ID; used to find or link the local account.
	ExpiresAt   time.Time `json:"expires_at"` // Token expiry as reported by the provider.
	Username    string    `json:"username"`   // Remote username; offered on the signup page.
	DisplayName string    `json:"-"`          // Never read from or written to JSON (tag "-").
	Email       string    `json:"email"`      // Remote email address; offered on the signup page.
	Error       string    `json:"error"`      // Error text from the provider's "error" key, if any.
}
77
// Upper bounds, in bytes, on how much of a provider response is read.
const (
	// tokenRequestMaxLen caps what is read from the /oauth/token endpoint.
	// One megabyte is plenty.
	tokenRequestMaxLen = 1000000

	// infoRequestMaxLen caps what is read from the /oauth/inspect endpoint.
	infoRequestMaxLen = 1000000
)
85
// OAuthDatastoreProvider provides a minimal interface of data store, config,
// and session store for use with the oauth handlers.
type OAuthDatastoreProvider interface {
	// DB returns the data store used for OAuth state and remote-user lookups.
	DB() OAuthDatastore
	// Config returns the application configuration.
	Config() *config.Config
	// SessionStore returns the store used to read and save user sessions.
	SessionStore() sessions.Store
}
93
// OAuthDatastore provides a minimal interface of data store methods used in
// oauth functionality.
//
// The parameters are unnamed; the meanings given below are taken from how
// the handlers in this file call each method.
type OAuthDatastore interface {
	// GetIDForRemoteUser is called with (remote user ID, provider, client ID)
	// and yields a local user ID. The callback handler treats -1 as "no
	// local user is linked to this remote account".
	GetIDForRemoteUser(context.Context, string, string, string) (int64, error)
	// RecordRemoteUserID is called with (local user ID, remote user ID,
	// provider, client ID, access token) to link a remote account to a
	// local user.
	RecordRemoteUserID(context.Context, int64, string, string, string, string) error
	// ValidateOAuthState is called with the state value received on the
	// callback. Its results are read as (provider, client ID, ID of the user
	// to attach to, invite code, error).
	ValidateOAuthState(context.Context, string) (string, string, int64, string, error)
	// GenerateOAuthState is called with (provider, client ID, ID of the user
	// to attach to, invite code) and returns the state value sent to the
	// provider. The attach ID is 0 when no attachment was requested.
	GenerateOAuthState(context.Context, string, string, int64, string) (string, error)

	// CreateUser is not called in this part of the file.
	// NOTE(review): confirm the meaning of the two string parameters against
	// the implementation.
	CreateUser(*config.Config, *User, string, string) error
	// GetUserByID loads the local user with the given ID.
	GetUserByID(int64) (*User, error)
}
105
// HttpClient is the single-method seam through which the OAuth code issues
// outbound HTTP requests. Do has the same signature as (*http.Client).Do, so
// an *http.Client satisfies it, and tests can substitute their own
// implementation.
type HttpClient interface {
	// Do sends req and returns the response or an error.
	Do(req *http.Request) (*http.Response, error)
}
109
// oauthClient is what each supported OAuth provider implements so that the
// shared handlers can drive the login flow without provider-specific code.
type oauthClient interface {
	// GetProvider returns the provider identifier. It is stored with the
	// OAuth state and used as the final path segment of the provider's
	// "/oauth/..." and "/oauth/callback/..." routes.
	GetProvider() string
	// GetClientID returns the OAuth client ID; it is stored with the state.
	GetClientID() string
	// GetCallbackLocation returns the callback URL the client was set up with.
	GetCallbackLocation() string
	// buildLoginURL returns the provider URL the browser is redirected to
	// in order to start the login, carrying the given state.
	buildLoginURL(state string) (string, error)
	// exchangeOauthCode trades the code received on the callback for a token.
	exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error)
	// inspectOauthAccessToken asks the provider about the holder of the
	// given access token.
	inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error)
}
118
// callbackProxyClient registers OAuth states with an external callback
// proxy (see register).
type callbackProxyClient struct {
	server           string     // URL that register POSTs the state registration to.
	callbackLocation string     // This instance's own callback URL, sent as the "location" form value.
	httpClient       HttpClient // Client used to send the registration request.
}
124
// oauthHandler bundles everything the OAuth HTTP handlers need for a single
// provider. configureOauthRoutes builds one per configured provider.
type oauthHandler struct {
	Config        *config.Config       // Application configuration.
	DB            OAuthDatastore       // Storage for OAuth state and remote-user links.
	Store         sessions.Store       // Session store used to log users in.
	EmailKey      []byte               // Copied from the app's keys; not used by the handlers visible in this part of the file.
	oauthClient   oauthClient          // Provider-specific client.
	callbackProxy *callbackProxyClient // Optional; when non-nil, each new state is registered with it before redirecting.
}
133
134func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error {
135	ctx := r.Context()
136
137	var attachUser int64
138	if attach := r.URL.Query().Get("attach"); attach == "t" {
139		user, _ := getUserAndSession(app, r)
140		if user == nil {
141			return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"}
142		}
143		attachUser = user.ID
144	}
145
146	state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code"))
147	if err != nil {
148		log.Error("viewOauthInit error: %s", err)
149		return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
150	}
151
152	if h.callbackProxy != nil {
153		if err := h.callbackProxy.register(ctx, state); err != nil {
154			log.Error("viewOauthInit error: %s", err)
155			return impart.HTTPError{http.StatusInternalServerError, "could not register state server"}
156		}
157	}
158
159	location, err := h.oauthClient.buildLoginURL(state)
160	if err != nil {
161		log.Error("viewOauthInit error: %s", err)
162		return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"}
163	}
164	return impart.HTTPError{http.StatusTemporaryRedirect, location}
165}
166
167func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) {
168	if app.Config().SlackOauth.ClientID != "" {
169		callbackLocation := app.Config().App.Host + "/oauth/callback/slack"
170
171		var stateRegisterClient *callbackProxyClient = nil
172		if app.Config().SlackOauth.CallbackProxyAPI != "" {
173			stateRegisterClient = &callbackProxyClient{
174				server:           app.Config().SlackOauth.CallbackProxyAPI,
175				callbackLocation: app.Config().App.Host + "/oauth/callback/slack",
176				httpClient:       config.DefaultHTTPClient(),
177			}
178			callbackLocation = app.Config().SlackOauth.CallbackProxy
179		}
180		oauthClient := slackOauthClient{
181			ClientID:         app.Config().SlackOauth.ClientID,
182			ClientSecret:     app.Config().SlackOauth.ClientSecret,
183			TeamID:           app.Config().SlackOauth.TeamID,
184			HttpClient:       config.DefaultHTTPClient(),
185			CallbackLocation: callbackLocation,
186		}
187		configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient)
188	}
189}
190
191func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) {
192	if app.Config().WriteAsOauth.ClientID != "" {
193		callbackLocation := app.Config().App.Host + "/oauth/callback/write.as"
194
195		var callbackProxy *callbackProxyClient = nil
196		if app.Config().WriteAsOauth.CallbackProxy != "" {
197			callbackProxy = &callbackProxyClient{
198				server:           app.Config().WriteAsOauth.CallbackProxyAPI,
199				callbackLocation: app.Config().App.Host + "/oauth/callback/write.as",
200				httpClient:       config.DefaultHTTPClient(),
201			}
202			callbackLocation = app.Config().WriteAsOauth.CallbackProxy
203		}
204
205		oauthClient := writeAsOauthClient{
206			ClientID:         app.Config().WriteAsOauth.ClientID,
207			ClientSecret:     app.Config().WriteAsOauth.ClientSecret,
208			ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation),
209			InspectLocation:  config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation),
210			AuthLocation:     config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation),
211			HttpClient:       config.DefaultHTTPClient(),
212			CallbackLocation: callbackLocation,
213		}
214		configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
215	}
216}
217
218func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) {
219	if app.Config().GitlabOauth.ClientID != "" {
220		callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab"
221
222		var callbackProxy *callbackProxyClient = nil
223		if app.Config().GitlabOauth.CallbackProxy != "" {
224			callbackProxy = &callbackProxyClient{
225				server:           app.Config().GitlabOauth.CallbackProxyAPI,
226				callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab",
227				httpClient:       config.DefaultHTTPClient(),
228			}
229			callbackLocation = app.Config().GitlabOauth.CallbackProxy
230		}
231
232		address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost)
233		oauthClient := gitlabOauthClient{
234			ClientID:         app.Config().GitlabOauth.ClientID,
235			ClientSecret:     app.Config().GitlabOauth.ClientSecret,
236			ExchangeLocation: address + "/oauth/token",
237			InspectLocation:  address + "/api/v4/user",
238			AuthLocation:     address + "/oauth/authorize",
239			HttpClient:       config.DefaultHTTPClient(),
240			CallbackLocation: callbackLocation,
241		}
242		configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
243	}
244}
245
246func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) {
247	if app.Config().GenericOauth.ClientID != "" {
248		callbackLocation := app.Config().App.Host + "/oauth/callback/generic"
249
250		var callbackProxy *callbackProxyClient = nil
251		if app.Config().GenericOauth.CallbackProxy != "" {
252			callbackProxy = &callbackProxyClient{
253				server:           app.Config().GenericOauth.CallbackProxyAPI,
254				callbackLocation: app.Config().App.Host + "/oauth/callback/generic",
255				httpClient:       config.DefaultHTTPClient(),
256			}
257			callbackLocation = app.Config().GenericOauth.CallbackProxy
258		}
259
260		oauthClient := genericOauthClient{
261			ClientID:         app.Config().GenericOauth.ClientID,
262			ClientSecret:     app.Config().GenericOauth.ClientSecret,
263			ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint,
264			InspectLocation:  app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint,
265			AuthLocation:     app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint,
266			HttpClient:       config.DefaultHTTPClient(),
267			CallbackLocation: callbackLocation,
268			Scope:            config.OrDefaultString(app.Config().GenericOauth.Scope, "read_user"),
269			MapUserID:        config.OrDefaultString(app.Config().GenericOauth.MapUserID, "user_id"),
270			MapUsername:      config.OrDefaultString(app.Config().GenericOauth.MapUsername, "username"),
271			MapDisplayName:   config.OrDefaultString(app.Config().GenericOauth.MapDisplayName, "-"),
272			MapEmail:         config.OrDefaultString(app.Config().GenericOauth.MapEmail, "email"),
273		}
274		configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
275	}
276}
277
278func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) {
279	if app.Config().GiteaOauth.ClientID != "" {
280		callbackLocation := app.Config().App.Host + "/oauth/callback/gitea"
281
282		var callbackProxy *callbackProxyClient = nil
283		if app.Config().GiteaOauth.CallbackProxy != "" {
284			callbackProxy = &callbackProxyClient{
285				server:           app.Config().GiteaOauth.CallbackProxyAPI,
286				callbackLocation: app.Config().App.Host + "/oauth/callback/gitea",
287				httpClient:       config.DefaultHTTPClient(),
288			}
289			callbackLocation = app.Config().GiteaOauth.CallbackProxy
290		}
291
292		oauthClient := giteaOauthClient{
293			ClientID:         app.Config().GiteaOauth.ClientID,
294			ClientSecret:     app.Config().GiteaOauth.ClientSecret,
295			ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token",
296			InspectLocation:  app.Config().GiteaOauth.Host + "/api/v1/user",
297			AuthLocation:     app.Config().GiteaOauth.Host + "/login/oauth/authorize",
298			HttpClient:       config.DefaultHTTPClient(),
299			CallbackLocation: callbackLocation,
300		}
301		configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy)
302	}
303}
304
305func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) {
306	handler := &oauthHandler{
307		Config:        app.Config(),
308		DB:            app.DB(),
309		Store:         app.SessionStore(),
310		oauthClient:   oauthClient,
311		EmailKey:      app.keys.EmailKey,
312		callbackProxy: callbackProxy,
313	}
314	r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET")
315	r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET")
316	r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST")
317}
318
319func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error {
320	ctx := r.Context()
321
322	code := r.FormValue("code")
323	state := r.FormValue("state")
324
325	provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state)
326	if err != nil {
327		log.Error("Unable to ValidateOAuthState: %s", err)
328		return impart.HTTPError{http.StatusInternalServerError, err.Error()}
329	}
330
331	tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code)
332	if err != nil {
333		log.Error("Unable to exchangeOauthCode: %s", err)
334		// TODO: show user friendly message if needed
335		// TODO: show NO message for cases like user pressing "Cancel" on authorize step
336		addSessionFlash(app, w, r, err.Error(), nil)
337		if attachUserID > 0 {
338			return impart.HTTPError{http.StatusFound, "/me/settings"}
339		}
340		return impart.HTTPError{http.StatusInternalServerError, err.Error()}
341	}
342
343	// Now that we have the access token, let's use it real quick to make sure
344	// it really really works.
345	tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken)
346	if err != nil {
347		log.Error("Unable to inspectOauthAccessToken: %s", err)
348		return impart.HTTPError{http.StatusInternalServerError, err.Error()}
349	}
350
351	localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID)
352	if err != nil {
353		log.Error("Unable to GetIDForRemoteUser: %s", err)
354		return impart.HTTPError{http.StatusInternalServerError, err.Error()}
355	}
356
357	if localUserID != -1 && attachUserID > 0 {
358		if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil {
359			return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()}
360		}
361		return impart.HTTPError{http.StatusFound, "/me/settings"}
362	}
363
364	if localUserID != -1 {
365		// Existing user, so log in now
366		user, err := h.DB.GetUserByID(localUserID)
367		if err != nil {
368			log.Error("Unable to GetUserByID %d: %s", localUserID, err)
369			return impart.HTTPError{http.StatusInternalServerError, err.Error()}
370		}
371		if err = loginOrFail(h.Store, w, r, user); err != nil {
372			log.Error("Unable to loginOrFail %d: %s", localUserID, err)
373			return impart.HTTPError{http.StatusInternalServerError, err.Error()}
374		}
375		return nil
376	}
377	if attachUserID > 0 {
378		log.Info("attaching to user %d", attachUserID)
379		err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken)
380		if err != nil {
381			return impart.HTTPError{http.StatusInternalServerError, err.Error()}
382		}
383		return impart.HTTPError{http.StatusFound, "/me/settings"}
384	}
385
386	// New user registration below.
387	// First, verify that user is allowed to register
388	if inviteCode != "" {
389		// Verify invite code is valid
390		i, err := app.db.GetUserInvite(inviteCode)
391		if err != nil {
392			return impart.HTTPError{http.StatusInternalServerError, err.Error()}
393		}
394		if !i.Active(app.db) {
395			return impart.HTTPError{http.StatusNotFound, "Invite link has expired."}
396		}
397	} else if !app.cfg.App.OpenRegistration {
398		addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil)
399		return impart.HTTPError{http.StatusFound, "/login"}
400	}
401
402	displayName := tokenInfo.DisplayName
403	if len(displayName) == 0 {
404		displayName = tokenInfo.Username
405	}
406
407	tp := &oauthSignupPageParams{
408		AccessToken:     tokenResponse.AccessToken,
409		TokenUsername:   tokenInfo.Username,
410		TokenAlias:      tokenInfo.DisplayName,
411		TokenEmail:      tokenInfo.Email,
412		TokenRemoteUser: tokenInfo.UserID,
413		Provider:        provider,
414		ClientID:        clientID,
415		InviteCode:      inviteCode,
416	}
417	tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
418
419	return h.showOauthSignupPage(app, w, r, tp, nil)
420}
421
422func (r *callbackProxyClient) register(ctx context.Context, state string) error {
423	form := url.Values{}
424	form.Add("state", state)
425	form.Add("location", r.callbackLocation)
426	req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode()))
427	if err != nil {
428		return err
429	}
430	req.Header.Set("User-Agent", ServerUserAgent(""))
431	req.Header.Set("Accept", "application/json")
432	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
433
434	resp, err := r.httpClient.Do(req)
435	if err != nil {
436		return err
437	}
438	if resp.StatusCode != http.StatusCreated {
439		return fmt.Errorf("unable register state location: %d", resp.StatusCode)
440	}
441
442	return nil
443}
444
// limitedJsonUnmarshal decodes JSON from body into thing, refusing payloads
// longer than n bytes.
//
// At most n+1 bytes are read. A payload of exactly n bytes is accepted; if
// n+1 bytes arrive, the content is over the allowance and an error is
// returned without attempting to decode. Read errors and JSON decoding
// errors are returned as-is. body is not closed here.
func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error {
	// Reading one byte past the allowance is what distinguishes "exactly n
	// bytes" from "more than n bytes".
	probe := n + 1
	payload, readErr := ioutil.ReadAll(io.LimitReader(body, int64(probe)))
	switch {
	case readErr != nil:
		return readErr
	case len(payload) == probe:
		return fmt.Errorf("content larger than max read allowance: %d", n)
	}
	return json.Unmarshal(payload, thing)
}
456
457func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error {
458	// An error may be returned, but a valid session should always be returned.
459	session, _ := store.Get(r, cookieName)
460	session.Values[cookieUserVal] = user.Cookie()
461	if err := session.Save(r, w); err != nil {
462		fmt.Println("error saving session", err)
463		return err
464	}
465	http.Redirect(w, r, "/", http.StatusTemporaryRedirect)
466	return nil
467}
468