1// Copyright 2019 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package iotest
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"log"
12	"testing"
13)
14
// errWriter is an io.Writer stub whose every Write fails with the stored err
// without consuming any input.
type errWriter struct {
	err error // returned verbatim from each Write call
}

// Write ignores its argument and reports that zero bytes were written,
// together with ew.err.
func (ew errWriter) Write(_ []byte) (n int, err error) {
	n, err = 0, ew.err
	return n, err
}
22
23func TestWriteLogger(t *testing.T) {
24	olw := log.Writer()
25	olf := log.Flags()
26	olp := log.Prefix()
27
28	// Revert the original log settings before we exit.
29	defer func() {
30		log.SetFlags(olf)
31		log.SetPrefix(olp)
32		log.SetOutput(olw)
33	}()
34
35	lOut := new(bytes.Buffer)
36	log.SetPrefix("lw: ")
37	log.SetOutput(lOut)
38	log.SetFlags(0)
39
40	lw := new(bytes.Buffer)
41	wl := NewWriteLogger("write:", lw)
42	if _, err := wl.Write([]byte("Hello, World!")); err != nil {
43		t.Fatalf("Unexpectedly failed to write: %v", err)
44	}
45
46	if g, w := lw.String(), "Hello, World!"; g != w {
47		t.Errorf("WriteLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
48	}
49	wantLogWithHex := fmt.Sprintf("lw: write: %x\n", "Hello, World!")
50	if g, w := lOut.String(), wantLogWithHex; g != w {
51		t.Errorf("WriteLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
52	}
53}
54
55func TestWriteLogger_errorOnWrite(t *testing.T) {
56	olw := log.Writer()
57	olf := log.Flags()
58	olp := log.Prefix()
59
60	// Revert the original log settings before we exit.
61	defer func() {
62		log.SetFlags(olf)
63		log.SetPrefix(olp)
64		log.SetOutput(olw)
65	}()
66
67	lOut := new(bytes.Buffer)
68	log.SetPrefix("lw: ")
69	log.SetOutput(lOut)
70	log.SetFlags(0)
71
72	lw := errWriter{err: errors.New("Write Error!")}
73	wl := NewWriteLogger("write:", lw)
74	if _, err := wl.Write([]byte("Hello, World!")); err == nil {
75		t.Fatalf("Unexpectedly succeeded to write: %v", err)
76	}
77
78	wantLogWithHex := fmt.Sprintf("lw: write: %x: %v\n", "", "Write Error!")
79	if g, w := lOut.String(), wantLogWithHex; g != w {
80		t.Errorf("WriteLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
81	}
82}
83
84func TestReadLogger(t *testing.T) {
85	olw := log.Writer()
86	olf := log.Flags()
87	olp := log.Prefix()
88
89	// Revert the original log settings before we exit.
90	defer func() {
91		log.SetFlags(olf)
92		log.SetPrefix(olp)
93		log.SetOutput(olw)
94	}()
95
96	lOut := new(bytes.Buffer)
97	log.SetPrefix("lr: ")
98	log.SetOutput(lOut)
99	log.SetFlags(0)
100
101	data := []byte("Hello, World!")
102	p := make([]byte, len(data))
103	lr := bytes.NewReader(data)
104	rl := NewReadLogger("read:", lr)
105
106	n, err := rl.Read(p)
107	if err != nil {
108		t.Fatalf("Unexpectedly failed to read: %v", err)
109	}
110
111	if g, w := p[:n], data; !bytes.Equal(g, w) {
112		t.Errorf("ReadLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
113	}
114
115	wantLogWithHex := fmt.Sprintf("lr: read: %x\n", "Hello, World!")
116	if g, w := lOut.String(), wantLogWithHex; g != w {
117		t.Errorf("ReadLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
118	}
119}
120
121func TestReadLogger_errorOnRead(t *testing.T) {
122	olw := log.Writer()
123	olf := log.Flags()
124	olp := log.Prefix()
125
126	// Revert the original log settings before we exit.
127	defer func() {
128		log.SetFlags(olf)
129		log.SetPrefix(olp)
130		log.SetOutput(olw)
131	}()
132
133	lOut := new(bytes.Buffer)
134	log.SetPrefix("lr: ")
135	log.SetOutput(lOut)
136	log.SetFlags(0)
137
138	data := []byte("Hello, World!")
139	p := make([]byte, len(data))
140
141	lr := ErrReader(errors.New("io failure"))
142	rl := NewReadLogger("read", lr)
143	n, err := rl.Read(p)
144	if err == nil {
145		t.Fatalf("Unexpectedly succeeded to read: %v", err)
146	}
147
148	wantLogWithHex := fmt.Sprintf("lr: read %x: io failure\n", p[:n])
149	if g, w := lOut.String(), wantLogWithHex; g != w {
150		t.Errorf("ReadLogger mismatch\n\tgot:  %q\n\twant: %q", g, w)
151	}
152}
153