1// Copyright 2017 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_ctxtags
5
6import (
7	"reflect"
8)
9
type (
	// RequestFieldExtractorFunc is a user-provided function that extracts field information from a gRPC request.
	// The tags middleware calls it when a unary request or a server-stream request arrives.
	// Every key/value pair it returns is added to the request's context tags; return nil when there are no fields.
	RequestFieldExtractorFunc func(fullMethod string, req interface{}) map[string]interface{}

	// requestFieldsExtractor is satisfied by request messages (usually code-generated) that can
	// report their own loggable fields.
	requestFieldsExtractor interface {
		// ExtractRequestFields is a method declared on a Protobuf message that writes its loggable
		// fields into appendToMap. Filling a caller-supplied map instead of returning a new one
		// avoids extra allocations.
		ExtractRequestFields(appendToMap map[string]interface{})
	}
)
20
21// CodeGenRequestFieldExtractor is a function that relies on code-generated functions that export log fields from requests.
22// These are usually coming from a protoc-plugin that generates additional information based on custom field options.
23func CodeGenRequestFieldExtractor(fullMethod string, req interface{}) map[string]interface{} {
24	if ext, ok := req.(requestFieldsExtractor); ok {
25		retMap := make(map[string]interface{})
26		ext.ExtractRequestFields(retMap)
27		if len(retMap) == 0 {
28			return nil
29		}
30		return retMap
31	}
32	return nil
33}
34
35// TagBasedRequestFieldExtractor is a function that relies on Go struct tags to export log fields from requests.
36// These are usually coming from a protoc-plugin, such as Gogo protobuf.
37//
38//  message Metadata {
39//     repeated string tags = 1 [ (gogoproto.moretags) = "log_field:\"meta_tags\"" ];
40//  }
41//
42// The tagName is configurable using the tagName variable. Here it would be "log_field".
43func TagBasedRequestFieldExtractor(tagName string) RequestFieldExtractorFunc {
44	return func(fullMethod string, req interface{}) map[string]interface{} {
45		retMap := make(map[string]interface{})
46		reflectMessageTags(req, retMap, tagName)
47		if len(retMap) == 0 {
48			return nil
49		}
50		return retMap
51	}
52}
53
// reflectMessageTags collects tagged fields of msg into existingMap.
//
// msg is inspected only when it is a non-nil pointer to a struct; any other value is ignored.
// For every exported field whose struct tag carries tagName, the field value is stored under the
// tag's value when the field is a bool, integer (including uintptr), float or string, or a non-empty
// array/slice whose first element is of such a kind.
//
// Nested messages are searched recursively. This covers messages reached through pointer or interface
// fields and messages embedded by value as struct fields. Unexported fields are skipped because
// reflection cannot read their values.
func reflectMessageTags(msg interface{}, existingMap map[string]interface{}, tagName string) {
	v := reflect.ValueOf(msg)
	// Only deal with pointers to structs. A nil pointer's Elem() is the invalid Value, so it is skipped too.
	if v.Kind() != reflect.Ptr || v.Elem().Kind() != reflect.Struct {
		return
	}
	reflectStructTags(v.Elem(), existingMap, tagName)
}

// reflectStructTags walks the fields of the struct value v, adding tagged loggable fields to existingMap.
func reflectStructTags(v reflect.Value, existingMap map[string]interface{}, tagName string) {
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		field := v.Field(i)
		// Interface() panics on unexported fields, so never touch them (even if they carry the tag).
		if !field.CanInterface() {
			continue
		}
		kind := field.Kind()
		switch kind {
		case reflect.Ptr, reflect.Interface:
			// Nested messages (and values held in interface fields) are searched recursively;
			// the pointer/interface itself is never logged.
			reflectMessageTags(field.Interface(), existingMap, tagName)
			continue
		case reflect.Struct:
			// Nested messages stored by value (e.g. gogoproto nullable=false) are searched in place.
			reflectStructTags(field, existingMap, tagName)
			continue
		case reflect.Array, reflect.Slice:
			// Repeated fields: classify by the element kind; empty collections are not logged.
			if field.Len() == 0 {
				continue
			}
			kind = field.Index(0).Kind()
		}
		if !isLoggableKind(kind) {
			continue
		}
		if tag := t.Field(i).Tag.Get(tagName); tag != "" {
			existingMap[tag] = field.Interface()
		}
	}
}

// isLoggableKind reports whether values of kind k are logged: bools, integers (including uintptr),
// floats and strings.
func isLoggableKind(k reflect.Kind) bool {
	return (k >= reflect.Bool && k <= reflect.Float64) || k == reflect.String
}
86