1// Copyright (C) 2021 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package console
5
import (
	"context"
	"crypto/rand"
	"crypto/subtle"
	"math/big"
	"time"

	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
	"github.com/zeebo/errs"
)
16
// MFARecoveryCodeCount is how many recovery codes are produced each time a
// user's MFA recovery codes are (re)generated.
const MFARecoveryCodeCount = 10
21
// User-facing messages attached to the MFA errors returned in this file.
const (
	// mfaConflictErrMsg: a request supplied both a passcode and a recovery code.
	mfaConflictErrMsg = "Expected either passcode or recovery code, but got both"

	// mfaPasscodeInvalidErrMsg: the TOTP passcode did not validate.
	mfaPasscodeInvalidErrMsg = "The MFA passcode is not valid or has expired"

	// mfaRecoveryGenerationErrMsg: recovery codes were requested while MFA is off.
	mfaRecoveryGenerationErrMsg = "MFA recovery codes cannot be generated while MFA is disabled."

	// mfaRecoveryInvalidErrMsg: the recovery code is not among the user's codes.
	mfaRecoveryInvalidErrMsg = "The MFA recovery code is not valid or has been previously used"

	// mfaRequiredErrMsg: neither a passcode nor a recovery code was supplied.
	mfaRequiredErrMsg = "A MFA passcode or recovery code is required"
)
30
31var (
32	// ErrMFAMissing is error type that occurs when a request is incomplete
33	// due to missing MFA passcode and recovery code.
34	ErrMFAMissing = errs.Class("MFA code required")
35
36	// ErrMFAConflict is error type that occurs when both a passcode and recovery code are given.
37	ErrMFAConflict = errs.Class("MFA conflict")
38
39	// ErrMFALogin is error type caused by MFA that occurs when logging in / retrieving token.
40	ErrMFALogin = errs.Class("MFA login")
41
42	// ErrMFARecoveryCode is error type that represents usage of invalid MFA recovery code.
43	ErrMFARecoveryCode = errs.Class("MFA recovery code")
44
45	// ErrMFAPasscode is error type that represents usage of invalid MFA passcode.
46	ErrMFAPasscode = errs.Class("MFA passcode")
47)
48
49// NewMFAValidationOpts returns the options used to validate TOTP passcodes.
50// These settings are also used to generate MFA secret keys for use in testing.
51func NewMFAValidationOpts() totp.ValidateOpts {
52	return totp.ValidateOpts{
53		Period:    30,
54		Skew:      1,
55		Digits:    6,
56		Algorithm: otp.AlgorithmSHA1,
57	}
58}
59
60// ValidateMFAPasscode returns whether the TOTP passcode is valid for the secret key at the given time.
61func ValidateMFAPasscode(passcode string, secretKey string, t time.Time) (bool, error) {
62	valid, err := totp.ValidateCustom(passcode, secretKey, t, NewMFAValidationOpts())
63	return valid, Error.Wrap(err)
64}
65
66// NewMFAPasscode derives a TOTP passcode from a secret key using a timestamp.
67func NewMFAPasscode(secretKey string, t time.Time) (string, error) {
68	code, err := totp.GenerateCodeCustom(secretKey, t, NewMFAValidationOpts())
69	return code, Error.Wrap(err)
70}
71
72// NewMFASecretKey generates a new TOTP secret key.
73func NewMFASecretKey() (string, error) {
74	opts := NewMFAValidationOpts()
75	key, err := totp.Generate(totp.GenerateOpts{
76		Issuer:      " ",
77		AccountName: " ",
78		Period:      opts.Period,
79		Digits:      otp.DigitsSix,
80		Algorithm:   opts.Algorithm,
81	})
82	if err != nil {
83		return "", Error.Wrap(err)
84	}
85	return key.Secret(), nil
86}
87
88// EnableUserMFA enables multi-factor authentication for the user if the given secret key and password are valid.
89func (s *Service) EnableUserMFA(ctx context.Context, passcode string, t time.Time) (err error) {
90	defer mon.Task()(&ctx)(&err)
91
92	auth, err := s.getAuthAndAuditLog(ctx, "enable MFA")
93	if err != nil {
94		return Error.Wrap(err)
95	}
96
97	valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t)
98	if err != nil {
99		return ErrValidation.Wrap(ErrMFAPasscode.Wrap(err))
100	}
101	if !valid {
102		return ErrValidation.Wrap(ErrMFAPasscode.New(mfaPasscodeInvalidErrMsg))
103	}
104
105	auth.User.MFAEnabled = true
106	err = s.store.Users().Update(ctx, &auth.User)
107	if err != nil {
108		return Error.Wrap(err)
109	}
110
111	return nil
112}
113
114// DisableUserMFA disables multi-factor authentication for the user if the given secret key and password are valid.
115func (s *Service) DisableUserMFA(ctx context.Context, passcode string, t time.Time, recoveryCode string) (err error) {
116	defer mon.Task()(&ctx)(&err)
117
118	auth, err := s.getAuthAndAuditLog(ctx, "disable MFA")
119	if err != nil {
120		return Error.Wrap(err)
121	}
122
123	user := &auth.User
124
125	if !user.MFAEnabled {
126		return nil
127	}
128
129	if recoveryCode != "" && passcode != "" {
130		return ErrMFAConflict.New(mfaConflictErrMsg)
131	}
132
133	if recoveryCode != "" {
134		found := false
135		for _, code := range user.MFARecoveryCodes {
136			if code == recoveryCode {
137				found = true
138				break
139			}
140		}
141		if !found {
142			return ErrUnauthorized.Wrap(ErrMFARecoveryCode.New(mfaRecoveryInvalidErrMsg))
143		}
144	} else if passcode != "" {
145		valid, err := ValidateMFAPasscode(passcode, auth.User.MFASecretKey, t)
146		if err != nil {
147			return ErrValidation.Wrap(ErrMFAPasscode.Wrap(err))
148		}
149		if !valid {
150			return ErrValidation.Wrap(ErrMFAPasscode.New(mfaPasscodeInvalidErrMsg))
151		}
152	} else {
153		return ErrMFAMissing.New(mfaRequiredErrMsg)
154	}
155
156	auth.User.MFAEnabled = false
157	auth.User.MFASecretKey = ""
158	auth.User.MFARecoveryCodes = nil
159	err = s.store.Users().Update(ctx, &auth.User)
160	if err != nil {
161		return Error.Wrap(err)
162	}
163
164	return nil
165}
166
// NewMFARecoveryCode returns a random MFA recovery code of the form
// XXXX-XXXX-XXXX: three groups of four characters joined by dashes. Each
// character is drawn uniformly from 0-9 and A-Z using crypto/rand.
// Any error from the random source is returned unwrapped.
func NewMFARecoveryCode() (string, error) {
	const (
		alphabet   = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
		groupCount = 3
		groupSize  = 4
	)
	upper := big.NewInt(int64(len(alphabet)))
	code := make([]byte, 0, groupCount*groupSize+groupCount-1)
	for g := 0; g < groupCount; g++ {
		if g > 0 {
			code = append(code, '-')
		}
		for j := 0; j < groupSize; j++ {
			idx, err := rand.Int(rand.Reader, upper)
			if err != nil {
				return "", err
			}
			code = append(code, alphabet[idx.Int64()])
		}
	}
	return string(code), nil
}
186
187// ResetMFASecretKey creates a new TOTP secret key for the user.
188func (s *Service) ResetMFASecretKey(ctx context.Context) (key string, err error) {
189	defer mon.Task()(&ctx)(&err)
190
191	auth, err := s.getAuthAndAuditLog(ctx, "reset MFA secret key")
192	if err != nil {
193		return "", Error.Wrap(err)
194	}
195
196	key, err = NewMFASecretKey()
197	if err != nil {
198		return "", Error.Wrap(err)
199	}
200
201	auth.User.MFASecretKey = key
202	err = s.store.Users().Update(ctx, &auth.User)
203	if err != nil {
204		return "", Error.Wrap(err)
205	}
206
207	return key, nil
208}
209
210// ResetMFARecoveryCodes creates a new set of MFA recovery codes for the user.
211func (s *Service) ResetMFARecoveryCodes(ctx context.Context) (codes []string, err error) {
212	defer mon.Task()(&ctx)(&err)
213
214	auth, err := s.getAuthAndAuditLog(ctx, "reset MFA recovery codes")
215	if err != nil {
216		return nil, Error.Wrap(err)
217	}
218
219	if !auth.User.MFAEnabled {
220		return nil, ErrUnauthorized.New(mfaRecoveryGenerationErrMsg)
221	}
222
223	codes = make([]string, MFARecoveryCodeCount)
224	for i := 0; i < MFARecoveryCodeCount; i++ {
225		code, err := NewMFARecoveryCode()
226		if err != nil {
227			return nil, Error.Wrap(err)
228		}
229		codes[i] = code
230	}
231	auth.User.MFARecoveryCodes = codes
232
233	err = s.store.Users().Update(ctx, &auth.User)
234	if err != nil {
235		return nil, Error.Wrap(err)
236	}
237
238	return codes, nil
239}
240