1package netlink
2
3import (
4	"fmt"
5	"syscall"
6
7	"github.com/vishvananda/netlink/nl"
8)
9
10func writeStateAlgo(a *XfrmStateAlgo) []byte {
11	algo := nl.XfrmAlgo{
12		AlgKeyLen: uint32(len(a.Key) * 8),
13		AlgKey:    a.Key,
14	}
15	end := len(a.Name)
16	if end > 64 {
17		end = 64
18	}
19	copy(algo.AlgName[:end], a.Name)
20	return algo.Serialize()
21}
22
23func writeStateAlgoAuth(a *XfrmStateAlgo) []byte {
24	algo := nl.XfrmAlgoAuth{
25		AlgKeyLen:   uint32(len(a.Key) * 8),
26		AlgTruncLen: uint32(a.TruncateLen),
27		AlgKey:      a.Key,
28	}
29	end := len(a.Name)
30	if end > 64 {
31		end = 64
32	}
33	copy(algo.AlgName[:end], a.Name)
34	return algo.Serialize()
35}
36
37// XfrmStateAdd will add an xfrm state to the system.
38// Equivalent to: `ip xfrm state add $state`
39func XfrmStateAdd(state *XfrmState) error {
40	// A state with spi 0 can't be deleted so don't allow it to be set
41	if state.Spi == 0 {
42		return fmt.Errorf("Spi must be set when adding xfrm state.")
43	}
44	req := nl.NewNetlinkRequest(nl.XFRM_MSG_NEWSA, syscall.NLM_F_CREATE|syscall.NLM_F_EXCL|syscall.NLM_F_ACK)
45
46	msg := &nl.XfrmUsersaInfo{}
47	msg.Family = uint16(nl.GetIPFamily(state.Dst))
48	msg.Id.Daddr.FromIP(state.Dst)
49	msg.Saddr.FromIP(state.Src)
50	msg.Id.Proto = uint8(state.Proto)
51	msg.Mode = uint8(state.Mode)
52	msg.Id.Spi = nl.Swap32(uint32(state.Spi))
53	msg.Reqid = uint32(state.Reqid)
54	msg.ReplayWindow = uint8(state.ReplayWindow)
55	msg.Lft.SoftByteLimit = nl.XFRM_INF
56	msg.Lft.HardByteLimit = nl.XFRM_INF
57	msg.Lft.SoftPacketLimit = nl.XFRM_INF
58	msg.Lft.HardPacketLimit = nl.XFRM_INF
59	req.AddData(msg)
60
61	if state.Auth != nil {
62		out := nl.NewRtAttr(nl.XFRMA_ALG_AUTH_TRUNC, writeStateAlgoAuth(state.Auth))
63		req.AddData(out)
64	}
65	if state.Crypt != nil {
66		out := nl.NewRtAttr(nl.XFRMA_ALG_CRYPT, writeStateAlgo(state.Crypt))
67		req.AddData(out)
68	}
69	if state.Encap != nil {
70		encapData := make([]byte, nl.SizeofXfrmEncapTmpl)
71		encap := nl.DeserializeXfrmEncapTmpl(encapData)
72		encap.EncapType = uint16(state.Encap.Type)
73		encap.EncapSport = nl.Swap16(uint16(state.Encap.SrcPort))
74		encap.EncapDport = nl.Swap16(uint16(state.Encap.DstPort))
75		encap.EncapOa.FromIP(state.Encap.OriginalAddress)
76		out := nl.NewRtAttr(nl.XFRMA_ENCAP, encapData)
77		req.AddData(out)
78	}
79
80	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
81	return err
82}
83
84// XfrmStateDel will delete an xfrm state from the system. Note that
85// the Algos are ignored when matching the state to delete.
86// Equivalent to: `ip xfrm state del $state`
87func XfrmStateDel(state *XfrmState) error {
88	req := nl.NewNetlinkRequest(nl.XFRM_MSG_DELSA, syscall.NLM_F_ACK)
89
90	msg := &nl.XfrmUsersaId{}
91	msg.Daddr.FromIP(state.Dst)
92	msg.Family = uint16(nl.GetIPFamily(state.Dst))
93	msg.Proto = uint8(state.Proto)
94	msg.Spi = nl.Swap32(uint32(state.Spi))
95	req.AddData(msg)
96
97	saddr := nl.XfrmAddress{}
98	saddr.FromIP(state.Src)
99	srcdata := nl.NewRtAttr(nl.XFRMA_SRCADDR, saddr.Serialize())
100
101	req.AddData(srcdata)
102
103	_, err := req.Execute(syscall.NETLINK_XFRM, 0)
104	return err
105}
106
107// XfrmStateList gets a list of xfrm states in the system.
108// Equivalent to: `ip xfrm state show`.
109// The list can be filtered by ip family.
110func XfrmStateList(family int) ([]XfrmState, error) {
111	req := nl.NewNetlinkRequest(nl.XFRM_MSG_GETSA, syscall.NLM_F_DUMP)
112
113	msg := nl.NewIfInfomsg(family)
114	req.AddData(msg)
115
116	msgs, err := req.Execute(syscall.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
117	if err != nil {
118		return nil, err
119	}
120
121	var res []XfrmState
122	for _, m := range msgs {
123		msg := nl.DeserializeXfrmUsersaInfo(m)
124
125		if family != FAMILY_ALL && family != int(msg.Family) {
126			continue
127		}
128
129		var state XfrmState
130
131		state.Dst = msg.Id.Daddr.ToIP()
132		state.Src = msg.Saddr.ToIP()
133		state.Proto = Proto(msg.Id.Proto)
134		state.Mode = Mode(msg.Mode)
135		state.Spi = int(nl.Swap32(msg.Id.Spi))
136		state.Reqid = int(msg.Reqid)
137		state.ReplayWindow = int(msg.ReplayWindow)
138
139		attrs, err := nl.ParseRouteAttr(m[msg.Len():])
140		if err != nil {
141			return nil, err
142		}
143
144		for _, attr := range attrs {
145			switch attr.Attr.Type {
146			case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
147				var resAlgo *XfrmStateAlgo
148				if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
149					if state.Auth == nil {
150						state.Auth = new(XfrmStateAlgo)
151					}
152					resAlgo = state.Auth
153				} else {
154					state.Crypt = new(XfrmStateAlgo)
155					resAlgo = state.Crypt
156				}
157				algo := nl.DeserializeXfrmAlgo(attr.Value[:])
158				(*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
159				(*resAlgo).Key = algo.AlgKey
160			case nl.XFRMA_ALG_AUTH_TRUNC:
161				if state.Auth == nil {
162					state.Auth = new(XfrmStateAlgo)
163				}
164				algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
165				state.Auth.Name = nl.BytesToString(algo.AlgName[:])
166				state.Auth.Key = algo.AlgKey
167				state.Auth.TruncateLen = int(algo.AlgTruncLen)
168			case nl.XFRMA_ENCAP:
169				encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
170				state.Encap = new(XfrmStateEncap)
171				state.Encap.Type = EncapType(encap.EncapType)
172				state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
173				state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
174				state.Encap.OriginalAddress = encap.EncapOa.ToIP()
175			}
176
177		}
178		res = append(res, state)
179	}
180	return res, nil
181}
182