1/* 2 * Copyright © 2019-2021 A Bunch Tell LLC. 3 * 4 * This file is part of WriteFreely. 5 * 6 * WriteFreely is free software: you can redistribute it and/or modify 7 * it under the terms of the GNU Affero General Public License, included 8 * in the LICENSE file in this source code package. 9 */ 10 11package writefreely 12 13import ( 14 "context" 15 "encoding/json" 16 "fmt" 17 "io" 18 "io/ioutil" 19 "net/http" 20 "net/url" 21 "strings" 22 "time" 23 24 "github.com/gorilla/mux" 25 "github.com/gorilla/sessions" 26 "github.com/writeas/impart" 27 "github.com/writeas/web-core/log" 28 "github.com/writefreely/writefreely/config" 29) 30 31// OAuthButtons holds display information for different OAuth providers we support. 32type OAuthButtons struct { 33 SlackEnabled bool 34 WriteAsEnabled bool 35 GitLabEnabled bool 36 GitLabDisplayName string 37 GiteaEnabled bool 38 GiteaDisplayName string 39 GenericEnabled bool 40 GenericDisplayName string 41} 42 43// NewOAuthButtons creates a new OAuthButtons struct based on our app configuration. 44func NewOAuthButtons(cfg *config.Config) *OAuthButtons { 45 return &OAuthButtons{ 46 SlackEnabled: cfg.SlackOauth.ClientID != "", 47 WriteAsEnabled: cfg.WriteAsOauth.ClientID != "", 48 GitLabEnabled: cfg.GitlabOauth.ClientID != "", 49 GitLabDisplayName: config.OrDefaultString(cfg.GitlabOauth.DisplayName, gitlabDisplayName), 50 GiteaEnabled: cfg.GiteaOauth.ClientID != "", 51 GiteaDisplayName: config.OrDefaultString(cfg.GiteaOauth.DisplayName, giteaDisplayName), 52 GenericEnabled: cfg.GenericOauth.ClientID != "", 53 GenericDisplayName: config.OrDefaultString(cfg.GenericOauth.DisplayName, genericOauthDisplayName), 54 } 55} 56 57// TokenResponse contains data returned when a token is created either 58// through a code exchange or using a refresh token. 59type TokenResponse struct { 60 AccessToken string `json:"access_token"` 61 ExpiresIn int `json:"expires_in"` 62 RefreshToken string `json:"refresh_token"` 63 TokenType string `json:"token_type"` 64 Error string `json:"error"` 65} 66 67// InspectResponse contains data returned when an access token is inspected. 68type InspectResponse struct { 69 ClientID string `json:"client_id"` 70 UserID string `json:"user_id"` 71 ExpiresAt time.Time `json:"expires_at"` 72 Username string `json:"username"` 73 DisplayName string `json:"-"` 74 Email string `json:"email"` 75 Error string `json:"error"` 76} 77 78// tokenRequestMaxLen is the most bytes that we'll read from the /oauth/token 79// endpoint. One megabyte is plenty. 80const tokenRequestMaxLen = 1000000 81 82// infoRequestMaxLen is the most bytes that we'll read from the 83// /oauth/inspect endpoint. 84const infoRequestMaxLen = 1000000 85 86// OAuthDatastoreProvider provides a minimal interface of data store, config, 87// and session store for use with the oauth handlers. 88type OAuthDatastoreProvider interface { 89 DB() OAuthDatastore 90 Config() *config.Config 91 SessionStore() sessions.Store 92} 93 94// OAuthDatastore provides a minimal interface of data store methods used in 95// oauth functionality. 96type OAuthDatastore interface { 97 GetIDForRemoteUser(context.Context, string, string, string) (int64, error) 98 RecordRemoteUserID(context.Context, int64, string, string, string, string) error 99 ValidateOAuthState(context.Context, string) (string, string, int64, string, error) 100 GenerateOAuthState(context.Context, string, string, int64, string) (string, error) 101 102 CreateUser(*config.Config, *User, string, string) error 103 GetUserByID(int64) (*User, error) 104} 105 106type HttpClient interface { 107 Do(req *http.Request) (*http.Response, error) 108} 109 110type oauthClient interface { 111 GetProvider() string 112 GetClientID() string 113 GetCallbackLocation() string 114 buildLoginURL(state string) (string, error) 115 exchangeOauthCode(ctx context.Context, code string) (*TokenResponse, error) 116 inspectOauthAccessToken(ctx context.Context, accessToken string) (*InspectResponse, error) 117} 118 119type callbackProxyClient struct { 120 server string 121 callbackLocation string 122 httpClient HttpClient 123} 124 125type oauthHandler struct { 126 Config *config.Config 127 DB OAuthDatastore 128 Store sessions.Store 129 EmailKey []byte 130 oauthClient oauthClient 131 callbackProxy *callbackProxyClient 132} 133 134func (h oauthHandler) viewOauthInit(app *App, w http.ResponseWriter, r *http.Request) error { 135 ctx := r.Context() 136 137 var attachUser int64 138 if attach := r.URL.Query().Get("attach"); attach == "t" { 139 user, _ := getUserAndSession(app, r) 140 if user == nil { 141 return impart.HTTPError{http.StatusInternalServerError, "cannot attach auth to user: user not found in session"} 142 } 143 attachUser = user.ID 144 } 145 146 state, err := h.DB.GenerateOAuthState(ctx, h.oauthClient.GetProvider(), h.oauthClient.GetClientID(), attachUser, r.FormValue("invite_code")) 147 if err != nil { 148 log.Error("viewOauthInit error: %s", err) 149 return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} 150 } 151 152 if h.callbackProxy != nil { 153 if err := h.callbackProxy.register(ctx, state); err != nil { 154 log.Error("viewOauthInit error: %s", err) 155 return impart.HTTPError{http.StatusInternalServerError, "could not register state server"} 156 } 157 } 158 159 location, err := h.oauthClient.buildLoginURL(state) 160 if err != nil { 161 log.Error("viewOauthInit error: %s", err) 162 return impart.HTTPError{http.StatusInternalServerError, "could not prepare oauth redirect url"} 163 } 164 return impart.HTTPError{http.StatusTemporaryRedirect, location} 165} 166 167func configureSlackOauth(parentHandler *Handler, r *mux.Router, app *App) { 168 if app.Config().SlackOauth.ClientID != "" { 169 callbackLocation := app.Config().App.Host + "/oauth/callback/slack" 170 171 var stateRegisterClient *callbackProxyClient = nil 172 if app.Config().SlackOauth.CallbackProxyAPI != "" { 173 stateRegisterClient = &callbackProxyClient{ 174 server: app.Config().SlackOauth.CallbackProxyAPI, 175 callbackLocation: app.Config().App.Host + "/oauth/callback/slack", 176 httpClient: config.DefaultHTTPClient(), 177 } 178 callbackLocation = app.Config().SlackOauth.CallbackProxy 179 } 180 oauthClient := slackOauthClient{ 181 ClientID: app.Config().SlackOauth.ClientID, 182 ClientSecret: app.Config().SlackOauth.ClientSecret, 183 TeamID: app.Config().SlackOauth.TeamID, 184 HttpClient: config.DefaultHTTPClient(), 185 CallbackLocation: callbackLocation, 186 } 187 configureOauthRoutes(parentHandler, r, app, oauthClient, stateRegisterClient) 188 } 189} 190 191func configureWriteAsOauth(parentHandler *Handler, r *mux.Router, app *App) { 192 if app.Config().WriteAsOauth.ClientID != "" { 193 callbackLocation := app.Config().App.Host + "/oauth/callback/write.as" 194 195 var callbackProxy *callbackProxyClient = nil 196 if app.Config().WriteAsOauth.CallbackProxy != "" { 197 callbackProxy = &callbackProxyClient{ 198 server: app.Config().WriteAsOauth.CallbackProxyAPI, 199 callbackLocation: app.Config().App.Host + "/oauth/callback/write.as", 200 httpClient: config.DefaultHTTPClient(), 201 } 202 callbackLocation = app.Config().WriteAsOauth.CallbackProxy 203 } 204 205 oauthClient := writeAsOauthClient{ 206 ClientID: app.Config().WriteAsOauth.ClientID, 207 ClientSecret: app.Config().WriteAsOauth.ClientSecret, 208 ExchangeLocation: config.OrDefaultString(app.Config().WriteAsOauth.TokenLocation, writeAsExchangeLocation), 209 InspectLocation: config.OrDefaultString(app.Config().WriteAsOauth.InspectLocation, writeAsIdentityLocation), 210 AuthLocation: config.OrDefaultString(app.Config().WriteAsOauth.AuthLocation, writeAsAuthLocation), 211 HttpClient: config.DefaultHTTPClient(), 212 CallbackLocation: callbackLocation, 213 } 214 configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) 215 } 216} 217 218func configureGitlabOauth(parentHandler *Handler, r *mux.Router, app *App) { 219 if app.Config().GitlabOauth.ClientID != "" { 220 callbackLocation := app.Config().App.Host + "/oauth/callback/gitlab" 221 222 var callbackProxy *callbackProxyClient = nil 223 if app.Config().GitlabOauth.CallbackProxy != "" { 224 callbackProxy = &callbackProxyClient{ 225 server: app.Config().GitlabOauth.CallbackProxyAPI, 226 callbackLocation: app.Config().App.Host + "/oauth/callback/gitlab", 227 httpClient: config.DefaultHTTPClient(), 228 } 229 callbackLocation = app.Config().GitlabOauth.CallbackProxy 230 } 231 232 address := config.OrDefaultString(app.Config().GitlabOauth.Host, gitlabHost) 233 oauthClient := gitlabOauthClient{ 234 ClientID: app.Config().GitlabOauth.ClientID, 235 ClientSecret: app.Config().GitlabOauth.ClientSecret, 236 ExchangeLocation: address + "/oauth/token", 237 InspectLocation: address + "/api/v4/user", 238 AuthLocation: address + "/oauth/authorize", 239 HttpClient: config.DefaultHTTPClient(), 240 CallbackLocation: callbackLocation, 241 } 242 configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) 243 } 244} 245 246func configureGenericOauth(parentHandler *Handler, r *mux.Router, app *App) { 247 if app.Config().GenericOauth.ClientID != "" { 248 callbackLocation := app.Config().App.Host + "/oauth/callback/generic" 249 250 var callbackProxy *callbackProxyClient = nil 251 if app.Config().GenericOauth.CallbackProxy != "" { 252 callbackProxy = &callbackProxyClient{ 253 server: app.Config().GenericOauth.CallbackProxyAPI, 254 callbackLocation: app.Config().App.Host + "/oauth/callback/generic", 255 httpClient: config.DefaultHTTPClient(), 256 } 257 callbackLocation = app.Config().GenericOauth.CallbackProxy 258 } 259 260 oauthClient := genericOauthClient{ 261 ClientID: app.Config().GenericOauth.ClientID, 262 ClientSecret: app.Config().GenericOauth.ClientSecret, 263 ExchangeLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.TokenEndpoint, 264 InspectLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.InspectEndpoint, 265 AuthLocation: app.Config().GenericOauth.Host + app.Config().GenericOauth.AuthEndpoint, 266 HttpClient: config.DefaultHTTPClient(), 267 CallbackLocation: callbackLocation, 268 Scope: config.OrDefaultString(app.Config().GenericOauth.Scope, "read_user"), 269 MapUserID: config.OrDefaultString(app.Config().GenericOauth.MapUserID, "user_id"), 270 MapUsername: config.OrDefaultString(app.Config().GenericOauth.MapUsername, "username"), 271 MapDisplayName: config.OrDefaultString(app.Config().GenericOauth.MapDisplayName, "-"), 272 MapEmail: config.OrDefaultString(app.Config().GenericOauth.MapEmail, "email"), 273 } 274 configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) 275 } 276} 277 278func configureGiteaOauth(parentHandler *Handler, r *mux.Router, app *App) { 279 if app.Config().GiteaOauth.ClientID != "" { 280 callbackLocation := app.Config().App.Host + "/oauth/callback/gitea" 281 282 var callbackProxy *callbackProxyClient = nil 283 if app.Config().GiteaOauth.CallbackProxy != "" { 284 callbackProxy = &callbackProxyClient{ 285 server: app.Config().GiteaOauth.CallbackProxyAPI, 286 callbackLocation: app.Config().App.Host + "/oauth/callback/gitea", 287 httpClient: config.DefaultHTTPClient(), 288 } 289 callbackLocation = app.Config().GiteaOauth.CallbackProxy 290 } 291 292 oauthClient := giteaOauthClient{ 293 ClientID: app.Config().GiteaOauth.ClientID, 294 ClientSecret: app.Config().GiteaOauth.ClientSecret, 295 ExchangeLocation: app.Config().GiteaOauth.Host + "/login/oauth/access_token", 296 InspectLocation: app.Config().GiteaOauth.Host + "/api/v1/user", 297 AuthLocation: app.Config().GiteaOauth.Host + "/login/oauth/authorize", 298 HttpClient: config.DefaultHTTPClient(), 299 CallbackLocation: callbackLocation, 300 } 301 configureOauthRoutes(parentHandler, r, app, oauthClient, callbackProxy) 302 } 303} 304 305func configureOauthRoutes(parentHandler *Handler, r *mux.Router, app *App, oauthClient oauthClient, callbackProxy *callbackProxyClient) { 306 handler := &oauthHandler{ 307 Config: app.Config(), 308 DB: app.DB(), 309 Store: app.SessionStore(), 310 oauthClient: oauthClient, 311 EmailKey: app.keys.EmailKey, 312 callbackProxy: callbackProxy, 313 } 314 r.HandleFunc("/oauth/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthInit)).Methods("GET") 315 r.HandleFunc("/oauth/callback/"+oauthClient.GetProvider(), parentHandler.OAuth(handler.viewOauthCallback)).Methods("GET") 316 r.HandleFunc("/oauth/signup", parentHandler.OAuth(handler.viewOauthSignup)).Methods("POST") 317} 318 319func (h oauthHandler) viewOauthCallback(app *App, w http.ResponseWriter, r *http.Request) error { 320 ctx := r.Context() 321 322 code := r.FormValue("code") 323 state := r.FormValue("state") 324 325 provider, clientID, attachUserID, inviteCode, err := h.DB.ValidateOAuthState(ctx, state) 326 if err != nil { 327 log.Error("Unable to ValidateOAuthState: %s", err) 328 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 329 } 330 331 tokenResponse, err := h.oauthClient.exchangeOauthCode(ctx, code) 332 if err != nil { 333 log.Error("Unable to exchangeOauthCode: %s", err) 334 // TODO: show user friendly message if needed 335 // TODO: show NO message for cases like user pressing "Cancel" on authorize step 336 addSessionFlash(app, w, r, err.Error(), nil) 337 if attachUserID > 0 { 338 return impart.HTTPError{http.StatusFound, "/me/settings"} 339 } 340 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 341 } 342 343 // Now that we have the access token, let's use it real quick to make sure 344 // it really really works. 345 tokenInfo, err := h.oauthClient.inspectOauthAccessToken(ctx, tokenResponse.AccessToken) 346 if err != nil { 347 log.Error("Unable to inspectOauthAccessToken: %s", err) 348 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 349 } 350 351 localUserID, err := h.DB.GetIDForRemoteUser(ctx, tokenInfo.UserID, provider, clientID) 352 if err != nil { 353 log.Error("Unable to GetIDForRemoteUser: %s", err) 354 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 355 } 356 357 if localUserID != -1 && attachUserID > 0 { 358 if err = addSessionFlash(app, w, r, "This Slack account is already attached to another user.", nil); err != nil { 359 return impart.HTTPError{Status: http.StatusInternalServerError, Message: err.Error()} 360 } 361 return impart.HTTPError{http.StatusFound, "/me/settings"} 362 } 363 364 if localUserID != -1 { 365 // Existing user, so log in now 366 user, err := h.DB.GetUserByID(localUserID) 367 if err != nil { 368 log.Error("Unable to GetUserByID %d: %s", localUserID, err) 369 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 370 } 371 if err = loginOrFail(h.Store, w, r, user); err != nil { 372 log.Error("Unable to loginOrFail %d: %s", localUserID, err) 373 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 374 } 375 return nil 376 } 377 if attachUserID > 0 { 378 log.Info("attaching to user %d", attachUserID) 379 err = h.DB.RecordRemoteUserID(r.Context(), attachUserID, tokenInfo.UserID, provider, clientID, tokenResponse.AccessToken) 380 if err != nil { 381 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 382 } 383 return impart.HTTPError{http.StatusFound, "/me/settings"} 384 } 385 386 // New user registration below. 387 // First, verify that user is allowed to register 388 if inviteCode != "" { 389 // Verify invite code is valid 390 i, err := app.db.GetUserInvite(inviteCode) 391 if err != nil { 392 return impart.HTTPError{http.StatusInternalServerError, err.Error()} 393 } 394 if !i.Active(app.db) { 395 return impart.HTTPError{http.StatusNotFound, "Invite link has expired."} 396 } 397 } else if !app.cfg.App.OpenRegistration { 398 addSessionFlash(app, w, r, ErrUserNotFound.Error(), nil) 399 return impart.HTTPError{http.StatusFound, "/login"} 400 } 401 402 displayName := tokenInfo.DisplayName 403 if len(displayName) == 0 { 404 displayName = tokenInfo.Username 405 } 406 407 tp := &oauthSignupPageParams{ 408 AccessToken: tokenResponse.AccessToken, 409 TokenUsername: tokenInfo.Username, 410 TokenAlias: tokenInfo.DisplayName, 411 TokenEmail: tokenInfo.Email, 412 TokenRemoteUser: tokenInfo.UserID, 413 Provider: provider, 414 ClientID: clientID, 415 InviteCode: inviteCode, 416 } 417 tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed) 418 419 return h.showOauthSignupPage(app, w, r, tp, nil) 420} 421 422func (r *callbackProxyClient) register(ctx context.Context, state string) error { 423 form := url.Values{} 424 form.Add("state", state) 425 form.Add("location", r.callbackLocation) 426 req, err := http.NewRequestWithContext(ctx, "POST", r.server, strings.NewReader(form.Encode())) 427 if err != nil { 428 return err 429 } 430 req.Header.Set("User-Agent", ServerUserAgent("")) 431 req.Header.Set("Accept", "application/json") 432 req.Header.Set("Content-Type", "application/x-www-form-urlencoded") 433 434 resp, err := r.httpClient.Do(req) 435 if err != nil { 436 return err 437 } 438 if resp.StatusCode != http.StatusCreated { 439 return fmt.Errorf("unable register state location: %d", resp.StatusCode) 440 } 441 442 return nil 443} 444 445func limitedJsonUnmarshal(body io.ReadCloser, n int, thing interface{}) error { 446 lr := io.LimitReader(body, int64(n+1)) 447 data, err := ioutil.ReadAll(lr) 448 if err != nil { 449 return err 450 } 451 if len(data) == n+1 { 452 return fmt.Errorf("content larger than max read allowance: %d", n) 453 } 454 return json.Unmarshal(data, thing) 455} 456 457func loginOrFail(store sessions.Store, w http.ResponseWriter, r *http.Request, user *User) error { 458 // An error may be returned, but a valid session should always be returned. 459 session, _ := store.Get(r, cookieName) 460 session.Values[cookieUserVal] = user.Cookie() 461 if err := session.Save(r, w); err != nil { 462 fmt.Println("error saving session", err) 463 return err 464 } 465 http.Redirect(w, r, "/", http.StatusTemporaryRedirect) 466 return nil 467} 468