1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package socktest provides utilities for socket testing.
6package socktest
7
8import (
9	"fmt"
10	"sync"
11)
12
13// A Switch represents a callpath point switch for socket system
14// calls.
15type Switch struct {
16	once sync.Once
17
18	fmu   sync.RWMutex
19	fltab map[FilterType]Filter
20
21	smu   sync.RWMutex
22	sotab Sockets
23	stats stats
24}
25
26func (sw *Switch) init() {
27	sw.fltab = make(map[FilterType]Filter)
28	sw.sotab = make(Sockets)
29	sw.stats = make(stats)
30}
31
32// Stats returns a list of per-cookie socket statistics.
33func (sw *Switch) Stats() []Stat {
34	var st []Stat
35	sw.smu.RLock()
36	for _, s := range sw.stats {
37		ns := *s
38		st = append(st, ns)
39	}
40	sw.smu.RUnlock()
41	return st
42}
43
44// Sockets returns mappings of socket descriptor to socket status.
45func (sw *Switch) Sockets() Sockets {
46	sw.smu.RLock()
47	tab := make(Sockets, len(sw.sotab))
48	for i, s := range sw.sotab {
49		tab[i] = s
50	}
51	sw.smu.RUnlock()
52	return tab
53}
54
// A Cookie represents a 3-tuple of a socket; address family, socket
// type and protocol number.
//
// The tuple is packed into 64 bits:
//
//	bits 48-63: address family
//	bits 16-47: socket type (any flag bits OR'ed into it are kept)
//	bits  0-15: protocol number
type Cookie uint64

// Family returns an address family.
func (c Cookie) Family() int { return int(c >> 48) }

// Type returns a socket type.
func (c Cookie) Type() int { return int(c << 16 >> 32) }

// Protocol returns a protocol number.
func (c Cookie) Protocol() int { return int(c & 0xffff) }

// cookie packs family, sotype and proto into a Cookie. Values too wide
// for their field are truncated: family to 16 bits, sotype to 32 bits
// and proto to 16 bits.
func cookie(family, sotype, proto int) Cookie {
	// The protocol field is 16 bits wide rather than 8. Protocol numbers
	// above 255 (e.g. Linux IPPROTO_MPTCP = 262) would otherwise be
	// truncated into an unrelated protocol (262&0xff == 6, i.e. TCP).
	// For protocols <= 255 the encoding is unchanged.
	return Cookie(family)<<48 | Cookie(sotype)&0xffffffff<<16 | Cookie(proto)&0xffff
}
71
72// A Status represents the status of a socket.
73type Status struct {
74	Cookie    Cookie
75	Err       error // error status of socket system call
76	SocketErr error // error status of socket by SO_ERROR
77}
78
79func (so Status) String() string {
80	return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
81}
82
// A Stat represents per-cookie socket statistics: the socket's
// (family, type, protocol) tuple plus success and failure counters for
// each tracked socket operation.
type Stat struct {
	Family   int // address family
	Type     int // socket type
	Protocol int // protocol number

	// Successful operations.
	Opened    uint64 // number of sockets opened
	Connected uint64 // number of sockets connected
	Listened  uint64 // number of sockets put into the listening state
	Accepted  uint64 // number of sockets accepted
	Closed    uint64 // number of sockets closed

	// Failed operations.
	OpenFailed    uint64 // number of socket opens that failed
	ConnectFailed uint64 // number of socket connects that failed
	ListenFailed  uint64 // number of socket listens that failed
	AcceptFailed  uint64 // number of socket accepts that failed
	CloseFailed   uint64 // number of socket closes that failed
}
101
102func (st Stat) String() string {
103	return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
104}
105
106type stats map[Cookie]*Stat
107
108func (st stats) getLocked(c Cookie) *Stat {
109	s, ok := st[c]
110	if !ok {
111		s = &Stat{Family: c.Family(), Type: c.Type(), Protocol: c.Protocol()}
112		st[c] = s
113	}
114	return s
115}
116
// A FilterType represents a filter type, i.e. which socket system call
// a Filter is deployed for.
type FilterType int

const (
	// FilterSocket filters Socket.
	FilterSocket FilterType = iota
	// FilterConnect filters Connect or ConnectEx.
	FilterConnect
	// FilterListen filters Listen.
	FilterListen
	// FilterAccept filters Accept, Accept4 or AcceptEx.
	FilterAccept
	// FilterGetsockoptInt filters GetsockoptInt.
	FilterGetsockoptInt
	// FilterClose filters Close or Closesocket.
	FilterClose
)
128
129// A Filter represents a socket system call filter.
130//
131// It will only be executed before a system call for a socket that has
132// an entry in internal table.
133// If the filter returns a non-nil error, the execution of system call
134// will be canceled and the system call function returns the non-nil
135// error.
136// It can return a non-nil AfterFilter for filtering after the
137// execution of the system call.
138type Filter func(*Status) (AfterFilter, error)
139
140func (f Filter) apply(st *Status) (AfterFilter, error) {
141	if f == nil {
142		return nil, nil
143	}
144	return f(st)
145}
146
147// An AfterFilter represents a socket system call filter after an
148// execution of a system call.
149//
150// It will only be executed after a system call for a socket that has
151// an entry in internal table.
152// If the filter returns a non-nil error, the system call function
153// returns the non-nil error.
154type AfterFilter func(*Status) error
155
156func (f AfterFilter) apply(st *Status) error {
157	if f == nil {
158		return nil
159	}
160	return f(st)
161}
162
163// Set deploys the socket system call filter f for the filter type t.
164func (sw *Switch) Set(t FilterType, f Filter) {
165	sw.once.Do(sw.init)
166	sw.fmu.Lock()
167	sw.fltab[t] = f
168	sw.fmu.Unlock()
169}
170