1package netlink
2
3import (
4	"fmt"
5	"unsafe"
6
7	"github.com/vishvananda/netlink/nl"
8	"golang.org/x/sys/unix"
9)
10
11func writeStateAlgo(a *XfrmStateAlgo) []byte {
12	algo := nl.XfrmAlgo{
13		AlgKeyLen: uint32(len(a.Key) * 8),
14		AlgKey:    a.Key,
15	}
16	end := len(a.Name)
17	if end > 64 {
18		end = 64
19	}
20	copy(algo.AlgName[:end], a.Name)
21	return algo.Serialize()
22}
23
24func writeStateAlgoAuth(a *XfrmStateAlgo) []byte {
25	algo := nl.XfrmAlgoAuth{
26		AlgKeyLen:   uint32(len(a.Key) * 8),
27		AlgTruncLen: uint32(a.TruncateLen),
28		AlgKey:      a.Key,
29	}
30	end := len(a.Name)
31	if end > 64 {
32		end = 64
33	}
34	copy(algo.AlgName[:end], a.Name)
35	return algo.Serialize()
36}
37
38func writeStateAlgoAead(a *XfrmStateAlgo) []byte {
39	algo := nl.XfrmAlgoAEAD{
40		AlgKeyLen: uint32(len(a.Key) * 8),
41		AlgICVLen: uint32(a.ICVLen),
42		AlgKey:    a.Key,
43	}
44	end := len(a.Name)
45	if end > 64 {
46		end = 64
47	}
48	copy(algo.AlgName[:end], a.Name)
49	return algo.Serialize()
50}
51
52func writeMark(m *XfrmMark) []byte {
53	mark := &nl.XfrmMark{
54		Value: m.Value,
55		Mask:  m.Mask,
56	}
57	if mark.Mask == 0 {
58		mark.Mask = ^uint32(0)
59	}
60	return mark.Serialize()
61}
62
63func writeReplayEsn(replayWindow int) []byte {
64	replayEsn := &nl.XfrmReplayStateEsn{
65		OSeq:         0,
66		Seq:          0,
67		OSeqHi:       0,
68		SeqHi:        0,
69		ReplayWindow: uint32(replayWindow),
70	}
71
72	// Linux stores the bitmap to identify the already received sequence packets in blocks of uint32 elements.
73	// Therefore bitmap length is the minimum number of uint32 elements needed. The following is a ceiling operation.
74	bytesPerElem := int(unsafe.Sizeof(replayEsn.BmpLen)) // Any uint32 variable is good for this
75	replayEsn.BmpLen = uint32((replayWindow + (bytesPerElem * 8) - 1) / (bytesPerElem * 8))
76
77	return replayEsn.Serialize()
78}
79
80// XfrmStateAdd will add an xfrm state to the system.
81// Equivalent to: `ip xfrm state add $state`
82func XfrmStateAdd(state *XfrmState) error {
83	return pkgHandle.XfrmStateAdd(state)
84}
85
86// XfrmStateAdd will add an xfrm state to the system.
87// Equivalent to: `ip xfrm state add $state`
88func (h *Handle) XfrmStateAdd(state *XfrmState) error {
89	return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_NEWSA)
90}
91
92// XfrmStateAllocSpi will allocate an xfrm state in the system.
93// Equivalent to: `ip xfrm state allocspi`
94func XfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
95	return pkgHandle.xfrmStateAllocSpi(state)
96}
97
98// XfrmStateUpdate will update an xfrm state to the system.
99// Equivalent to: `ip xfrm state update $state`
100func XfrmStateUpdate(state *XfrmState) error {
101	return pkgHandle.XfrmStateUpdate(state)
102}
103
104// XfrmStateUpdate will update an xfrm state to the system.
105// Equivalent to: `ip xfrm state update $state`
106func (h *Handle) XfrmStateUpdate(state *XfrmState) error {
107	return h.xfrmStateAddOrUpdate(state, nl.XFRM_MSG_UPDSA)
108}
109
110func (h *Handle) xfrmStateAddOrUpdate(state *XfrmState, nlProto int) error {
111
112	// A state with spi 0 can't be deleted so don't allow it to be set
113	if state.Spi == 0 {
114		return fmt.Errorf("Spi must be set when adding xfrm state.")
115	}
116	req := h.newNetlinkRequest(nlProto, unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
117
118	msg := xfrmUsersaInfoFromXfrmState(state)
119
120	if state.ESN {
121		if state.ReplayWindow == 0 {
122			return fmt.Errorf("ESN flag set without ReplayWindow")
123		}
124		msg.Flags |= nl.XFRM_STATE_ESN
125		msg.ReplayWindow = 0
126	}
127
128	limitsToLft(state.Limits, &msg.Lft)
129	req.AddData(msg)
130
131	if state.Auth != nil {
132		out := nl.NewRtAttr(nl.XFRMA_ALG_AUTH_TRUNC, writeStateAlgoAuth(state.Auth))
133		req.AddData(out)
134	}
135	if state.Crypt != nil {
136		out := nl.NewRtAttr(nl.XFRMA_ALG_CRYPT, writeStateAlgo(state.Crypt))
137		req.AddData(out)
138	}
139	if state.Aead != nil {
140		out := nl.NewRtAttr(nl.XFRMA_ALG_AEAD, writeStateAlgoAead(state.Aead))
141		req.AddData(out)
142	}
143	if state.Encap != nil {
144		encapData := make([]byte, nl.SizeofXfrmEncapTmpl)
145		encap := nl.DeserializeXfrmEncapTmpl(encapData)
146		encap.EncapType = uint16(state.Encap.Type)
147		encap.EncapSport = nl.Swap16(uint16(state.Encap.SrcPort))
148		encap.EncapDport = nl.Swap16(uint16(state.Encap.DstPort))
149		encap.EncapOa.FromIP(state.Encap.OriginalAddress)
150		out := nl.NewRtAttr(nl.XFRMA_ENCAP, encapData)
151		req.AddData(out)
152	}
153	if state.Mark != nil {
154		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
155		req.AddData(out)
156	}
157	if state.ESN {
158		out := nl.NewRtAttr(nl.XFRMA_REPLAY_ESN_VAL, writeReplayEsn(state.ReplayWindow))
159		req.AddData(out)
160	}
161	if state.OutputMark != 0 {
162		out := nl.NewRtAttr(nl.XFRMA_OUTPUT_MARK, nl.Uint32Attr(uint32(state.OutputMark)))
163		req.AddData(out)
164	}
165
166	ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid)))
167	req.AddData(ifId)
168
169	_, err := req.Execute(unix.NETLINK_XFRM, 0)
170	return err
171}
172
173func (h *Handle) xfrmStateAllocSpi(state *XfrmState) (*XfrmState, error) {
174	req := h.newNetlinkRequest(nl.XFRM_MSG_ALLOCSPI,
175		unix.NLM_F_CREATE|unix.NLM_F_EXCL|unix.NLM_F_ACK)
176
177	msg := &nl.XfrmUserSpiInfo{}
178	msg.XfrmUsersaInfo = *(xfrmUsersaInfoFromXfrmState(state))
179	// 1-255 is reserved by IANA for future use
180	msg.Min = 0x100
181	msg.Max = 0xffffffff
182	req.AddData(msg)
183
184	if state.Mark != nil {
185		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
186		req.AddData(out)
187	}
188
189	msgs, err := req.Execute(unix.NETLINK_XFRM, 0)
190	if err != nil {
191		return nil, err
192	}
193
194	return parseXfrmState(msgs[0], FAMILY_ALL)
195}
196
197// XfrmStateDel will delete an xfrm state from the system. Note that
198// the Algos are ignored when matching the state to delete.
199// Equivalent to: `ip xfrm state del $state`
200func XfrmStateDel(state *XfrmState) error {
201	return pkgHandle.XfrmStateDel(state)
202}
203
204// XfrmStateDel will delete an xfrm state from the system. Note that
205// the Algos are ignored when matching the state to delete.
206// Equivalent to: `ip xfrm state del $state`
207func (h *Handle) XfrmStateDel(state *XfrmState) error {
208	_, err := h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_DELSA)
209	return err
210}
211
212// XfrmStateList gets a list of xfrm states in the system.
213// Equivalent to: `ip [-4|-6] xfrm state show`.
214// The list can be filtered by ip family.
215func XfrmStateList(family int) ([]XfrmState, error) {
216	return pkgHandle.XfrmStateList(family)
217}
218
219// XfrmStateList gets a list of xfrm states in the system.
220// Equivalent to: `ip xfrm state show`.
221// The list can be filtered by ip family.
222func (h *Handle) XfrmStateList(family int) ([]XfrmState, error) {
223	req := h.newNetlinkRequest(nl.XFRM_MSG_GETSA, unix.NLM_F_DUMP)
224
225	msgs, err := req.Execute(unix.NETLINK_XFRM, nl.XFRM_MSG_NEWSA)
226	if err != nil {
227		return nil, err
228	}
229
230	var res []XfrmState
231	for _, m := range msgs {
232		if state, err := parseXfrmState(m, family); err == nil {
233			res = append(res, *state)
234		} else if err == familyError {
235			continue
236		} else {
237			return nil, err
238		}
239	}
240	return res, nil
241}
242
243// XfrmStateGet gets the xfrm state described by the ID, if found.
244// Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
245// Only the fields which constitue the SA ID must be filled in:
246// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
247// mark is optional
248func XfrmStateGet(state *XfrmState) (*XfrmState, error) {
249	return pkgHandle.XfrmStateGet(state)
250}
251
252// XfrmStateGet gets the xfrm state described by the ID, if found.
253// Equivalent to: `ip xfrm state get ID [ mark MARK [ mask MASK ] ]`.
254// Only the fields which constitue the SA ID must be filled in:
255// ID := [ src ADDR ] [ dst ADDR ] [ proto XFRM-PROTO ] [ spi SPI ]
256// mark is optional
257func (h *Handle) XfrmStateGet(state *XfrmState) (*XfrmState, error) {
258	return h.xfrmStateGetOrDelete(state, nl.XFRM_MSG_GETSA)
259}
260
261func (h *Handle) xfrmStateGetOrDelete(state *XfrmState, nlProto int) (*XfrmState, error) {
262	req := h.newNetlinkRequest(nlProto, unix.NLM_F_ACK)
263
264	msg := &nl.XfrmUsersaId{}
265	msg.Family = uint16(nl.GetIPFamily(state.Dst))
266	msg.Daddr.FromIP(state.Dst)
267	msg.Proto = uint8(state.Proto)
268	msg.Spi = nl.Swap32(uint32(state.Spi))
269	req.AddData(msg)
270
271	if state.Mark != nil {
272		out := nl.NewRtAttr(nl.XFRMA_MARK, writeMark(state.Mark))
273		req.AddData(out)
274	}
275	if state.Src != nil {
276		out := nl.NewRtAttr(nl.XFRMA_SRCADDR, state.Src.To16())
277		req.AddData(out)
278	}
279
280	ifId := nl.NewRtAttr(nl.XFRMA_IF_ID, nl.Uint32Attr(uint32(state.Ifid)))
281	req.AddData(ifId)
282
283	resType := nl.XFRM_MSG_NEWSA
284	if nlProto == nl.XFRM_MSG_DELSA {
285		resType = 0
286	}
287
288	msgs, err := req.Execute(unix.NETLINK_XFRM, uint16(resType))
289	if err != nil {
290		return nil, err
291	}
292
293	if nlProto == nl.XFRM_MSG_DELSA {
294		return nil, nil
295	}
296
297	s, err := parseXfrmState(msgs[0], FAMILY_ALL)
298	if err != nil {
299		return nil, err
300	}
301
302	return s, nil
303}
304
// familyError is the sentinel returned by parseXfrmState when a message's
// address family does not match the requested one; XfrmStateList compares
// against it to skip such states.
var familyError error = fmt.Errorf("family error")
306
307func xfrmStateFromXfrmUsersaInfo(msg *nl.XfrmUsersaInfo) *XfrmState {
308	var state XfrmState
309
310	state.Dst = msg.Id.Daddr.ToIP()
311	state.Src = msg.Saddr.ToIP()
312	state.Proto = Proto(msg.Id.Proto)
313	state.Mode = Mode(msg.Mode)
314	state.Spi = int(nl.Swap32(msg.Id.Spi))
315	state.Reqid = int(msg.Reqid)
316	state.ReplayWindow = int(msg.ReplayWindow)
317	lftToLimits(&msg.Lft, &state.Limits)
318	curToStats(&msg.Curlft, &msg.Stats, &state.Statistics)
319
320	return &state
321}
322
323func parseXfrmState(m []byte, family int) (*XfrmState, error) {
324	msg := nl.DeserializeXfrmUsersaInfo(m)
325
326	// This is mainly for the state dump
327	if family != FAMILY_ALL && family != int(msg.Family) {
328		return nil, familyError
329	}
330
331	state := xfrmStateFromXfrmUsersaInfo(msg)
332
333	attrs, err := nl.ParseRouteAttr(m[nl.SizeofXfrmUsersaInfo:])
334	if err != nil {
335		return nil, err
336	}
337
338	for _, attr := range attrs {
339		switch attr.Attr.Type {
340		case nl.XFRMA_ALG_AUTH, nl.XFRMA_ALG_CRYPT:
341			var resAlgo *XfrmStateAlgo
342			if attr.Attr.Type == nl.XFRMA_ALG_AUTH {
343				if state.Auth == nil {
344					state.Auth = new(XfrmStateAlgo)
345				}
346				resAlgo = state.Auth
347			} else {
348				state.Crypt = new(XfrmStateAlgo)
349				resAlgo = state.Crypt
350			}
351			algo := nl.DeserializeXfrmAlgo(attr.Value[:])
352			(*resAlgo).Name = nl.BytesToString(algo.AlgName[:])
353			(*resAlgo).Key = algo.AlgKey
354		case nl.XFRMA_ALG_AUTH_TRUNC:
355			if state.Auth == nil {
356				state.Auth = new(XfrmStateAlgo)
357			}
358			algo := nl.DeserializeXfrmAlgoAuth(attr.Value[:])
359			state.Auth.Name = nl.BytesToString(algo.AlgName[:])
360			state.Auth.Key = algo.AlgKey
361			state.Auth.TruncateLen = int(algo.AlgTruncLen)
362		case nl.XFRMA_ALG_AEAD:
363			state.Aead = new(XfrmStateAlgo)
364			algo := nl.DeserializeXfrmAlgoAEAD(attr.Value[:])
365			state.Aead.Name = nl.BytesToString(algo.AlgName[:])
366			state.Aead.Key = algo.AlgKey
367			state.Aead.ICVLen = int(algo.AlgICVLen)
368		case nl.XFRMA_ENCAP:
369			encap := nl.DeserializeXfrmEncapTmpl(attr.Value[:])
370			state.Encap = new(XfrmStateEncap)
371			state.Encap.Type = EncapType(encap.EncapType)
372			state.Encap.SrcPort = int(nl.Swap16(encap.EncapSport))
373			state.Encap.DstPort = int(nl.Swap16(encap.EncapDport))
374			state.Encap.OriginalAddress = encap.EncapOa.ToIP()
375		case nl.XFRMA_MARK:
376			mark := nl.DeserializeXfrmMark(attr.Value[:])
377			state.Mark = new(XfrmMark)
378			state.Mark.Value = mark.Value
379			state.Mark.Mask = mark.Mask
380		case nl.XFRMA_OUTPUT_MARK:
381			state.OutputMark = int(native.Uint32(attr.Value))
382		case nl.XFRMA_IF_ID:
383			state.Ifid = int(native.Uint32(attr.Value))
384		}
385	}
386
387	return state, nil
388}
389
390// XfrmStateFlush will flush the xfrm state on the system.
391// proto = 0 means any transformation protocols
392// Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
393func XfrmStateFlush(proto Proto) error {
394	return pkgHandle.XfrmStateFlush(proto)
395}
396
397// XfrmStateFlush will flush the xfrm state on the system.
398// proto = 0 means any transformation protocols
399// Equivalent to: `ip xfrm state flush [ proto XFRM-PROTO ]`
400func (h *Handle) XfrmStateFlush(proto Proto) error {
401	req := h.newNetlinkRequest(nl.XFRM_MSG_FLUSHSA, unix.NLM_F_ACK)
402
403	req.AddData(&nl.XfrmUsersaFlush{Proto: uint8(proto)})
404
405	_, err := req.Execute(unix.NETLINK_XFRM, 0)
406	return err
407}
408
409func limitsToLft(lmts XfrmStateLimits, lft *nl.XfrmLifetimeCfg) {
410	if lmts.ByteSoft != 0 {
411		lft.SoftByteLimit = lmts.ByteSoft
412	} else {
413		lft.SoftByteLimit = nl.XFRM_INF
414	}
415	if lmts.ByteHard != 0 {
416		lft.HardByteLimit = lmts.ByteHard
417	} else {
418		lft.HardByteLimit = nl.XFRM_INF
419	}
420	if lmts.PacketSoft != 0 {
421		lft.SoftPacketLimit = lmts.PacketSoft
422	} else {
423		lft.SoftPacketLimit = nl.XFRM_INF
424	}
425	if lmts.PacketHard != 0 {
426		lft.HardPacketLimit = lmts.PacketHard
427	} else {
428		lft.HardPacketLimit = nl.XFRM_INF
429	}
430	lft.SoftAddExpiresSeconds = lmts.TimeSoft
431	lft.HardAddExpiresSeconds = lmts.TimeHard
432	lft.SoftUseExpiresSeconds = lmts.TimeUseSoft
433	lft.HardUseExpiresSeconds = lmts.TimeUseHard
434}
435
436func lftToLimits(lft *nl.XfrmLifetimeCfg, lmts *XfrmStateLimits) {
437	*lmts = *(*XfrmStateLimits)(unsafe.Pointer(lft))
438}
439
440func curToStats(cur *nl.XfrmLifetimeCur, wstats *nl.XfrmStats, stats *XfrmStateStats) {
441	stats.Bytes = cur.Bytes
442	stats.Packets = cur.Packets
443	stats.AddTime = cur.AddTime
444	stats.UseTime = cur.UseTime
445	stats.ReplayWindow = wstats.ReplayWindow
446	stats.Replay = wstats.Replay
447	stats.Failed = wstats.IntegrityFailed
448}
449
450func xfrmUsersaInfoFromXfrmState(state *XfrmState) *nl.XfrmUsersaInfo {
451	msg := &nl.XfrmUsersaInfo{}
452	msg.Family = uint16(nl.GetIPFamily(state.Dst))
453	msg.Id.Daddr.FromIP(state.Dst)
454	msg.Saddr.FromIP(state.Src)
455	msg.Id.Proto = uint8(state.Proto)
456	msg.Mode = uint8(state.Mode)
457	msg.Id.Spi = nl.Swap32(uint32(state.Spi))
458	msg.Reqid = uint32(state.Reqid)
459	msg.ReplayWindow = uint8(state.ReplayWindow)
460
461	return msg
462}
463