1package nats
2
3import (
4	"context"
5	"encoding/json"
6	"github.com/go-kit/kit/endpoint"
7	"github.com/nats-io/nats.go"
8	"time"
9)
10
11// Publisher wraps a URL and provides a method that implements endpoint.Endpoint.
12type Publisher struct {
13	publisher *nats.Conn
14	subject   string
15	enc       EncodeRequestFunc
16	dec       DecodeResponseFunc
17	before    []RequestFunc
18	after     []PublisherResponseFunc
19	timeout   time.Duration
20}
21
22// NewPublisher constructs a usable Publisher for a single remote method.
23func NewPublisher(
24	publisher *nats.Conn,
25	subject string,
26	enc EncodeRequestFunc,
27	dec DecodeResponseFunc,
28	options ...PublisherOption,
29) *Publisher {
30	p := &Publisher{
31		publisher: publisher,
32		subject:   subject,
33		enc:       enc,
34		dec:       dec,
35		timeout:   10 * time.Second,
36	}
37	for _, option := range options {
38		option(p)
39	}
40	return p
41}
42
43// PublisherOption sets an optional parameter for clients.
44type PublisherOption func(*Publisher)
45
46// PublisherBefore sets the RequestFuncs that are applied to the outgoing NATS
47// request before it's invoked.
48func PublisherBefore(before ...RequestFunc) PublisherOption {
49	return func(p *Publisher) { p.before = append(p.before, before...) }
50}
51
52// PublisherAfter sets the ClientResponseFuncs applied to the incoming NATS
53// request prior to it being decoded. This is useful for obtaining anything off
54// of the response and adding onto the context prior to decoding.
55func PublisherAfter(after ...PublisherResponseFunc) PublisherOption {
56	return func(p *Publisher) { p.after = append(p.after, after...) }
57}
58
59// PublisherTimeout sets the available timeout for NATS request.
60func PublisherTimeout(timeout time.Duration) PublisherOption {
61	return func(p *Publisher) { p.timeout = timeout }
62}
63
64// Endpoint returns a usable endpoint that invokes the remote endpoint.
65func (p Publisher) Endpoint() endpoint.Endpoint {
66	return func(ctx context.Context, request interface{}) (interface{}, error) {
67		ctx, cancel := context.WithTimeout(ctx, p.timeout)
68		defer cancel()
69
70		msg := nats.Msg{Subject: p.subject}
71
72		if err := p.enc(ctx, &msg, request); err != nil {
73			return nil, err
74		}
75
76		for _, f := range p.before {
77			ctx = f(ctx, &msg)
78		}
79
80		resp, err := p.publisher.RequestWithContext(ctx, msg.Subject, msg.Data)
81		if err != nil {
82			return nil, err
83		}
84
85		for _, f := range p.after {
86			ctx = f(ctx, resp)
87		}
88
89		response, err := p.dec(ctx, resp)
90		if err != nil {
91			return nil, err
92		}
93
94		return response, nil
95	}
96}
97
98// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a
99// JSON object to the Data of the Msg. Many JSON-over-NATS services can use it as
100// a sensible default.
101func EncodeJSONRequest(_ context.Context, msg *nats.Msg, request interface{}) error {
102	b, err := json.Marshal(request)
103	if err != nil {
104		return err
105	}
106
107	msg.Data = b
108
109	return nil
110}
111