1/*
2 *
3 * Copyright 2020 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
// Package local implements local transport credentials.
// Local credentials report the security level based on the type
// of connection. If the connection is local TCP, NoSecurity will be
// reported, and if the connection is UDS, PrivacyAndIntegrity will be
// reported. If local credentials are used on a connection that is not
// local (neither local TCP nor UDS), the handshake will fail.
25//
26// Experimental
27//
28// Notice: This package is EXPERIMENTAL and may be changed or removed in a
29// later release.
30package local
31
32import (
33	"context"
34	"fmt"
35	"net"
36	"strings"
37
38	"google.golang.org/grpc/credentials"
39)
40
41// info contains the auth information for a local connection.
42// It implements the AuthInfo interface.
43type info struct {
44	credentials.CommonAuthInfo
45}
46
47// AuthType returns the type of info as a string.
48func (info) AuthType() string {
49	return "local"
50}
51
52// localTC is the credentials required to establish a local connection.
53type localTC struct {
54	info credentials.ProtocolInfo
55}
56
57func (c *localTC) Info() credentials.ProtocolInfo {
58	return c.info
59}
60
61// getSecurityLevel returns the security level for a local connection.
62// It returns an error if a connection is not local.
63func getSecurityLevel(network, addr string) (credentials.SecurityLevel, error) {
64	switch {
65	// Local TCP connection
66	case strings.HasPrefix(addr, "127."), strings.HasPrefix(addr, "[::1]:"):
67		return credentials.NoSecurity, nil
68	// UDS connection
69	case network == "unix":
70		return credentials.PrivacyAndIntegrity, nil
71	// Not a local connection and should fail
72	default:
73		return credentials.InvalidSecurityLevel, fmt.Errorf("local credentials rejected connection to non-local address %q", addr)
74	}
75}
76
77func (*localTC) ClientHandshake(ctx context.Context, authority string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
78	secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
79	if err != nil {
80		return nil, nil, err
81	}
82	return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
83}
84
85func (*localTC) ServerHandshake(conn net.Conn) (net.Conn, credentials.AuthInfo, error) {
86	secLevel, err := getSecurityLevel(conn.RemoteAddr().Network(), conn.RemoteAddr().String())
87	if err != nil {
88		return nil, nil, err
89	}
90	return conn, info{credentials.CommonAuthInfo{SecurityLevel: secLevel}}, nil
91}
92
93// NewCredentials returns a local credential implementing credentials.TransportCredentials.
94func NewCredentials() credentials.TransportCredentials {
95	return &localTC{
96		info: credentials.ProtocolInfo{
97			SecurityProtocol: "local",
98		},
99	}
100}
101
102// Clone makes a copy of Local credentials.
103func (c *localTC) Clone() credentials.TransportCredentials {
104	return &localTC{info: c.info}
105}
106
107// OverrideServerName overrides the server name used to verify the hostname on the returned certificates from the server.
108// Since this feature is specific to TLS (SNI + hostname verification check), it does not take any effet for local credentials.
109func (c *localTC) OverrideServerName(serverNameOverride string) error {
110	c.info.ServerName = serverNameOverride
111	return nil
112}
113