1// This file and its contents are licensed under the Apache License 2.0.
2// Please see the included NOTICE for copyright information and
3// LICENSE for a copy of the license.
4
5package extension
6
7import (
8	"context"
9	"fmt"
10	"sort"
11	"strings"
12
13	"github.com/blang/semver/v4"
14	"github.com/jackc/pgx/v4"
15	"github.com/timescale/promscale/pkg/log"
16	"github.com/timescale/promscale/pkg/pgmodel/common/errors"
17	"github.com/timescale/promscale/pkg/pgmodel/common/schema"
18	"github.com/timescale/promscale/pkg/version"
19)
20
21var (
22	ExtensionIsInstalled      = false
23	PromscaleExtensionVersion semver.Version
24)
25
// ExtensionMigrateOptions controls which migration actions MigrateExtension
// may perform on an extension.
type ExtensionMigrateOptions struct {
	// Install permits CREATE EXTENSION when the extension is not installed.
	Install bool
	// Upgrade permits ALTER EXTENSION ... UPDATE when the extension is installed.
	Upgrade bool
	// UpgradePreRelease allows pre-release versions to be selected as targets.
	UpgradePreRelease bool
}
31
// InstallUpgradeTimescaleDBExtensions installs or updates the TimescaleDB
// extension (in the "public" schema), as permitted by extOptions.
//
// Note that after this call any previous connections can break,
// so this has to be called ahead of opening connections.
//
// Also this takes a connection string, not a connection, because for
// updates the ALTER has to be the first command on the connection,
// thus we cannot reuse existing connections.
func InstallUpgradeTimescaleDBExtensions(connstr string, extOptions ExtensionMigrateOptions) error {
	db, err := pgx.Connect(context.Background(), connstr)
	if err != nil {
		return err
	}
	// Closing is best-effort: the migration result is what matters to callers.
	defer func() { _ = db.Close(context.Background()) }()

	if err := MigrateExtension(db, "timescaledb", "public", version.TimescaleVersionRange, version.TimescaleVersionRangeFullString, extOptions); err != nil {
		return fmt.Errorf("could not install timescaledb: %w", err)
	}
	return nil
}
54
// InstallUpgradePromscaleExtensions tries to install or upgrade the promscale
// extension in schema.Ext. A migration failure is only logged, because the
// extension is optional. Afterwards the installed extension versions are
// re-checked, and the resulting ExtensionIsInstalled flag is returned.
func InstallUpgradePromscaleExtensions(db *pgx.Conn, extOptions ExtensionMigrateOptions) (bool, error) {
	ExtensionIsInstalled = false

	migrateErr := MigrateExtension(db, "promscale", schema.Ext, version.ExtVersionRange, version.ExtVersionRangeString, extOptions)
	if migrateErr != nil {
		log.Warn("msg", fmt.Sprintf("could not install promscale: %v. continuing without extension", migrateErr))
	}

	if checkErr := CheckExtensionsVersion(db, false, extOptions); checkErr != nil {
		return false, fmt.Errorf("error encountered while migrating extension: %w", checkErr)
	}
	return ExtensionIsInstalled, nil
}
68
// CheckVersions verifies that the connected PostgreSQL server and its installed
// extensions are at versions compatible with this connector. The server
// version is checked first; extension checks run only if it passes.
func CheckVersions(conn *pgx.Conn, migrationFailedDueToLockError bool, extOptions ExtensionMigrateOptions) error {
	err := checkPgVersion(conn)
	if err != nil {
		return fmt.Errorf("problem checking PostgreSQL version: %w", err)
	}

	err = CheckExtensionsVersion(conn, migrationFailedDueToLockError, extOptions)
	if err != nil {
		return fmt.Errorf("problem checking extension version: %w", err)
	}
	return nil
}
79
// checkPgVersion reads server_version_num from the server and returns an
// error unless the version satisfies version.VerifyPgVersion.
func checkPgVersion(conn *pgx.Conn) error {
	var versionString string
	if err := conn.QueryRow(context.Background(), "SHOW server_version_num;").Scan(&versionString); err != nil {
		return fmt.Errorf("error fetching postgresql version: %w", err)
	}
	v, err := parseServerVersionNum(versionString)
	if err != nil {
		return fmt.Errorf("could not parse postgresql version string: %w", err)
	}
	if !version.VerifyPgVersion(v) {
		return fmt.Errorf("Incompatible postgresql version. Supported server version %s", version.PgVersionNumRange)
	}
	return nil
}

// parseServerVersionNum converts PostgreSQL's server_version_num into a semver
// version. The last four digits always encode the part after the major version:
//   - PostgreSQL 10 and later use major*10000 + minor
//     (e.g. 140005 -> 14.5.0; a 3-digit major such as 1000005 -> 100.5.0).
//   - Releases before 10 use major*10000 + minor*100 + patch
//     (e.g. 90624 -> 9.6.24).
//
// Leading zeros are stripped from every component because semver rejects them.
// Non-numeric input is rejected by semver.Parse.
func parseServerVersionNum(num string) (semver.Version, error) {
	num = strings.TrimSpace(num)
	// The shortest valid value is a 5-digit pre-10 number; anything shorter
	// would previously have caused an out-of-range slice panic.
	if len(num) < 5 {
		return semver.Version{}, fmt.Errorf("unexpected server_version_num %q", num)
	}
	major, rest := trimLeadingZeros(num[:len(num)-4]), num[len(num)-4:]
	if len(num) == 5 {
		return semver.Parse(fmt.Sprintf("%s.%s.%s", major, trimLeadingZeros(rest[:2]), trimLeadingZeros(rest[2:])))
	}
	return semver.Parse(fmt.Sprintf("%s.%s.0", major, trimLeadingZeros(rest)))
}
103
// CheckExtensionsVersion validates the installed TimescaleDB version, then
// inspects the promscale extension. That second step updates
// ExtensionIsInstalled (and PromscaleExtensionVersion when the installed
// version is within the supported range).
func CheckExtensionsVersion(conn *pgx.Conn, migrationFailedDueToLockError bool, extOptions ExtensionMigrateOptions) error {
	err := checkTimescaleDBVersion(conn)
	if err != nil {
		return fmt.Errorf("problem checking timescaledb extension version: %w", err)
	}

	err = checkPromscaleExtensionVersion(conn, migrationFailedDueToLockError, extOptions)
	if err != nil {
		return fmt.Errorf("problem checking promscale extension version: %w", err)
	}
	return nil
}
115
// checkTimescaleDBVersion compares the installed TimescaleDB version with the
// supported ranges:
//   - a missing extension only produces a warning;
//   - a "warn"-level version produces a warning;
//   - an "err"-level version produces an error.
func checkTimescaleDBVersion(conn *pgx.Conn) error {
	installed, found, err := fetchInstalledExtensionVersion(conn, "timescaledb")
	switch {
	case err != nil:
		return fmt.Errorf("could not get the installed extension version: %w", err)
	case !found:
		log.Warn("msg", "Running Promscale without TimescaleDB. Some features will be disabled.")
		return nil
	}

	switch status := version.VerifyTimescaleVersion(installed); status {
	case version.Warn, version.Err:
		// The Safe range string holds the lower and upper bounds separated by a space.
		bounds := strings.Split(version.TimescaleVersionRangeString.Safe, " ")
		if status == version.Err {
			return fmt.Errorf("incompatible Timescaledb version: %s. Expected version within %s to %s", installed.String(), bounds[0], bounds[1])
		}
		log.Warn(
			"msg",
			fmt.Sprintf(
				"Might lead to incompatibility issues due to TimescaleDB version. Expected version within %s to %s.",
				bounds[0], bounds[1],
			),
			"Installed Timescaledb version:", installed.String(),
		)
	}
	return nil
}
143
// checkPromscaleExtensionVersion refreshes ExtensionIsInstalled and
// PromscaleExtensionVersion from the installed promscale extension.
// An unavailable extension is not an error, because the extension is optional.
// If a pending install/upgrade could not run because the migration lock was
// not acquired, a warning is logged.
func checkPromscaleExtensionVersion(conn *pgx.Conn, migrationFailedDueToLockError bool, extOptions ExtensionMigrateOptions) error {
	installed, target, err := extensionVersions(conn, "promscale", version.ExtVersionRange, version.ExtVersionRangeString, extOptions)
	if err != nil {
		ExtensionIsInstalled = false
		if err != errors.ErrExtUnavailable {
			return fmt.Errorf("could not get the extension versions: %w", err)
		}
		// the promscale extension is optional
		return nil
	}

	needsMigration := installed == nil || installed.Compare(*target) < 0
	if needsMigration && migrationFailedDueToLockError {
		log.Warn("msg", "Unable to install/update the Promscale extension; failed to acquire the lock. Ensure there are no other connectors running and try again.")
	}

	usable := installed != nil && version.ExtVersionRange(*installed)
	ExtensionIsInstalled = usable
	if usable {
		PromscaleExtensionVersion = *installed
	}
	return nil
}
169
// extensionVersions returns the installed version of extName (nil when it is
// not installed) and the version it should be migrated to, as chosen by
// getNewExtensionVersion. It returns errors.ErrExtUnavailable when the server
// offers no versions of extName at all.
func extensionVersions(conn *pgx.Conn, extName string, validRange semver.Range, rangeString string, extOptions ExtensionMigrateOptions) (currentVersion *semver.Version, newVersion *semver.Version, err error) {
	available, fetchErr := fetchAvailableExtensionVersions(conn, extName)
	switch {
	case fetchErr != nil:
		return nil, nil, fmt.Errorf("problem fetching available version: %w", fetchErr)
	case len(available) == 0:
		return nil, nil, errors.ErrExtUnavailable
	}

	defaultVersion, defaultErr := fetchDefaultExtensionVersions(conn, extName)
	if defaultErr != nil {
		return nil, nil, fmt.Errorf("problem fetching default version: %w", defaultErr)
	}

	installed, found, installedErr := fetchInstalledExtensionVersion(conn, extName)
	if installedErr != nil {
		return nil, nil, fmt.Errorf("problem getting the installed extension version: %w", installedErr)
	}

	target, ok := getNewExtensionVersion(extName, available, defaultVersion, validRange, found, extOptions.UpgradePreRelease, installed)
	if !ok {
		return nil, nil, fmt.Errorf("The %v extension is not available at the right version, need version: %v and the default version is %s ", extName, rangeString, defaultVersion)
	}

	newVersion = &target
	if found {
		currentVersion = &installed
	}
	return currentVersion, newVersion, nil
}
199
// MigrateExtension brings extName to the version selected by extensionVersions,
// honoring extOptions.
//   - If the extension is not installed and extOptions.Install is set, it is
//     created in extSchemaName.
//   - If it is installed at an older version and extOptions.Upgrade is set, it
//     is updated with ALTER EXTENSION.
//   - If the installed version is newer than the target, an error is returned.
//
// When an upgrade fails but the installed version still satisfies validRange,
// the failure is logged and nil is returned. Otherwise the failure is returned.
func MigrateExtension(conn *pgx.Conn, extName string, extSchemaName string, validRange semver.Range, rangeString string, extOptions ExtensionMigrateOptions) error {
	currentVersion, newVersion, err := extensionVersions(conn, extName, validRange, rangeString, extOptions)
	if err != nil {
		return err
	}
	isInstalled := currentVersion != nil

	if !isInstalled && !extOptions.Install {
		log.Info("msg", "skipping "+extName+" extension install as install extension is disabled.")
		return nil
	}

	if isInstalled && !extOptions.Upgrade {
		log.Info("msg", "skipping "+extName+" extension upgrade as upgrade extension is disabled. The current extension version is "+currentVersion.String())
		if !validRange(*currentVersion) {
			return fmt.Errorf("The %v extension is not installed at the right version and upgrades are disabled, need version: %v and the installed version is %s ", extName, rangeString, currentVersion)
		}
		return nil
	}

	if !isInstalled {
		_, extErr := conn.Exec(context.Background(),
			fmt.Sprintf("CREATE EXTENSION IF NOT EXISTS %s WITH SCHEMA %s VERSION '%s'",
				extName, extSchemaName, getSqlVersion(*newVersion, extName)))
		if extErr != nil {
			return extErr
		}
		log.Info("msg", "successfully created "+extName+" extension at version "+newVersion.String())
		return nil
	}

	switch comparator := currentVersion.Compare(*newVersion); {
	case comparator > 0:
		// currentVersion greater than what we can handle, don't use the extension
		return fmt.Errorf("the extension at a greater version than supported by the connector: %v > %v", currentVersion, newVersion)
	case comparator == 0:
		// Nothing to do, we are at the correct version.
		return nil
	}

	// Upgrade to the right version.
	connAlter := conn
	if extName == "timescaledb" {
		// TimescaleDB requires the ALTER to be the first command on a
		// connection, so use a fresh one.
		// Note: all previously opened connections will become invalid.
		freshConn, connErr := pgx.ConnectConfig(context.Background(), conn.Config())
		if connErr != nil {
			return connErr
		}
		defer func() { _ = freshConn.Close(context.Background()) }()
		connAlter = freshConn
	}

	_, alterErr := connAlter.Exec(context.Background(),
		fmt.Sprintf("ALTER EXTENSION %s UPDATE TO '%s'", extName,
			getSqlVersion(*newVersion, extName)))
	if alterErr != nil {
		if !validRange(*currentVersion) {
			return fmt.Errorf("The %v extension is not installed at the right version and the upgrades failed, need version: %v and the installed version is %s and the upgrade failed with %w ", extName, rangeString, currentVersion, alterErr)
		}
		// The installed version is still usable, so don't crash: log the failure.
		// Do not fall through to the success log below.
		log.Error("msg", fmt.Sprintf("Failed to migrate extension %v from %v to %v: %v", extName, currentVersion, newVersion, alterErr))
		return nil
	}

	log.Info("msg", "successfully updated extension", "extension_name", extName, "old_version", currentVersion, "new_version", newVersion)
	return nil
}
265
// fetchAvailableExtensionVersions lists the versions of extName that the
// server offers (pg_available_extension_versions), parsed as semver.
// Versions whose normalized string starts with "mock" are skipped.
// The returned slice is non-nil; on error it holds whatever was parsed so far.
func fetchAvailableExtensionVersions(conn *pgx.Conn, extName string) (semver.Versions, error) {
	var raw []string
	parsed := make(semver.Versions, 0)

	row := conn.QueryRow(context.Background(),
		"SELECT array_agg(version) FROM pg_available_extension_versions WHERE name = $1", extName)
	if err := row.Scan(&raw); err != nil {
		return parsed, err
	}

	for _, r := range raw {
		normalized := correctVersionString(r, extName)
		if strings.HasPrefix(normalized, "mock") {
			// ignore mock ext versions
			continue
		}
		v, err := semver.Parse(normalized)
		if err != nil {
			return parsed, fmt.Errorf("Could not parse available extension version %v: %w", normalized, err)
		}
		parsed = append(parsed, v)
	}
	return parsed, nil
}
294
// fetchDefaultExtensionVersions returns the default_version that
// pg_available_extensions reports for extName, parsed as semver.
func fetchDefaultExtensionVersions(conn *pgx.Conn, extName string) (semver.Version, error) {
	var raw string
	row := conn.QueryRow(context.Background(),
		"SELECT default_version FROM pg_available_extensions WHERE name = $1", extName)
	if err := row.Scan(&raw); err != nil {
		return semver.Version{}, err
	}

	normalized := correctVersionString(raw, extName)
	parsed, err := semver.Parse(normalized)
	if err != nil {
		return parsed, fmt.Errorf("Could not parse default extension version %v: %w", normalized, err)
	}
	return parsed, nil
}
311
// fetchInstalledExtensionVersion reads the installed version of extensionName
// from pg_extension.
//   - If no row exists, it returns (zero, false, nil).
//   - On any other query or parse error, the bool is true together with the error.
func fetchInstalledExtensionVersion(conn *pgx.Conn, extensionName string) (semver.Version, bool, error) {
	var raw string
	err := conn.QueryRow(
		context.Background(),
		"SELECT extversion FROM pg_extension WHERE extname=$1;",
		extensionName,
	).Scan(&raw)
	switch {
	case err == pgx.ErrNoRows:
		return semver.Version{}, false, nil
	case err != nil:
		return semver.Version{}, true, err
	}

	normalized := correctVersionString(raw, extensionName)
	parsed, err := semver.Parse(normalized)
	if err != nil {
		return parsed, true, fmt.Errorf("could not parse current %s extension version %v: %w", extensionName, normalized, err)
	}
	return parsed, true, nil
}
333
// correctVersionString maps catalog version strings onto valid semver. The
// promscale extension was originally published as "0.1", which is not valid
// semver, so that one value is rewritten to "0.1.0". Everything else is
// returned unchanged.
func correctVersionString(v string, extName string) string {
	if extName != "promscale" || v != "0.1" {
		return v
	}
	return "0.1.0"
}
341
342func getSqlVersion(v semver.Version, extName string) string {
343	if extName == "promscale" && v.String() == "0.1.0" {
344		return "0.1"
345	}
346	return v.String()
347}
348
// getNewExtensionVersion picks the highest version in availableVersions that
// validRange accepts. It skips candidates that are:
//   - newer than defaultVersion;
//   - pre-releases, unless upgradePreRelease is set or the candidate is the
//     installed currentVersion;
//   - of a different major version than currentVersion, when validCurrentVersion
//     is true. In that case a warning is logged once for a newer non-pre-release
//     major.
//
// availableVersions is sorted in place in descending order. The boolean result
// is false when no candidate qualifies.
func getNewExtensionVersion(extName string,
	availableVersions semver.Versions,
	defaultVersion semver.Version,
	validRange semver.Range,
	validCurrentVersion, upgradePreRelease bool,
	currentVersion semver.Version) (semver.Version, bool) {
	// Walk from the highest version down so the first acceptable one wins.
	sort.Sort(sort.Reverse(availableVersions))

	installed := currentVersion.String()
	warnedAboutMajor := false
	for _, candidate := range availableVersions {
		isPreRelease := len(candidate.Pre) > 0
		switch {
		case candidate.GT(defaultVersion):
			// Never auto-upgrade past the version the server marks as default.
		case isPreRelease && !upgradePreRelease && candidate.String() != installed:
			log.Warn("msg", "skipping upgrade to prerelease version "+candidate.String()+" as --upgrade-prerelease-extensions is disabled.")
		case validCurrentVersion && currentVersion.Major != candidate.Major:
			// Never auto-upgrade across major versions; hint once about a newer stable major.
			if !warnedAboutMajor && candidate.Major > currentVersion.Major && !isPreRelease {
				log.Warn("msg", "Newer major version of "+extName+" is available, but has to be upgraded manually with ALTER EXTENSION (we do not upgrade across major versions automatically).",
					"available_version", candidate.String())
				warnedAboutMajor = true
			}
		case validRange(candidate):
			return candidate, true
		}
	}
	return semver.Version{}, false
}
388
// trimLeadingZeros strips every leading '0' from s. If nothing is left (s was
// empty or all zeros), it returns "0", so the result is never empty.
func trimLeadingZeros(s string) string {
	start := 0
	for start < len(s) && s[start] == '0' {
		start++
	}
	if start == len(s) {
		return "0"
	}
	return s[start:]
}
396