1package net
2
3import (
4	"fmt"
5	"math"
6	"os"
7	"runtime"
8	"testing"
9
10	"github.com/shirou/gopsutil/internal/common"
11)
12
13func skipIfNotImplementedErr(t *testing.T, err error) {
14	if err == common.ErrNotImplementedError {
15		t.Skip("not implemented")
16	}
17}
18
19func TestAddrString(t *testing.T) {
20	v := Addr{IP: "192.168.0.1", Port: 8000}
21
22	s := fmt.Sprintf("%v", v)
23	if s != "{\"ip\":\"192.168.0.1\",\"port\":8000}" {
24		t.Errorf("Addr string is invalid: %v", v)
25	}
26}
27
28func TestNetIOCountersStatString(t *testing.T) {
29	v := IOCountersStat{
30		Name:      "test",
31		BytesSent: 100,
32	}
33	e := `{"name":"test","bytesSent":100,"bytesRecv":0,"packetsSent":0,"packetsRecv":0,"errin":0,"errout":0,"dropin":0,"dropout":0,"fifoin":0,"fifoout":0}`
34	if e != fmt.Sprintf("%v", v) {
35		t.Errorf("NetIOCountersStat string is invalid: %v", v)
36	}
37}
38
39func TestNetProtoCountersStatString(t *testing.T) {
40	v := ProtoCountersStat{
41		Protocol: "tcp",
42		Stats: map[string]int64{
43			"MaxConn":      -1,
44			"ActiveOpens":  4000,
45			"PassiveOpens": 3000,
46		},
47	}
48	e := `{"protocol":"tcp","stats":{"ActiveOpens":4000,"MaxConn":-1,"PassiveOpens":3000}}`
49	if e != fmt.Sprintf("%v", v) {
50		t.Errorf("NetProtoCountersStat string is invalid: %v", v)
51	}
52
53}
54
55func TestNetConnectionStatString(t *testing.T) {
56	v := ConnectionStat{
57		Fd:     10,
58		Family: 10,
59		Type:   10,
60		Uids:   []int32{10, 10},
61	}
62	e := `{"fd":10,"family":10,"type":10,"localaddr":{"ip":"","port":0},"remoteaddr":{"ip":"","port":0},"status":"","uids":[10,10],"pid":0}`
63	if e != fmt.Sprintf("%v", v) {
64		t.Errorf("NetConnectionStat string is invalid: %v", v)
65	}
66
67}
68
69func TestNetIOCountersAll(t *testing.T) {
70	v, err := IOCounters(false)
71	skipIfNotImplementedErr(t, err)
72	if err != nil {
73		t.Errorf("Could not get NetIOCounters: %v", err)
74	}
75	per, err := IOCounters(true)
76	skipIfNotImplementedErr(t, err)
77	if err != nil {
78		t.Errorf("Could not get NetIOCounters: %v", err)
79	}
80	if len(v) != 1 {
81		t.Errorf("Could not get NetIOCounters: %v", v)
82	}
83	if v[0].Name != "all" {
84		t.Errorf("Invalid NetIOCounters: %v", v)
85	}
86	var pr uint64
87	for _, p := range per {
88		pr += p.PacketsRecv
89	}
90	// small diff is ok
91	if math.Abs(float64(v[0].PacketsRecv-pr)) > 5 {
92		if ci := os.Getenv("CI"); ci != "" {
93			// This test often fails in CI. so just print even if failed.
94			fmt.Printf("invalid sum value: %v, %v", v[0].PacketsRecv, pr)
95		} else {
96			t.Errorf("invalid sum value: %v, %v", v[0].PacketsRecv, pr)
97		}
98	}
99}
100
101func TestNetIOCountersPerNic(t *testing.T) {
102	v, err := IOCounters(true)
103	skipIfNotImplementedErr(t, err)
104	if err != nil {
105		t.Errorf("Could not get NetIOCounters: %v", err)
106	}
107	if len(v) == 0 {
108		t.Errorf("Could not get NetIOCounters: %v", v)
109	}
110	for _, vv := range v {
111		if vv.Name == "" {
112			t.Errorf("Invalid NetIOCounters: %v", vv)
113		}
114	}
115}
116
117func TestGetNetIOCountersAll(t *testing.T) {
118	n := []IOCountersStat{
119		{
120			Name:        "a",
121			BytesRecv:   10,
122			PacketsRecv: 10,
123		},
124		{
125			Name:        "b",
126			BytesRecv:   10,
127			PacketsRecv: 10,
128			Errin:       10,
129		},
130	}
131	ret, err := getIOCountersAll(n)
132	skipIfNotImplementedErr(t, err)
133	if err != nil {
134		t.Error(err)
135	}
136	if len(ret) != 1 {
137		t.Errorf("invalid return count")
138	}
139	if ret[0].Name != "all" {
140		t.Errorf("invalid return name")
141	}
142	if ret[0].BytesRecv != 20 {
143		t.Errorf("invalid count bytesrecv")
144	}
145	if ret[0].Errin != 10 {
146		t.Errorf("invalid count errin")
147	}
148}
149
150func TestNetInterfaces(t *testing.T) {
151	v, err := Interfaces()
152	skipIfNotImplementedErr(t, err)
153	if err != nil {
154		t.Errorf("Could not get NetInterfaceStat: %v", err)
155	}
156	if len(v) == 0 {
157		t.Errorf("Could not get NetInterfaceStat: %v", err)
158	}
159	for _, vv := range v {
160		if vv.Name == "" {
161			t.Errorf("Invalid NetInterface: %v", vv)
162		}
163	}
164}
165
166func TestNetProtoCountersStatsAll(t *testing.T) {
167	v, err := ProtoCounters(nil)
168	skipIfNotImplementedErr(t, err)
169	if err != nil {
170		t.Fatalf("Could not get NetProtoCounters: %v", err)
171	}
172	if len(v) == 0 {
173		t.Fatalf("Could not get NetProtoCounters: %v", err)
174	}
175	for _, vv := range v {
176		if vv.Protocol == "" {
177			t.Errorf("Invalid NetProtoCountersStat: %v", vv)
178		}
179		if len(vv.Stats) == 0 {
180			t.Errorf("Invalid NetProtoCountersStat: %v", vv)
181		}
182	}
183}
184
185func TestNetProtoCountersStats(t *testing.T) {
186	v, err := ProtoCounters([]string{"tcp", "ip"})
187	skipIfNotImplementedErr(t, err)
188	if err != nil {
189		t.Fatalf("Could not get NetProtoCounters: %v", err)
190	}
191	if len(v) == 0 {
192		t.Fatalf("Could not get NetProtoCounters: %v", err)
193	}
194	if len(v) != 2 {
195		t.Fatalf("Go incorrect number of NetProtoCounters: %v", err)
196	}
197	for _, vv := range v {
198		if vv.Protocol != "tcp" && vv.Protocol != "ip" {
199			t.Errorf("Invalid NetProtoCountersStat: %v", vv)
200		}
201		if len(vv.Stats) == 0 {
202			t.Errorf("Invalid NetProtoCountersStat: %v", vv)
203		}
204	}
205}
206
207func TestNetConnections(t *testing.T) {
208	if ci := os.Getenv("CI"); ci != "" { // skip if test on drone.io
209		return
210	}
211
212	v, err := Connections("inet")
213	skipIfNotImplementedErr(t, err)
214	if err != nil {
215		t.Errorf("could not get NetConnections: %v", err)
216	}
217	if len(v) == 0 {
218		t.Errorf("could not get NetConnections: %v", v)
219	}
220	for _, vv := range v {
221		if vv.Family == 0 {
222			t.Errorf("invalid NetConnections: %v", vv)
223		}
224	}
225
226}
227
228func TestNetFilterCounters(t *testing.T) {
229	if ci := os.Getenv("CI"); ci != "" { // skip if test on drone.io
230		return
231	}
232
233	if runtime.GOOS == "linux" {
234		// some test environment has not the path.
235		if !common.PathExists("/proc/sys/net/netfilter/nf_conntrackCount") {
236			t.SkipNow()
237		}
238	}
239
240	v, err := FilterCounters()
241	skipIfNotImplementedErr(t, err)
242	if err != nil {
243		t.Errorf("could not get NetConnections: %v", err)
244	}
245	if len(v) == 0 {
246		t.Errorf("could not get NetConnections: %v", v)
247	}
248	for _, vv := range v {
249		if vv.ConnTrackMax == 0 {
250			t.Errorf("nf_conntrackMax needs to be greater than zero: %v", vv)
251		}
252	}
253
254}
255