1package mail
2
3import (
4	"bytes"
5	"crypto/tls"
6	"io"
7	"net"
8	"net/smtp"
9	"reflect"
10	"testing"
11	"time"
12)
13
// testPort is the port used by the dialer tests that expect a STARTTLS
// exchange.
const testPort = 587

// testSSLPort is the port used by the SSL dialer tests; on it the tests
// expect no STARTTLS exchange.
const testSSLPort = 465
18
19var (
20	testConn    = &net.TCPConn{}
21	testTLSConn = tls.Client(testConn, &tls.Config{InsecureSkipVerify: true})
22	testConfig  = &tls.Config{InsecureSkipVerify: true}
23	testAuth    = smtp.PlainAuth("", testUser, testPwd, testHost)
24)
25
// TestDialer checks the default Dialer on testPort: it negotiates STARTTLS,
// authenticates, delivers the message to both recipients and quits.
func TestDialer(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")

	handshake := []string{"Extension STARTTLS", "StartTLS", "Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
43
// TestDialerSSL checks a Dialer on testSSLPort: no STARTTLS exchange takes
// place, only authentication followed by delivery.
func TestDialerSSL(t *testing.T) {
	dialer := NewDialer(testHost, testSSLPort, "user", "pwd")

	handshake := []string{"Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
59
// TestDialerConfig checks that a custom LocalName is sent with Hello before
// any other command, and that the custom TLSConfig is the one handed to
// StartTLS (the mock client asserts it against d.TLSConfig).
func TestDialerConfig(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")
	dialer.LocalName = "test"
	dialer.TLSConfig = testConfig

	handshake := []string{
		"Hello test",
		"Extension STARTTLS",
		"StartTLS",
		"Extension AUTH",
		"Auth",
	}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
80
// TestDialerSSLConfig checks a Dialer on testSSLPort with a custom LocalName
// and TLSConfig: Hello carries the local name and no STARTTLS exchange
// happens before authentication.
func TestDialerSSLConfig(t *testing.T) {
	dialer := NewDialer(testHost, testSSLPort, "user", "pwd")
	dialer.LocalName = "test"
	dialer.TLSConfig = testConfig

	handshake := []string{"Hello test", "Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
99
// TestDialerNoStartTLS checks that with the NoStartTLS policy the dialer never
// queries the STARTTLS extension, even though the mock server advertises it.
func TestDialerNoStartTLS(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")
	dialer.StartTLSPolicy = NoStartTLS

	handshake := []string{"Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
116
// TestDialerOpportunisticStartTLS checks that the opportunistic policy
// upgrades the connection when the server advertises STARTTLS. It also pins
// OpportunisticStartTLS to the zero value of the policy type, presumably so
// that a Dialer which never sets StartTLSPolicy behaves opportunistically.
func TestDialerOpportunisticStartTLS(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")
	dialer.StartTLSPolicy = OpportunisticStartTLS

	handshake := []string{"Extension STARTTLS", "StartTLS", "Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))

	if OpportunisticStartTLS == 0 {
		return
	}
	t.Errorf("OpportunisticStartTLS: expected 0, got %d",
		OpportunisticStartTLS)
}
140
// TestDialerOpportunisticStartTLSUnsupported checks that with the
// opportunistic policy the dialer asks for STARTTLS, finds it unsupported,
// and carries on without upgrading to authenticate and deliver.
func TestDialerOpportunisticStartTLSUnsupported(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")
	dialer.StartTLSPolicy = OpportunisticStartTLS

	handshake := []string{"Extension STARTTLS", "Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMailStartTLSUnsupported(t, dialer, append(handshake, delivery...))
}
158
// TestDialerMandatoryStartTLS checks that MandatoryStartTLS against a server
// that supports STARTTLS upgrades, authenticates and delivers as usual.
func TestDialerMandatoryStartTLS(t *testing.T) {
	dialer := NewDialer(testHost, testPort, "user", "pwd")
	dialer.StartTLSPolicy = MandatoryStartTLS

	handshake := []string{"Extension STARTTLS", "StartTLS", "Extension AUTH", "Auth"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
177
// TestDialerMandatoryStartTLSUnsupported checks that MandatoryStartTLS stops
// right after learning that the server does not advertise STARTTLS, and that
// DialAndSend then returns a StartTLSUnsupportedError with the expected
// message.
func TestDialerMandatoryStartTLSUnsupported(t *testing.T) {
	d := NewDialer(testHost, testPort, "user", "pwd")
	d.StartTLSPolicy = MandatoryStartTLS

	testClient := &mockClient{
		t:        t,
		addr:     addr(d.Host, d.Port),
		config:   d.TLSConfig,
		startTLS: false,
		timeout:  true,
	}

	err := doTestSendMail(t, d, testClient, []string{
		"Extension STARTTLS",
	})

	// A nil error would make both checks below panic (nil reflect.Type,
	// nil err.Error()), so stop here with a clear message instead.
	if err == nil {
		t.Fatal("expected StartTLSUnsupportedError, but got nil")
	}

	if _, ok := err.(StartTLSUnsupportedError); !ok {
		// %T also names pointer and unnamed types, unlike reflect's Name().
		t.Errorf("expected StartTLSUnsupportedError, but got: %T", err)
	}

	expected := "gomail: MandatoryStartTLS required, " +
		"but SMTP server does not support STARTTLS"
	if err.Error() != expected {
		t.Errorf("expected %s, but got: %s", expected, err)
	}
}
205
// TestDialerNoAuth checks that a Dialer built without credentials never asks
// for the AUTH extension, but still upgrades the connection via STARTTLS.
func TestDialerNoAuth(t *testing.T) {
	dialer := &Dialer{Host: testHost, Port: testPort}

	handshake := []string{"Extension STARTTLS", "StartTLS"}
	delivery := []string{
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, append(handshake, delivery...))
}
224
// TestDialerTimeout checks that with RetryFailure set, a MAIL command that
// fails with io.EOF (the mock's simulated timeout) makes the dialer start
// over from STARTTLS and then complete the delivery successfully.
func TestDialerTimeout(t *testing.T) {
	dialer := &Dialer{Host: testHost, Port: testPort, RetryFailure: true}

	failedAttempt := []string{"Extension STARTTLS", "StartTLS", "Mail " + testFrom}
	retriedAttempt := []string{
		"Extension STARTTLS",
		"StartTLS",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMailTimeout(t, dialer, append(failedAttempt, retriedAttempt...))
}
247
// TestDialerTimeoutNoRetry checks that with RetryFailure disabled, a MAIL
// command failing with io.EOF aborts the send: the dialer quits and
// DialAndSend returns the wrapped EOF error.
func TestDialerTimeoutNoRetry(t *testing.T) {
	d := &Dialer{
		Host:         testHost,
		Port:         testPort,
		RetryFailure: false,
	}
	testClient := &mockClient{
		t:        t,
		addr:     addr(d.Host, d.Port),
		config:   d.TLSConfig,
		startTLS: true,
		timeout:  true,
	}

	err := doTestSendMail(t, d, testClient, []string{
		"Extension STARTTLS",
		"StartTLS",
		"Mail " + testFrom,
		"Quit",
	})

	// Check for nil first: calling err.Error() on a nil error would panic and
	// hide the actual failure.
	if err == nil || err.Error() != "gomail: could not send email 1: EOF" {
		t.Error("expected to have got EOF, but got:", err)
	}
}
273
274type mockClient struct {
275	t        *testing.T
276	i        int
277	want     []string
278	addr     string
279	config   *tls.Config
280	startTLS bool
281	timeout  bool
282}
283
284func (c *mockClient) Hello(localName string) error {
285	c.do("Hello " + localName)
286	return nil
287}
288
289func (c *mockClient) Extension(ext string) (bool, string) {
290	c.do("Extension " + ext)
291	ok := true
292	if ext == "STARTTLS" {
293		ok = c.startTLS
294	}
295	return ok, ""
296}
297
// StartTLS checks the TLS config it receives against the expected one, then
// records "StartTLS". It never fails.
func (c *mockClient) StartTLS(config *tls.Config) error {
	expected := c.config
	assertConfig(c.t, config, expected)
	c.do("StartTLS")
	return nil
}
303
// Auth reports a non-fatal test error unless a deep-equals testAuth, then
// records "Auth". It never fails.
func (c *mockClient) Auth(a smtp.Auth) error {
	if same := reflect.DeepEqual(a, testAuth); !same {
		c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth)
	}
	c.do("Auth")
	return nil
}
311
// Mail records "Mail <from>". While c.timeout is set, it clears the flag and
// returns io.EOF, so exactly one call fails. Otherwise it succeeds.
func (c *mockClient) Mail(from string) error {
	c.do("Mail " + from)
	if !c.timeout {
		return nil
	}
	c.timeout = false
	return io.EOF
}
320
321func (c *mockClient) Rcpt(to string) error {
322	c.do("Rcpt " + to)
323	return nil
324}
325
// Data records "Data" and returns a mockWriter that expects testMsg as the
// message body.
func (c *mockClient) Data() (io.WriteCloser, error) {
	c.do("Data")
	w := &mockWriter{want: testMsg, c: c}
	return w, nil
}
330
331func (c *mockClient) Quit() error {
332	c.do("Quit")
333	return nil
334}
335
336func (c *mockClient) Close() error {
337	c.do("Close")
338	return nil
339}
340
341func (c *mockClient) do(cmd string) {
342	if c.i >= len(c.want) {
343		c.t.Fatalf("Invalid command %q", cmd)
344	}
345
346	if cmd != c.want[c.i] {
347		c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
348	}
349	c.i++
350}
351
352type mockWriter struct {
353	want string
354	c    *mockClient
355	buf  bytes.Buffer
356}
357
358func (w *mockWriter) Write(p []byte) (int, error) {
359	if w.buf.Len() == 0 {
360		w.c.do("Write message")
361	}
362	w.buf.Write(p)
363	return len(p), nil
364}
365
366func (w *mockWriter) Close() error {
367	compareBodies(w.c.t, w.buf.String(), w.want)
368	w.c.do("Close writer")
369	return nil
370}
371
// newDialerMockClient returns a mockClient that expects to be dialed at d's
// host and port with d's TLS config. startTLS controls whether the mock
// server advertises STARTTLS; timeout makes its first Mail call fail with
// io.EOF.
func newDialerMockClient(t *testing.T, d *Dialer, startTLS, timeout bool) *mockClient {
	return &mockClient{
		t:        t,
		addr:     addr(d.Host, d.Port),
		config:   d.TLSConfig,
		startTLS: startTLS,
		timeout:  timeout,
	}
}

// expectSendMail sends the test message through d using client c, checks
// that exactly the commands in want are issued, and reports any error
// returned by the send as a test failure.
func expectSendMail(t *testing.T, d *Dialer, c *mockClient, want []string) {
	if err := doTestSendMail(t, d, c, want); err != nil {
		t.Error(err)
	}
}

// testSendMail expects d to send the test message successfully to a server
// that advertises STARTTLS, issuing exactly the commands in want.
func testSendMail(t *testing.T, d *Dialer, want []string) {
	expectSendMail(t, d, newDialerMockClient(t, d, true, false), want)
}

// testSendMailStartTLSUnsupported is like testSendMail, but the server does
// not advertise STARTTLS.
func testSendMailStartTLSUnsupported(t *testing.T, d *Dialer, want []string) {
	expectSendMail(t, d, newDialerMockClient(t, d, false, false), want)
}

// testSendMailTimeout is like testSendMail, but the server's first Mail call
// fails with io.EOF, so d must still end up sending successfully.
func testSendMailTimeout(t *testing.T, d *Dialer, want []string) {
	expectSendMail(t, d, newDialerMockClient(t, d, true, true), want)
}
413
// doTestSendMail scripts testClient with want, replaces the package's
// NetDialTimeout, tlsClient and smtpNewClient hooks with mocks, and returns
// the result of d.DialAndSend on the test message.
//
// The mocks check three things. d must dial testClient.addr over "tcp". Any
// TLS wrap must be of testConn with testClient.config. The SMTP client must
// be created for testHost. The mocks then hand back testConn, testTLSConn
// and testClient. The original hooks are restored before returning, so the
// mocks do not leak into other tests.
func doTestSendMail(t *testing.T, d *Dialer, testClient *mockClient, want []string) error {
	testClient.want = want

	origDial, origTLSClient, origNewClient := NetDialTimeout, tlsClient, smtpNewClient
	defer func() {
		NetDialTimeout = origDial
		tlsClient = origTLSClient
		smtpNewClient = origNewClient
	}()

	// The timeout argument is ignored; it is named so that it does not shadow
	// the Dialer parameter d.
	NetDialTimeout = func(network, address string, timeout time.Duration) (net.Conn, error) {
		if network != "tcp" {
			t.Errorf("Invalid network, got %q, want tcp", network)
		}
		if address != testClient.addr {
			t.Errorf("Invalid address, got %q, want %q",
				address, testClient.addr)
		}
		return testConn, nil
	}

	tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn {
		if conn != testConn {
			t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn)
		}
		assertConfig(t, config, testClient.config)
		return testTLSConn
	}

	smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
		if host != testHost {
			t.Errorf("Invalid host, got %q, want %q", host, testHost)
		}
		return testClient, nil
	}

	return d.DialAndSend(getTestMessage())
}
445
// assertConfig reports a non-fatal test error when got's ServerName or
// InsecureSkipVerify differ from want's. A nil want stands for the config
// the tests expect by default: one whose ServerName is testHost. A nil got
// is reported as an error instead of panicking on the field access.
func assertConfig(t *testing.T, got, want *tls.Config) {
	t.Helper()
	if want == nil {
		want = &tls.Config{ServerName: testHost}
	}
	if got == nil {
		t.Errorf("Invalid config, got nil, want ServerName %q", want.ServerName)
		return
	}
	if got.ServerName != want.ServerName {
		t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
	}
	if got.InsecureSkipVerify != want.InsecureSkipVerify {
		t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
	}
}
457