1// Copyright (C) MongoDB, Inc. 2017-present.
2//
3// Licensed under the Apache License, Version 2.0 (the "License"); you may
4// not use this file except in compliance with the License. You may obtain
5// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6
7// Copyright (C) MongoDB, Inc. 2018-present.
8//
9// Licensed under the Apache License, Version 2.0 (the "License"); you may
10// not use this file except in compliance with the License. You may obtain
11// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
12
13package auth
14
15import (
16	"context"
17	"fmt"
18
19	"github.com/xdg/scram"
20	"github.com/xdg/stringprep"
21	"go.mongodb.org/mongo-driver/x/mongo/driver"
22	"go.mongodb.org/mongo-driver/x/mongo/driver/description"
23)
24
// SASL mechanism names for the SCRAM authenticators defined in this file.
const (
	// SCRAMSHA1 holds the mechanism name "SCRAM-SHA-1"
	SCRAMSHA1 = "SCRAM-SHA-1"

	// SCRAMSHA256 holds the mechanism name "SCRAM-SHA-256"
	SCRAMSHA256 = "SCRAM-SHA-256"
)
30
31func newScramSHA1Authenticator(cred *Cred) (Authenticator, error) {
32	passdigest := mongoPasswordDigest(cred.Username, cred.Password)
33	client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "")
34	if err != nil {
35		return nil, newAuthError("error initializing SCRAM-SHA-1 client", err)
36	}
37	client.WithMinIterations(4096)
38	return &ScramAuthenticator{
39		mechanism: SCRAMSHA1,
40		source:    cred.Source,
41		client:    client,
42	}, nil
43}
44
45func newScramSHA256Authenticator(cred *Cred) (Authenticator, error) {
46	passprep, err := stringprep.SASLprep.Prepare(cred.Password)
47	if err != nil {
48		return nil, newAuthError(fmt.Sprintf("error SASLprepping password '%s'", cred.Password), err)
49	}
50	client, err := scram.SHA256.NewClientUnprepped(cred.Username, passprep, "")
51	if err != nil {
52		return nil, newAuthError("error initializing SCRAM-SHA-256 client", err)
53	}
54	client.WithMinIterations(4096)
55	return &ScramAuthenticator{
56		mechanism: SCRAMSHA256,
57		source:    cred.Source,
58		client:    client,
59	}, nil
60}
61
// ScramAuthenticator uses the SCRAM algorithm over SASL to authenticate a connection.
//
// Fields:
//   - mechanism: the SASL mechanism name given to each conversation adapter;
//     the constructors in this file set it to SCRAMSHA1 or SCRAMSHA256.
//   - source: copied from the credential's Source and passed to
//     ConductSaslConversation by Auth.
//   - client: the SCRAM client from which Auth creates a new conversation on
//     every call.
type ScramAuthenticator struct {
	mechanism string
	source    string
	client    *scram.Client
}
68
69// Auth authenticates the connection.
70func (a *ScramAuthenticator) Auth(ctx context.Context, _ description.Server, conn driver.Connection) error {
71	adapter := &scramSaslAdapter{conversation: a.client.NewConversation(), mechanism: a.mechanism}
72	err := ConductSaslConversation(ctx, conn, a.source, adapter)
73	if err != nil {
74		return newAuthError("sasl conversation error", err)
75	}
76	return nil
77}
78
// scramSaslAdapter wraps a single SCRAM client conversation so it can be
// driven through Start, Next and Completed. mechanism is the SASL mechanism
// name reported by Start; conversation is the SCRAM conversation that every
// step is forwarded to.
type scramSaslAdapter struct {
	mechanism    string
	conversation *scram.ClientConversation
}
83
84func (a *scramSaslAdapter) Start() (string, []byte, error) {
85	step, err := a.conversation.Step("")
86	if err != nil {
87		return a.mechanism, nil, err
88	}
89	return a.mechanism, []byte(step), nil
90}
91
92func (a *scramSaslAdapter) Next(challenge []byte) ([]byte, error) {
93	step, err := a.conversation.Step(string(challenge))
94	if err != nil {
95		return nil, err
96	}
97	return []byte(step), nil
98}
99
// Completed reports whether the wrapped SCRAM conversation is finished, as
// indicated by the conversation's Done method.
func (a *scramSaslAdapter) Completed() bool {
	return a.conversation.Done()
}
103