1package testutil
2
3import (
4	"fmt"
5	"io"
6	"net"
7	"strings"
8	"testing"
9	"time"
10
11	"github.com/hashicorp/go-msgpack/codec"
12	cstructs "github.com/hashicorp/nomad/client/structs"
13	"github.com/hashicorp/nomad/nomad/structs"
14	"github.com/stretchr/testify/require"
15)
16
17// StreamingRPC may be satisfied by client.Client or server.Server.
18type StreamingRPC interface {
19	StreamingRpcHandler(method string) (structs.StreamingRpcHandler, error)
20}
21
// StreamingRPCErrorTestCase describes one scenario to be passed to
// AssertStreamingRPCError.
type StreamingRPCErrorTestCase struct {
	// Name is a human-readable label for the case.
	Name string

	// RPC is the streaming RPC method name used to look up the handler.
	RPC string

	// Req is the request value encoded and sent to the handler.
	Req interface{}

	// Assert receives the error carried by the stream message (nil when the
	// message has no error) and reports whether it is the expected outcome.
	Assert func(error) bool
}
30
31// AssertStreamingRPCError asserts a streaming RPC's error matches the given
32// assertion in the test case.
33func AssertStreamingRPCError(t *testing.T, s StreamingRPC, tc StreamingRPCErrorTestCase) {
34	handler, err := s.StreamingRpcHandler(tc.RPC)
35	require.NoError(t, err)
36
37	// Create a pipe
38	p1, p2 := net.Pipe()
39	defer p1.Close()
40	defer p2.Close()
41
42	errCh := make(chan error, 1)
43	streamMsg := make(chan *cstructs.StreamErrWrapper, 1)
44
45	// Start the handler
46	go handler(p2)
47
48	// Start the decoder
49	go func() {
50		decoder := codec.NewDecoder(p1, structs.MsgpackHandle)
51		for {
52			var msg cstructs.StreamErrWrapper
53			if err := decoder.Decode(&msg); err != nil {
54				if err == io.EOF || strings.Contains(err.Error(), "closed") {
55					return
56				}
57				errCh <- fmt.Errorf("error decoding: %v", err)
58			}
59
60			streamMsg <- &msg
61		}
62	}()
63
64	// Send the request
65	encoder := codec.NewEncoder(p1, structs.MsgpackHandle)
66	require.NoError(t, encoder.Encode(tc.Req))
67
68	timeout := time.After(5 * time.Second)
69
70	for {
71		select {
72		case <-timeout:
73			t.Fatal("timeout")
74		case err := <-errCh:
75			require.NoError(t, err)
76		case msg := <-streamMsg:
77			// Convert RpcError to error
78			var err error
79			if msg.Error != nil {
80				err = msg.Error
81			}
82			require.True(t, tc.Assert(err), "(%T) %s", msg.Error, msg.Error)
83			return
84		}
85	}
86}
87