1package firewall
2
3import (
4	"encoding/json"
5	"fmt"
6	"io/ioutil"
7	"net"
8	"os"
9
10	"github.com/hetznercloud/hcloud-go/hcloud"
11	"github.com/hetznercloud/hcloud-go/hcloud/schema"
12	"github.com/spf13/cobra"
13
14	"github.com/hetznercloud/cli/internal/cmd/cmpl"
15	"github.com/hetznercloud/cli/internal/cmd/util"
16	"github.com/hetznercloud/cli/internal/state"
17)
18
19func newReplaceRulesCommand(cli *state.State) *cobra.Command {
20	cmd := &cobra.Command{
21		Use:                   "replace-rules FIREWALL FLAGS",
22		Short:                 "Replaces all rules from a Firewall from a file",
23		Args:                  cobra.ExactArgs(1),
24		ValidArgsFunction:     cmpl.SuggestArgs(cmpl.SuggestCandidatesF(cli.FirewallNames)),
25		TraverseChildren:      true,
26		DisableFlagsInUseLine: true,
27		PreRunE:               util.ChainRunE(cli.EnsureToken),
28		RunE:                  cli.Wrap(runFirewallReplaceRules),
29	}
30	cmd.Flags().String("rules-file", "", "JSON file containing your routes (use - to read from stdin). The structure of the file needs to be the same as within the API: https://docs.hetzner.cloud/#firewalls-get-a-firewall")
31	cmd.MarkFlagRequired("rules-file")
32	return cmd
33}
34
35func runFirewallReplaceRules(cli *state.State, cmd *cobra.Command, args []string) error {
36	idOrName := args[0]
37	firewall, _, err := cli.Client().Firewall.Get(cli.Context, idOrName)
38	if err != nil {
39		return err
40	}
41	if firewall == nil {
42		return fmt.Errorf("Firewall not found: %v", idOrName)
43	}
44
45	opts := hcloud.FirewallSetRulesOpts{}
46
47	rulesFile, _ := cmd.Flags().GetString("rules-file")
48
49	var data []byte
50	if rulesFile == "-" {
51		data, err = ioutil.ReadAll(os.Stdin)
52	} else {
53		data, err = ioutil.ReadFile(rulesFile)
54	}
55	if err != nil {
56		return err
57	}
58	var rules []schema.FirewallRule
59	err = json.Unmarshal(data, &rules)
60	if err != nil {
61		return err
62	}
63	for _, rule := range rules {
64		d := hcloud.FirewallRuleDirection(rule.Direction)
65		r := hcloud.FirewallRule{
66			Direction: d,
67			Protocol:  hcloud.FirewallRuleProtocol(rule.Protocol),
68			Port:      rule.Port,
69		}
70		switch d {
71		case hcloud.FirewallRuleDirectionOut:
72			r.DestinationIPs = make([]net.IPNet, 0, len(rule.DestinationIPs))
73			for i, ip := range rule.DestinationIPs {
74				_, n, err := net.ParseCIDR(ip)
75				if err != nil {
76					return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
77				}
78				r.DestinationIPs[i] = *n
79			}
80		case hcloud.FirewallRuleDirectionIn:
81			r.SourceIPs = make([]net.IPNet, 0, len(rule.SourceIPs))
82			for i, ip := range rule.SourceIPs {
83				_, n, err := net.ParseCIDR(ip)
84				if err != nil {
85					return fmt.Errorf("invalid CIDR on index %d : %s", i, err)
86				}
87				r.SourceIPs[i] = *n
88			}
89		}
90		opts.Rules = append(opts.Rules, r)
91	}
92
93	actions, _, err := cli.Client().Firewall.SetRules(cli.Context, firewall, opts)
94	if err != nil {
95		return err
96	}
97	if err := cli.ActionsProgresses(cli.Context, actions); err != nil {
98		return err
99	}
100	fmt.Printf("Firewall Rules for Firewall %d updated\n", firewall.ID)
101
102	return nil
103}
104