1package protoutil
2
3import (
4	"errors"
5	"fmt"
6
7	"github.com/golang/protobuf/proto"
8	"github.com/golang/protobuf/protoc-gen-go/descriptor"
9	"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
10)
11
12// GetOpExtension gets the OperationMsg from a method descriptor
13func GetOpExtension(m *descriptor.MethodDescriptorProto) (*gitalypb.OperationMsg, error) {
14	ext, err := getExtension(m.GetOptions(), gitalypb.E_OpType)
15	if err != nil {
16		return nil, err
17	}
18
19	return ext.(*gitalypb.OperationMsg), nil
20}
21
22// IsInterceptedService returns whether the serivce is intercepted by Praefect.
23func IsInterceptedService(s *descriptor.ServiceDescriptorProto) (bool, error) {
24	return getBoolExtension(s.GetOptions(), gitalypb.E_Intercepted)
25}
26
27// GetRepositoryExtension gets the repository extension from a field descriptor
28func GetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
29	return getBoolExtension(m.GetOptions(), gitalypb.E_Repository)
30}
31
32// GetStorageExtension gets the storage extension from a field descriptor
33func GetStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
34	return getBoolExtension(m.GetOptions(), gitalypb.E_Storage)
35}
36
37// GetTargetRepositoryExtension gets the target_repository extension from a field descriptor
38func GetTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
39	return getBoolExtension(m.GetOptions(), gitalypb.E_TargetRepository)
40}
41
42// GetAdditionalRepositoryExtension gets the target_repository extension from a field descriptor
43func GetAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) {
44	return getBoolExtension(m.GetOptions(), gitalypb.E_AdditionalRepository)
45}
46
47func getBoolExtension(options proto.Message, extension *proto.ExtensionDesc) (bool, error) {
48	val, err := getExtension(options, extension)
49	if err != nil {
50		if errors.Is(err, proto.ErrMissingExtension) {
51			return false, nil
52		}
53
54		return false, err
55	}
56
57	return *val.(*bool), nil
58}
59
60func getExtension(options proto.Message, extension *proto.ExtensionDesc) (interface{}, error) {
61	if !proto.HasExtension(options, extension) {
62		return nil, fmt.Errorf("protoutil.getExtension %q: %w", extension.Name, proto.ErrMissingExtension)
63	}
64
65	ext, err := proto.GetExtension(options, extension)
66	if err != nil {
67		return nil, fmt.Errorf("protoutil.getExtension %q: %w", extension.Name, err)
68	}
69
70	return ext, nil
71}
72