package nodes

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"math/rand"
	"sync"
	"time"

	grpc_prometheus "github.com/grpc-ecosystem/go-grpc-prometheus"
	"github.com/sirupsen/logrus"
	gitalyauth "gitlab.com/gitlab-org/gitaly/v14/auth"
	"gitlab.com/gitlab-org/gitaly/v14/internal/gitaly/client"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/commonerr"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/datastore"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/grpc-proxy/proxy"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/metrics"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/middleware"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/nodes/tracker"
	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry"
	prommetrics "gitlab.com/gitlab-org/gitaly/v14/internal/prometheus/metrics"
	"google.golang.org/grpc"
	healthpb "google.golang.org/grpc/health/grpc_health_v1"
)

// Shard is a primary with a set of secondaries
type Shard struct {
	Primary     Node
	Secondaries []Node
}

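// GetNode returns the node with the given storage name, or an error if no
// such node exists in the shard.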
func (s Shard) GetNode(storage string) (Node, error) {
	if storage == s.Primary.GetStorage() {
		return s.Primary, nil
	}

	for _, node := range s.Secondaries {
		if storage == node.GetStorage() {
			return node, nil
		}
	}

	return nil, fmt.Errorf("node with storage %q does not exist", storage)
}

// GetHealthySecondaries returns all secondaries of the shard which are
// currently known to be healthy.
func (s Shard) GetHealthySecondaries() []Node {
	healthySecondaries := make([]Node, 0, len(s.Secondaries))
	for _, secondary := range s.Secondaries {
		if !secondary.IsHealthy() {
			continue
		}
		healthySecondaries = append(healthySecondaries, secondary)
	}
	return healthySecondaries
}

// Manager is responsible for returning shards for virtual storages
type Manager interface {
	// GetShard returns the current shard for the given virtual storage.
	GetShard(ctx context.Context, virtualStorageName string) (Shard, error)
	// GetSyncedNode returns a random storage node based on the state of the replication.
	// It returns the primary if there are no up-to-date secondaries or an error occurs.
	GetSyncedNode(ctx context.Context, virtualStorageName, repoPath string) (Node, error)
	// HealthyNodes returns healthy storages by virtual storage.
	HealthyNodes() map[string][]string
	// Nodes returns nodes by their virtual storages.
	Nodes() map[string][]Node
}

const (
	// healthcheckTimeout is the maximum duration allowed for checking a node's health status.
	// If the check takes longer, it is considered failed.
	healthcheckTimeout = 1 * time.Second
	// healthcheckThreshold is the number of consecutive healthpb.HealthCheckResponse_SERVING necessary
	// for deeming a node "healthy"
	healthcheckThreshold = 3
)

// Node represents some metadata of a node as well as a connection
type Node interface {
	GetStorage() string
	GetAddress() string
	GetToken() string
	GetConnection() *grpc.ClientConn
	// IsHealthy reports whether the node is healthy and can handle requests.
	// A node is considered healthy if the last 'healthcheckThreshold' checks were positive.
	IsHealthy() bool
	// CheckHealth executes a health check for the node and records the result among the last 'healthcheckThreshold' checks.
	CheckHealth(context.Context) (bool, error)
}

// Mgr is a concrete type that adheres to the Manager interface
type Mgr struct {
	// strategies is a map of strategies keyed on virtual storage name
	strategies map[string]leaderElectionStrategy
	db         *sql.DB
	// nodes contains nodes by their virtual storages
	nodes map[string][]Node
	csg   datastore.ConsistentStoragesGetter
}

// leaderElectionStrategy defines the interface by which primary and
// secondaries are managed.
type leaderElectionStrategy interface {
	start(bootstrapInterval, monitorInterval time.Duration)
	checkNodes(context.Context) error
	GetShard(ctx context.Context) (Shard, error)
}

// ErrPrimaryNotHealthy indicates the primary of a shard is not in a healthy state and hence
// should not be used for a new request
var ErrPrimaryNotHealthy = errors.New("primary gitaly is not healthy")

const dialTimeout = 10 * time.Second

// Dial dials a node with the necessary interceptors configured.
func Dial(ctx context.Context, node *config.Node, registry *protoregistry.Registry, errorTracker tracker.ErrorTracker, handshaker client.Handshaker) (*grpc.ClientConn, error) {
	streamInterceptors := []grpc.StreamClientInterceptor{
		grpc_prometheus.StreamClientInterceptor,
	}

	if errorTracker != nil {
		streamInterceptors = append(streamInterceptors, middleware.StreamErrorHandler(registry, errorTracker, node.Storage))
	}

	dialOpts := []grpc.DialOption{
		grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())),
		grpc.WithPerRPCCredentials(gitalyauth.RPCCredentialsV2(node.Token)),
		grpc.WithChainStreamInterceptor(streamInterceptors...),
		grpc.WithChainUnaryInterceptor(grpc_prometheus.UnaryClientInterceptor),
	}

	return client.Dial(ctx, node.Address, dialOpts, handshaker)
}

// NewManager creates a new Mgr based on virtual storage configs
func NewManager(
	log *logrus.Entry,
	c config.Config,
	db *sql.DB,
	csg datastore.ConsistentStoragesGetter,
	latencyHistogram prommetrics.HistogramVec,
	registry *protoregistry.Registry,
	errorTracker tracker.ErrorTracker,
	handshaker client.Handshaker,
) (*Mgr, error) {
	if !c.Failover.Enabled {
		errorTracker = nil
	}

	ctx, cancel := context.WithTimeout(context.Background(), dialTimeout)
	defer cancel()

	nodes := make(map[string][]Node, len(c.VirtualStorages))
	strategies := make(map[string]leaderElectionStrategy, len(c.VirtualStorages))
	for _, virtualStorage := range c.VirtualStorages {
		log = log.WithField("virtual_storage", virtualStorage.Name)

		ns := make([]*nodeStatus, 0, len(virtualStorage.Nodes))
		for _, node := range virtualStorage.Nodes {
			conn, err := Dial(ctx, node, registry, errorTracker, handshaker)
			if err != nil {
				return nil, err
			}

			cs := newConnectionStatus(*node, conn, log, latencyHistogram, errorTracker)
			ns = append(ns, cs)
		}

		for _, node := range ns {
			nodes[virtualStorage.Name] = append(nodes[virtualStorage.Name], node)
		}

		if c.Failover.Enabled {
			if c.Failover.ElectionStrategy == config.ElectionStrategySQL {
				strategies[virtualStorage.Name] = newSQLElector(virtualStorage.Name, c, db, log, ns)
			} else {
				strategies[virtualStorage.Name] = newLocalElector(virtualStorage.Name, log, ns)
			}
		} else {
			strategies[virtualStorage.Name] = newDisabledElector(virtualStorage.Name, ns)
		}
	}

	return &Mgr{
		db:         db,
		strategies: strategies,
		nodes:      nodes,
		csg:        csg,
	}, nil
}

// Start will bootstrap the node manager by performing health checks on the nodes as well as kicking off
// the monitoring process. Start must be called before the Mgr can be used.
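//
// A minimal usage sketch (the argument names follow NewManager's parameters;
// the exact wiring depends on the caller):
//
//	mgr, err := NewManager(log, conf, db, csg, latencyHist, registry, errTracker, handshaker)
//	if err != nil {
//		// handle error
//	}
//	mgr.Start(time.Second, 3*time.Second)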
func (n *Mgr) Start(bootstrapInterval, monitorInterval time.Duration) {
	for _, strategy := range n.strategies {
		strategy.start(bootstrapInterval, monitorInterval)
	}
}

// checkShards performs health checks on all the available shards. The
// election strategy is responsible for determining the criteria for
// when to elect a new primary and when a node is down.
func (n *Mgr) checkShards() {
	for _, strategy := range n.strategies {
		ctx := context.Background()
		strategy.checkNodes(ctx)
	}
}

// ErrVirtualStorageNotExist indicates the node manager is not aware of the virtual storage for which a shard is being requested
var ErrVirtualStorageNotExist = errors.New("virtual storage does not exist")

// GetShard retrieves a shard for a virtual storage name
func (n *Mgr) GetShard(ctx context.Context, virtualStorageName string) (Shard, error) {
	strategy, ok := n.strategies[virtualStorageName]
	if !ok {
		return Shard{}, fmt.Errorf("virtual storage %q: %w", virtualStorageName, ErrVirtualStorageNotExist)
	}

	return strategy.GetShard(ctx)
}

// GetPrimary returns the current primary of a repository. This is an adapter so NodeManager can be used
// as a praefect.PrimaryGetter in newer code which is written to support repository-specific primaries.
func (n *Mgr) GetPrimary(ctx context.Context, virtualStorage, _ string) (string, error) {
	shard, err := n.GetShard(ctx, virtualStorage)
	if err != nil {
		return "", err
	}

	return shard.Primary.GetStorage(), nil
}

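// GetSyncedNode returns a random healthy storage node that is fully up to date
// for the given repository. If the replication state is not yet known, it falls
// back to the shard's primary.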
func (n *Mgr) GetSyncedNode(ctx context.Context, virtualStorageName, repoPath string) (Node, error) {
	upToDateStorages, err := n.csg.GetConsistentStorages(ctx, virtualStorageName, repoPath)
	if err != nil && !errors.As(err, new(commonerr.RepositoryNotFoundError)) {
		return nil, err
	}

	if len(upToDateStorages) == 0 {
		// this is possible when there is no data yet in the database for the repository
		shard, err := n.GetShard(ctx, virtualStorageName)
		if err != nil {
			return nil, fmt.Errorf("get shard for %q: %w", virtualStorageName, err)
		}

		upToDateStorages = map[string]struct{}{shard.Primary.GetStorage(): {}}
	}

	healthyStorages := make([]Node, 0, len(upToDateStorages))
	for _, node := range n.Nodes()[virtualStorageName] {
		if !node.IsHealthy() {
			continue
		}

		if _, ok := upToDateStorages[node.GetStorage()]; !ok {
			continue
		}

		healthyStorages = append(healthyStorages, node)
	}

	if len(healthyStorages) == 0 {
		return nil, fmt.Errorf("no healthy nodes: %w", ErrPrimaryNotHealthy)
	}

	return healthyStorages[rand.Intn(len(healthyStorages))], nil
}

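// HealthyNodes returns the names of the currently healthy storages keyed by
// virtual storage name.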
func (n *Mgr) HealthyNodes() map[string][]string {
	healthy := make(map[string][]string, len(n.nodes))
	for vs, nodes := range n.nodes {
		storages := make([]string, 0, len(nodes))
		for _, node := range nodes {
			if node.IsHealthy() {
				storages = append(storages, node.GetStorage())
			}
		}

		healthy[vs] = storages
	}

	return healthy
}

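// Nodes returns all nodes keyed by virtual storage name.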
func (n *Mgr) Nodes() map[string][]Node { return n.nodes }

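// newConnectionStatus wraps a node's configuration and gRPC connection
// together with the state needed to track its health.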
func newConnectionStatus(node config.Node, cc *grpc.ClientConn, l logrus.FieldLogger, latencyHist prommetrics.HistogramVec, errorTracker tracker.ErrorTracker) *nodeStatus {
	return &nodeStatus{
		node:        node,
		clientConn:  cc,
		log:         l,
		latencyHist: latencyHist,
		errTracker:  errorTracker,
	}
}

type nodeStatus struct {
	node        config.Node
	clientConn  *grpc.ClientConn
	log         logrus.FieldLogger
	latencyHist prommetrics.HistogramVec
	mtx         sync.RWMutex
	statuses    []bool
	errTracker  tracker.ErrorTracker
}

// GetStorage gets the storage name of a node
func (n *nodeStatus) GetStorage() string {
	return n.node.Storage
}

// GetAddress gets the address of a node
func (n *nodeStatus) GetAddress() string {
	return n.node.Address
}

// GetToken gets the token of a node
func (n *nodeStatus) GetToken() string {
	return n.node.Token
}

// GetConnection gets the client connection of a node
func (n *nodeStatus) GetConnection() *grpc.ClientConn {
	return n.clientConn
}

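// IsHealthy reports whether the node is currently considered healthy based on
// its most recent health check results.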
func (n *nodeStatus) IsHealthy() bool {
	n.mtx.RLock()
	healthy := n.isHealthy()
	n.mtx.RUnlock()
	return healthy
}

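// isHealthy reports whether the last healthcheckThreshold checks all succeeded.
// The caller must hold the mutex.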
func (n *nodeStatus) isHealthy() bool {
	if len(n.statuses) < healthcheckThreshold {
		return false
	}

	for _, ok := range n.statuses[len(n.statuses)-healthcheckThreshold:] {
		if !ok {
			return false
		}
	}

	return true
}

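// updateStatus records the result of the latest health check, retaining only
// the most recent healthcheckThreshold results.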
func (n *nodeStatus) updateStatus(status bool) {
	n.mtx.Lock()
	n.statuses = append(n.statuses, status)
	if len(n.statuses) > healthcheckThreshold {
		n.statuses = n.statuses[1:]
	}
	n.mtx.Unlock()
}

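// CheckHealth performs a single check against the node's gRPC health service,
// records the result and the observed latency, and returns whether the node
// reported itself as serving.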
func (n *nodeStatus) CheckHealth(ctx context.Context) (bool, error) {
	health := healthpb.NewHealthClient(n.clientConn)
	if n.errTracker != nil {
		health = tracker.NewHealthClient(health, n.GetStorage(), n.errTracker)
	}

	ctx, cancel := context.WithTimeout(ctx, healthcheckTimeout)
	defer cancel()

	start := time.Now()
	resp, err := health.Check(ctx, &healthpb.HealthCheckRequest{})
	n.latencyHist.WithLabelValues(n.node.Storage).Observe(time.Since(start).Seconds())
	if err != nil {
		n.log.WithError(err).WithFields(logrus.Fields{
			"storage": n.node.Storage,
			"address": n.node.Address,
		}).Warn("error when pinging healthcheck")
	}

	status := resp.GetStatus() == healthpb.HealthCheckResponse_SERVING

	metrics.NodeLastHealthcheckGauge.WithLabelValues(n.GetStorage()).Set(metrics.BoolAsFloat(status))

	n.updateStatus(status)

	return status, err
}