1/*
2** Zabbix
3** Copyright (C) 2001-2021 Zabbix SIA
4**
5** This program is free software; you can redistribute it and/or modify
6** it under the terms of the GNU General Public License as published by
7** the Free Software Foundation; either version 2 of the License, or
8** (at your option) any later version.
9**
10** This program is distributed in the hope that it will be useful,
11** but WITHOUT ANY WARRANTY; without even the implied warranty of
12** MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13** GNU General Public License for more details.
14**
15** You should have received a copy of the GNU General Public License
16** along with this program; if not, write to the Free Software
17** Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301, USA.
18**/
19
20// Package uri provides a helper for URI validation and parsing
21package uri
22
23import (
24	"errors"
25	"fmt"
26	"net"
27	"net/url"
28	"strconv"
29	"strings"
30)
31
// URI holds the components of a parsed connection address. It describes
// either a network endpoint (scheme, host, optional port, path and query)
// or a Unix socket (scheme and socket path). Credentials are kept in
// separate fields and are only set through withCreds. All fields are
// unexported; use the accessor methods to read them.
type URI struct {
	scheme   string
	host     string
	port     string
	rawQuery string
	socket   string
	user     string
	password string
	rawUri   string
	path     string
}

// Scheme returns the URI scheme.
func (u *URI) Scheme() string {
	return u.scheme
}

// Host returns the host name or IP address without the port.
func (u *URI) Host() string {
	return u.host
}

// Socket returns the socket path, or an empty string if the URI is not a socket.
func (u *URI) Socket() string {
	return u.socket
}

// Port returns the port as a string, or an empty string if no port is set.
func (u *URI) Port() string {
	return u.port
}

// Query returns the raw (still encoded) query string without the leading '?'.
func (u *URI) Query() string {
	return u.rawQuery
}

// Path returns the path component of a non-socket URI.
func (u *URI) Path() string {
	return u.path
}

// GetParam returns the first value of the query parameter key.
// It returns an empty string if the key is absent or the raw query cannot be parsed.
func (u *URI) GetParam(key string) string {
	params, err := url.ParseQuery(u.rawQuery)
	if err != nil {
		return ""
	}

	return params.Get(key)
}

// Password returns the password set for the URI.
func (u *URI) Password() string {
	return u.password
}

// User returns the user name set for the URI.
func (u *URI) User() string {
	return u.user
}

// Addr combines a host and a port into a network address ("host:port") or returns a socket.
// If no port is set, the bare host is returned.
func (u *URI) Addr() string {
	if u.socket != "" {
		return u.socket
	}

	if u.port == "" {
		return u.host
	}

	return net.JoinHostPort(u.host, u.port)
}

// String reassembles the URI to a valid URI string.
// A socket URI is rendered as the scheme followed by the socket path; any other
// URI is rendered with its host, optional port and path. The query and the
// credentials are included when they are set.
func (u *URI) String() string {
	t := &url.URL{
		Scheme:   u.scheme,
		RawQuery: u.rawQuery,
	}

	if u.socket != "" {
		t.Path = u.socket
	} else {
		// The socket is empty here, so Addr yields "host" or "host:port".
		t.Host = u.Addr()
		// Keep the path so that the result round-trips the parsed URI.
		t.Path = u.path
	}

	if u.user != "" {
		if u.password != "" {
			t.User = url.UserPassword(u.user, u.password)
		} else {
			t.User = url.User(u.user)
		}
	}

	return t.String()
}

// withCreds stores user and password in u and returns the same URI,
// which allows the call to be chained.
func (u *URI) withCreds(user, password string) *URI {
	u.password = password
	u.user = user

	return u
}
132
// Defaults holds fallback values that New applies to a URI lacking them.
// Port is used when the parsed URI has no port; it goes through the same
// validation as an explicit port. Scheme is prepended when the raw URI
// contains no ":/" sequence; if Scheme is empty, New falls back to "tcp".
type Defaults struct {
	Port   string
	Scheme string
}
137
138// New parses a given rawUri and returns a new filled URI structure.
139// It ignores embedded credentials according to https://www.ietf.org/rfc/rfc3986.txt.
140// Use NewWithCreds to add credentials to a structure.
141func New(rawUri string, defaults *Defaults) (res *URI, err error) {
142	var (
143		isSocket bool
144		noScheme bool
145		port     string
146	)
147
148	rawUri = strings.TrimSpace(rawUri)
149
150	res = &URI{
151		rawUri: rawUri,
152	}
153
154	// https://tools.ietf.org/html/rfc6874#section-2
155	// %25 is allowed to escape a percent sign in IPv6 scoped-address literals
156	if !strings.Contains(rawUri, "%25") {
157		rawUri = strings.Replace(rawUri, "%", "%25", -1)
158	}
159
160	if noScheme = !strings.Contains(rawUri, ":/"); noScheme {
161		if defaults != nil && defaults.Scheme != "" {
162			rawUri = defaults.Scheme + "://" + rawUri
163		} else {
164			rawUri = "tcp://" + rawUri
165		}
166	}
167
168	u, err := url.Parse(rawUri)
169	if err != nil {
170		return nil, err
171	}
172
173	res.scheme = u.Scheme
174	port = u.Port()
175
176	if port == "" {
177		if defaults != nil {
178			port = defaults.Port
179		}
180	}
181
182	if port != "" {
183		if _, err = strconv.ParseUint(port, 10, 16); err != nil {
184			return nil, errors.New("port must be integer and must be between 0 and 65535")
185		}
186	}
187
188	isSocket = res.scheme == "unix" || (noScheme && u.Hostname() == "" && u.Path != "")
189	if isSocket {
190		if u.Path == "" {
191			return nil, errors.New("socket is required")
192		}
193
194		res.scheme = "unix"
195		res.socket = u.Path
196	} else {
197		if u.Hostname() == "" {
198			return nil, errors.New("host is required")
199		}
200
201		res.host = u.Hostname()
202		res.port = port
203		res.path = u.Path
204	}
205
206	res.rawQuery = u.RawQuery
207
208	return res, err
209}
210
211func NewWithCreds(rawUri, user, password string, defaults *Defaults) (res *URI, err error) {
212	res, err = New(rawUri, defaults)
213	if err != nil {
214		return nil, err
215	}
216
217	return res.withCreds(user, password), nil
218}
219
220type URIValidator struct {
221	Defaults       *Defaults
222	AllowedSchemes []string
223}
224
225func (v URIValidator) Validate(value *string) error {
226	if value == nil {
227		return nil
228	}
229
230	res, err := New(*value, v.Defaults)
231	if err != nil {
232		return err
233	}
234
235	if v.AllowedSchemes != nil {
236		for _, s := range v.AllowedSchemes {
237			if res.Scheme() == s {
238				return nil
239			}
240		}
241
242		return fmt.Errorf("allowed schemes: %s", strings.Join(v.AllowedSchemes, ", "))
243	}
244
245	return nil
246}
247
248func IsHostnameOnly(host string) error {
249	if strings.Contains(host, ":/") {
250		return fmt.Errorf("must not contain scheme")
251	}
252
253	uri, err := New(host, &Defaults{Port: "", Scheme: ""})
254	if err != nil {
255		return err
256	}
257
258	if uri.Port() != "" {
259		return fmt.Errorf("must not contain port")
260	}
261
262	if uri.Socket() != "" {
263		return fmt.Errorf("must not contain socket")
264	}
265
266	if uri.User() != "" {
267		return fmt.Errorf("must not contain user")
268	}
269
270	if uri.Password() != "" {
271		return fmt.Errorf("must not contain password")
272	}
273
274	if uri.Query() != "" {
275		return fmt.Errorf("must not contain query")
276	}
277
278	if uri.Path() != "" {
279		return fmt.Errorf("must not contain path")
280	}
281
282	return nil
283}
284