1package circuitbreaker
2
3import (
4	"context"
5	"net/http"
6
7	"github.com/opentracing/opentracing-go/ext"
8	"github.com/traefik/traefik/v2/pkg/config/dynamic"
9	"github.com/traefik/traefik/v2/pkg/log"
10	"github.com/traefik/traefik/v2/pkg/middlewares"
11	"github.com/traefik/traefik/v2/pkg/tracing"
12	"github.com/vulcand/oxy/cbreaker"
13)
14
// typeName is the middleware type label passed to middlewares.GetLoggerCtx so
// log lines from this package are tagged as coming from a circuit breaker.
const typeName = "CircuitBreaker"
18
// circuitBreaker is the http.Handler returned by New: a named wrapper around an
// oxy circuit breaker.
type circuitBreaker struct {
	circuitBreaker *cbreaker.CircuitBreaker // underlying oxy breaker; ServeHTTP delegates every request to it
	name           string                   // middleware name; reported as the span name by GetTracingInformation
}
23
24// New creates a new circuit breaker middleware.
25func New(ctx context.Context, next http.Handler, confCircuitBreaker dynamic.CircuitBreaker, name string) (http.Handler, error) {
26	expression := confCircuitBreaker.Expression
27
28	logger := log.FromContext(middlewares.GetLoggerCtx(ctx, name, typeName))
29	logger.Debug("Creating middleware")
30	logger.Debug("Setting up with expression: %s", expression)
31
32	oxyCircuitBreaker, err := cbreaker.New(next, expression, createCircuitBreakerOptions(expression))
33	if err != nil {
34		return nil, err
35	}
36	return &circuitBreaker{
37		circuitBreaker: oxyCircuitBreaker,
38		name:           name,
39	}, nil
40}
41
42// NewCircuitBreakerOptions returns a new CircuitBreakerOption.
43func createCircuitBreakerOptions(expression string) cbreaker.CircuitBreakerOption {
44	return cbreaker.Fallback(http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) {
45		tracing.SetErrorWithEvent(req, "blocked by circuit-breaker (%q)", expression)
46		rw.WriteHeader(http.StatusServiceUnavailable)
47
48		if _, err := rw.Write([]byte(http.StatusText(http.StatusServiceUnavailable))); err != nil {
49			log.FromContext(req.Context()).Error(err)
50		}
51	}))
52}
53
54func (c *circuitBreaker) GetTracingInformation() (string, ext.SpanKindEnum) {
55	return c.name, tracing.SpanKindNoneEnum
56}
57
58func (c *circuitBreaker) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
59	c.circuitBreaker.ServeHTTP(rw, req)
60}
61