1// +build windows
2
3package common
4
5import (
6	"context"
7	"path/filepath"
8	"strings"
9	"syscall"
10	"unsafe"
11
12	"github.com/StackExchange/wmi"
13	"golang.org/x/sys/windows"
14)
15
// PDH_FMT_COUNTERVALUE_DOUBLE mirrors the Win32 PDH_FMT_COUNTERVALUE structure
// as filled in by PdhGetFormattedCounterValue when PDH_FMT_DOUBLE is requested.
//
// In C the value lives in a union that is 8-byte aligned, so it starts at
// offset 8 and the whole structure is 16 bytes on both 32- and 64-bit Windows.
// The explicit padding reproduces that layout on windows/386 too, where Go
// aligns float64 struct fields to only 4 bytes.
type PDH_FMT_COUNTERVALUE_DOUBLE struct {
	CStatus     uint32
	_           [4]byte
	DoubleValue float64
}
21
// PDH_FMT_COUNTERVALUE_LARGE mirrors the Win32 PDH_FMT_COUNTERVALUE structure
// as filled in by PdhGetFormattedCounterValue when PDH_FMT_LARGE is requested.
//
// In C the 64-bit value lives in an 8-byte-aligned union at offset 8, giving a
// 16-byte structure on both 32- and 64-bit Windows. The explicit padding keeps
// that layout on windows/386, where Go aligns int64 struct fields to 4 bytes.
type PDH_FMT_COUNTERVALUE_LARGE struct {
	CStatus    uint32
	_          [4]byte
	LargeValue int64
}
27
// PDH_FMT_COUNTERVALUE_LONG mirrors the Win32 PDH_FMT_COUNTERVALUE structure
// as filled in by PdhGetFormattedCounterValue when PDH_FMT_LONG is requested.
//
// The C union holding the value is 8-byte aligned (it also holds a double and
// a LONGLONG), so the 32-bit longValue starts at offset 8, not 4. The structure
// is 16 bytes in total. Padding therefore goes *before* LongValue (alignment)
// and after it (tail), so PDH never writes past the end of this value.
type PDH_FMT_COUNTERVALUE_LONG struct {
	CStatus   uint32
	_         [4]byte
	LongValue int32
	_         [4]byte
}
34
// Win32 error codes.
const (
	ERROR_SUCCESS        = 0
	ERROR_FILE_NOT_FOUND = 2
)

// Drive type values, as returned by GetDriveTypeW.
const (
	DRIVE_REMOVABLE = 2
	DRIVE_FIXED     = 3
)

// Registry access: the predefined HKEY_LOCAL_MACHINE root key and the
// RRF_RT_* value-type restriction flags accepted by RegGetValue.
const (
	HKEY_LOCAL_MACHINE = 0x80000002
	RRF_RT_REG_SZ      = 0x00000002
	RRF_RT_REG_DWORD   = 0x00000010
)

// PDH_FMT_* flags select the representation PdhGetFormattedCounterValue
// writes (see the PDH_FMT_COUNTERVALUE_* structs).
const (
	PDH_FMT_LONG   = 0x00000100
	PDH_FMT_DOUBLE = 0x00000200
	PDH_FMT_LARGE  = 0x00000400
)

// PDH_STATUS codes returned by Performance Data Helper functions.
const (
	PDH_INVALID_DATA   = 0xC0000BC6
	PDH_INVALID_HANDLE = 0xC0000BBC
	PDH_NO_DATA        = 0x800007D5
)
51
// PROCESSINFOCLASS values passed to NtQueryInformationProcess.
const (
	// ProcessBasicInformation requests a PROCESS_BASIC_INFORMATION record.
	ProcessBasicInformation = 0
	// ProcessWow64Information asks whether the process runs under WOW64.
	ProcessWow64Information = 26
)
56
57var (
58	Modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
59	ModNt       = windows.NewLazySystemDLL("ntdll.dll")
60	ModPdh      = windows.NewLazySystemDLL("pdh.dll")
61	ModPsapi    = windows.NewLazySystemDLL("psapi.dll")
62
63	ProcGetSystemTimes                   = Modkernel32.NewProc("GetSystemTimes")
64	ProcNtQuerySystemInformation         = ModNt.NewProc("NtQuerySystemInformation")
65	ProcRtlGetNativeSystemInformation    = ModNt.NewProc("RtlGetNativeSystemInformation")
66	ProcRtlNtStatusToDosError            = ModNt.NewProc("RtlNtStatusToDosError")
67	ProcNtQueryInformationProcess        = ModNt.NewProc("NtQueryInformationProcess")
68	ProcNtReadVirtualMemory              = ModNt.NewProc("NtReadVirtualMemory")
69	ProcNtWow64QueryInformationProcess64 = ModNt.NewProc("NtWow64QueryInformationProcess64")
70	ProcNtWow64ReadVirtualMemory64       = ModNt.NewProc("NtWow64ReadVirtualMemory64")
71
72	PdhOpenQuery                         = ModPdh.NewProc("PdhOpenQuery")
73	PdhAddCounter                        = ModPdh.NewProc("PdhAddCounterW")
74	PdhCollectQueryData                  = ModPdh.NewProc("PdhCollectQueryData")
75	PdhGetFormattedCounterValue          = ModPdh.NewProc("PdhGetFormattedCounterValue")
76	PdhCloseQuery                        = ModPdh.NewProc("PdhCloseQuery")
77
78	procQueryDosDeviceW                  = Modkernel32.NewProc("QueryDosDeviceW")
79)
80
// FILETIME mirrors the Win32 FILETIME structure: one 64-bit value stored as
// two 32-bit halves. The field order and types are the ABI layout, so they
// must not change. NOTE(review): presumably filled in through
// ProcGetSystemTimes (GetSystemTimes); confirm at the call sites.
type FILETIME struct {
	DwLowDateTime  uint32 // low-order 32 bits
	DwHighDateTime uint32 // high-order 32 bits
}
85
// BytePtrToString returns the Go string for the NUL-terminated byte string
// that starts at p (originally borrowed from net/interface_windows.go).
//
// A nil p yields "". At most 10000 bytes are examined. If no NUL terminator
// appears within that bound, the first 10000 bytes are returned instead of
// panicking.
func BytePtrToString(p *uint8) string {
	if p == nil {
		return ""
	}
	const maxLen = 10000

	// Walk memory one byte at a time rather than reinterpreting p as a
	// pointer to a fixed-size array. That cast claims memory past the end of
	// the real allocation, which the runtime's checkptr instrumentation
	// (implied by -race) rejects. Each address is formed within a single
	// expression, as the unsafe.Pointer arithmetic rules require.
	var buf []byte
	for i := uintptr(0); i < maxLen; i++ {
		c := *(*uint8)(unsafe.Pointer(uintptr(unsafe.Pointer(p)) + i))
		if c == 0 {
			break
		}
		buf = append(buf, c)
	}
	return string(buf)
}
95
// CounterInfo pairs a PDH counter handle with the names it was created from.
// CreateCounter returns one for each successful PdhAddCounterW call.
//
//   - PostName: the caller-supplied label (CreateCounter's pname argument).
//     NOTE(review): presumably the name the metric is reported under; confirm
//     with callers.
//   - CounterName: the PDH counter path that was passed to PdhAddCounterW.
//   - Counter: the counter handle PdhAddCounterW wrote back.
//
// copied from https://github.com/mackerelio/mackerel-agent/
type CounterInfo struct {
	PostName    string
	CounterName string
	Counter     windows.Handle
}
103
104// CreateQuery XXX
105// copied from https://github.com/mackerelio/mackerel-agent/
106func CreateQuery() (windows.Handle, error) {
107	var query windows.Handle
108	r, _, err := PdhOpenQuery.Call(0, 0, uintptr(unsafe.Pointer(&query)))
109	if r != 0 {
110		return 0, err
111	}
112	return query, nil
113}
114
115// CreateCounter XXX
116func CreateCounter(query windows.Handle, pname, cname string) (*CounterInfo, error) {
117	var counter windows.Handle
118	r, _, err := PdhAddCounter.Call(
119		uintptr(query),
120		uintptr(unsafe.Pointer(windows.StringToUTF16Ptr(cname))),
121		0,
122		uintptr(unsafe.Pointer(&counter)))
123	if r != 0 {
124		return nil, err
125	}
126	return &CounterInfo{
127		PostName:    pname,
128		CounterName: cname,
129		Counter:     counter,
130	}, nil
131}
132
133// WMIQueryWithContext - wraps wmi.Query with a timed-out context to avoid hanging
134func WMIQueryWithContext(ctx context.Context, query string, dst interface{}, connectServerArgs ...interface{}) error {
135	if _, ok := ctx.Deadline(); !ok {
136		ctxTimeout, cancel := context.WithTimeout(ctx, Timeout)
137		defer cancel()
138		ctx = ctxTimeout
139	}
140
141	errChan := make(chan error, 1)
142	go func() {
143		errChan <- wmi.Query(query, dst, connectServerArgs...)
144	}()
145
146	select {
147	case <-ctx.Done():
148		return ctx.Err()
149	case err := <-errChan:
150		return err
151	}
152}
153
154// Convert paths using native DOS format like:
155//   "\Device\HarddiskVolume1\Windows\systemew\file.txt"
156// into:
157//   "C:\Windows\systemew\file.txt"
158func ConvertDOSPath(p string) string {
159	rawDrive := strings.Join(strings.Split(p, `\`)[:3], `\`)
160
161	for d := 'A'; d <= 'Z'; d++ {
162		szDeviceName := string(d) + ":"
163		szTarget := make([]uint16, 512)
164		ret, _, _ := procQueryDosDeviceW.Call(uintptr(unsafe.Pointer(syscall.StringToUTF16Ptr(szDeviceName))),
165			uintptr(unsafe.Pointer(&szTarget[0])),
166			uintptr(len(szTarget)))
167		if ret != 0 && windows.UTF16ToString(szTarget[:]) == rawDrive {
168			return filepath.Join(szDeviceName, p[len(rawDrive):])
169		}
170	}
171	return p
172}
173