1package redisotel
2
3import (
4	"context"
5
6	"github.com/go-redis/redis/extra/rediscmd"
7	"github.com/go-redis/redis/v8"
8	"go.opentelemetry.io/otel"
9	"go.opentelemetry.io/otel/attribute"
10	"go.opentelemetry.io/otel/codes"
11	"go.opentelemetry.io/otel/trace"
12)
13
14var tracer = otel.Tracer("github.com/go-redis/redis")
15
16type TracingHook struct{}
17
18var _ redis.Hook = TracingHook{}
19
20func (TracingHook) BeforeProcess(ctx context.Context, cmd redis.Cmder) (context.Context, error) {
21	if !trace.SpanFromContext(ctx).IsRecording() {
22		return ctx, nil
23	}
24
25	ctx, span := tracer.Start(ctx, cmd.FullName())
26	span.SetAttributes(
27		attribute.String("db.system", "redis"),
28		attribute.String("db.statement", rediscmd.CmdString(cmd)),
29	)
30
31	return ctx, nil
32}
33
34func (TracingHook) AfterProcess(ctx context.Context, cmd redis.Cmder) error {
35	span := trace.SpanFromContext(ctx)
36	if err := cmd.Err(); err != nil {
37		recordError(ctx, span, err)
38	}
39	span.End()
40	return nil
41}
42
43func (TracingHook) BeforeProcessPipeline(ctx context.Context, cmds []redis.Cmder) (context.Context, error) {
44	if !trace.SpanFromContext(ctx).IsRecording() {
45		return ctx, nil
46	}
47
48	summary, cmdsString := rediscmd.CmdsString(cmds)
49
50	ctx, span := tracer.Start(ctx, "pipeline "+summary)
51	span.SetAttributes(
52		attribute.String("db.system", "redis"),
53		attribute.Int("db.redis.num_cmd", len(cmds)),
54		attribute.String("db.statement", cmdsString),
55	)
56
57	return ctx, nil
58}
59
60func (TracingHook) AfterProcessPipeline(ctx context.Context, cmds []redis.Cmder) error {
61	span := trace.SpanFromContext(ctx)
62	if err := cmds[0].Err(); err != nil {
63		recordError(ctx, span, err)
64	}
65	span.End()
66	return nil
67}
68
69func recordError(ctx context.Context, span trace.Span, err error) {
70	if err != redis.Nil {
71		span.RecordError(err)
72		span.SetStatus(codes.Error, err.Error())
73	}
74}
75