1package pubsub
2
3import (
4	"context"
5	"fmt"
6	"runtime"
7	"time"
8
9	"github.com/libp2p/go-libp2p-core/peer"
10)
11
// Default sizing of the validation pipeline; each can be overridden through
// the corresponding option.
const (
	// defaultValidateQueueSize is the buffer size of the validation front-end
	// queue (see WithValidateQueueSize).
	defaultValidateQueueSize = 32
	// defaultValidateConcurrency is the per-topic cap on concurrent validator
	// invocations (see WithValidatorConcurrency).
	defaultValidateConcurrency = 1024
	// defaultValidateThrottle is the global cap on active asynchronous
	// validation goroutines (see WithValidateThrottle).
	defaultValidateThrottle = 8192
)
17
18// Validator is a function that validates a message with a binary decision: accept or reject.
19type Validator func(context.Context, peer.ID, *Message) bool
20
21// ValidatorEx is an extended validation function that validates a message with an enumerated decision
22type ValidatorEx func(context.Context, peer.ID, *Message) ValidationResult
23
// ValidationResult represents the decision of an extended validator.
type ValidationResult int

const (
	// ValidationAccept is a validation decision that indicates a valid message that should be accepted and
	// delivered to the application and forwarded to the network.
	ValidationAccept = ValidationResult(0)
	// ValidationReject is a validation decision that indicates an invalid message that should not be
	// delivered to the application or forwarded to the network. Furthermore the peer that forwarded
	// the message should be penalized by peer scoring routers.
	ValidationReject = ValidationResult(1)
	// ValidationIgnore is a validation decision that indicates a message that should be ignored: it will
	// be neither delivered to the application nor forwarded to the network. However, in contrast to
	// ValidationReject, the peer that forwarded the message must not be penalized by peer scoring routers.
	ValidationIgnore = ValidationResult(2)
	// validationThrottled is internal to the pipeline: it is produced when a
	// validator could not run because its concurrency throttle was full.
	validationThrottled = ValidationResult(-1)
)
42
43// ValidatorOpt is an option for RegisterTopicValidator.
44type ValidatorOpt func(addVal *addValReq) error
45
// validation represents the validator pipeline.
// The validator pipeline performs signature validation and runs a
// sequence of user-configured validators per-topic. It is possible to
// adjust various concurrency parameters, such as the number of
// workers and the max number of simultaneous validations. The user
// can also attach inline validators that will be executed
// synchronously; this may be useful to prevent superfluous
// context-switching for lightweight tasks.
type validation struct {
	// p is the PubSub instance the pipeline is attached to; set by Start.
	p *PubSub

	// tracer is copied from p.tracer by Start and receives validation events
	// (rejections, duplicates, validation start).
	tracer *pubsubTracer

	// topicVals tracks per topic validators; AddValidator allows at most one
	// validator per topic.
	topicVals map[string]*topicVal

	// validateQ is the front-end to the validation pipeline; Push enqueues
	// without blocking and drops the message when the queue is full.
	validateQ chan *validateReq

	// validateThrottle limits the number of active validation goroutines.
	// It is used as a counting semaphore: a send acquires a slot before an
	// async validation goroutine is spawned, and the goroutine receives to
	// release it.
	validateThrottle chan struct{}

	// this is the number of synchronous validation workers started by Start
	validateWorkers int
}
71
// validateReq is a validation request queued by Push and consumed by the
// validation workers.
type validateReq struct {
	// vals are the validators that apply to msg (from getValidators).
	vals []*topicVal
	// src is the peer the message came from.
	src  peer.ID
	// msg is the message to validate.
	msg  *Message
}
78
// topicVal is the registered validator for a single topic, together with its
// execution settings.
type topicVal struct {
	// topic is the topic this validator is registered for.
	topic            string
	// validate is the user validator, normalized to the extended form.
	validate         ValidatorEx
	// validateTimeout bounds each validator call when > 0; zero means no timeout.
	validateTimeout  time.Duration
	// validateThrottle caps concurrent invocations of this validator on the
	// throttled paths (validateTopic / validateSingleTopic); inline validators
	// are called directly and do not take a slot.
	validateThrottle chan struct{}
	// validateInline makes the validator run synchronously on the validation
	// worker instead of in a separate goroutine.
	validateInline   bool
}
87
// async request to add a topic validator
type addValReq struct {
	// topic is the topic to attach the validator to.
	topic    string
	// validate must be a Validator, a ValidatorEx, or an unnamed func with one
	// of their signatures; AddValidator rejects anything else.
	validate interface{}
	// timeout is applied to each validator call only when > 0.
	timeout  time.Duration
	// throttle is the per-topic concurrency limit; when <= 0,
	// defaultValidateConcurrency is used.
	throttle int
	// inline requests synchronous execution on the validation worker.
	inline   bool
	// resp receives nil on success or the registration error.
	resp     chan error
}
97
// async request to remove a topic validator
type rmValReq struct {
	// topic is the topic whose validator should be removed.
	topic string
	// resp receives nil on success, or an error if the topic had no validator.
	resp  chan error
}
103
104// newValidation creates a new validation pipeline
105func newValidation() *validation {
106	return &validation{
107		topicVals:        make(map[string]*topicVal),
108		validateQ:        make(chan *validateReq, defaultValidateQueueSize),
109		validateThrottle: make(chan struct{}, defaultValidateThrottle),
110		validateWorkers:  runtime.NumCPU(),
111	}
112}
113
114// Start attaches the validation pipeline to a pubsub instance and starts background
115// workers
116func (v *validation) Start(p *PubSub) {
117	v.p = p
118	v.tracer = p.tracer
119	for i := 0; i < v.validateWorkers; i++ {
120		go v.validateWorker()
121	}
122}
123
124// AddValidator adds a new validator
125func (v *validation) AddValidator(req *addValReq) {
126	topic := req.topic
127
128	_, ok := v.topicVals[topic]
129	if ok {
130		req.resp <- fmt.Errorf("Duplicate validator for topic %s", topic)
131		return
132	}
133
134	makeValidatorEx := func(v Validator) ValidatorEx {
135		return func(ctx context.Context, p peer.ID, msg *Message) ValidationResult {
136			if v(ctx, p, msg) {
137				return ValidationAccept
138			} else {
139				return ValidationReject
140			}
141		}
142	}
143
144	var validator ValidatorEx
145	switch v := req.validate.(type) {
146	case func(ctx context.Context, p peer.ID, msg *Message) bool:
147		validator = makeValidatorEx(Validator(v))
148	case Validator:
149		validator = makeValidatorEx(v)
150
151	case func(ctx context.Context, p peer.ID, msg *Message) ValidationResult:
152		validator = ValidatorEx(v)
153	case ValidatorEx:
154		validator = v
155
156	default:
157		req.resp <- fmt.Errorf("Unknown validator type for topic %s; must be an instance of Validator or ValidatorEx", topic)
158		return
159	}
160
161	val := &topicVal{
162		topic:            topic,
163		validate:         validator,
164		validateTimeout:  0,
165		validateThrottle: make(chan struct{}, defaultValidateConcurrency),
166		validateInline:   req.inline,
167	}
168
169	if req.timeout > 0 {
170		val.validateTimeout = req.timeout
171	}
172
173	if req.throttle > 0 {
174		val.validateThrottle = make(chan struct{}, req.throttle)
175	}
176
177	v.topicVals[topic] = val
178	req.resp <- nil
179}
180
181// RemoveValidator removes an existing validator
182func (v *validation) RemoveValidator(req *rmValReq) {
183	topic := req.topic
184
185	_, ok := v.topicVals[topic]
186	if ok {
187		delete(v.topicVals, topic)
188		req.resp <- nil
189	} else {
190		req.resp <- fmt.Errorf("No validator for topic %s", topic)
191	}
192}
193
194// Push pushes a message into the validation pipeline.
195// It returns true if the message can be forwarded immediately without validation.
196func (v *validation) Push(src peer.ID, msg *Message) bool {
197	vals := v.getValidators(msg)
198
199	if len(vals) > 0 || msg.Signature != nil {
200		select {
201		case v.validateQ <- &validateReq{vals, src, msg}:
202		default:
203			log.Debugf("message validation throttled: queue full; dropping message from %s", src)
204			v.tracer.RejectMessage(msg, rejectValidationQueueFull)
205		}
206		return false
207	}
208
209	return true
210}
211
212// getValidators returns all validators that apply to a given message
213func (v *validation) getValidators(msg *Message) []*topicVal {
214	topic := msg.GetTopic()
215
216	val, ok := v.topicVals[topic]
217	if !ok {
218		return nil
219	}
220
221	return []*topicVal{val}
222}
223
224// validateWorker is an active goroutine performing inline validation
225func (v *validation) validateWorker() {
226	for {
227		select {
228		case req := <-v.validateQ:
229			v.validate(req.vals, req.src, req.msg)
230		case <-v.p.ctx.Done():
231			return
232		}
233	}
234}
235
236// validate performs validation and only sends the message if all validators succeed
237// signature validation is performed synchronously, while user validators are invoked
238// asynchronously, throttled by the global validation throttle.
239func (v *validation) validate(vals []*topicVal, src peer.ID, msg *Message) {
240	// If signature verification is enabled, but signing is disabled,
241	// the Signature is required to be nil upon receiving the message in PubSub.pushMsg.
242	if msg.Signature != nil {
243		if !v.validateSignature(msg) {
244			log.Debugf("message signature validation failed; dropping message from %s", src)
245			v.tracer.RejectMessage(msg, rejectInvalidSignature)
246			return
247		}
248	}
249
250	// we can mark the message as seen now that we have verified the signature
251	// and avoid invoking user validators more than once
252	id := v.p.msgID(msg.Message)
253	if !v.p.markSeen(id) {
254		v.tracer.DuplicateMessage(msg)
255		return
256	} else {
257		v.tracer.ValidateMessage(msg)
258	}
259
260	var inline, async []*topicVal
261	for _, val := range vals {
262		if val.validateInline {
263			inline = append(inline, val)
264		} else {
265			async = append(async, val)
266		}
267	}
268
269	// apply inline (synchronous) validators
270	result := ValidationAccept
271loop:
272	for _, val := range inline {
273		switch val.validateMsg(v.p.ctx, src, msg) {
274		case ValidationAccept:
275		case ValidationReject:
276			result = ValidationReject
277			break loop
278		case ValidationIgnore:
279			result = ValidationIgnore
280		}
281	}
282
283	if result == ValidationReject {
284		log.Debugf("message validation failed; dropping message from %s", src)
285		v.tracer.RejectMessage(msg, rejectValidationFailed)
286		return
287	}
288
289	// apply async validators
290	if len(async) > 0 {
291		select {
292		case v.validateThrottle <- struct{}{}:
293			go func() {
294				v.doValidateTopic(async, src, msg, result)
295				<-v.validateThrottle
296			}()
297		default:
298			log.Debugf("message validation throttled; dropping message from %s", src)
299			v.tracer.RejectMessage(msg, rejectValidationThrottled)
300		}
301		return
302	}
303
304	if result == ValidationIgnore {
305		v.tracer.RejectMessage(msg, rejectValidationIgnored)
306		return
307	}
308
309	// no async validators, accepted message, send it!
310	v.p.sendMsg <- msg
311}
312
313func (v *validation) validateSignature(msg *Message) bool {
314	err := verifyMessageSignature(msg.Message)
315	if err != nil {
316		log.Debugf("signature verification error: %s", err.Error())
317		return false
318	}
319
320	return true
321}
322
323func (v *validation) doValidateTopic(vals []*topicVal, src peer.ID, msg *Message, r ValidationResult) {
324	result := v.validateTopic(vals, src, msg)
325
326	if result == ValidationAccept && r != ValidationAccept {
327		result = r
328	}
329
330	switch result {
331	case ValidationAccept:
332		v.p.sendMsg <- msg
333	case ValidationReject:
334		log.Debugf("message validation failed; dropping message from %s", src)
335		v.tracer.RejectMessage(msg, rejectValidationFailed)
336		return
337	case ValidationIgnore:
338		log.Debugf("message validation punted; ignoring message from %s", src)
339		v.tracer.RejectMessage(msg, rejectValidationIgnored)
340		return
341	case validationThrottled:
342		log.Debugf("message validation throttled; ignoring message from %s", src)
343		v.tracer.RejectMessage(msg, rejectValidationThrottled)
344
345	default:
346		// BUG: this would be an internal programming error, so a panic seems appropiate.
347		panic(fmt.Errorf("Unexpected validation result: %d", result))
348	}
349}
350
351func (v *validation) validateTopic(vals []*topicVal, src peer.ID, msg *Message) ValidationResult {
352	if len(vals) == 1 {
353		return v.validateSingleTopic(vals[0], src, msg)
354	}
355
356	ctx, cancel := context.WithCancel(v.p.ctx)
357	defer cancel()
358
359	rch := make(chan ValidationResult, len(vals))
360	rcount := 0
361
362	for _, val := range vals {
363		rcount++
364
365		select {
366		case val.validateThrottle <- struct{}{}:
367			go func(val *topicVal) {
368				rch <- val.validateMsg(ctx, src, msg)
369				<-val.validateThrottle
370			}(val)
371
372		default:
373			log.Debugf("validation throttled for topic %s", val.topic)
374			rch <- validationThrottled
375		}
376	}
377
378	result := ValidationAccept
379loop:
380	for i := 0; i < rcount; i++ {
381		switch <-rch {
382		case ValidationAccept:
383		case ValidationReject:
384			result = ValidationReject
385			break loop
386		case ValidationIgnore:
387			// throttled validation has the same effect, but takes precedence over Ignore as it is not
388			// known whether the throttled validator would have signaled rejection.
389			if result != validationThrottled {
390				result = ValidationIgnore
391			}
392		case validationThrottled:
393			result = validationThrottled
394		}
395	}
396
397	return result
398}
399
400// fast path for single topic validation that avoids the extra goroutine
401func (v *validation) validateSingleTopic(val *topicVal, src peer.ID, msg *Message) ValidationResult {
402	select {
403	case val.validateThrottle <- struct{}{}:
404		res := val.validateMsg(v.p.ctx, src, msg)
405		<-val.validateThrottle
406		return res
407
408	default:
409		log.Debugf("validation throttled for topic %s", val.topic)
410		return validationThrottled
411	}
412}
413
414func (val *topicVal) validateMsg(ctx context.Context, src peer.ID, msg *Message) ValidationResult {
415	start := time.Now()
416	defer func() {
417		log.Debugf("validation done; took %s", time.Since(start))
418	}()
419
420	if val.validateTimeout > 0 {
421		var cancel func()
422		ctx, cancel = context.WithTimeout(ctx, val.validateTimeout)
423		defer cancel()
424	}
425
426	r := val.validate(ctx, src, msg)
427	switch r {
428	case ValidationAccept:
429		fallthrough
430	case ValidationReject:
431		fallthrough
432	case ValidationIgnore:
433		return r
434
435	default:
436		log.Warnf("Unexpected result from validator: %d; ignoring message", r)
437		return ValidationIgnore
438	}
439}
440
// Options for configuring the validation pipeline and individual topic validators.
442
443// WithValidateQueueSize sets the buffer of validate queue. Defaults to 32.
444// When queue is full, validation is throttled and new messages are dropped.
445func WithValidateQueueSize(n int) Option {
446	return func(ps *PubSub) error {
447		if n > 0 {
448			ps.val.validateQ = make(chan *validateReq, n)
449			return nil
450		}
451		return fmt.Errorf("validate queue size must be > 0")
452	}
453}
454
455// WithValidateThrottle sets the upper bound on the number of active validation
456// goroutines across all topics. The default is 8192.
457func WithValidateThrottle(n int) Option {
458	return func(ps *PubSub) error {
459		ps.val.validateThrottle = make(chan struct{}, n)
460		return nil
461	}
462}
463
464// WithValidateWorkers sets the number of synchronous validation worker goroutines.
465// Defaults to NumCPU.
466//
467// The synchronous validation workers perform signature validation, apply inline
468// user validators, and schedule asynchronous user validators.
469// You can adjust this parameter to devote less cpu time to synchronous validation.
470func WithValidateWorkers(n int) Option {
471	return func(ps *PubSub) error {
472		if n > 0 {
473			ps.val.validateWorkers = n
474			return nil
475		}
476		return fmt.Errorf("number of validation workers must be > 0")
477	}
478}
479
480// WithValidatorTimeout is an option that sets a timeout for an (asynchronous) topic validator.
481// By default there is no timeout in asynchronous validators.
482func WithValidatorTimeout(timeout time.Duration) ValidatorOpt {
483	return func(addVal *addValReq) error {
484		addVal.timeout = timeout
485		return nil
486	}
487}
488
489// WithValidatorConcurrency is an option that sets the topic validator throttle.
490// This controls the number of active validation goroutines for the topic; the default is 1024.
491func WithValidatorConcurrency(n int) ValidatorOpt {
492	return func(addVal *addValReq) error {
493		addVal.throttle = n
494		return nil
495	}
496}
497
498// WithValidatorInline is an option that sets the validation disposition to synchronous:
499// it will be executed inline in validation front-end, without spawning a new goroutine.
500// This is suitable for simple or cpu-bound validators that do not block.
501func WithValidatorInline(inline bool) ValidatorOpt {
502	return func(addVal *addValReq) error {
503		addVal.inline = inline
504		return nil
505	}
506}
507