1// Copyright 2015 CoreOS, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package iptables
16
17import (
18	"bytes"
19	"fmt"
20	"io"
21	"net"
22	"os/exec"
23	"regexp"
24	"strconv"
25	"strings"
26	"syscall"
27)
28
// Error is returned by runWithOutput when an iptables command exits
// unsuccessfully. It embeds the underlying exec.ExitError and adds the
// command that was run, the stderr it produced, and the protocol in use.
type Error struct {
	exec.ExitError
	cmd        exec.Cmd // the command that was run
	msg        string   // stderr captured while running cmd
	proto      Protocol // protocol of the IPTables instance that ran cmd
	exitStatus *int     // if non-nil, overrides the status reported by ExitStatus
}
37
38func (e *Error) ExitStatus() int {
39	if e.exitStatus != nil {
40		return *e.exitStatus
41	}
42	return e.Sys().(syscall.WaitStatus).ExitStatus()
43}
44
45func (e *Error) Error() string {
46	return fmt.Sprintf("running %v: exit status %v: %v", e.cmd.Args, e.ExitStatus(), e.msg)
47}
48
49// IsNotExist returns true if the error is due to the chain or rule not existing
50func (e *Error) IsNotExist() bool {
51	return e.ExitStatus() == 1 &&
52		(e.msg == fmt.Sprintf("%s: Bad rule (does a matching rule exist in that chain?).\n", getIptablesCommand(e.proto)) ||
53			e.msg == fmt.Sprintf("%s: No chain/target/match by that name.\n", getIptablesCommand(e.proto)))
54}
55
// Protocol selects between IPv4 and IPv6, and therefore between the
// "iptables" and "ip6tables" commands.
type Protocol byte

// Supported protocols.
const (
	// ProtocolIPv4 selects the "iptables" command.
	ProtocolIPv4 Protocol = 0
	// ProtocolIPv6 selects the "ip6tables" command.
	ProtocolIPv6 Protocol = 1
)
63
// IPTables runs iptables (or ip6tables) commands. Its fields are filled in by
// NewWithProtocol from the located binary and the version it reports.
type IPTables struct {
	path           string   // binary location as resolved by exec.LookPath
	proto          Protocol // selects "iptables" (IPv4) or "ip6tables" (IPv6)
	hasCheck       bool     // supports -C; otherwise Exists uses existsForOldIptables
	hasWait        bool     // supports --wait; otherwise runWithOutput takes a file lock
	hasRandomFully bool     // supports --random-fully
	v1             int      // first version component, e.g. 1 in v1.8.2
	v2             int      // second version component, e.g. 8 in v1.8.2
	v3             int      // third version component, e.g. 2 in v1.8.2
	mode           string   // the underlying iptables operating mode, e.g. nf_tables
}
75
// Stat represents a structured statistic entry: one rule row produced by
// Stats and converted by ParseStat.
//
// Packets and Bytes hold the rule's counters. Source and Destination hold the
// parsed CIDR networks. The remaining string fields are copied verbatim from
// the corresponding Stats columns. Options holds all trailing match options
// joined with single spaces.
type Stat struct {
	Packets     uint64     `json:"pkts"`
	Bytes       uint64     `json:"bytes"`
	Target      string     `json:"target"`
	Protocol    string     `json:"prot"`
	Opt         string     `json:"opt"`
	Input       string     `json:"in"`
	Output      string     `json:"out"`
	Source      *net.IPNet `json:"source"`
	Destination *net.IPNet `json:"destination"`
	Options     string     `json:"options"`
}
89
90// New creates a new IPTables.
91// For backwards compatibility, this always uses IPv4, i.e. "iptables".
92func New() (*IPTables, error) {
93	return NewWithProtocol(ProtocolIPv4)
94}
95
96// New creates a new IPTables for the given proto.
97// The proto will determine which command is used, either "iptables" or "ip6tables".
98func NewWithProtocol(proto Protocol) (*IPTables, error) {
99	path, err := exec.LookPath(getIptablesCommand(proto))
100	if err != nil {
101		return nil, err
102	}
103	vstring, err := getIptablesVersionString(path)
104	v1, v2, v3, mode, err := extractIptablesVersion(vstring)
105
106	checkPresent, waitPresent, randomFullyPresent := getIptablesCommandSupport(v1, v2, v3)
107
108	ipt := IPTables{
109		path:           path,
110		proto:          proto,
111		hasCheck:       checkPresent,
112		hasWait:        waitPresent,
113		hasRandomFully: randomFullyPresent,
114		v1:             v1,
115		v2:             v2,
116		v3:             v3,
117		mode:           mode,
118	}
119	return &ipt, nil
120}
121
122// Proto returns the protocol used by this IPTables.
123func (ipt *IPTables) Proto() Protocol {
124	return ipt.proto
125}
126
127// Exists checks if given rulespec in specified table/chain exists
128func (ipt *IPTables) Exists(table, chain string, rulespec ...string) (bool, error) {
129	if !ipt.hasCheck {
130		return ipt.existsForOldIptables(table, chain, rulespec)
131
132	}
133	cmd := append([]string{"-t", table, "-C", chain}, rulespec...)
134	err := ipt.run(cmd...)
135	eerr, eok := err.(*Error)
136	switch {
137	case err == nil:
138		return true, nil
139	case eok && eerr.ExitStatus() == 1:
140		return false, nil
141	default:
142		return false, err
143	}
144}
145
146// Insert inserts rulespec to specified table/chain (in specified pos)
147func (ipt *IPTables) Insert(table, chain string, pos int, rulespec ...string) error {
148	cmd := append([]string{"-t", table, "-I", chain, strconv.Itoa(pos)}, rulespec...)
149	return ipt.run(cmd...)
150}
151
152// Append appends rulespec to specified table/chain
153func (ipt *IPTables) Append(table, chain string, rulespec ...string) error {
154	cmd := append([]string{"-t", table, "-A", chain}, rulespec...)
155	return ipt.run(cmd...)
156}
157
158// AppendUnique acts like Append except that it won't add a duplicate
159func (ipt *IPTables) AppendUnique(table, chain string, rulespec ...string) error {
160	exists, err := ipt.Exists(table, chain, rulespec...)
161	if err != nil {
162		return err
163	}
164
165	if !exists {
166		return ipt.Append(table, chain, rulespec...)
167	}
168
169	return nil
170}
171
172// Delete removes rulespec in specified table/chain
173func (ipt *IPTables) Delete(table, chain string, rulespec ...string) error {
174	cmd := append([]string{"-t", table, "-D", chain}, rulespec...)
175	return ipt.run(cmd...)
176}
177
178// List rules in specified table/chain
179func (ipt *IPTables) List(table, chain string) ([]string, error) {
180	args := []string{"-t", table, "-S", chain}
181	return ipt.executeList(args)
182}
183
184// List rules (with counters) in specified table/chain
185func (ipt *IPTables) ListWithCounters(table, chain string) ([]string, error) {
186	args := []string{"-t", table, "-v", "-S", chain}
187	return ipt.executeList(args)
188}
189
190// ListChains returns a slice containing the name of each chain in the specified table.
191func (ipt *IPTables) ListChains(table string) ([]string, error) {
192	args := []string{"-t", table, "-S"}
193
194	result, err := ipt.executeList(args)
195	if err != nil {
196		return nil, err
197	}
198
199	// Iterate over rules to find all default (-P) and user-specified (-N) chains.
200	// Chains definition always come before rules.
201	// Format is the following:
202	// -P OUTPUT ACCEPT
203	// -N Custom
204	var chains []string
205	for _, val := range result {
206		if strings.HasPrefix(val, "-P") || strings.HasPrefix(val, "-N") {
207			chains = append(chains, strings.Fields(val)[1])
208		} else {
209			break
210		}
211	}
212	return chains, nil
213}
214
215// Stats lists rules including the byte and packet counts
216func (ipt *IPTables) Stats(table, chain string) ([][]string, error) {
217	args := []string{"-t", table, "-L", chain, "-n", "-v", "-x"}
218	lines, err := ipt.executeList(args)
219	if err != nil {
220		return nil, err
221	}
222
223	appendSubnet := func(addr string) string {
224		if strings.IndexByte(addr, byte('/')) < 0 {
225			if strings.IndexByte(addr, '.') < 0 {
226				return addr + "/128"
227			}
228			return addr + "/32"
229		}
230		return addr
231	}
232
233	ipv6 := ipt.proto == ProtocolIPv6
234
235	rows := [][]string{}
236	for i, line := range lines {
237		// Skip over chain name and field header
238		if i < 2 {
239			continue
240		}
241
242		// Fields:
243		// 0=pkts 1=bytes 2=target 3=prot 4=opt 5=in 6=out 7=source 8=destination 9=options
244		line = strings.TrimSpace(line)
245		fields := strings.Fields(line)
246
247		// The ip6tables verbose output cannot be naively split due to the default "opt"
248		// field containing 2 single spaces.
249		if ipv6 {
250			// Check if field 6 is "opt" or "source" address
251			dest := fields[6]
252			ip, _, _ := net.ParseCIDR(dest)
253			if ip == nil {
254				ip = net.ParseIP(dest)
255			}
256
257			// If we detected a CIDR or IP, the "opt" field is empty.. insert it.
258			if ip != nil {
259				f := []string{}
260				f = append(f, fields[:4]...)
261				f = append(f, "  ") // Empty "opt" field for ip6tables
262				f = append(f, fields[4:]...)
263				fields = f
264			}
265		}
266
267		// Adjust "source" and "destination" to include netmask, to match regular
268		// List output
269		fields[7] = appendSubnet(fields[7])
270		fields[8] = appendSubnet(fields[8])
271
272		// Combine "options" fields 9... into a single space-delimited field.
273		options := fields[9:]
274		fields = fields[:9]
275		fields = append(fields, strings.Join(options, " "))
276		rows = append(rows, fields)
277	}
278	return rows, nil
279}
280
281// ParseStat parses a single statistic row into a Stat struct. The input should
282// be a string slice that is returned from calling the Stat method.
283func (ipt *IPTables) ParseStat(stat []string) (parsed Stat, err error) {
284	// For forward-compatibility, expect at least 10 fields in the stat
285	if len(stat) < 10 {
286		return parsed, fmt.Errorf("stat contained fewer fields than expected")
287	}
288
289	// Convert the fields that are not plain strings
290	parsed.Packets, err = strconv.ParseUint(stat[0], 0, 64)
291	if err != nil {
292		return parsed, fmt.Errorf(err.Error(), "could not parse packets")
293	}
294	parsed.Bytes, err = strconv.ParseUint(stat[1], 0, 64)
295	if err != nil {
296		return parsed, fmt.Errorf(err.Error(), "could not parse bytes")
297	}
298	_, parsed.Source, err = net.ParseCIDR(stat[7])
299	if err != nil {
300		return parsed, fmt.Errorf(err.Error(), "could not parse source")
301	}
302	_, parsed.Destination, err = net.ParseCIDR(stat[8])
303	if err != nil {
304		return parsed, fmt.Errorf(err.Error(), "could not parse destination")
305	}
306
307	// Put the fields that are strings
308	parsed.Target = stat[2]
309	parsed.Protocol = stat[3]
310	parsed.Opt = stat[4]
311	parsed.Input = stat[5]
312	parsed.Output = stat[6]
313	parsed.Options = stat[9]
314
315	return parsed, nil
316}
317
318// StructuredStats returns statistics as structured data which may be further
319// parsed and marshaled.
320func (ipt *IPTables) StructuredStats(table, chain string) ([]Stat, error) {
321	rawStats, err := ipt.Stats(table, chain)
322	if err != nil {
323		return nil, err
324	}
325
326	structStats := []Stat{}
327	for _, rawStat := range rawStats {
328		stat, err := ipt.ParseStat(rawStat)
329		if err != nil {
330			return nil, err
331		}
332		structStats = append(structStats, stat)
333	}
334
335	return structStats, nil
336}
337
338func (ipt *IPTables) executeList(args []string) ([]string, error) {
339	var stdout bytes.Buffer
340	if err := ipt.runWithOutput(args, &stdout); err != nil {
341		return nil, err
342	}
343
344	rules := strings.Split(stdout.String(), "\n")
345
346	// strip trailing newline
347	if len(rules) > 0 && rules[len(rules)-1] == "" {
348		rules = rules[:len(rules)-1]
349	}
350
351	for i, rule := range rules {
352		rules[i] = filterRuleOutput(rule)
353	}
354
355	return rules, nil
356}
357
358// NewChain creates a new chain in the specified table.
359// If the chain already exists, it will result in an error.
360func (ipt *IPTables) NewChain(table, chain string) error {
361	return ipt.run("-t", table, "-N", chain)
362}
363
364const existsErr = 1
365
366// ClearChain flushed (deletes all rules) in the specified table/chain.
367// If the chain does not exist, a new one will be created
368func (ipt *IPTables) ClearChain(table, chain string) error {
369	err := ipt.NewChain(table, chain)
370
371	eerr, eok := err.(*Error)
372	switch {
373	case err == nil:
374		return nil
375	case eok && eerr.ExitStatus() == existsErr:
376		// chain already exists. Flush (clear) it.
377		return ipt.run("-t", table, "-F", chain)
378	default:
379		return err
380	}
381}
382
383// RenameChain renames the old chain to the new one.
384func (ipt *IPTables) RenameChain(table, oldChain, newChain string) error {
385	return ipt.run("-t", table, "-E", oldChain, newChain)
386}
387
388// DeleteChain deletes the chain in the specified table.
389// The chain must be empty
390func (ipt *IPTables) DeleteChain(table, chain string) error {
391	return ipt.run("-t", table, "-X", chain)
392}
393
394// ChangePolicy changes policy on chain to target
395func (ipt *IPTables) ChangePolicy(table, chain, target string) error {
396	return ipt.run("-t", table, "-P", chain, target)
397}
398
399// Check if the underlying iptables command supports the --random-fully flag
400func (ipt *IPTables) HasRandomFully() bool {
401	return ipt.hasRandomFully
402}
403
404// Return version components of the underlying iptables command
405func (ipt *IPTables) GetIptablesVersion() (int, int, int) {
406	return ipt.v1, ipt.v2, ipt.v3
407}
408
409// run runs an iptables command with the given arguments, ignoring
410// any stdout output
411func (ipt *IPTables) run(args ...string) error {
412	return ipt.runWithOutput(args, nil)
413}
414
415// runWithOutput runs an iptables command with the given arguments,
416// writing any stdout output to the given writer
417func (ipt *IPTables) runWithOutput(args []string, stdout io.Writer) error {
418	args = append([]string{ipt.path}, args...)
419	if ipt.hasWait {
420		args = append(args, "--wait")
421	} else {
422		fmu, err := newXtablesFileLock()
423		if err != nil {
424			return err
425		}
426		ul, err := fmu.tryLock()
427		if err != nil {
428			syscall.Close(fmu.fd)
429			return err
430		}
431		defer ul.Unlock()
432	}
433
434	var stderr bytes.Buffer
435	cmd := exec.Cmd{
436		Path:   ipt.path,
437		Args:   args,
438		Stdout: stdout,
439		Stderr: &stderr,
440	}
441
442	if err := cmd.Run(); err != nil {
443		switch e := err.(type) {
444		case *exec.ExitError:
445			return &Error{*e, cmd, stderr.String(), ipt.proto, nil}
446		default:
447			return err
448		}
449	}
450
451	return nil
452}
453
454// getIptablesCommand returns the correct command for the given protocol, either "iptables" or "ip6tables".
455func getIptablesCommand(proto Protocol) string {
456	if proto == ProtocolIPv6 {
457		return "ip6tables"
458	} else {
459		return "iptables"
460	}
461}
462
463// Checks if iptables has the "-C" and "--wait" flag
464func getIptablesCommandSupport(v1 int, v2 int, v3 int) (bool, bool, bool) {
465	return iptablesHasCheckCommand(v1, v2, v3), iptablesHasWaitCommand(v1, v2, v3), iptablesHasRandomFully(v1, v2, v3)
466}
467
// iptablesVersionRegex matches "vX.Y.Z", optionally followed by whitespace and
// a parenthesised mode word such as "(nf_tables)" or "(legacy)". It is
// compiled once at package initialisation instead of on every call.
var iptablesVersionRegex = regexp.MustCompile(`v([0-9]+)\.([0-9]+)\.([0-9]+)(?:\s+\((\w+))?`)

// extractIptablesVersion returns the first three components of the iptables
// version found in str, plus the operating mode (e.g. nf_tables or legacy).
// When str names no mode, "legacy" is assumed.
// e.g. "iptables v1.3.66" would return (1, 3, 66, "legacy", nil), and
// "iptables v1.8.2 (nf_tables)" would return (1, 8, 2, "nf_tables", nil).
// An error is returned if no version number is found in str.
func extractIptablesVersion(str string) (int, int, int, string, error) {
	result := iptablesVersionRegex.FindStringSubmatch(str)
	if result == nil {
		return 0, 0, 0, "", fmt.Errorf("no iptables version found in string: %s", str)
	}

	v1, err := strconv.Atoi(result[1])
	if err != nil {
		return 0, 0, 0, "", err
	}

	v2, err := strconv.Atoi(result[2])
	if err != nil {
		return 0, 0, 0, "", err
	}

	v3, err := strconv.Atoi(result[3])
	if err != nil {
		return 0, 0, 0, "", err
	}

	mode := "legacy"
	if result[4] != "" {
		mode = result[4]
	}
	return v1, v2, v3, mode, nil
}
499
// getIptablesVersionString runs "<path> --version" and returns what the
// command wrote to stdout; stderr is not captured. If the command cannot be
// started or exits unsuccessfully, the error is returned with an empty string.
func getIptablesVersionString(path string) (string, error) {
	var stdout bytes.Buffer
	cmd := exec.Command(path, "--version")
	cmd.Stdout = &stdout
	if err := cmd.Run(); err != nil {
		return "", err
	}
	return stdout.String(), nil
}
511
// iptablesHasCheckCommand reports whether iptables version v1.v2.v3 is at
// least 1.4.11, the release that added the -C (--check) command.
func iptablesHasCheckCommand(v1 int, v2 int, v3 int) bool {
	switch {
	case v1 != 1:
		return v1 > 1
	case v2 != 4:
		return v2 > 4
	default:
		return v3 >= 11
	}
}
525
// iptablesHasWaitCommand reports whether iptables version v1.v2.v3 is at
// least 1.4.20, the release that added the --wait flag.
func iptablesHasWaitCommand(v1 int, v2 int, v3 int) bool {
	switch {
	case v1 != 1:
		return v1 > 1
	case v2 != 4:
		return v2 > 4
	default:
		return v3 >= 20
	}
}
539
// iptablesHasRandomFully reports whether iptables version v1.v2.v3 is at
// least 1.6.2, the release that added the --random-fully flag.
func iptablesHasRandomFully(v1 int, v2 int, v3 int) bool {
	switch {
	case v1 != 1:
		return v1 > 1
	case v2 != 6:
		return v2 > 6
	default:
		return v3 >= 2
	}
}
553
554// Checks if a rule specification exists for a table
555func (ipt *IPTables) existsForOldIptables(table, chain string, rulespec []string) (bool, error) {
556	rs := strings.Join(append([]string{"-A", chain}, rulespec...), " ")
557	args := []string{"-t", table, "-S"}
558	var stdout bytes.Buffer
559	err := ipt.runWithOutput(args, &stdout)
560	if err != nil {
561		return false, err
562	}
563	return strings.Contains(stdout.String(), rs), nil
564}
565
// counterRegex matches a leading "[packets:bytes] " counter prefix, as printed
// in nftables mode.
var counterRegex = regexp.MustCompile(`^\[([0-9]+):([0-9]+)\] `)

// filterRuleOutput works around some inconsistencies in output.
// For example, when iptables is in legacy vs. nftables mode, it produces
// different results.
//
// In nftables mode counters are emitted in iptables-save style as a leading
// "[pkts:bytes] " prefix. That prefix is removed and the counters are
// re-appended in `iptables -S` style as " -c <pkts> <bytes>". Lines without
// the prefix are returned unchanged.
//
// Fixes #49
func filterRuleOutput(rule string) string {
	m := counterRegex.FindStringSubmatch(rule)
	if m == nil {
		return rule
	}
	rest := rule[len(m[0]):]
	return fmt.Sprintf("%s -c %s %s", rest, m[1], m[2])
}
588