1package webhook
2
3import (
4	"encoding/hex"
5	"fmt"
6	"testing"
7	"time"
8)
9
// Shared fixtures for the signing helpers below.
var (
	// testPayload is the minimal event body that gets signed by default.
	testPayload = []byte(`{
  "id": "evt_test_webhook",
  "object": "event"
}`)

	// testSecret is the default signing secret.
	testSecret = "whsec_test_secret"
)
15
// SignedPayload is a webhook test fixture: a payload together with the inputs
// used to sign it and the resulting signature header. Build one with
// newSignedPayload rather than by hand so signature and header stay in sync.
type SignedPayload struct {
	timestamp time.Time // signing time; its Unix seconds become the header's t= value
	payload   []byte    // raw body that is signed
	secret    string    // key passed to ComputeSignature
	scheme    string    // scheme label written into the header, e.g. "v1"
	signature []byte    // raw (not hex-encoded) signature bytes
	header    string    // rendered header: "t=<unix>,<scheme>=<hex signature>"
}
24
25func newSignedPayload(options ...func(*SignedPayload)) *SignedPayload {
26	signedPayload := &SignedPayload{}
27	signedPayload.timestamp = time.Now()
28	signedPayload.payload = testPayload
29	signedPayload.secret = testSecret
30	signedPayload.scheme = "v1"
31
32	for _, opt := range options {
33		opt(signedPayload)
34	}
35
36	if signedPayload.signature == nil {
37		signedPayload.signature = ComputeSignature(signedPayload.timestamp, signedPayload.payload, signedPayload.secret)
38	}
39	signedPayload.header = generateHeader(*signedPayload)
40	return signedPayload
41}
42
43func (p *SignedPayload) hexSignature() string {
44	return hex.EncodeToString(p.signature)
45}
46
47func generateHeader(p SignedPayload) string {
48	return fmt.Sprintf("t=%d,%s=%s", p.timestamp.Unix(), p.scheme, hex.EncodeToString(p.signature))
49}
50
51func TestTokenNew(t *testing.T) {
52	p := newSignedPayload()
53
54	evt, err := ConstructEvent(p.payload, p.header, p.secret)
55	if err != nil {
56		t.Errorf("Error validating signature: %v", err)
57	} else if evt.ID != "evt_test_webhook" {
58		t.Errorf("Expected a parsed event matching the test payload, got %v", evt)
59	}
60
61	p = newSignedPayload(func(p *SignedPayload) {
62		p.payload = append(p.payload, byte('['))
63	})
64	evt, err = ConstructEvent(p.payload, p.header, p.secret)
65	if err == nil {
66		t.Errorf("Invalid JSON did not cause a parse error")
67	}
68
69	p = newSignedPayload()
70	err = ValidatePayload(p.payload, "", p.secret)
71	if err != ErrNotSigned {
72		t.Errorf("Expected ErrNotSigned from missing signature, got %v", err)
73	}
74	evt, err = ConstructEvent(p.payload, "", p.secret)
75	if err != ErrNotSigned {
76		t.Errorf("Expected ErrNotSigned from missing signature, got %v", err)
77	}
78
79	evt, err = ConstructEvent(p.payload, "v1,t=1", p.secret)
80	if err != ErrInvalidHeader {
81		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
82	}
83
84	err = ValidatePayload(p.payload, "t=", p.secret)
85	if err != ErrInvalidHeader {
86		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
87	}
88	evt, err = ConstructEvent(p.payload, "t=", p.secret)
89	if err != ErrInvalidHeader {
90		t.Errorf("Expected ErrInvalidHeader from bad header format, got %v", err)
91	}
92
93	err = ValidatePayload(p.payload, p.header+",v1=bad_signature", p.secret)
94	if err != nil {
95		t.Errorf("Received unexpected %v error with an unreadable signature in the header (should be ignored)", err)
96	}
97	evt, err = ConstructEvent(p.payload, p.header+",v1=bad_signature", p.secret)
98	if err != nil {
99		t.Errorf("Received unexpected %v error with an unreadable signature in the header (should be ignored)", err)
100	}
101
102	p = newSignedPayload(func(p *SignedPayload) {
103		p.scheme = "v0"
104	})
105	err = ValidatePayload(p.payload, p.header, p.secret)
106	if err != ErrNoValidSignature {
107		t.Errorf("Expected error from mismatched schema, got %v", err)
108	}
109	evt, err = ConstructEvent(p.payload, p.header, p.secret)
110	if err != ErrNoValidSignature {
111		t.Errorf("Expected error from mismatched schema, got %v", err)
112	}
113
114	p = newSignedPayload(func(p *SignedPayload) {
115		p.signature = []byte("deadbeef")
116	})
117	err = ValidatePayload(p.payload, p.header, p.secret)
118	if err != ErrNoValidSignature {
119		t.Errorf("Expected error from fake signature, got %v", err)
120	}
121	evt, err = ConstructEvent(p.payload, p.header, p.secret)
122	if err != ErrNoValidSignature {
123		t.Errorf("Expected error from fake signature, got %v", err)
124	}
125
126	p = newSignedPayload()
127	p2 := newSignedPayload(func(p *SignedPayload) {
128		p.secret = testSecret + "_rolled_key"
129	})
130	headerWithRolledKey := p.header + ",v1=" + p2.hexSignature()
131	if p.hexSignature() == p2.hexSignature() {
132		t.Errorf("Got the same signature with two different secret keys")
133	}
134
135	err = ValidatePayload(p.payload, headerWithRolledKey, p.secret)
136	if err != nil {
137		t.Errorf("Expected to be able to decode webhook with old key after rolling key, but got %v", err)
138	}
139	evt, err = ConstructEvent(p.payload, headerWithRolledKey, p.secret)
140	if err != nil {
141		t.Errorf("Expected to be able to decode webhook with old key after rolling key, but got %v", err)
142	}
143	err = ValidatePayload(p.payload, headerWithRolledKey, p2.secret)
144	if err != nil {
145		t.Errorf("Expected to be able to decode webhook with new key after rolling key, but got %v", err)
146	}
147	evt, err = ConstructEvent(p.payload, headerWithRolledKey, p2.secret)
148	if err != nil {
149		t.Errorf("Expected to be able to decode webhook with new key after rolling key, but got %v", err)
150	}
151
152	p = newSignedPayload(func(p *SignedPayload) {
153		p.timestamp = time.Now().Add(-15 * time.Second)
154	})
155	err = ValidatePayloadWithTolerance(p.payload, p.header, p.secret, 10*time.Second)
156	if err != ErrTooOld {
157		t.Errorf("Received %v error when validating timestamp outside of allowed timing window", err)
158	}
159	evt, err = ConstructEventWithTolerance(p.payload, p.header, p.secret, 10*time.Second)
160	if err != ErrTooOld {
161		t.Errorf("Received %v error when validating timestamp outside of allowed timing window", err)
162	}
163
164	err = ValidatePayloadWithTolerance(p.payload, p.header, p.secret, 20*time.Second)
165	if err != nil {
166		t.Errorf("Received %v error when validating timestamp inside allowed timing window", err)
167	}
168	evt, err = ConstructEventWithTolerance(p.payload, p.header, p.secret, 20*time.Second)
169	if err != nil {
170		t.Errorf("Received %v error when validating timestamp inside allowed timing window", err)
171	}
172
173	p = newSignedPayload(func(p *SignedPayload) {
174		p.timestamp = time.Unix(12345, 0)
175	})
176	err = ValidatePayloadIgnoringTolerance(p.payload, p.header, p.secret)
177	if err != nil {
178		t.Errorf("Received %v error when timestamp outside window but no tolerance specified", err)
179	}
180	evt, err = ConstructEventIgnoringTolerance(p.payload, p.header, p.secret)
181	if err != nil {
182		t.Errorf("Received %v error when timestamp outside window but no tolerance specified", err)
183	}
184}
185