1package netlink
2
3import (
4	"fmt"
5
6	"github.com/vishvananda/netlink/nl"
7	"github.com/vishvananda/netns"
8	"golang.org/x/sys/unix"
9)
10
11type XfrmMsg interface {
12	Type() nl.XfrmMsgType
13}
14
15type XfrmMsgExpire struct {
16	XfrmState *XfrmState
17	Hard      bool
18}
19
20func (ue *XfrmMsgExpire) Type() nl.XfrmMsgType {
21	return nl.XFRM_MSG_EXPIRE
22}
23
24func parseXfrmMsgExpire(b []byte) *XfrmMsgExpire {
25	var e XfrmMsgExpire
26
27	msg := nl.DeserializeXfrmUserExpire(b)
28	e.XfrmState = xfrmStateFromXfrmUsersaInfo(&msg.XfrmUsersaInfo)
29	e.Hard = msg.Hard == 1
30
31	return &e
32}
33
34func XfrmMonitor(ch chan<- XfrmMsg, done <-chan struct{}, errorChan chan<- error,
35	types ...nl.XfrmMsgType) error {
36
37	groups, err := xfrmMcastGroups(types)
38	if err != nil {
39		return nil
40	}
41	s, err := nl.SubscribeAt(netns.None(), netns.None(), unix.NETLINK_XFRM, groups...)
42	if err != nil {
43		return err
44	}
45
46	if done != nil {
47		go func() {
48			<-done
49			s.Close()
50		}()
51
52	}
53
54	go func() {
55		defer close(ch)
56		for {
57			msgs, from, err := s.Receive()
58			if err != nil {
59				errorChan <- err
60				return
61			}
62			if from.Pid != nl.PidKernel {
63				errorChan <- fmt.Errorf("Wrong sender portid %d, expected %d", from.Pid, nl.PidKernel)
64				return
65			}
66			for _, m := range msgs {
67				switch m.Header.Type {
68				case nl.XFRM_MSG_EXPIRE:
69					ch <- parseXfrmMsgExpire(m.Data)
70				default:
71					errorChan <- fmt.Errorf("unsupported msg type: %x", m.Header.Type)
72				}
73			}
74		}
75	}()
76
77	return nil
78}
79
80func xfrmMcastGroups(types []nl.XfrmMsgType) ([]uint, error) {
81	groups := make([]uint, 0)
82
83	if len(types) == 0 {
84		return nil, fmt.Errorf("no xfrm msg type specified")
85	}
86
87	for _, t := range types {
88		var group uint
89
90		switch t {
91		case nl.XFRM_MSG_EXPIRE:
92			group = nl.XFRMNLGRP_EXPIRE
93		default:
94			return nil, fmt.Errorf("unsupported group: %x", t)
95		}
96
97		groups = append(groups, group)
98	}
99
100	return groups, nil
101}
102