1// Package sasl is an implementation detail of the mgo package.
2//
3// This package is not meant to be used by itself.
4//
5
6// +build !windows
7
8package sasl
9
10// #cgo LDFLAGS: -lsasl2
11//
12// struct sasl_conn {};
13//
14// #include <stdlib.h>
15// #include <sasl/sasl.h>
16//
17// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password);
18//
19import "C"
20
21import (
22	"fmt"
23	"strings"
24	"sync"
25	"unsafe"
26)
27
// saslStepper is the interface returned by New. It drives one SASL
// authentication exchange, one message at a time.
type saslStepper interface {
	// Step takes the data received from the server (if any) and returns the
	// data to send back. done reports whether the exchange has completed.
	Step(serverData []byte) (clientData []byte, done bool, err error)
	// Close releases the resources associated with the exchange.
	Close()
}
32
// saslSession is the saslStepper implementation returned by New. It wraps a
// client connection of the C SASL library together with the C allocations
// that must be released in Close.
type saslSession struct {
	conn *C.sasl_conn_t // connection handle filled in by sasl_client_new
	step int            // number of Step calls made so far
	mech string         // mechanism name handed to sasl_client_start

	cstrings  []*C.char          // C strings allocated through cstr; freed in Close
	callbacks *C.sasl_callback_t // result of mgo_sasl_callbacks; freed in Close
}
41
// Library initialization state: initOnce guards the single call to initSASL,
// and initError records its failure, if any, for every later call to New.
var (
	initError error
	initOnce  sync.Once
)
44
45func initSASL() {
46	rc := C.sasl_client_init(nil)
47	if rc != C.SASL_OK {
48		initError = saslError(rc, nil, "cannot initialize SASL library")
49	}
50}
51
52func New(username, password, mechanism, service, host string) (saslStepper, error) {
53	initOnce.Do(initSASL)
54	if initError != nil {
55		return nil, initError
56	}
57
58	ss := &saslSession{mech: mechanism}
59	if service == "" {
60		service = "mongodb"
61	}
62	if i := strings.Index(host, ":"); i >= 0 {
63		host = host[:i]
64	}
65	ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password))
66	rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn)
67	if rc != C.SASL_OK {
68		ss.Close()
69		return nil, saslError(rc, nil, "cannot create new SASL client")
70	}
71	return ss, nil
72}
73
74func (ss *saslSession) cstr(s string) *C.char {
75	cstr := C.CString(s)
76	ss.cstrings = append(ss.cstrings, cstr)
77	return cstr
78}
79
80func (ss *saslSession) Close() {
81	for _, cstr := range ss.cstrings {
82		C.free(unsafe.Pointer(cstr))
83	}
84	ss.cstrings = nil
85
86	if ss.callbacks != nil {
87		C.free(unsafe.Pointer(ss.callbacks))
88	}
89
90	// The documentation of SASL dispose makes it clear that this should only
91	// be done when the connection is done, not when the authentication phase
92	// is done, because an encryption layer may have been negotiated.
93	// Even then, we'll do this for now, because it's simpler and prevents
94	// keeping track of this state for every socket. If it breaks, we'll fix it.
95	C.sasl_dispose(&ss.conn)
96}
97
98func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) {
99	ss.step++
100	if ss.step > 10 {
101		return nil, false, fmt.Errorf("too many SASL steps without authentication")
102	}
103	var cclientData *C.char
104	var cclientDataLen C.uint
105	var rc C.int
106	if ss.step == 1 {
107		var mechanism *C.char // ignored - must match cred
108		rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism)
109	} else {
110		var cserverData *C.char
111		var cserverDataLen C.uint
112		if len(serverData) > 0 {
113			cserverData = (*C.char)(unsafe.Pointer(&serverData[0]))
114			cserverDataLen = C.uint(len(serverData))
115		}
116		rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen)
117	}
118	if cclientData != nil && cclientDataLen > 0 {
119		clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen))
120	}
121	if rc == C.SASL_OK {
122		return clientData, true, nil
123	}
124	if rc == C.SASL_CONTINUE {
125		return clientData, false, nil
126	}
127	return nil, false, saslError(rc, ss.conn, "cannot establish SASL session")
128}
129
130func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error {
131	var detail string
132	if conn == nil {
133		detail = C.GoString(C.sasl_errstring(rc, nil, nil))
134	} else {
135		detail = C.GoString(C.sasl_errdetail(conn))
136	}
137	return fmt.Errorf(msg + ": " + detail)
138}
139