1// Copyright 2020 Google LLC.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package http
6
7import (
8	"testing"
9
10	"crypto/tls"
11
12	"github.com/google/go-cmp/cmp"
13	"google.golang.org/api/internal"
14)
15
16func TestGetEndpoint(t *testing.T) {
17	testCases := []struct {
18		UserEndpoint    string
19		DefaultEndpoint string
20		Want            string
21		WantErr         bool
22	}{
23		{
24			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
25			Want:            "https://foo.googleapis.com/bar/baz",
26		},
27		{
28			UserEndpoint:    "myhost:3999",
29			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
30			Want:            "https://myhost:3999/bar/baz",
31		},
32		{
33			UserEndpoint:    "https://host/path/to/bar",
34			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
35			Want:            "https://host/path/to/bar",
36		},
37		{
38			UserEndpoint:    "host:port",
39			DefaultEndpoint: "",
40			WantErr:         true,
41		},
42	}
43
44	for _, tc := range testCases {
45		got, err := getEndpoint(&internal.DialSettings{
46			Endpoint:        tc.UserEndpoint,
47			DefaultEndpoint: tc.DefaultEndpoint,
48		}, nil)
49		if tc.WantErr && err == nil {
50			t.Errorf("want err, got nil err")
51			continue
52		}
53		if !tc.WantErr && err != nil {
54			t.Errorf("want nil err, got %v", err)
55			continue
56		}
57		if tc.Want != got {
58			t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
59		}
60	}
61}
62
63func TestGetEndpointWithClientCertSource(t *testing.T) {
64	dummyClientCertSource := func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }
65	testCases := []struct {
66		UserEndpoint    string
67		DefaultEndpoint string
68		Want            string
69		WantErr         bool
70	}{
71		{
72			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
73			Want:            "https://foo.mtls.googleapis.com/bar/baz",
74		},
75		{
76			DefaultEndpoint: "https://staging-foo.sandbox.googleapis.com/bar/baz",
77			Want:            "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
78		},
79		{
80			UserEndpoint:    "myhost:3999",
81			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
82			Want:            "https://myhost:3999/bar/baz",
83		},
84		{
85			UserEndpoint:    "https://host/path/to/bar",
86			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
87			Want:            "https://host/path/to/bar",
88		},
89		{
90			UserEndpoint:    "host:port",
91			DefaultEndpoint: "",
92			WantErr:         true,
93		},
94	}
95
96	for _, tc := range testCases {
97		got, err := getEndpoint(&internal.DialSettings{
98			Endpoint:        tc.UserEndpoint,
99			DefaultEndpoint: tc.DefaultEndpoint,
100		}, dummyClientCertSource)
101		if tc.WantErr && err == nil {
102			t.Errorf("want err, got nil err")
103			continue
104		}
105		if !tc.WantErr && err != nil {
106			t.Errorf("want nil err, got %v", err)
107			continue
108		}
109		if tc.Want != got {
110			t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
111		}
112	}
113}
114
115func TestGenerateDefaultMtlsEndpoint(t *testing.T) {
116	mtlsEndpoint := generateDefaultMtlsEndpoint("pubsub.googleapis.com")
117	wantMtlsEndpoint := "pubsub.mtls.googleapis.com"
118	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
119		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
120	}
121}
122
123func TestGenerateDefaultMtlsEndpointSandbox(t *testing.T) {
124	mtlsEndpoint := generateDefaultMtlsEndpoint("staging-pubsub.sandbox.googleapis.com")
125	wantMtlsEndpoint := "staging-pubsub.mtls.sandbox.googleapis.com"
126	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
127		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
128	}
129}
130
131func TestGenerateDefaultMtlsEndpointUnsupported(t *testing.T) {
132	mtlsEndpoint := generateDefaultMtlsEndpoint("unsupported.google.com")
133	wantMtlsEndpoint := "unsupported.google.com"
134	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
135		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
136	}
137}
138