1package amqp_test
2
3import (
4	"context"
5	"encoding/json"
6	"errors"
7	"testing"
8	"time"
9
10	amqptransport "github.com/go-kit/kit/transport/amqp"
11	"github.com/streadway/amqp"
12)
13
// errTypeAssertion signals that a value handed to a test helper did not have
// the concrete type the helper expected.
var errTypeAssertion = errors.New("type assertion error")
17
18// mockChannel is a mock of *amqp.Channel.
19type mockChannel struct {
20	f          func(exchange, key string, mandatory, immediate bool)
21	c          chan<- amqp.Publishing
22	deliveries []amqp.Delivery
23}
24
25// Publish runs a test function f and sends resultant message to a channel.
26func (ch *mockChannel) Publish(exchange, key string, mandatory, immediate bool, msg amqp.Publishing) error {
27	ch.f(exchange, key, mandatory, immediate)
28	ch.c <- msg
29	return nil
30}
31
32var nullFunc = func(exchange, key string, mandatory, immediate bool) {
33}
34
35func (ch *mockChannel) Consume(queue, consumer string, autoAck, exclusive, noLocal, noWail bool, args amqp.Table) (<-chan amqp.Delivery, error) {
36	c := make(chan amqp.Delivery, len(ch.deliveries))
37	for _, d := range ch.deliveries {
38		c <- d
39	}
40	return c, nil
41}
42
43// TestSubscriberBadDecode checks if decoder errors are handled properly.
44func TestSubscriberBadDecode(t *testing.T) {
45	sub := amqptransport.NewSubscriber(
46		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
47		func(context.Context, *amqp.Delivery) (interface{}, error) { return nil, errors.New("err!") },
48		func(context.Context, *amqp.Publishing, interface{}) error {
49			return nil
50		},
51		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
52	)
53
54	outputChan := make(chan amqp.Publishing, 1)
55	ch := &mockChannel{f: nullFunc, c: outputChan}
56	sub.ServeDelivery(ch)(&amqp.Delivery{})
57
58	var msg amqp.Publishing
59	select {
60	case msg = <-outputChan:
61		break
62
63	case <-time.After(100 * time.Millisecond):
64		t.Fatal("Timed out waiting for publishing")
65	}
66	res, err := decodeSubscriberError(msg)
67	if err != nil {
68		t.Fatal(err)
69	}
70	if want, have := "err!", res.Error; want != have {
71		t.Errorf("want %s, have %s", want, have)
72	}
73}
74
75// TestSubscriberBadEndpoint checks if endpoint errors are handled properly.
76func TestSubscriberBadEndpoint(t *testing.T) {
77	sub := amqptransport.NewSubscriber(
78		func(context.Context, interface{}) (interface{}, error) { return nil, errors.New("err!") },
79		func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
80		func(context.Context, *amqp.Publishing, interface{}) error {
81			return nil
82		},
83		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
84	)
85
86	outputChan := make(chan amqp.Publishing, 1)
87	ch := &mockChannel{f: nullFunc, c: outputChan}
88	sub.ServeDelivery(ch)(&amqp.Delivery{})
89
90	var msg amqp.Publishing
91
92	select {
93	case msg = <-outputChan:
94		break
95
96	case <-time.After(100 * time.Millisecond):
97		t.Fatal("Timed out waiting for publishing")
98	}
99
100	res, err := decodeSubscriberError(msg)
101	if err != nil {
102		t.Fatal(err)
103	}
104	if want, have := "err!", res.Error; want != have {
105		t.Errorf("want %s, have %s", want, have)
106	}
107}
108
109// TestSubscriberBadEncoder checks if encoder errors are handled properly.
110func TestSubscriberBadEncoder(t *testing.T) {
111	sub := amqptransport.NewSubscriber(
112		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
113		func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
114		func(context.Context, *amqp.Publishing, interface{}) error {
115			return errors.New("err!")
116		},
117		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
118	)
119
120	outputChan := make(chan amqp.Publishing, 1)
121	ch := &mockChannel{f: nullFunc, c: outputChan}
122	sub.ServeDelivery(ch)(&amqp.Delivery{})
123
124	var msg amqp.Publishing
125
126	select {
127	case msg = <-outputChan:
128		break
129
130	case <-time.After(100 * time.Millisecond):
131		t.Fatal("Timed out waiting for publishing")
132	}
133
134	res, err := decodeSubscriberError(msg)
135	if err != nil {
136		t.Fatal(err)
137	}
138	if want, have := "err!", res.Error; want != have {
139		t.Errorf("want %s, have %s", want, have)
140	}
141}
142
143// TestSubscriberSuccess checks if CorrelationId and ReplyTo are set properly
144// and if the payload is encoded properly.
145func TestSubscriberSuccess(t *testing.T) {
146	cid := "correlation"
147	replyTo := "sender"
148	obj := testReq{
149		Squadron: 436,
150	}
151	b, err := json.Marshal(obj)
152	if err != nil {
153		t.Fatal(err)
154	}
155
156	sub := amqptransport.NewSubscriber(
157		testEndpoint,
158		testReqDecoder,
159		amqptransport.EncodeJSONResponse,
160		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
161	)
162
163	checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) {
164		if want, have := replyTo, key; want != have {
165			t.Errorf("want %s, have %s", want, have)
166		}
167	}
168
169	outputChan := make(chan amqp.Publishing, 1)
170	ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
171	sub.ServeDelivery(ch)(&amqp.Delivery{
172		CorrelationId: cid,
173		ReplyTo:       replyTo,
174		Body:          b,
175	})
176
177	var msg amqp.Publishing
178
179	select {
180	case msg = <-outputChan:
181		break
182
183	case <-time.After(100 * time.Millisecond):
184		t.Fatal("Timed out waiting for publishing")
185	}
186
187	if want, have := cid, msg.CorrelationId; want != have {
188		t.Errorf("want %s, have %s", want, have)
189	}
190
191	// check if error is not thrown
192	errRes, err := decodeSubscriberError(msg)
193	if err != nil {
194		t.Fatal(err)
195	}
196	if errRes.Error != "" {
197		t.Error("Received error from subscriber", errRes.Error)
198		return
199	}
200
201	// check obj vals
202	response, err := testResDecoder(msg.Body)
203	if err != nil {
204		t.Fatal(err)
205	}
206	res, ok := response.(testRes)
207	if !ok {
208		t.Error(errTypeAssertion)
209	}
210
211	if want, have := obj.Squadron, res.Squadron; want != have {
212		t.Errorf("want %d, have %d", want, have)
213	}
214	if want, have := names[obj.Squadron], res.Name; want != have {
215		t.Errorf("want %s, have %s", want, have)
216	}
217}
218
219// TestNopResponseSubscriber checks if setting responsePublisher to
220// NopResponsePublisher works properly by disabling response.
221func TestNopResponseSubscriber(t *testing.T) {
222	cid := "correlation"
223	replyTo := "sender"
224	obj := testReq{
225		Squadron: 436,
226	}
227	b, err := json.Marshal(obj)
228	if err != nil {
229		t.Fatal(err)
230	}
231
232	sub := amqptransport.NewSubscriber(
233		testEndpoint,
234		testReqDecoder,
235		amqptransport.EncodeJSONResponse,
236		amqptransport.SubscriberResponsePublisher(amqptransport.NopResponsePublisher),
237		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
238	)
239
240	checkReplyToFunc := func(exchange, key string, mandatory, immediate bool) {}
241
242	outputChan := make(chan amqp.Publishing, 1)
243	ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
244	sub.ServeDelivery(ch)(&amqp.Delivery{
245		CorrelationId: cid,
246		ReplyTo:       replyTo,
247		Body:          b,
248	})
249
250	select {
251	case <-outputChan:
252		t.Fatal("Subscriber with NopResponsePublisher replied.")
253	case <-time.After(100 * time.Millisecond):
254		break
255	}
256}
257
258// TestSubscriberMultipleBefore checks if options to set exchange, key, deliveryMode
259// are working.
260func TestSubscriberMultipleBefore(t *testing.T) {
261	exchange := "some exchange"
262	key := "some key"
263	deliveryMode := uint8(127)
264	contentType := "some content type"
265	contentEncoding := "some content encoding"
266	sub := amqptransport.NewSubscriber(
267		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
268		func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
269		amqptransport.EncodeJSONResponse,
270		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
271		amqptransport.SubscriberBefore(
272			amqptransport.SetPublishExchange(exchange),
273			amqptransport.SetPublishKey(key),
274			amqptransport.SetPublishDeliveryMode(deliveryMode),
275			amqptransport.SetContentType(contentType),
276			amqptransport.SetContentEncoding(contentEncoding),
277		),
278	)
279	checkReplyToFunc := func(exch, k string, mandatory, immediate bool) {
280		if want, have := exchange, exch; want != have {
281			t.Errorf("want %s, have %s", want, have)
282		}
283		if want, have := key, k; want != have {
284			t.Errorf("want %s, have %s", want, have)
285		}
286	}
287
288	outputChan := make(chan amqp.Publishing, 1)
289	ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
290	sub.ServeDelivery(ch)(&amqp.Delivery{})
291
292	var msg amqp.Publishing
293
294	select {
295	case msg = <-outputChan:
296		break
297
298	case <-time.After(100 * time.Millisecond):
299		t.Fatal("Timed out waiting for publishing")
300	}
301
302	// check if error is not thrown
303	errRes, err := decodeSubscriberError(msg)
304	if err != nil {
305		t.Fatal(err)
306	}
307	if errRes.Error != "" {
308		t.Error("Received error from subscriber", errRes.Error)
309		return
310	}
311
312	if want, have := contentType, msg.ContentType; want != have {
313		t.Errorf("want %s, have %s", want, have)
314	}
315
316	if want, have := contentEncoding, msg.ContentEncoding; want != have {
317		t.Errorf("want %s, have %s", want, have)
318	}
319
320	if want, have := deliveryMode, msg.DeliveryMode; want != have {
321		t.Errorf("want %d, have %d", want, have)
322	}
323}
324
325// TestDefaultContentMetaData checks that default ContentType and Content-Encoding
326// is not set as mentioned by AMQP specification.
327func TestDefaultContentMetaData(t *testing.T) {
328	defaultContentType := ""
329	defaultContentEncoding := ""
330	sub := amqptransport.NewSubscriber(
331		func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil },
332		func(context.Context, *amqp.Delivery) (interface{}, error) { return struct{}{}, nil },
333		amqptransport.EncodeJSONResponse,
334		amqptransport.SubscriberErrorEncoder(amqptransport.ReplyErrorEncoder),
335	)
336	checkReplyToFunc := func(exch, k string, mandatory, immediate bool) {}
337	outputChan := make(chan amqp.Publishing, 1)
338	ch := &mockChannel{f: checkReplyToFunc, c: outputChan}
339	sub.ServeDelivery(ch)(&amqp.Delivery{})
340
341	var msg amqp.Publishing
342
343	select {
344	case msg = <-outputChan:
345		break
346
347	case <-time.After(100 * time.Millisecond):
348		t.Fatal("Timed out waiting for publishing")
349	}
350
351	// check if error is not thrown
352	errRes, err := decodeSubscriberError(msg)
353	if err != nil {
354		t.Fatal(err)
355	}
356	if errRes.Error != "" {
357		t.Error("Received error from subscriber", errRes.Error)
358		return
359	}
360
361	if want, have := defaultContentType, msg.ContentType; want != have {
362		t.Errorf("want %s, have %s", want, have)
363	}
364	if want, have := defaultContentEncoding, msg.ContentEncoding; want != have {
365		t.Errorf("want %s, have %s", want, have)
366	}
367}
368
369func decodeSubscriberError(pub amqp.Publishing) (amqptransport.DefaultErrorResponse, error) {
370	var res amqptransport.DefaultErrorResponse
371	err := json.Unmarshal(pub.Body, &res)
372	return res, err
373}
374
type (
	// testReq is the request payload used by the tests; the squadron number
	// is serialized under the JSON key "s".
	testReq struct {
		Squadron int `json:"s"`
	}

	// testRes is the response payload produced by testEndpoint; the squadron
	// number and its name are serialized under the JSON keys "s" and "n".
	testRes struct {
		Squadron int    `json:"s"`
		Name     string `json:"n"`
	}
)
382
383func testEndpoint(_ context.Context, request interface{}) (interface{}, error) {
384	req, ok := request.(testReq)
385	if !ok {
386		return nil, errTypeAssertion
387	}
388	name, prs := names[req.Squadron]
389	if !prs {
390		return nil, errors.New("unknown squadron name")
391	}
392	res := testRes{
393		Squadron: req.Squadron,
394		Name:     name,
395	}
396	return res, nil
397}
398
399func testReqDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
400	var obj testReq
401	err := json.Unmarshal(d.Body, &obj)
402	return obj, err
403}
404
405func testReqEncoder(_ context.Context, p *amqp.Publishing, request interface{}) error {
406	req, ok := request.(testReq)
407	if !ok {
408		return errors.New("type assertion failure")
409	}
410	b, err := json.Marshal(req)
411	if err != nil {
412		return err
413	}
414	p.Body = b
415	return nil
416}
417
418func testResDeliveryDecoder(_ context.Context, d *amqp.Delivery) (interface{}, error) {
419	return testResDecoder(d.Body)
420}
421
422func testResDecoder(b []byte) (interface{}, error) {
423	var obj testRes
424	err := json.Unmarshal(b, &obj)
425	return obj, err
426}
427
// names maps the squadron numbers known to testEndpoint to their names.
var names = map[int]string{
	437: "husky",
	436: "tusker",
	429: "bison",
	426: "thunderbird",
	424: "tiger",
}
435