1package diagnose
2
import (
	"context"
	"fmt"
	"io"
	"strings"
	"time"
	"unicode/utf8"

	"github.com/hashicorp/vault/sdk/helper/strutil"
	"go.opentelemetry.io/otel/attribute"
	"go.opentelemetry.io/otel/codes"
	sdktrace "go.opentelemetry.io/otel/sdk/trace"
	"go.opentelemetry.io/otel/trace"
)
16
17const (
18	warningEventName          = "warning"
19	skippedEventName          = "skipped"
20	actionKey                 = "actionKey"
21	spotCheckOkEventName      = "spot-check-ok"
22	spotCheckWarnEventName    = "spot-check-warn"
23	spotCheckErrorEventName   = "spot-check-error"
24	spotCheckSkippedEventName = "spot-check-skipped"
25	adviceEventName           = "advice"
26	errorMessageKey           = attribute.Key("error.message")
27	nameKey                   = attribute.Key("name")
28	messageKey                = attribute.Key("message")
29	adviceKey                 = attribute.Key("advice")
30)
31
32var (
33	MainSection = trace.WithAttributes(attribute.Key("diagnose").String("main-section"))
34)
35
// diagnoseSessionKey is the unexported type of the context key under which the active *Session is stored.
// Using a dedicated type, rather than the bare struct{}{} value, guarantees the key can never compare equal to a
// key defined in any other package, as the context.WithValue documentation recommends.
type diagnoseSessionKey struct{}

// diagnoseSession is the context key used by Context and CurrentSession.
var diagnoseSession = diagnoseSessionKey{}
37var noopTracer = trace.NewNoopTracerProvider().Tracer("vault-diagnose")
38
// testFunction is the signature of a check run by Test: it receives the context of the span created for it and
// returns an error if the check failed.
type testFunction func(ctx context.Context) error
40
// Session holds the state of one Diagnose tracing run; New constructs it.  Attach it to a context with Context so
// that StartSpan, Test and SpotCheck can find it via CurrentSession.
type Session struct {
	// tc is registered with tp as a span processor; its RootResult is what Finalize returns.
	tc          *TelemetryCollector
	// tracer starts spans for StartSpan when this session is present in the context.
	tracer      trace.Tracer
	// tp is the tracer provider that feeds spans to tc; Finalize flushes it.
	tp          *sdktrace.TracerProvider
	// SkipFilters lists span / spot-check names to skip; IsSkipped matches names against it.
	SkipFilters []string
}
47
48// New initializes a Diagnose tracing session.  In particular this wires a TelemetryCollector, which
49// synchronously receives and tracks OpenTelemetry spans in order to provide a tree structure of results
50// when the outermost span ends.
51func New(w io.Writer) *Session {
52	tc := NewTelemetryCollector(w)
53	//so, _ := stdout.NewExporter(stdout.WithPrettyPrint())
54	tp := sdktrace.NewTracerProvider(
55		sdktrace.WithSampler(sdktrace.AlwaysSample()),
56		//sdktrace.WithSpanProcessor(sdktrace.NewSimpleSpanProcessor(so)),
57		sdktrace.WithSpanProcessor(tc),
58	)
59	tracer := tp.Tracer("vault-diagnose")
60	sess := &Session{
61		tp:     tp,
62		tc:     tc,
63		tracer: tracer,
64	}
65	return sess
66}
67
68// IsSkipped returns true if skipName is present in the SkipFilters list.  Can be used in combination with Skip to mark a
69// span skipped and conditionally skips some logic.
70func (s *Session) IsSkipped(spanName string) bool {
71	return strutil.StrListContainsCaseInsensitive(s.SkipFilters, spanName)
72}
73
74// Context returns a new context with a defined diagnose session
75func Context(ctx context.Context, sess *Session) context.Context {
76	return context.WithValue(ctx, diagnoseSession, sess)
77}
78
79// CurrentSession retrieves the active diagnose session from the context, or nil if none.
80func CurrentSession(ctx context.Context) *Session {
81	sessionCtxVal := ctx.Value(diagnoseSession)
82	if sessionCtxVal != nil {
83
84		return sessionCtxVal.(*Session)
85
86	}
87	return nil
88}
89
90// Finalize ends the Diagnose session, returning the root of the result tree.  This will be empty until
91// the outermost span ends.
92func (s *Session) Finalize(ctx context.Context) *Result {
93	s.tp.ForceFlush(ctx)
94	return s.tc.RootResult
95}
96
97// StartSpan starts a "diagnose" span, which is really just an OpenTelemetry Tracing span.
98func StartSpan(ctx context.Context, spanName string, options ...trace.SpanOption) (context.Context, trace.Span) {
99	session := CurrentSession(ctx)
100	if session != nil {
101		return session.tracer.Start(ctx, spanName, options...)
102	} else {
103		return noopTracer.Start(ctx, spanName, options...)
104	}
105}
106
107// Success sets the span to Successful (overriding any previous status) and sets the message to the input.
108func Success(ctx context.Context, message string) {
109	span := trace.SpanFromContext(ctx)
110	span.SetStatus(codes.Ok, message)
111}
112
113// Fail records a failure in the current span
114func Fail(ctx context.Context, message string) {
115	span := trace.SpanFromContext(ctx)
116	span.SetStatus(codes.Error, message)
117}
118
119// Error records an error in the current span (but unlike Fail, doesn't set the overall span status to Error)
120func Error(ctx context.Context, err error, options ...trace.EventOption) error {
121	span := trace.SpanFromContext(ctx)
122	span.RecordError(err, options...)
123	return err
124}
125
126// Skipped marks the current span skipped
127func Skipped(ctx context.Context, message string) {
128	span := trace.SpanFromContext(ctx)
129	span.AddEvent(skippedEventName)
130	span.SetStatus(codes.Error, message)
131}
132
133// Warn records a warning on the current span
134func Warn(ctx context.Context, msg string) {
135	span := trace.SpanFromContext(ctx)
136	span.AddEvent(warningEventName, trace.WithAttributes(messageKey.String(msg)))
137}
138
139// SpotOk adds an Ok result without adding a new Span.  This should be used for instantaneous checks with no
140// possible sub-spans
141func SpotOk(ctx context.Context, checkName, message string, options ...trace.EventOption) {
142	addSpotCheckResult(ctx, spotCheckOkEventName, checkName, message, options...)
143}
144
145// SpotWarn adds a Warning result without adding a new Span.  This should be used for instantaneous checks with no
146// possible sub-spans
147func SpotWarn(ctx context.Context, checkName, message string, options ...trace.EventOption) {
148	addSpotCheckResult(ctx, spotCheckWarnEventName, checkName, message, options...)
149}
150
151// SpotError adds an Error result without adding a new Span.  This should be used for instantaneous checks with no
152// possible sub-spans
153func SpotError(ctx context.Context, checkName string, err error, options ...trace.EventOption) error {
154	var message string
155	if err != nil {
156		message = err.Error()
157	}
158	addSpotCheckResult(ctx, spotCheckErrorEventName, checkName, message, options...)
159	return err
160}
161
162// SpotSkipped adds a Skipped result without adding a new Span.
163func SpotSkipped(ctx context.Context, checkName, message string, options ...trace.EventOption) {
164	addSpotCheckResult(ctx, spotCheckSkippedEventName, checkName, message, options...)
165}
166
167// Advice builds an EventOption containing advice message.  Use to add to spot results.
168func Advice(message string) trace.EventOption {
169	return trace.WithAttributes(adviceKey.String(message))
170}
171
172// Advise adds advice to the current diagnose span
173func Advise(ctx context.Context, message string) {
174	span := trace.SpanFromContext(ctx)
175	span.AddEvent(adviceEventName, Advice(message))
176}
177
178func addSpotCheckResult(ctx context.Context, eventName, checkName, message string, options ...trace.EventOption) {
179	span := trace.SpanFromContext(ctx)
180	attrs := append(options, trace.WithAttributes(nameKey.String(checkName)))
181	if message != "" {
182		attrs = append(attrs, trace.WithAttributes(messageKey.String(message)))
183	}
184	span.AddEvent(eventName, attrs...)
185}
186
187func SpotCheck(ctx context.Context, checkName string, f func() error) error {
188	sess := CurrentSession(ctx)
189	if sess.IsSkipped(checkName) {
190		SpotSkipped(ctx, checkName, "skipped as requested")
191		return nil
192	}
193
194	err := f()
195	if err != nil {
196		SpotError(ctx, checkName, err)
197		return err
198	} else {
199		SpotOk(ctx, checkName, "")
200	}
201	return nil
202}
203
204// Test creates a new named span, and executes the provided function within it.  If the function returns an error,
205// the span is considered to have failed.
206func Test(ctx context.Context, spanName string, function testFunction, options ...trace.SpanOption) error {
207	ctx, span := StartSpan(ctx, spanName, options...)
208	defer span.End()
209	sess := CurrentSession(ctx)
210	if sess.IsSkipped(spanName) {
211		Skipped(ctx, "skipped as requested")
212		return nil
213	}
214
215	err := function(ctx)
216	if err != nil {
217		span.SetStatus(codes.Error, err.Error())
218	}
219	return err
220}
221
222// WithTimeout wraps a context consuming function, and when called, returns an error if the sub-function does not
223// complete within the timeout, e.g.
224//
225// diagnose.Test(ctx, "my-span", diagnose.WithTimeout(5 * time.Second, myTestFunc))
226func WithTimeout(d time.Duration, f testFunction) testFunction {
227	return func(ctx context.Context) error {
228		rch := make(chan error)
229		t := time.NewTimer(d)
230		defer t.Stop()
231		go func() { rch <- f(ctx) }()
232		select {
233		case <-t.C:
234			return fmt.Errorf("Timeout after %s.", d.String())
235		case err := <-rch:
236			return err
237		}
238	}
239}
240
// CapitalizeFirstLetter returns msg with its first character converted to title case; the rest of msg is left
// untouched.  An empty string is returned unchanged.
//
// Only the first rune is mapped on purpose.  strings.Title (now deprecated) would also capitalize every letter
// that follows punctuation, mangling messages such as "don't" ("Don'T") or "/vault/config" ("/Vault/Config").
func CapitalizeFirstLetter(msg string) string {
	// size is 0 for "", making both slices empty; otherwise it is the byte length of the first rune.
	_, size := utf8.DecodeRuneInString(msg)
	return strings.ToTitle(msg[:size]) + msg[size:]
}
252