1package pb
2
3import (
4	"encoding/json"
5	"errors"
6	"time"
7
8	"github.com/golang/protobuf/ptypes"
9	"github.com/hashicorp/vault/sdk/helper/errutil"
10	"github.com/hashicorp/vault/sdk/helper/parseutil"
11	"github.com/hashicorp/vault/sdk/helper/wrapping"
12	"github.com/hashicorp/vault/sdk/logical"
13)
14
// ErrType discriminators stored in ProtoError.ErrType. ErrToProtoErr picks one
// of these when encoding an error and ProtoErrToErr switches on it to rebuild
// the matching Go error. Both sides must agree on the numbering. Peers may
// presumably be built from different versions, so the values are written out
// explicitly and must never be renumbered; new kinds are appended at the end.
const (
	ErrTypeUnknown              uint32 = 0
	ErrTypeUserError            uint32 = 1
	ErrTypeInternalError        uint32 = 2
	ErrTypeCodedError           uint32 = 3
	ErrTypeStatusBadRequest     uint32 = 4
	ErrTypeUnsupportedOperation uint32 = 5
	ErrTypeUnsupportedPath      uint32 = 6
	ErrTypeInvalidRequest       uint32 = 7
	ErrTypePermissionDenied     uint32 = 8
	ErrTypeMultiAuthzPending    uint32 = 9
)
27
28func ProtoErrToErr(e *ProtoError) error {
29	if e == nil {
30		return nil
31	}
32
33	var err error
34	switch e.ErrType {
35	case ErrTypeUnknown:
36		err = errors.New(e.ErrMsg)
37	case ErrTypeUserError:
38		err = errutil.UserError{Err: e.ErrMsg}
39	case ErrTypeInternalError:
40		err = errutil.InternalError{Err: e.ErrMsg}
41	case ErrTypeCodedError:
42		err = logical.CodedError(int(e.ErrCode), e.ErrMsg)
43	case ErrTypeStatusBadRequest:
44		err = &logical.StatusBadRequest{Err: e.ErrMsg}
45	case ErrTypeUnsupportedOperation:
46		err = logical.ErrUnsupportedOperation
47	case ErrTypeUnsupportedPath:
48		err = logical.ErrUnsupportedPath
49	case ErrTypeInvalidRequest:
50		err = logical.ErrInvalidRequest
51	case ErrTypePermissionDenied:
52		err = logical.ErrPermissionDenied
53	case ErrTypeMultiAuthzPending:
54		err = logical.ErrMultiAuthzPending
55	}
56
57	return err
58}
59
60func ErrToProtoErr(e error) *ProtoError {
61	if e == nil {
62		return nil
63	}
64	pbErr := &ProtoError{
65		ErrMsg:  e.Error(),
66		ErrType: ErrTypeUnknown,
67	}
68
69	switch e.(type) {
70	case errutil.UserError:
71		pbErr.ErrType = ErrTypeUserError
72	case errutil.InternalError:
73		pbErr.ErrType = ErrTypeInternalError
74	case logical.HTTPCodedError:
75		pbErr.ErrType = ErrTypeCodedError
76		pbErr.ErrCode = int64(e.(logical.HTTPCodedError).Code())
77	case *logical.StatusBadRequest:
78		pbErr.ErrType = ErrTypeStatusBadRequest
79	}
80
81	switch {
82	case e == logical.ErrUnsupportedOperation:
83		pbErr.ErrType = ErrTypeUnsupportedOperation
84	case e == logical.ErrUnsupportedPath:
85		pbErr.ErrType = ErrTypeUnsupportedPath
86	case e == logical.ErrInvalidRequest:
87		pbErr.ErrType = ErrTypeInvalidRequest
88	case e == logical.ErrPermissionDenied:
89		pbErr.ErrType = ErrTypePermissionDenied
90	case e == logical.ErrMultiAuthzPending:
91		pbErr.ErrType = ErrTypeMultiAuthzPending
92	}
93
94	return pbErr
95}
96
// ErrToString returns the message of e, or the empty string when e is nil.
func ErrToString(e error) string {
	var msg string
	if e != nil {
		msg = e.Error()
	}
	return msg
}
104
105func LogicalStorageEntryToProtoStorageEntry(e *logical.StorageEntry) *StorageEntry {
106	if e == nil {
107		return nil
108	}
109
110	return &StorageEntry{
111		Key:      e.Key,
112		Value:    e.Value,
113		SealWrap: e.SealWrap,
114	}
115}
116
117func ProtoStorageEntryToLogicalStorageEntry(e *StorageEntry) *logical.StorageEntry {
118	if e == nil {
119		return nil
120	}
121
122	return &logical.StorageEntry{
123		Key:      e.Key,
124		Value:    e.Value,
125		SealWrap: e.SealWrap,
126	}
127}
128
129func ProtoLeaseOptionsToLogicalLeaseOptions(l *LeaseOptions) (logical.LeaseOptions, error) {
130	if l == nil {
131		return logical.LeaseOptions{}, nil
132	}
133
134	t, err := ptypes.Timestamp(l.IssueTime)
135	return logical.LeaseOptions{
136		TTL:       time.Duration(l.TTL),
137		Renewable: l.Renewable,
138		Increment: time.Duration(l.Increment),
139		IssueTime: t,
140		MaxTTL:    time.Duration(l.MaxTTL),
141	}, err
142}
143
144func LogicalLeaseOptionsToProtoLeaseOptions(l logical.LeaseOptions) (*LeaseOptions, error) {
145	t, err := ptypes.TimestampProto(l.IssueTime)
146	if err != nil {
147		return nil, err
148	}
149
150	return &LeaseOptions{
151		TTL:       int64(l.TTL),
152		Renewable: l.Renewable,
153		Increment: int64(l.Increment),
154		IssueTime: t,
155		MaxTTL:    int64(l.MaxTTL),
156	}, err
157}
158
159func ProtoSecretToLogicalSecret(s *Secret) (*logical.Secret, error) {
160	if s == nil {
161		return nil, nil
162	}
163
164	data := map[string]interface{}{}
165	err := json.Unmarshal([]byte(s.InternalData), &data)
166	if err != nil {
167		return nil, err
168	}
169
170	lease, err := ProtoLeaseOptionsToLogicalLeaseOptions(s.LeaseOptions)
171	if err != nil {
172		return nil, err
173	}
174
175	return &logical.Secret{
176		LeaseOptions: lease,
177		InternalData: data,
178		LeaseID:      s.LeaseID,
179	}, nil
180}
181
182func LogicalSecretToProtoSecret(s *logical.Secret) (*Secret, error) {
183	if s == nil {
184		return nil, nil
185	}
186
187	buf, err := json.Marshal(s.InternalData)
188	if err != nil {
189		return nil, err
190	}
191
192	lease, err := LogicalLeaseOptionsToProtoLeaseOptions(s.LeaseOptions)
193	if err != nil {
194		return nil, err
195	}
196
197	return &Secret{
198		LeaseOptions: lease,
199		InternalData: string(buf[:]),
200		LeaseID:      s.LeaseID,
201	}, err
202}
203
204func LogicalRequestToProtoRequest(r *logical.Request) (*Request, error) {
205	if r == nil {
206		return nil, nil
207	}
208
209	buf, err := json.Marshal(r.Data)
210	if err != nil {
211		return nil, err
212	}
213
214	secret, err := LogicalSecretToProtoSecret(r.Secret)
215	if err != nil {
216		return nil, err
217	}
218
219	auth, err := LogicalAuthToProtoAuth(r.Auth)
220	if err != nil {
221		return nil, err
222	}
223
224	headers := map[string]*Header{}
225	for k, v := range r.Headers {
226		headers[k] = &Header{Header: v}
227	}
228
229	return &Request{
230		ID:                       r.ID,
231		ReplicationCluster:       r.ReplicationCluster,
232		Operation:                string(r.Operation),
233		Path:                     r.Path,
234		Data:                     string(buf[:]),
235		Secret:                   secret,
236		Auth:                     auth,
237		Headers:                  headers,
238		ClientToken:              r.ClientToken,
239		ClientTokenAccessor:      r.ClientTokenAccessor,
240		DisplayName:              r.DisplayName,
241		MountPoint:               r.MountPoint,
242		MountType:                r.MountType,
243		MountAccessor:            r.MountAccessor,
244		WrapInfo:                 LogicalRequestWrapInfoToProtoRequestWrapInfo(r.WrapInfo),
245		ClientTokenRemainingUses: int64(r.ClientTokenRemainingUses),
246		Connection:               LogicalConnectionToProtoConnection(r.Connection),
247		EntityID:                 r.EntityID,
248		PolicyOverride:           r.PolicyOverride,
249		Unauthenticated:          r.Unauthenticated,
250	}, nil
251}
252
253func ProtoRequestToLogicalRequest(r *Request) (*logical.Request, error) {
254	if r == nil {
255		return nil, nil
256	}
257
258	data := map[string]interface{}{}
259	err := json.Unmarshal([]byte(r.Data), &data)
260	if err != nil {
261		return nil, err
262	}
263
264	secret, err := ProtoSecretToLogicalSecret(r.Secret)
265	if err != nil {
266		return nil, err
267	}
268
269	auth, err := ProtoAuthToLogicalAuth(r.Auth)
270	if err != nil {
271		return nil, err
272	}
273
274	var headers map[string][]string
275	if len(r.Headers) > 0 {
276		headers = make(map[string][]string, len(r.Headers))
277		for k, v := range r.Headers {
278			headers[k] = v.Header
279		}
280	}
281
282	return &logical.Request{
283		ID:                       r.ID,
284		ReplicationCluster:       r.ReplicationCluster,
285		Operation:                logical.Operation(r.Operation),
286		Path:                     r.Path,
287		Data:                     data,
288		Secret:                   secret,
289		Auth:                     auth,
290		Headers:                  headers,
291		ClientToken:              r.ClientToken,
292		ClientTokenAccessor:      r.ClientTokenAccessor,
293		DisplayName:              r.DisplayName,
294		MountPoint:               r.MountPoint,
295		MountType:                r.MountType,
296		MountAccessor:            r.MountAccessor,
297		WrapInfo:                 ProtoRequestWrapInfoToLogicalRequestWrapInfo(r.WrapInfo),
298		ClientTokenRemainingUses: int(r.ClientTokenRemainingUses),
299		Connection:               ProtoConnectionToLogicalConnection(r.Connection),
300		EntityID:                 r.EntityID,
301		PolicyOverride:           r.PolicyOverride,
302		Unauthenticated:          r.Unauthenticated,
303	}, nil
304}
305
306func LogicalConnectionToProtoConnection(c *logical.Connection) *Connection {
307	if c == nil {
308		return nil
309	}
310
311	return &Connection{
312		RemoteAddr: c.RemoteAddr,
313	}
314}
315
316func ProtoConnectionToLogicalConnection(c *Connection) *logical.Connection {
317	if c == nil {
318		return nil
319	}
320
321	return &logical.Connection{
322		RemoteAddr: c.RemoteAddr,
323	}
324}
325
326func LogicalRequestWrapInfoToProtoRequestWrapInfo(i *logical.RequestWrapInfo) *RequestWrapInfo {
327	if i == nil {
328		return nil
329	}
330
331	return &RequestWrapInfo{
332		TTL:      int64(i.TTL),
333		Format:   i.Format,
334		SealWrap: i.SealWrap,
335	}
336}
337
338func ProtoRequestWrapInfoToLogicalRequestWrapInfo(i *RequestWrapInfo) *logical.RequestWrapInfo {
339	if i == nil {
340		return nil
341	}
342
343	return &logical.RequestWrapInfo{
344		TTL:      time.Duration(i.TTL),
345		Format:   i.Format,
346		SealWrap: i.SealWrap,
347	}
348}
349
350func ProtoResponseToLogicalResponse(r *Response) (*logical.Response, error) {
351	if r == nil {
352		return nil, nil
353	}
354
355	secret, err := ProtoSecretToLogicalSecret(r.Secret)
356	if err != nil {
357		return nil, err
358	}
359
360	auth, err := ProtoAuthToLogicalAuth(r.Auth)
361	if err != nil {
362		return nil, err
363	}
364
365	data := map[string]interface{}{}
366	err = json.Unmarshal([]byte(r.Data), &data)
367	if err != nil {
368		return nil, err
369	}
370
371	wrapInfo, err := ProtoResponseWrapInfoToLogicalResponseWrapInfo(r.WrapInfo)
372	if err != nil {
373		return nil, err
374	}
375
376	var headers map[string][]string
377	if len(r.Headers) > 0 {
378		headers = make(map[string][]string, len(r.Headers))
379		for k, v := range r.Headers {
380			headers[k] = v.Header
381		}
382	}
383
384	return &logical.Response{
385		Secret:   secret,
386		Auth:     auth,
387		Data:     data,
388		Redirect: r.Redirect,
389		Warnings: r.Warnings,
390		WrapInfo: wrapInfo,
391		Headers:  headers,
392	}, nil
393}
394
395func ProtoResponseWrapInfoToLogicalResponseWrapInfo(i *ResponseWrapInfo) (*wrapping.ResponseWrapInfo, error) {
396	if i == nil {
397		return nil, nil
398	}
399
400	t, err := ptypes.Timestamp(i.CreationTime)
401	if err != nil {
402		return nil, err
403	}
404
405	return &wrapping.ResponseWrapInfo{
406		TTL:             time.Duration(i.TTL),
407		Token:           i.Token,
408		Accessor:        i.Accessor,
409		CreationTime:    t,
410		WrappedAccessor: i.WrappedAccessor,
411		WrappedEntityID: i.WrappedEntityID,
412		Format:          i.Format,
413		CreationPath:    i.CreationPath,
414		SealWrap:        i.SealWrap,
415	}, nil
416}
417
418func LogicalResponseWrapInfoToProtoResponseWrapInfo(i *wrapping.ResponseWrapInfo) (*ResponseWrapInfo, error) {
419	if i == nil {
420		return nil, nil
421	}
422
423	t, err := ptypes.TimestampProto(i.CreationTime)
424	if err != nil {
425		return nil, err
426	}
427
428	return &ResponseWrapInfo{
429		TTL:             int64(i.TTL),
430		Token:           i.Token,
431		Accessor:        i.Accessor,
432		CreationTime:    t,
433		WrappedAccessor: i.WrappedAccessor,
434		WrappedEntityID: i.WrappedEntityID,
435		Format:          i.Format,
436		CreationPath:    i.CreationPath,
437		SealWrap:        i.SealWrap,
438	}, nil
439}
440
441func LogicalResponseToProtoResponse(r *logical.Response) (*Response, error) {
442	if r == nil {
443		return nil, nil
444	}
445
446	secret, err := LogicalSecretToProtoSecret(r.Secret)
447	if err != nil {
448		return nil, err
449	}
450
451	auth, err := LogicalAuthToProtoAuth(r.Auth)
452	if err != nil {
453		return nil, err
454	}
455
456	buf, err := json.Marshal(r.Data)
457	if err != nil {
458		return nil, err
459	}
460
461	wrapInfo, err := LogicalResponseWrapInfoToProtoResponseWrapInfo(r.WrapInfo)
462	if err != nil {
463		return nil, err
464	}
465
466	headers := map[string]*Header{}
467	for k, v := range r.Headers {
468		headers[k] = &Header{Header: v}
469	}
470
471	return &Response{
472		Secret:   secret,
473		Auth:     auth,
474		Data:     string(buf[:]),
475		Redirect: r.Redirect,
476		Warnings: r.Warnings,
477		WrapInfo: wrapInfo,
478		Headers:  headers,
479	}, nil
480}
481
482func LogicalAuthToProtoAuth(a *logical.Auth) (*Auth, error) {
483	if a == nil {
484		return nil, nil
485	}
486
487	buf, err := json.Marshal(a.InternalData)
488	if err != nil {
489		return nil, err
490	}
491
492	lo, err := LogicalLeaseOptionsToProtoLeaseOptions(a.LeaseOptions)
493	if err != nil {
494		return nil, err
495	}
496
497	boundCIDRs := make([]string, len(a.BoundCIDRs))
498	for i, cidr := range a.BoundCIDRs {
499		boundCIDRs[i] = cidr.String()
500	}
501
502	return &Auth{
503		LeaseOptions:     lo,
504		TokenType:        uint32(a.TokenType),
505		InternalData:     string(buf[:]),
506		DisplayName:      a.DisplayName,
507		Policies:         a.Policies,
508		TokenPolicies:    a.TokenPolicies,
509		IdentityPolicies: a.IdentityPolicies,
510		NoDefaultPolicy:  a.NoDefaultPolicy,
511		Metadata:         a.Metadata,
512		ClientToken:      a.ClientToken,
513		Accessor:         a.Accessor,
514		Period:           int64(a.Period),
515		NumUses:          int64(a.NumUses),
516		EntityID:         a.EntityID,
517		Alias:            a.Alias,
518		GroupAliases:     a.GroupAliases,
519		BoundCIDRs:       boundCIDRs,
520		ExplicitMaxTTL:   int64(a.ExplicitMaxTTL),
521	}, nil
522}
523
524func ProtoAuthToLogicalAuth(a *Auth) (*logical.Auth, error) {
525	if a == nil {
526		return nil, nil
527	}
528
529	data := map[string]interface{}{}
530	err := json.Unmarshal([]byte(a.InternalData), &data)
531	if err != nil {
532		return nil, err
533	}
534
535	lo, err := ProtoLeaseOptionsToLogicalLeaseOptions(a.LeaseOptions)
536	if err != nil {
537		return nil, err
538	}
539
540	boundCIDRs, err := parseutil.ParseAddrs(a.BoundCIDRs)
541	if err != nil {
542		return nil, err
543	}
544	if len(boundCIDRs) == 0 {
545		// On inbound auths, if auth.BoundCIDRs is empty, it will be nil.
546		// Let's match that behavior outbound.
547		boundCIDRs = nil
548	}
549
550	return &logical.Auth{
551		LeaseOptions:     lo,
552		TokenType:        logical.TokenType(a.TokenType),
553		InternalData:     data,
554		DisplayName:      a.DisplayName,
555		Policies:         a.Policies,
556		TokenPolicies:    a.TokenPolicies,
557		IdentityPolicies: a.IdentityPolicies,
558		NoDefaultPolicy:  a.NoDefaultPolicy,
559		Metadata:         a.Metadata,
560		ClientToken:      a.ClientToken,
561		Accessor:         a.Accessor,
562		Period:           time.Duration(a.Period),
563		NumUses:          int(a.NumUses),
564		EntityID:         a.EntityID,
565		Alias:            a.Alias,
566		GroupAliases:     a.GroupAliases,
567		BoundCIDRs:       boundCIDRs,
568		ExplicitMaxTTL:   time.Duration(a.ExplicitMaxTTL),
569	}, nil
570}
571
572func LogicalTokenEntryToProtoTokenEntry(t *logical.TokenEntry) *TokenEntry {
573	if t == nil {
574		return nil
575	}
576
577	boundCIDRs := make([]string, len(t.BoundCIDRs))
578	for i, cidr := range t.BoundCIDRs {
579		boundCIDRs[i] = cidr.String()
580	}
581
582	return &TokenEntry{
583		ID:             t.ID,
584		Accessor:       t.Accessor,
585		Parent:         t.Parent,
586		Policies:       t.Policies,
587		Path:           t.Path,
588		Meta:           t.Meta,
589		DisplayName:    t.DisplayName,
590		NumUses:        int64(t.NumUses),
591		CreationTime:   t.CreationTime,
592		TTL:            int64(t.TTL),
593		ExplicitMaxTTL: int64(t.ExplicitMaxTTL),
594		Role:           t.Role,
595		Period:         int64(t.Period),
596		EntityID:       t.EntityID,
597		BoundCIDRs:     boundCIDRs,
598		NamespaceID:    t.NamespaceID,
599		CubbyholeID:    t.CubbyholeID,
600		Type:           uint32(t.Type),
601	}
602}
603
604func ProtoTokenEntryToLogicalTokenEntry(t *TokenEntry) (*logical.TokenEntry, error) {
605	if t == nil {
606		return nil, nil
607	}
608
609	boundCIDRs, err := parseutil.ParseAddrs(t.BoundCIDRs)
610	if err != nil {
611		return nil, err
612	}
613	if len(boundCIDRs) == 0 {
614		// On inbound auths, if auth.BoundCIDRs is empty, it will be nil.
615		// Let's match that behavior outbound.
616		boundCIDRs = nil
617	}
618
619	return &logical.TokenEntry{
620		ID:             t.ID,
621		Accessor:       t.Accessor,
622		Parent:         t.Parent,
623		Policies:       t.Policies,
624		Path:           t.Path,
625		Meta:           t.Meta,
626		DisplayName:    t.DisplayName,
627		NumUses:        int(t.NumUses),
628		CreationTime:   t.CreationTime,
629		TTL:            time.Duration(t.TTL),
630		ExplicitMaxTTL: time.Duration(t.ExplicitMaxTTL),
631		Role:           t.Role,
632		Period:         time.Duration(t.Period),
633		EntityID:       t.EntityID,
634		BoundCIDRs:     boundCIDRs,
635		NamespaceID:    t.NamespaceID,
636		CubbyholeID:    t.CubbyholeID,
637		Type:           logical.TokenType(t.Type),
638	}, nil
639}
640