1/*
2   Copyright The containerd Authors.
3
4   Licensed under the Apache License, Version 2.0 (the "License");
5   you may not use this file except in compliance with the License.
6   You may obtain a copy of the License at
7
8       http://www.apache.org/licenses/LICENSE-2.0
9
10   Unless required by applicable law or agreed to in writing, software
11   distributed under the License is distributed on an "AS IS" BASIS,
12   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   See the License for the specific language governing permissions and
14   limitations under the License.
15*/
16
17package typeurl
18
19import (
20	"encoding/json"
21	"path"
22	"reflect"
23	"sync"
24
25	"github.com/gogo/protobuf/proto"
26	"github.com/gogo/protobuf/types"
27	"github.com/pkg/errors"
28)
29
// mu serializes access to registry.
var mu sync.Mutex

// registry maps a registered (dereferenced) Go type to its type url.
var registry = map[reflect.Type]string{}

// ErrNotFound is the sentinel wrapped by lookups in this package when no
// type or url matches the request.
var ErrNotFound = errors.New("not found")
36
37// Register a type with the base url of the type
38func Register(v interface{}, args ...string) {
39	var (
40		t = tryDereference(v)
41		p = path.Join(args...)
42	)
43	mu.Lock()
44	defer mu.Unlock()
45	if et, ok := registry[t]; ok {
46		if et != p {
47			panic(errors.Errorf("type registred with alternate path %q != %q", et, p))
48		}
49		return
50	}
51	registry[t] = p
52}
53
54// TypeURL returns the type url for a registred type
55func TypeURL(v interface{}) (string, error) {
56	mu.Lock()
57	u, ok := registry[tryDereference(v)]
58	mu.Unlock()
59	if !ok {
60		// fallback to the proto registry if it is a proto message
61		pb, ok := v.(proto.Message)
62		if !ok {
63			return "", errors.Wrapf(ErrNotFound, "type %s", reflect.TypeOf(v))
64		}
65		return proto.MessageName(pb), nil
66	}
67	return u, nil
68}
69
70// Is returns true if the type of the Any is the same as v
71func Is(any *types.Any, v interface{}) bool {
72	// call to check that v is a pointer
73	tryDereference(v)
74	url, err := TypeURL(v)
75	if err != nil {
76		return false
77	}
78	return any.TypeUrl == url
79}
80
81// MarshalAny marshals the value v into an any with the correct TypeUrl.
82// If the provided object is already a proto.Any message, then it will be
83// returned verbatim. If it is of type proto.Message, it will be marshaled as a
84// protocol buffer. Otherwise, the object will be marshaled to json.
85func MarshalAny(v interface{}) (*types.Any, error) {
86	var marshal func(v interface{}) ([]byte, error)
87	switch t := v.(type) {
88	case *types.Any:
89		// avoid reserializing the type if we have an any.
90		return t, nil
91	case proto.Message:
92		marshal = func(v interface{}) ([]byte, error) {
93			return proto.Marshal(t)
94		}
95	default:
96		marshal = json.Marshal
97	}
98
99	url, err := TypeURL(v)
100	if err != nil {
101		return nil, err
102	}
103
104	data, err := marshal(v)
105	if err != nil {
106		return nil, err
107	}
108	return &types.Any{
109		TypeUrl: url,
110		Value:   data,
111	}, nil
112}
113
114// UnmarshalAny unmarshals the any type into a concrete type
115func UnmarshalAny(any *types.Any) (interface{}, error) {
116	t, err := getTypeByUrl(any.TypeUrl)
117	if err != nil {
118		return nil, err
119	}
120	v := reflect.New(t.t).Interface()
121	if t.isProto {
122		err = proto.Unmarshal(any.Value, v.(proto.Message))
123	} else {
124		err = json.Unmarshal(any.Value, v)
125	}
126	return v, err
127}
128
// urlType is the result of resolving a type url: the Go type to instantiate
// and which decoder applies to it.
type urlType struct {
	// t is the non-pointer type; UnmarshalAny allocates it with reflect.New.
	t reflect.Type
	// isProto is set when the type was resolved through the proto registry
	// and must therefore be decoded as a protocol buffer rather than json.
	isProto bool
}
133
134func getTypeByUrl(url string) (urlType, error) {
135	for t, u := range registry {
136		if u == url {
137			return urlType{
138				t: t,
139			}, nil
140		}
141	}
142	// fallback to proto registry
143	t := proto.MessageType(url)
144	if t != nil {
145		return urlType{
146			// get the underlying Elem because proto returns a pointer to the type
147			t:       t.Elem(),
148			isProto: true,
149		}, nil
150	}
151	return urlType{}, errors.Wrapf(ErrNotFound, "type with url %s", url)
152}
153
// tryDereference returns the type that v points to.
//
// It panics with "v is not a pointer to a type" when v is not a pointer. This
// includes an untyped nil, for which reflect.TypeOf returns nil and a direct
// Kind call would otherwise crash with a nil dereference instead of the
// intended message.
func tryDereference(v interface{}) reflect.Type {
	t := reflect.TypeOf(v)
	if t == nil || t.Kind() != reflect.Ptr {
		panic("v is not a pointer to a type")
	}
	// require check of pointer but dereference to register
	return t.Elem()
}
162