1// +build windows
2
3package common
4
5import (
6	"context"
7	"fmt"
8	"path/filepath"
9	"strings"
10	"syscall"
11	"unsafe"
12
13	"github.com/StackExchange/wmi"
14	"golang.org/x/sys/windows"
15)
16
// PDH_FMT_COUNTERVALUE_DOUBLE is the PDH_FMT_COUNTERVALUE layout specialised
// for PDH_FMT_DOUBLE results.
//
// In C the value lives in an 8-byte-aligned union that starts at offset 8.
// The explicit gap keeps DoubleValue at offset 8 (and the size at 16) on every
// architecture, including 386 where Go would otherwise align float64 to 4.
type PDH_FMT_COUNTERVALUE_DOUBLE struct {
	CStatus     uint32
	_           [4]byte
	DoubleValue float64
}
22
// PDH_FMT_COUNTERVALUE_LARGE is the PDH_FMT_COUNTERVALUE layout specialised
// for PDH_FMT_LARGE (64-bit integer) results.
//
// The explicit gap keeps LargeValue at offset 8 (and the size at 16) on every
// architecture, matching the 8-byte-aligned C union, including 386 where Go
// would otherwise align int64 to 4.
type PDH_FMT_COUNTERVALUE_LARGE struct {
	CStatus    uint32
	_          [4]byte
	LargeValue int64
}
28
// PDH_FMT_COUNTERVALUE_LONG is the PDH_FMT_COUNTERVALUE layout specialised
// for PDH_FMT_LONG (32-bit integer) results.
//
// The C union holding the value is 8-byte aligned, so the 32-bit value sits at
// offset 8 (not 4) and the whole struct is 16 bytes. Both gaps are explicit so
// the Go layout matches on every architecture.
type PDH_FMT_COUNTERVALUE_LONG struct {
	CStatus   uint32
	_         [4]byte // alignment gap before the union
	LongValue int32
	padding   [4]byte // tail padding up to the 16-byte C size
}
35
// Windows API numeric constants used by this package.
const (
	// Win32 error codes.
	ERROR_SUCCESS        = 0
	ERROR_FILE_NOT_FOUND = 2

	// Drive type values.
	DRIVE_REMOVABLE = 2
	DRIVE_FIXED     = 3

	// Registry-related values.
	HKEY_LOCAL_MACHINE = 0x80000002
	RRF_RT_REG_SZ      = 0x00000002
	RRF_RT_REG_DWORD   = 0x00000010

	// PDH counter-value format flags (e.g. passed to PdhGetFormattedCounterValue).
	PDH_FMT_LONG   = 0x00000100
	PDH_FMT_DOUBLE = 0x00000200
	PDH_FMT_LARGE  = 0x00000400

	// PDH status codes compared against PDH call results.
	PDH_INVALID_DATA   = 0xC0000BC6
	PDH_INVALID_HANDLE = 0xC0000BBC
	PDH_NO_DATA        = 0x800007D5
)
52
// Process information-class values; presumably used with
// NtQueryInformationProcess / NtWow64QueryInformationProcess64 — confirm at call sites.
const ProcessBasicInformation = 0

// ProcessWow64Information is the information-class value 26.
const ProcessWow64Information = 26
57
58var (
59	Modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
60	ModNt       = windows.NewLazySystemDLL("ntdll.dll")
61	ModPdh      = windows.NewLazySystemDLL("pdh.dll")
62	ModPsapi    = windows.NewLazySystemDLL("psapi.dll")
63
64	ProcGetSystemTimes                   = Modkernel32.NewProc("GetSystemTimes")
65	ProcNtQuerySystemInformation         = ModNt.NewProc("NtQuerySystemInformation")
66	ProcRtlGetNativeSystemInformation    = ModNt.NewProc("RtlGetNativeSystemInformation")
67	ProcRtlNtStatusToDosError            = ModNt.NewProc("RtlNtStatusToDosError")
68	ProcNtQueryInformationProcess        = ModNt.NewProc("NtQueryInformationProcess")
69	ProcNtReadVirtualMemory              = ModNt.NewProc("NtReadVirtualMemory")
70	ProcNtWow64QueryInformationProcess64 = ModNt.NewProc("NtWow64QueryInformationProcess64")
71	ProcNtWow64ReadVirtualMemory64       = ModNt.NewProc("NtWow64ReadVirtualMemory64")
72
73	PdhOpenQuery                = ModPdh.NewProc("PdhOpenQuery")
74	PdhAddEnglishCounterW       = ModPdh.NewProc("PdhAddEnglishCounterW")
75	PdhCollectQueryData         = ModPdh.NewProc("PdhCollectQueryData")
76	PdhGetFormattedCounterValue = ModPdh.NewProc("PdhGetFormattedCounterValue")
77	PdhCloseQuery               = ModPdh.NewProc("PdhCloseQuery")
78
79	procQueryDosDeviceW = Modkernel32.NewProc("QueryDosDeviceW")
80)
81
// FILETIME holds a 64-bit time value as two 32-bit halves, low half first.
type FILETIME struct {
	DwLowDateTime, DwHighDateTime uint32
}
86
// BytePtrToString returns a copy of the NUL-terminated byte sequence that p
// points to, excluding the terminator. A nil p yields "".
//
// Unlike the original net/interface_windows.go helper this was borrowed from,
// there is no fixed upper bound on the string length: the bytes are walked
// one at a time until the NUL terminator is found.
func BytePtrToString(p *uint8) string {
	if p == nil {
		return ""
	}
	var buf []byte
	// The pointer advance is written as a single unsafe.Pointer(uintptr(...)+1)
	// expression, as required by the unsafe.Pointer conversion rules, and only
	// happens after a non-NUL byte was read, so we never step past the terminator.
	for ptr := unsafe.Pointer(p); ; ptr = unsafe.Pointer(uintptr(ptr) + 1) {
		c := *(*byte)(ptr)
		if c == 0 {
			break
		}
		buf = append(buf, c)
	}
	return string(buf)
}
96
97// CounterInfo struct is used to track a windows performance counter
98// copied from https://github.com/mackerelio/mackerel-agent/
99type CounterInfo struct {
100	PostName    string
101	CounterName string
102	Counter     windows.Handle
103}
104
105// CreateQuery with a PdhOpenQuery call
106// copied from https://github.com/mackerelio/mackerel-agent/
107func CreateQuery() (windows.Handle, error) {
108	var query windows.Handle
109	r, _, err := PdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query)))
110	if r != 0 {
111		return 0, err
112	}
113	return query, nil
114}
115
116// CreateCounter with a PdhAddEnglishCounterW call
117func CreateCounter(query windows.Handle, pname, cname string) (*CounterInfo, error) {
118	var counter windows.Handle
119	r, _, err := PdhAddEnglishCounterW.Call(
120		uintptr(query),
121		uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cname))),
122		0,
123		uintptr(unsafe.Pointer(&counter)))
124	if r != 0 {
125		return nil, err
126	}
127	return &CounterInfo{
128		PostName:    pname,
129		CounterName: cname,
130		Counter:     counter,
131	}, nil
132}
133
134// GetCounterValue get counter value from handle
135// adapted from https://github.com/mackerelio/mackerel-agent/
136func GetCounterValue(counter windows.Handle) (float64, error) {
137	var value PDH_FMT_COUNTERVALUE_DOUBLE
138	r, _, err := PdhGetFormattedCounterValue.Call(uintptr(counter), PDH_FMT_DOUBLE, uintptr(0), uintptr(unsafe.Pointer(&value)))
139	if r != 0 && r != PDH_INVALID_DATA {
140		return 0.0, err
141	}
142	return value.DoubleValue, nil
143}
144
145type Win32PerformanceCounter struct {
146	PostName    string
147	CounterName string
148	Query       windows.Handle
149	Counter     windows.Handle
150}
151
152func NewWin32PerformanceCounter(postName, counterName string) (*Win32PerformanceCounter, error) {
153	query, err := CreateQuery()
154	if err != nil {
155		return nil, err
156	}
157	var counter = Win32PerformanceCounter{
158		Query:       query,
159		PostName:    postName,
160		CounterName: counterName,
161	}
162	r, _, err := PdhAddEnglishCounterW.Call(
163		uintptr(counter.Query),
164		uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(counter.CounterName))),
165		0,
166		uintptr(unsafe.Pointer(&counter.Counter)),
167	)
168	if r != 0 {
169		return nil, err
170	}
171	return &counter, nil
172}
173
174func (w *Win32PerformanceCounter) GetValue() (float64, error) {
175	r, _, err := PdhCollectQueryData.Call(uintptr(w.Query))
176	if r != 0 && err != nil {
177		if r == PDH_NO_DATA {
178			return 0.0, fmt.Errorf("%w: this counter has not data", err)
179		}
180		return 0.0, err
181	}
182
183	return GetCounterValue(w.Counter)
184}
185
186func ProcessorQueueLengthCounter() (*Win32PerformanceCounter, error) {
187	return NewWin32PerformanceCounter("processor_queue_length", `\System\Processor Queue Length`)
188}
189
190// WMIQueryWithContext - wraps wmi.Query with a timed-out context to avoid hanging
191func WMIQueryWithContext(ctx context.Context, query string, dst interface{}, connectServerArgs ...interface{}) error {
192	if _, ok := ctx.Deadline(); !ok {
193		ctxTimeout, cancel := context.WithTimeout(ctx, Timeout)
194		defer cancel()
195		ctx = ctxTimeout
196	}
197
198	errChan := make(chan error, 1)
199	go func() {
200		errChan <- wmi.Query(query, dst, connectServerArgs...)
201	}()
202
203	select {
204	case <-ctx.Done():
205		return ctx.Err()
206	case err := <-errChan:
207		return err
208	}
209}
210
211// Convert paths using native DOS format like:
212//   "\Device\HarddiskVolume1\Windows\systemew\file.txt"
213// into:
214//   "C:\Windows\systemew\file.txt"
215func ConvertDOSPath(p string) string {
216	rawDrive := strings.Join(strings.Split(p, `\`)[:3], `\`)
217
218	for d := 'A'; d <= 'Z'; d++ {
219		szDeviceName := string(d) + ":"
220		szTarget := make([]uint16, 512)
221		ret, _, _ := procQueryDosDeviceW.Call(uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(szDeviceName))),
222			uintptr(unsafe.Pointer(&szTarget[0])),
223			uintptr(len(szTarget)))
224		if ret != 0 && windows.UTF16ToString(szTarget[:]) == rawDrive {
225			return filepath.Join(szDeviceName, p[len(rawDrive):])
226		}
227	}
228	return p
229}
230