1/*
2 * Copyright © 2020-2021 A Bunch Tell LLC.
3 *
4 * This file is part of WriteFreely.
5 *
6 * WriteFreely is free software: you can redistribute it and/or modify
7 * it under the terms of the GNU Affero General Public License, included
8 * in the LICENSE file in this source code package.
9 */
10
11package writefreely
12
import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/hex"
	"fmt"
	"html/template"
	"net/http"
	"strings"
	"time"

	"github.com/writeas/impart"
	"github.com/writeas/web-core/auth"
	"github.com/writeas/web-core/log"
	"github.com/writefreely/writefreely/page"
)
26
// viewOauthSignupVars is the data passed to the "signup-oauth.tmpl" template
// (executed by showOauthSignupPage).
type viewOauthSignupVars struct {
	// To is the submitted "to" form value echoed back to the page; Flashes
	// holds session flash messages followed by the current error, if any.
	page.StaticPage
	To      string
	Message template.HTML
	Flashes []template.HTML

	// Values copied from oauthSignupPageParams and carried through the form.
	// TokenHash is the signature over the token fields, AccessToken, ClientID
	// and Provider (see HashTokenParams); InviteCode is not signed.
	AccessToken     string
	TokenUsername   string
	TokenAlias      string // TODO: rename this to match the data it represents: the collection title
	TokenEmail      string
	TokenRemoteUser string
	Provider        string
	ClientID        string
	TokenHash       string
	InviteCode      string

	// Pre-filled values for the user-editable fields: the token values,
	// overridden by any non-empty values already submitted in the form.
	LoginUsername string
	Alias         string // TODO: rename this to match the data it represents: the collection title
	Email         string
}
47
// Names of the form values read when handling the OAuth signup form.
//
// The token_* values, access_token, client_id and provider are passed through
// the form and verified against the "signature" value on submit (see
// HashTokenParams). username, alias, email and password are entered by the
// user and validated separately. invite_code is passed through unsigned.
const (
	oauthParamAccessToken       = "access_token"
	oauthParamTokenUsername     = "token_username"
	oauthParamTokenAlias        = "token_alias"
	oauthParamTokenEmail        = "token_email"
	oauthParamTokenRemoteUserID = "token_remote_user"
	oauthParamClientID          = "client_id"
	oauthParamProvider          = "provider"
	oauthParamHash              = "signature"
	oauthParamUsername          = "username"
	oauthParamAlias             = "alias"
	oauthParamEmail             = "email"
	oauthParamPassword          = "password"
	oauthParamInviteCode        = "invite_code"
)
63
// oauthSignupPageParams holds the values carried through the OAuth signup
// form. TokenHash is the signature produced by HashTokenParams, used to detect
// a client altering the signed fields between page render and submission.
type oauthSignupPageParams struct {
	AccessToken     string
	TokenUsername   string
	TokenAlias      string // TODO: rename this to match the data it represents: the collection title
	TokenEmail      string
	TokenRemoteUser string
	ClientID        string
	Provider        string
	TokenHash       string
	InviteCode      string
}

// HashTokenParams returns a hex-encoded HMAC-SHA256 signature over
// AccessToken, TokenUsername, TokenAlias, TokenEmail, TokenRemoteUser,
// ClientID and Provider, keyed with key (the server's hash seed).
//
// Each field is written with a length prefix so field boundaries are part of
// the signed data. Without framing, moving bytes from one field into its
// neighbour (e.g. from TokenEmail into TokenRemoteUser) would leave the
// signature unchanged. HMAC is used rather than sha256(key || data), which is
// vulnerable to length extension.
//
// TokenHash and InviteCode are not covered. Changing this construction
// invalidates signatures issued by any previous implementation.
func (p oauthSignupPageParams) HashTokenParams(key string) string {
	mac := hmac.New(sha256.New, []byte(key))
	for _, field := range []string{
		p.AccessToken,
		p.TokenUsername,
		p.TokenAlias,
		p.TokenEmail,
		p.TokenRemoteUser,
		p.ClientID,
		p.Provider,
	} {
		// Netstring-style framing "<len>:<data>,". Writes to a hash.Hash
		// never return an error, so the result is safely ignored.
		fmt.Fprintf(mac, "%d:%s,", len(field), field)
	}
	return hex.EncodeToString(mac.Sum(nil))
}
88
89func (h oauthHandler) viewOauthSignup(app *App, w http.ResponseWriter, r *http.Request) error {
90	tp := &oauthSignupPageParams{
91		AccessToken:     r.FormValue(oauthParamAccessToken),
92		TokenUsername:   r.FormValue(oauthParamTokenUsername),
93		TokenAlias:      r.FormValue(oauthParamTokenAlias),
94		TokenEmail:      r.FormValue(oauthParamTokenEmail),
95		TokenRemoteUser: r.FormValue(oauthParamTokenRemoteUserID),
96		ClientID:        r.FormValue(oauthParamClientID),
97		Provider:        r.FormValue(oauthParamProvider),
98		InviteCode:      r.FormValue(oauthParamInviteCode),
99	}
100	if tp.HashTokenParams(h.Config.Server.HashSeed) != r.FormValue(oauthParamHash) {
101		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Request has been tampered with."}
102	}
103	tp.TokenHash = tp.HashTokenParams(h.Config.Server.HashSeed)
104	if err := h.validateOauthSignup(r); err != nil {
105		return h.showOauthSignupPage(app, w, r, tp, err)
106	}
107
108	var err error
109	hashedPass := []byte{}
110	clearPass := r.FormValue(oauthParamPassword)
111	hasPass := clearPass != ""
112	if hasPass {
113		hashedPass, err = auth.HashPass([]byte(clearPass))
114		if err != nil {
115			return h.showOauthSignupPage(app, w, r, tp, fmt.Errorf("unable to hash password"))
116		}
117	}
118	newUser := &User{
119		Username:   r.FormValue(oauthParamUsername),
120		HashedPass: hashedPass,
121		HasPass:    hasPass,
122		Email:      prepareUserEmail(r.FormValue(oauthParamEmail), h.EmailKey),
123		Created:    time.Now().Truncate(time.Second).UTC(),
124	}
125	displayName := r.FormValue(oauthParamAlias)
126	if len(displayName) == 0 {
127		displayName = r.FormValue(oauthParamUsername)
128	}
129
130	err = h.DB.CreateUser(h.Config, newUser, displayName, "")
131	if err != nil {
132		return h.showOauthSignupPage(app, w, r, tp, err)
133	}
134
135	// Log invite if needed
136	if tp.InviteCode != "" {
137		err = app.db.CreateInvitedUser(tp.InviteCode, newUser.ID)
138		if err != nil {
139			return err
140		}
141	}
142
143	err = h.DB.RecordRemoteUserID(r.Context(), newUser.ID, r.FormValue(oauthParamTokenRemoteUserID), r.FormValue(oauthParamProvider), r.FormValue(oauthParamClientID), r.FormValue(oauthParamAccessToken))
144	if err != nil {
145		return h.showOauthSignupPage(app, w, r, tp, err)
146	}
147
148	if err := loginOrFail(h.Store, w, r, newUser); err != nil {
149		return h.showOauthSignupPage(app, w, r, tp, err)
150	}
151	return nil
152}
153
154func (h oauthHandler) validateOauthSignup(r *http.Request) error {
155	username := r.FormValue(oauthParamUsername)
156	if len(username) < h.Config.App.MinUsernameLen {
157		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too short."}
158	}
159	if len(username) > 100 {
160		return impart.HTTPError{Status: http.StatusBadRequest, Message: "Username is too long."}
161	}
162	collTitle := r.FormValue(oauthParamAlias)
163	if len(collTitle) == 0 {
164		collTitle = username
165	}
166	email := r.FormValue(oauthParamEmail)
167	if len(email) > 0 {
168		parts := strings.Split(email, "@")
169		if len(parts) != 2 || (len(parts[0]) < 1 || len(parts[1]) < 1) {
170			return impart.HTTPError{Status: http.StatusBadRequest, Message: "Invalid email address"}
171		}
172	}
173	return nil
174}
175
176func (h oauthHandler) showOauthSignupPage(app *App, w http.ResponseWriter, r *http.Request, tp *oauthSignupPageParams, errMsg error) error {
177	username := tp.TokenUsername
178	collTitle := tp.TokenAlias
179	email := tp.TokenEmail
180
181	session, err := app.sessionStore.Get(r, cookieName)
182	if err != nil {
183		// Ignore this
184		log.Error("Unable to get session; ignoring: %v", err)
185	}
186
187	if tmpValue := r.FormValue(oauthParamUsername); len(tmpValue) > 0 {
188		username = tmpValue
189	}
190	if tmpValue := r.FormValue(oauthParamAlias); len(tmpValue) > 0 {
191		collTitle = tmpValue
192	}
193	if tmpValue := r.FormValue(oauthParamEmail); len(tmpValue) > 0 {
194		email = tmpValue
195	}
196
197	p := &viewOauthSignupVars{
198		StaticPage: pageForReq(app, r),
199		To:         r.FormValue("to"),
200		Flashes:    []template.HTML{},
201
202		AccessToken:     tp.AccessToken,
203		TokenUsername:   tp.TokenUsername,
204		TokenAlias:      tp.TokenAlias,
205		TokenEmail:      tp.TokenEmail,
206		TokenRemoteUser: tp.TokenRemoteUser,
207		Provider:        tp.Provider,
208		ClientID:        tp.ClientID,
209		TokenHash:       tp.TokenHash,
210		InviteCode:      tp.InviteCode,
211
212		LoginUsername: username,
213		Alias:         collTitle,
214		Email:         email,
215	}
216
217	// Display any error messages
218	flashes, _ := getSessionFlashes(app, w, r, session)
219	for _, flash := range flashes {
220		p.Flashes = append(p.Flashes, template.HTML(flash))
221	}
222	if errMsg != nil {
223		p.Flashes = append(p.Flashes, template.HTML(errMsg.Error()))
224	}
225	err = pages["signup-oauth.tmpl"].ExecuteTemplate(w, "base", p)
226	if err != nil {
227		log.Error("Unable to render signup-oauth: %v", err)
228		return err
229	}
230	return nil
231}
232