1// Licensed to Elasticsearch B.V. under one or more contributor
2// license agreements. See the NOTICE file distributed with
3// this work for additional information regarding copyright
4// ownership. Elasticsearch B.V. licenses this file to you under
5// the Apache License, Version 2.0 (the "License"); you may
6// not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9//     http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18package apmzerolog
19
20import (
21	"bytes"
22	"context"
23	"encoding/hex"
24	"encoding/json"
25	"io"
26	"strconv"
27	"time"
28
29	"github.com/pkg/errors"
30	"github.com/rs/zerolog"
31	"github.com/rs/zerolog/pkgerrors"
32
33	"go.elastic.co/apm"
34	"go.elastic.co/apm/stacktrace"
35)
36
// DefaultFatalFlushTimeout is the default value for Writer.FatalFlushTimeout.
// It applies whenever FatalFlushTimeout is left at zero.
const DefaultFatalFlushTimeout = 5 * time.Second

// Keys of the per-frame objects in an encoded error stack. They are used
// alongside pkgerrors.StackSourceFileName, which names the file key.
const (
	// StackSourceFunctionName is the key for the function name of a stack frame.
	StackSourceFunctionName = "func"

	// StackSourceLineName is the key for the line number of a stack frame.
	StackSourceLineName = "line"
)
47
48func init() {
49	stacktrace.RegisterLibraryPackage("github.com/rs/zerolog")
50}
51
52// Writer is an implementation of zerolog.LevelWriter, reporting log records as
53// errors to the APM Server. If TraceContext is used to add trace IDs to the log
54// records, the errors reported will be associated with them.
55//
56// Because we only have access to the serialised form of the log record, we must
57// rely on enough information being encoded into the events. For error stack traces,
58// you must use zerolog's Stack() method, and set zerolog.ErrorStackMarshaler
59// either to github.com/rs/zerolog/pkgerrors.MarshalStack, or to the function
60// apmzerolog.MarshalErrorStack in this package. The pkgerrors.MarshalStack
61// implementation omits some information, whereas apmzerolog is designed to
62// convey the complete file location and fully qualified function name.
63type Writer struct {
64	// Tracer is the apm.Tracer to use for reporting errors.
65	// If Tracer is nil, then apm.DefaultTracer will be used.
66	Tracer *apm.Tracer
67
68	// FatalFlushTimeout is the amount of time to wait while
69	// flushing a fatal log message to the APM Server before
70	// the process is exited. If this is 0, then
71	// DefaultFatalFlushTimeout will be used. If the timeout
72	// is a negative value, then no flushing will be performed.
73	FatalFlushTimeout time.Duration
74}
75
76func (w *Writer) tracer() *apm.Tracer {
77	tracer := w.Tracer
78	if tracer == nil {
79		tracer = apm.DefaultTracer
80	}
81	return tracer
82}
83
84// Write is a no-op.
85func (*Writer) Write(p []byte) (int, error) {
86	return len(p), nil
87}
88
89// WriteLevel decodes the JSON-encoded log record in p, and reports it as an error using w.Tracer.
90func (w *Writer) WriteLevel(level zerolog.Level, p []byte) (int, error) {
91	if level < zerolog.ErrorLevel || level >= zerolog.NoLevel {
92		return len(p), nil
93	}
94	tracer := w.tracer()
95	if !tracer.Active() {
96		return len(p), nil
97	}
98	var logRecord logRecord
99	if err := logRecord.decode(bytes.NewReader(p)); err != nil {
100		return 0, err
101	}
102
103	errlog := tracer.NewErrorLog(apm.ErrorLogRecord{
104		Level:   level.String(),
105		Message: logRecord.message,
106		Error:   logRecord.err,
107	})
108	if !logRecord.timestamp.IsZero() {
109		errlog.Timestamp = logRecord.timestamp
110	}
111	errlog.Handled = true
112	errlog.SetStacktrace(1)
113	errlog.TraceID = logRecord.traceID
114	errlog.TransactionID = logRecord.transactionID
115	if logRecord.spanID.Validate() == nil {
116		errlog.ParentID = logRecord.spanID
117	} else {
118		errlog.ParentID = logRecord.transactionID
119	}
120	errlog.Send()
121
122	if level == zerolog.FatalLevel {
123		// Zap will exit the process following a fatal log message, so we flush the tracer.
124		flushTimeout := w.FatalFlushTimeout
125		if flushTimeout == 0 {
126			flushTimeout = DefaultFatalFlushTimeout
127		}
128		if flushTimeout >= 0 {
129			ctx, cancel := context.WithTimeout(context.Background(), flushTimeout)
130			defer cancel()
131			tracer.Flush(ctx.Done())
132		}
133	}
134	return len(p), nil
135}
136
137type logRecord struct {
138	message               string
139	timestamp             time.Time
140	err                   error
141	traceID               apm.TraceID
142	transactionID, spanID apm.SpanID
143}
144
145func (l *logRecord) decode(r io.Reader) (result error) {
146	m := make(map[string]interface{})
147	d := json.NewDecoder(r)
148	d.UseNumber()
149	if err := d.Decode(&m); err != nil {
150		return err
151	}
152
153	l.message, _ = m[zerolog.MessageFieldName].(string)
154	if strval, ok := m[zerolog.TimestampFieldName].(string); ok {
155		if t, err := time.Parse(zerolog.TimeFieldFormat, strval); err == nil {
156			l.timestamp = t.UTC()
157		}
158	}
159	if errmsg, ok := m[zerolog.ErrorFieldName].(string); ok {
160		err := &jsonError{message: errmsg}
161		if stack, ok := m[zerolog.ErrorStackFieldName].([]interface{}); ok {
162			frames := make([]stacktrace.Frame, 0, len(stack))
163			for i := range stack {
164				in, ok := stack[i].(map[string]interface{})
165				if !ok {
166					continue
167				}
168				var frame stacktrace.Frame
169				frame.File, _ = in[pkgerrors.StackSourceFileName].(string)
170				frame.Function, _ = in[StackSourceFunctionName].(string)
171				if strval, ok := in[StackSourceLineName].(string); ok {
172					if line, err := strconv.Atoi(strval); err == nil {
173						frame.Line = line
174					}
175				}
176				frames = append(frames, frame)
177			}
178			err.stack = frames
179		}
180		l.err = err
181	}
182
183	if strval, ok := m[SpanIDFieldName].(string); ok {
184		if err := decodeHex(l.spanID[:], strval); err != nil {
185			return errors.Wrap(err, "invalid span.id")
186		}
187	}
188
189	if strval, ok := m[TraceIDFieldName].(string); ok {
190		if err := decodeHex(l.traceID[:], strval); err != nil {
191			return errors.Wrap(err, "invalid trace.id")
192		}
193	}
194	if strval, ok := m[TransactionIDFieldName].(string); ok {
195		if err := decodeHex(l.transactionID[:], strval); err != nil {
196			return errors.Wrap(err, "invalid transaction.id")
197		}
198	}
199	return nil
200}
201
202func decodeHex(out []byte, in string) error {
203	if n := hex.EncodedLen(len(out)); n != len(in) {
204		return errors.Errorf(
205			"invalid value length (expected %d bytes, got %d)",
206			n, len(in),
207		)
208	}
209	_, err := hex.Decode(out, []byte(in))
210	return err
211}
212
213type jsonError struct {
214	message string
215	stack   []stacktrace.Frame
216}
217
218func (e *jsonError) Type() string {
219	return "error"
220}
221
222func (e *jsonError) Error() string {
223	return e.message
224}
225
226func (e *jsonError) StackTrace() []stacktrace.Frame {
227	return e.stack
228}
229