1package pgs
2
3import (
4	"errors"
5	"fmt"
6	"reflect"
7
8	"github.com/golang/protobuf/proto"
9)
10
11// An Extension is a custom option annotation that can be applied to an Entity to provide additional
12// semantic details and metadata about the Entity.
13type Extension interface {
14	Field
15
16	// ParentEntity returns the ParentEntity where the Extension is defined
17	DefinedIn() ParentEntity
18
19	// Extendee returns the Message that the Extension is extending
20	Extendee() Message
21
22	setExtendee(m Message)
23}
24
// ext carries the Extension-specific state (DefinedIn, Extendee, setExtendee)
// on top of the embedded field, which supplies the shared Field behavior.
type ext struct {
	field

	parent   ParentEntity // entity the extension is declared in; returned by DefinedIn
	extendee Message      // message being extended; returned by Extendee, set by setExtendee
	fqn      string       // value returned by FullyQualifiedName
}
32
33func (e *ext) FullyQualifiedName() string { return e.fqn }
34func (e *ext) Syntax() Syntax             { return e.parent.Syntax() }
35func (e *ext) Package() Package           { return e.parent.Package() }
36func (e *ext) File() File                 { return e.parent.File() }
37func (e *ext) BuildTarget() bool          { return e.parent.BuildTarget() }
38func (e *ext) DefinedIn() ParentEntity    { return e.parent }
39func (e *ext) Extendee() Message          { return e.extendee }
40func (e *ext) Message() Message           { return nil }
41func (e *ext) InOneOf() bool              { return false }
42func (e *ext) OneOf() OneOf               { return nil }
43func (e *ext) setMessage(m Message)       {} // noop
44func (e *ext) setOneOf(o OneOf)           {} // noop
45func (e *ext) setExtendee(m Message)      { e.extendee = m }
46
47func (e *ext) accept(v Visitor) (err error) {
48	if v == nil {
49		return
50	}
51
52	_, err = v.VisitExtension(e)
53	return
54}
55
56var extractor extExtractor
57
58func init() { extractor = protoExtExtractor{} }
59
60type extExtractor interface {
61	HasExtension(proto.Message, *proto.ExtensionDesc) bool
62	GetExtension(proto.Message, *proto.ExtensionDesc) (interface{}, error)
63}
64
65type protoExtExtractor struct{}
66
67func (e protoExtExtractor) HasExtension(pb proto.Message, ext *proto.ExtensionDesc) bool {
68	return proto.HasExtension(pb, ext)
69}
70
71func (e protoExtExtractor) GetExtension(pb proto.Message, ext *proto.ExtensionDesc) (interface{}, error) {
72	return proto.GetExtension(pb, ext)
73}
74
75func extension(opts proto.Message, e *proto.ExtensionDesc, out interface{}) (bool, error) {
76	if opts == nil || reflect.ValueOf(opts).IsNil() {
77		return false, nil
78	}
79
80	if e == nil {
81		return false, errors.New("nil *proto.ExtensionDesc parameter provided")
82	}
83
84	if out == nil {
85		return false, errors.New("nil extension output parameter provided")
86	}
87
88	o := reflect.ValueOf(out)
89	if o.Kind() != reflect.Ptr {
90		return false, errors.New("out parameter must be a pointer type")
91	}
92
93	if !extractor.HasExtension(opts, e) {
94		return false, nil
95	}
96
97	val, err := extractor.GetExtension(opts, e)
98	if err != nil || val == nil {
99		return false, err
100	}
101
102	v := reflect.ValueOf(val)
103	for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface {
104		v = v.Elem()
105	}
106
107	for o.Kind() == reflect.Ptr || o.Kind() == reflect.Interface {
108		if o.Kind() == reflect.Ptr && o.IsNil() {
109			o.Set(reflect.New(o.Type().Elem()))
110		}
111		o = o.Elem()
112	}
113
114	if v.Type().AssignableTo(o.Type()) {
115		o.Set(v)
116		return true, nil
117	}
118
119	return true, fmt.Errorf("cannot assign extension type %q to output type %q",
120		v.Type().String(),
121		o.Type().String())
122}
123