1package config
2
3import (
4	"errors"
5	"fmt"
6	"os"
7	"strings"
8	"time"
9
10	"github.com/pelletier/go-toml"
11	promclient "github.com/prometheus/client_golang/prometheus"
12	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config"
13	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config/auth"
14	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config/log"
15	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config/prometheus"
16	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/config/sentry"
17)
18
// ElectionStrategy is a Praefect primary election strategy. The accepted values are the
// ElectionStrategy* constants declared in this package; validate rejects anything else.
type ElectionStrategy string
21
22// validate validates the election strategy is a valid one.
23func (es ElectionStrategy) validate() error {
24	switch es {
25	case ElectionStrategyLocal, ElectionStrategySQL, ElectionStrategyPerRepository:
26		return nil
27	default:
28		return fmt.Errorf("invalid election strategy: %q", es)
29	}
30}
31
const (
	// ElectionStrategyLocal configures a single node, in-memory election strategy.
	ElectionStrategyLocal ElectionStrategy = "local"
	// ElectionStrategySQL configures an SQL based strategy that elects a primary for a virtual storage.
	ElectionStrategySQL ElectionStrategy = "sql"
	// ElectionStrategyPerRepository configures an SQL based strategy that elects different primaries per repository.
	ElectionStrategyPerRepository ElectionStrategy = "per_repository"

	// minimalSyncCheckInterval and minimalSyncRunInterval are the smallest values that
	// Config.Validate accepts for repositories_cleanup.check_interval and
	// repositories_cleanup.run_interval when the run interval is positive.
	minimalSyncCheckInterval = time.Minute
	minimalSyncRunInterval   = time.Minute
)
43
// Failover contains failover specific configuration options. FromFile seeds it with
// Enabled set to true and ElectionStrategyPerRepository before the TOML is decoded.
//
// ErrorThresholdWindow, WriteErrorThresholdCount and ReadErrorThresholdCount are expected
// to be either all zero or all set; ErrorThresholdsConfigured reports an error for any
// other combination.
type Failover struct {
	// Enabled is decoded from the `enabled` key. setDefaults only fills in
	// BootstrapInterval and MonitorInterval when it is true.
	Enabled bool `toml:"enabled"`
	// ElectionStrategy is the strategy to use for electing primary nodes.
	ElectionStrategy         ElectionStrategy `toml:"election_strategy"`
	ErrorThresholdWindow     config.Duration  `toml:"error_threshold_window"`
	WriteErrorThresholdCount uint32           `toml:"write_error_threshold_count"`
	ReadErrorThresholdCount  uint32           `toml:"read_error_threshold_count"`
	// BootstrapInterval is the time duration used on startup to make the initial health check.
	// The default value is 1s.
	BootstrapInterval config.Duration `toml:"bootstrap_interval"`
	// MonitorInterval is the time duration used after bootstrap is completed to execute health checks.
	// The default value is 3s.
	MonitorInterval config.Duration `toml:"monitor_interval"`
}
58
59// ErrorThresholdsConfigured checks whether returns whether the errors thresholds are configured. If they
60// are configured but in an invalid way, an error is returned.
61func (f Failover) ErrorThresholdsConfigured() (bool, error) {
62	if f.ErrorThresholdWindow == 0 && f.WriteErrorThresholdCount == 0 && f.ReadErrorThresholdCount == 0 {
63		return false, nil
64	}
65
66	if f.ErrorThresholdWindow == 0 {
67		return false, errors.New("threshold window not set")
68	}
69
70	if f.WriteErrorThresholdCount == 0 {
71		return false, errors.New("write error threshold not set")
72	}
73
74	if f.ReadErrorThresholdCount == 0 {
75		return false, errors.New("read error threshold not set")
76	}
77
78	return true, nil
79}
80
// Reconciliation contains reconciliation specific configuration options. See
// DefaultReconciliationConfig for the values used when the TOML does not set them.
type Reconciliation struct {
	// SchedulingInterval is the interval between each automatic reconciliation run. If set to 0,
	// automatic reconciliation is disabled.
	SchedulingInterval config.Duration `toml:"scheduling_interval"`
	// HistogramBuckets configures the reconciliation scheduling duration histogram's buckets.
	HistogramBuckets []float64 `toml:"histogram_buckets"`
}
89
90// DefaultReconciliationConfig returns the default values for reconciliation configuration.
91func DefaultReconciliationConfig() Reconciliation {
92	return Reconciliation{
93		SchedulingInterval: 5 * config.Duration(time.Minute),
94		HistogramBuckets:   promclient.DefBuckets,
95	}
96}
97
// Replication contains replication specific configuration options. See
// DefaultReplicationConfig for the default values.
type Replication struct {
	// BatchSize controls how many replication jobs to dequeue and lock
	// in a single call to the database. Config.Validate rejects values below 1.
	BatchSize uint `toml:"batch_size"`
	// ParallelStorageProcessingWorkers is the number of workers used to process replication
	// events per virtual storage (how many storages would be processed in parallel).
	ParallelStorageProcessingWorkers uint `toml:"parallel_storage_processing_workers"`
}
107
108// DefaultReplicationConfig returns the default values for replication configuration.
109func DefaultReplicationConfig() Replication {
110	return Replication{BatchSize: 10, ParallelStorageProcessingWorkers: 1}
111}
112
// Config is a container for everything found in the TOML config file. FromFile seeds
// it with defaults before decoding and Validate checks a loaded configuration.
//
// Notes on individual fields:
//   - At least one of ListenAddr, TLSListenAddr and SocketPath must be set for Validate
//     to succeed.
//   - DB is embedded and decoded from the `database` table.
//   - FailoverEnabled is the legacy top-level switch; when set, FromFile forces
//     Failover.Enabled to true.
//   - GracefulStopTimeout defaults to one minute when left at zero.
//   - MemoryQueueEnabled is one of the inputs of NeedsSQL.
type Config struct {
	AllowLegacyElectors  bool              `toml:"i_understand_my_election_strategy_is_unsupported_and_will_be_removed_without_warning"`
	Reconciliation       Reconciliation    `toml:"reconciliation"`
	Replication          Replication       `toml:"replication"`
	ListenAddr           string            `toml:"listen_addr"`
	TLSListenAddr        string            `toml:"tls_listen_addr"`
	SocketPath           string            `toml:"socket_path"`
	VirtualStorages      []*VirtualStorage `toml:"virtual_storage"`
	Logging              log.Config        `toml:"logging"`
	Sentry               sentry.Config     `toml:"sentry"`
	PrometheusListenAddr string            `toml:"prometheus_listen_addr"`
	Prometheus           prometheus.Config `toml:"prometheus"`
	Auth                 auth.Config       `toml:"auth"`
	TLS                  config.TLS        `toml:"tls"`
	DB                   `toml:"database"`
	Failover             Failover `toml:"failover"`
	// Keep for legacy reasons: remove after Omnibus has switched
	FailoverEnabled     bool                `toml:"failover_enabled"`
	MemoryQueueEnabled  bool                `toml:"memory_queue_enabled"`
	GracefulStopTimeout config.Duration     `toml:"graceful_stop_timeout"`
	RepositoriesCleanup RepositoriesCleanup `toml:"repositories_cleanup"`
	// ForceCreateRepositories will enable force-creation of repositories in the
	// coordinator when routing repository-scoped mutators. This must never be used
	// outside of tests.
	ForceCreateRepositories bool `toml:"force_create_repositories_for_testing_purposes"`
}
140
// VirtualStorage represents a set of nodes for a storage. Config.Validate requires a
// non-empty Name that is unique among the virtual storages and at least one node. Every
// node needs a storage name that is unique within its virtual storage and an address
// that is unique across the whole configuration.
type VirtualStorage struct {
	Name  string  `toml:"name"`
	Nodes []*Node `toml:"node"`
	// DefaultReplicationFactor is the replication factor set for new repositories.
	// A valid value is inclusive between 1 and the number of configured storages in the
	// virtual storage. Setting the value to 0 or below causes Praefect to not store any
	// host assignments, falling back to the behavior of replicating to every configured
	// storage. Config.Validate rejects values greater than the number of nodes.
	DefaultReplicationFactor int `toml:"default_replication_factor"`
}
152
153// FromFile loads the config for the passed file path
154func FromFile(filePath string) (Config, error) {
155	b, err := os.ReadFile(filePath)
156	if err != nil {
157		return Config{}, err
158	}
159
160	conf := &Config{
161		Reconciliation: DefaultReconciliationConfig(),
162		Replication:    DefaultReplicationConfig(),
163		Prometheus:     prometheus.DefaultConfig(),
164		// Sets the default Failover, to be overwritten when deserializing the TOML
165		Failover:            Failover{Enabled: true, ElectionStrategy: ElectionStrategyPerRepository},
166		RepositoriesCleanup: DefaultRepositoriesCleanup(),
167	}
168	if err := toml.Unmarshal(b, conf); err != nil {
169		return Config{}, err
170	}
171
172	// TODO: Remove this after failover_enabled has moved under a separate failover section. This is for
173	// backwards compatibility only
174	if conf.FailoverEnabled {
175		conf.Failover.Enabled = true
176	}
177
178	conf.setDefaults()
179
180	return *conf, nil
181}
182
// Errors reported by Config.Validate. errNoListener, errNoVirtualStorages and
// errVirtualStorageUnnamed are returned as-is; the others are wrapped with the name of
// the offending virtual storage, so they should be matched with errors.Is.
var (
	errDuplicateStorage         = errors.New("internal gitaly storages are not unique")
	errGitalyWithoutAddr        = errors.New("all gitaly nodes must have an address")
	errGitalyWithoutStorage     = errors.New("all gitaly nodes must have a storage")
	errNoGitalyServers          = errors.New("no primary gitaly backends configured")
	errNoListener               = errors.New("no listen address or socket path configured")
	errNoVirtualStorages        = errors.New("no virtual storages configured")
	errStorageAddressDuplicate  = errors.New("multiple storages have the same address")
	errVirtualStoragesNotUnique = errors.New("virtual storages must have unique names")
	errVirtualStorageUnnamed    = errors.New("virtual storages must have a name")
)
194
195// Validate establishes if the config is valid
196func (c *Config) Validate() error {
197	if err := c.Failover.ElectionStrategy.validate(); err != nil {
198		return err
199	}
200
201	if c.ListenAddr == "" && c.SocketPath == "" && c.TLSListenAddr == "" {
202		return errNoListener
203	}
204
205	if len(c.VirtualStorages) == 0 {
206		return errNoVirtualStorages
207	}
208
209	if c.Replication.BatchSize < 1 {
210		return fmt.Errorf("replication batch size was %d but must be >=1", c.Replication.BatchSize)
211	}
212
213	allAddresses := make(map[string]struct{})
214	virtualStorages := make(map[string]struct{}, len(c.VirtualStorages))
215
216	for _, virtualStorage := range c.VirtualStorages {
217		if virtualStorage.Name == "" {
218			return errVirtualStorageUnnamed
219		}
220
221		if len(virtualStorage.Nodes) == 0 {
222			return fmt.Errorf("virtual storage %q: %w", virtualStorage.Name, errNoGitalyServers)
223		}
224
225		if _, ok := virtualStorages[virtualStorage.Name]; ok {
226			return fmt.Errorf("virtual storage %q: %w", virtualStorage.Name, errVirtualStoragesNotUnique)
227		}
228		virtualStorages[virtualStorage.Name] = struct{}{}
229
230		storages := make(map[string]struct{}, len(virtualStorage.Nodes))
231		for _, node := range virtualStorage.Nodes {
232			if node.Storage == "" {
233				return fmt.Errorf("virtual storage %q: %w", virtualStorage.Name, errGitalyWithoutStorage)
234			}
235
236			if node.Address == "" {
237				return fmt.Errorf("virtual storage %q: %w", virtualStorage.Name, errGitalyWithoutAddr)
238			}
239
240			if _, found := storages[node.Storage]; found {
241				return fmt.Errorf("virtual storage %q: %w", virtualStorage.Name, errDuplicateStorage)
242			}
243			storages[node.Storage] = struct{}{}
244
245			if _, found := allAddresses[node.Address]; found {
246				return fmt.Errorf("virtual storage %q: address %q : %w", virtualStorage.Name, node.Address, errStorageAddressDuplicate)
247			}
248			allAddresses[node.Address] = struct{}{}
249		}
250
251		if virtualStorage.DefaultReplicationFactor > len(virtualStorage.Nodes) {
252			return fmt.Errorf(
253				"virtual storage %q has a default replication factor (%d) which is higher than the number of storages (%d)",
254				virtualStorage.Name, virtualStorage.DefaultReplicationFactor, len(virtualStorage.Nodes),
255			)
256		}
257	}
258
259	if c.RepositoriesCleanup.RunInterval.Duration() > 0 {
260		if c.RepositoriesCleanup.CheckInterval.Duration() < minimalSyncCheckInterval {
261			return fmt.Errorf("repositories_cleanup.check_interval is less then %s, which could lead to a database performance problem", minimalSyncCheckInterval.String())
262		}
263		if c.RepositoriesCleanup.RunInterval.Duration() < minimalSyncRunInterval {
264			return fmt.Errorf("repositories_cleanup.run_interval is less then %s, which could lead to a database performance problem", minimalSyncRunInterval.String())
265		}
266	}
267
268	return nil
269}
270
271// NeedsSQL returns true if the driver for SQL needs to be initialized
272func (c *Config) NeedsSQL() bool {
273	return !c.MemoryQueueEnabled || (c.Failover.Enabled && c.Failover.ElectionStrategy != ElectionStrategyLocal)
274}
275
276func (c *Config) setDefaults() {
277	if c.GracefulStopTimeout.Duration() == 0 {
278		c.GracefulStopTimeout = config.Duration(time.Minute)
279	}
280
281	if c.Failover.Enabled {
282		if c.Failover.BootstrapInterval.Duration() == 0 {
283			c.Failover.BootstrapInterval = config.Duration(time.Second)
284		}
285
286		if c.Failover.MonitorInterval.Duration() == 0 {
287			c.Failover.MonitorInterval = config.Duration(3 * time.Second)
288		}
289	}
290}
291
292// VirtualStorageNames returns names of all virtual storages configured.
293func (c *Config) VirtualStorageNames() []string {
294	names := make([]string, len(c.VirtualStorages))
295	for i, virtual := range c.VirtualStorages {
296		names[i] = virtual.Name
297	}
298	return names
299}
300
301// StorageNames returns storage names by virtual storage.
302func (c *Config) StorageNames() map[string][]string {
303	storages := make(map[string][]string, len(c.VirtualStorages))
304	for _, vs := range c.VirtualStorages {
305		nodes := make([]string, len(vs.Nodes))
306		for i, n := range vs.Nodes {
307			nodes[i] = n.Storage
308		}
309
310		storages[vs.Name] = nodes
311	}
312
313	return storages
314}
315
316// DefaultReplicationFactors returns a map with the default replication factors of
317// the virtual storages.
318func (c Config) DefaultReplicationFactors() map[string]int {
319	replicationFactors := make(map[string]int, len(c.VirtualStorages))
320	for _, vs := range c.VirtualStorages {
321		replicationFactors[vs.Name] = vs.DefaultReplicationFactor
322	}
323
324	return replicationFactors
325}
326
// DBConnection holds Postgres client configuration data. It is the type of
// DB.SessionPooled, whose non-zero fields take precedence over the corresponding DB
// fields when DB.ToPQString is called with direct set to true.
type DBConnection struct {
	Host        string `toml:"host"`
	Port        int    `toml:"port"`
	User        string `toml:"user"`
	Password    string `toml:"password"`
	DBName      string `toml:"dbname"`
	SSLMode     string `toml:"sslmode"`
	SSLCert     string `toml:"sslcert"`
	SSLKey      string `toml:"sslkey"`
	SSLRootCert string `toml:"sslrootcert"`
}
339
// DB holds database configuration data. The top-level fields describe the connection
// built by ToPQString(false); for ToPQString(true) they serve as fallbacks for values
// not set in SessionPooled.
type DB struct {
	Host        string `toml:"host"`
	Port        int    `toml:"port"`
	User        string `toml:"user"`
	Password    string `toml:"password"`
	DBName      string `toml:"dbname"`
	SSLMode     string `toml:"sslmode"`
	SSLCert     string `toml:"sslcert"`
	SSLKey      string `toml:"sslkey"`
	SSLRootCert string `toml:"sslrootcert"`

	// SessionPooled holds the overrides consulted by ToPQString when direct is true.
	SessionPooled DBConnection `toml:"session_pooled"`

	// The following configuration keys are deprecated and
	// will be removed. Use Host and Port attributes of
	// SessionPooled instead.
	HostNoProxy string `toml:"host_no_proxy"`
	PortNoProxy int    `toml:"port_no_proxy"`
}
360
// coalesceStr returns the first non-empty string among values, or the empty string
// when there is none.
func coalesceStr(values ...string) string {
	for i := range values {
		if values[i] == "" {
			continue
		}
		return values[i]
	}

	return ""
}
369
// coalesceInt returns the first non-zero integer among values, or 0 when there is none.
func coalesceInt(values ...int) int {
	for i := range values {
		if values[i] == 0 {
			continue
		}
		return values[i]
	}

	return 0
}
378
379// ToPQString returns a connection string that can be passed to github.com/lib/pq.
380func (db DB) ToPQString(direct bool) string {
381	var hostVal, userVal, passwordVal, dbNameVal string
382	var sslModeVal, sslCertVal, sslKeyVal, sslRootCertVal string
383	var portVal int
384
385	if direct {
386		hostVal = coalesceStr(db.SessionPooled.Host, db.HostNoProxy, db.Host)
387		portVal = coalesceInt(db.SessionPooled.Port, db.PortNoProxy, db.Port)
388		userVal = coalesceStr(db.SessionPooled.User, db.User)
389		passwordVal = coalesceStr(db.SessionPooled.Password, db.Password)
390		dbNameVal = coalesceStr(db.SessionPooled.DBName, db.DBName)
391		sslModeVal = coalesceStr(db.SessionPooled.SSLMode, db.SSLMode)
392		sslCertVal = coalesceStr(db.SessionPooled.SSLCert, db.SSLCert)
393		sslKeyVal = coalesceStr(db.SessionPooled.SSLKey, db.SSLKey)
394		sslRootCertVal = coalesceStr(db.SessionPooled.SSLRootCert, db.SSLRootCert)
395	} else {
396		hostVal = db.Host
397		portVal = db.Port
398		userVal = db.User
399		passwordVal = db.Password
400		dbNameVal = db.DBName
401		sslModeVal = db.SSLMode
402		sslCertVal = db.SSLCert
403		sslKeyVal = db.SSLKey
404		sslRootCertVal = db.SSLRootCert
405	}
406
407	var fields []string
408	if portVal > 0 {
409		fields = append(fields, fmt.Sprintf("port=%d", portVal))
410	}
411
412	for _, kv := range []struct{ key, value string }{
413		{"host", hostVal},
414		{"user", userVal},
415		{"password", passwordVal},
416		{"dbname", dbNameVal},
417		{"sslmode", sslModeVal},
418		{"sslcert", sslCertVal},
419		{"sslkey", sslKeyVal},
420		{"sslrootcert", sslRootCertVal},
421		{"binary_parameters", "yes"},
422	} {
423		if len(kv.value) == 0 {
424			continue
425		}
426
427		kv.value = strings.ReplaceAll(kv.value, "'", `\'`)
428		kv.value = strings.ReplaceAll(kv.value, " ", `\ `)
429
430		fields = append(fields, kv.key+"="+kv.value)
431	}
432
433	return strings.Join(fields, " ")
434}
435
// RepositoriesCleanup configures repository synchronisation. When RunInterval is
// positive, Config.Validate requires both CheckInterval and RunInterval to be at least
// one minute. See DefaultRepositoriesCleanup for the default values.
type RepositoriesCleanup struct {
	// CheckInterval is a time period used to check if operation should be executed.
	// It is recommended to keep it less than run_interval configuration as some
	// nodes may be out of service, so they can be stale for too long.
	CheckInterval config.Duration `toml:"check_interval"`
	// RunInterval: the check runs if the previous operation was done at least RunInterval before.
	RunInterval config.Duration `toml:"run_interval"`
	// RepositoriesInBatch is the number of repositories to pass as a batch for processing.
	RepositoriesInBatch int `toml:"repositories_in_batch"`
}
447
448// DefaultRepositoriesCleanup contains default configuration values for the RepositoriesCleanup.
449func DefaultRepositoriesCleanup() RepositoriesCleanup {
450	return RepositoriesCleanup{
451		CheckInterval:       config.Duration(30 * time.Minute),
452		RunInterval:         config.Duration(24 * time.Hour),
453		RepositoriesInBatch: 16,
454	}
455}
456