1package protoutil 2 3import ( 4 "errors" 5 "fmt" 6 7 "github.com/golang/protobuf/proto" 8 "github.com/golang/protobuf/protoc-gen-go/descriptor" 9 "gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb" 10) 11 12// GetOpExtension gets the OperationMsg from a method descriptor 13func GetOpExtension(m *descriptor.MethodDescriptorProto) (*gitalypb.OperationMsg, error) { 14 ext, err := getExtension(m.GetOptions(), gitalypb.E_OpType) 15 if err != nil { 16 return nil, err 17 } 18 19 return ext.(*gitalypb.OperationMsg), nil 20} 21 22// IsInterceptedService returns whether the serivce is intercepted by Praefect. 23func IsInterceptedService(s *descriptor.ServiceDescriptorProto) (bool, error) { 24 return getBoolExtension(s.GetOptions(), gitalypb.E_Intercepted) 25} 26 27// GetRepositoryExtension gets the repository extension from a field descriptor 28func GetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { 29 return getBoolExtension(m.GetOptions(), gitalypb.E_Repository) 30} 31 32// GetStorageExtension gets the storage extension from a field descriptor 33func GetStorageExtension(m *descriptor.FieldDescriptorProto) (bool, error) { 34 return getBoolExtension(m.GetOptions(), gitalypb.E_Storage) 35} 36 37// GetTargetRepositoryExtension gets the target_repository extension from a field descriptor 38func GetTargetRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { 39 return getBoolExtension(m.GetOptions(), gitalypb.E_TargetRepository) 40} 41 42// GetAdditionalRepositoryExtension gets the target_repository extension from a field descriptor 43func GetAdditionalRepositoryExtension(m *descriptor.FieldDescriptorProto) (bool, error) { 44 return getBoolExtension(m.GetOptions(), gitalypb.E_AdditionalRepository) 45} 46 47func getBoolExtension(options proto.Message, extension *proto.ExtensionDesc) (bool, error) { 48 val, err := getExtension(options, extension) 49 if err != nil { 50 if errors.Is(err, proto.ErrMissingExtension) { 51 return false, nil 52 } 53 54 return false, err 55 } 56 57 return *val.(*bool), nil 58} 59 60func getExtension(options proto.Message, extension *proto.ExtensionDesc) (interface{}, error) { 61 if !proto.HasExtension(options, extension) { 62 return nil, fmt.Errorf("protoutil.getExtension %q: %w", extension.Name, proto.ErrMissingExtension) 63 } 64 65 ext, err := proto.GetExtension(options, extension) 66 if err != nil { 67 return nil, fmt.Errorf("protoutil.getExtension %q: %w", extension.Name, err) 68 } 69 70 return ext, nil 71} 72