1// Copyright 2016 The etcd Authors
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package rafthttp
16
17import (
18	"fmt"
19	"io"
20	"io/ioutil"
21	"net/http"
22	"net/http/httptest"
23	"os"
24	"strings"
25	"testing"
26	"time"
27
28	"github.com/coreos/etcd/pkg/types"
29	"github.com/coreos/etcd/raft/raftpb"
30	"github.com/coreos/etcd/snap"
31)
32
// strReaderCloser adapts a *strings.Reader into an io.ReadCloser so that an
// in-memory string can stand in for a snapshot body. Closing it releases
// nothing and always succeeds.
type strReaderCloser struct{ *strings.Reader }

// Close satisfies io.Closer; there is no underlying resource, so it returns nil.
func (strReaderCloser) Close() error {
	return nil
}
36
37func TestSnapshotSend(t *testing.T) {
38	tests := []struct {
39		m    raftpb.Message
40		rc   io.ReadCloser
41		size int64
42
43		wsent  bool
44		wfiles int
45	}{
46		// sent and receive with no errors
47		{
48			m:    raftpb.Message{Type: raftpb.MsgSnap, To: 1},
49			rc:   strReaderCloser{strings.NewReader("hello")},
50			size: 5,
51
52			wsent:  true,
53			wfiles: 1,
54		},
55		// error when reading snapshot for send
56		{
57			m:    raftpb.Message{Type: raftpb.MsgSnap, To: 1},
58			rc:   &errReadCloser{fmt.Errorf("snapshot error")},
59			size: 1,
60
61			wsent:  false,
62			wfiles: 0,
63		},
64		// sends less than the given snapshot length
65		{
66			m:    raftpb.Message{Type: raftpb.MsgSnap, To: 1},
67			rc:   strReaderCloser{strings.NewReader("hello")},
68			size: 10000,
69
70			wsent:  false,
71			wfiles: 0,
72		},
73		// sends less than actual snapshot length
74		{
75			m:    raftpb.Message{Type: raftpb.MsgSnap, To: 1},
76			rc:   strReaderCloser{strings.NewReader("hello")},
77			size: 1,
78
79			wsent:  false,
80			wfiles: 0,
81		},
82	}
83
84	for i, tt := range tests {
85		sent, files := testSnapshotSend(t, snap.NewMessage(tt.m, tt.rc, tt.size))
86		if tt.wsent != sent {
87			t.Errorf("#%d: snapshot expected %v, got %v", i, tt.wsent, sent)
88		}
89		if tt.wfiles != len(files) {
90			t.Fatalf("#%d: expected %d files, got %d files", i, tt.wfiles, len(files))
91		}
92	}
93}
94
95func testSnapshotSend(t *testing.T, sm *snap.Message) (bool, []os.FileInfo) {
96	d, err := ioutil.TempDir(os.TempDir(), "snapdir")
97	if err != nil {
98		t.Fatal(err)
99	}
100	defer os.RemoveAll(d)
101
102	r := &fakeRaft{}
103	tr := &Transport{pipelineRt: &http.Transport{}, ClusterID: types.ID(1), Raft: r}
104	ch := make(chan struct{}, 1)
105	h := &syncHandler{newSnapshotHandler(tr, r, snap.New(d), types.ID(1)), ch}
106	srv := httptest.NewServer(h)
107	defer srv.Close()
108
109	picker := mustNewURLPicker(t, []string{srv.URL})
110	snapsend := newSnapshotSender(tr, picker, types.ID(1), newPeerStatus(types.ID(1)))
111	defer snapsend.stop()
112
113	snapsend.send(*sm)
114
115	sent := false
116	select {
117	case <-time.After(time.Second):
118		t.Fatalf("timed out sending snapshot")
119	case sent = <-sm.CloseNotify():
120	}
121
122	// wait for handler to finish accepting snapshot
123	<-ch
124
125	files, rerr := ioutil.ReadDir(d)
126	if rerr != nil {
127		t.Fatal(rerr)
128	}
129	return sent, files
130}
131
// errReadCloser is an io.ReadCloser whose every operation fails with err.
// Read never yields any bytes.
type errReadCloser struct{ err error }

// Read reports zero bytes read together with the configured error.
func (e *errReadCloser) Read(_ []byte) (int, error) {
	return 0, e.err
}

// Close returns the configured error.
func (e *errReadCloser) Close() error {
	return e.err
}
136
// syncHandler wraps an http.Handler and sends a value on ch each time the
// wrapped handler's ServeHTTP returns. Tests use it to wait until the server
// side has completely finished handling a request.
type syncHandler struct {
	h  http.Handler    // handler that actually serves the request
	ch chan<- struct{} // signalled after each request has been handled
}

// ServeHTTP delegates to the wrapped handler and then notifies ch. The
// notification blocks until ch has buffer space or a receiver is ready. It is
// not sent if the wrapped handler panics.
func (s *syncHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	inner, done := s.h, s.ch
	inner.ServeHTTP(w, r)
	done <- struct{}{}
}
146