1package main
2
3import (
4	"context"
5	"flag"
6	"fmt"
7	"io"
8	"strings"
9
10	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/config"
11	"gitlab.com/gitlab-org/gitaly/v14/proto/go/gitalypb"
12)
13
// paramReplicationFactor is the name of the command-line flag through which the
// set-replication-factor subcommand receives the desired replication factor.
const (
	paramReplicationFactor = "replication-factor"
)
15
// setReplicationFactorSubcommand holds the state of the set-replication-factor
// subcommand: where to print results and the values parsed from its flags.
type setReplicationFactorSubcommand struct {
	// stdout receives the human-readable output of the command.
	stdout io.Writer
	// virtualStorage and relativePath identify the target repository.
	virtualStorage, relativePath string
	// replicationFactor is the requested factor; FlagSet defaults it to -1.
	replicationFactor int
}
22
23func newSetReplicatioFactorSubcommand(stdout io.Writer) *setReplicationFactorSubcommand {
24	return &setReplicationFactorSubcommand{stdout: stdout}
25}
26
27func (cmd *setReplicationFactorSubcommand) FlagSet() *flag.FlagSet {
28	fs := flag.NewFlagSet("set-replication-factor", flag.ContinueOnError)
29	fs.StringVar(&cmd.virtualStorage, paramVirtualStorage, "", "name of the repository's virtual storage")
30	fs.StringVar(&cmd.relativePath, paramRelativePath, "", "repository to set the replication factor for")
31	fs.IntVar(&cmd.replicationFactor, paramReplicationFactor, -1, "desired replication factor")
32	return fs
33}
34
35func (cmd *setReplicationFactorSubcommand) Exec(flags *flag.FlagSet, cfg config.Config) error {
36	if flags.NArg() > 0 {
37		return unexpectedPositionalArgsError{Command: flags.Name()}
38	} else if cmd.virtualStorage == "" {
39		return requiredParameterError(paramVirtualStorage)
40	} else if cmd.relativePath == "" {
41		return requiredParameterError(paramRelativePath)
42	} else if cmd.replicationFactor < 0 {
43		return requiredParameterError(paramReplicationFactor)
44	}
45
46	nodeAddr, err := getNodeAddress(cfg)
47	if err != nil {
48		return err
49	}
50
51	conn, err := subCmdDial(nodeAddr, cfg.Auth.Token)
52	if err != nil {
53		return fmt.Errorf("error dialing: %w", err)
54	}
55	defer conn.Close()
56
57	client := gitalypb.NewPraefectInfoServiceClient(conn)
58	resp, err := client.SetReplicationFactor(context.TODO(), &gitalypb.SetReplicationFactorRequest{
59		VirtualStorage:    cmd.virtualStorage,
60		RelativePath:      cmd.relativePath,
61		ReplicationFactor: int32(cmd.replicationFactor),
62	})
63	if err != nil {
64		return err
65	}
66
67	fmt.Fprintf(cmd.stdout, "current assignments: %v\n", strings.Join(resp.Storages, ", "))
68
69	return nil
70}
71