1// +build windows
2
3package wmi
4
5import (
6	"fmt"
7	"reflect"
8	"runtime"
9	"sync"
10
11	"github.com/go-ole/go-ole"
12	"github.com/go-ole/go-ole/oleutil"
13)
14
15// SWbemServices is used to access wmi. See https://msdn.microsoft.com/en-us/library/aa393719(v=vs.85).aspx
16type SWbemServices struct {
17	//TODO: track namespace. Not sure if we can re connect to a different namespace using the same instance
18	cWMIClient            *Client //This could also be an embedded struct, but then we would need to branch on Client vs SWbemServices in the Query method
19	sWbemLocatorIUnknown  *ole.IUnknown
20	sWbemLocatorIDispatch *ole.IDispatch
21	queries               chan *queryRequest
22	closeError            chan error
23	lQueryorClose         sync.Mutex
24}
25
// queryRequest is one unit of work handed to the background goroutine.
type queryRequest struct {
	query string        // WQL text passed to ExecQuery
	dst   interface{}   // destination for the loaded rows
	args  []interface{} // arguments forwarded to SWbemLocator.ConnectServer

	// finished receives the query's error (only if there is one) and is
	// then closed by the background goroutine.
	finished chan error
}
32
33// InitializeSWbemServices will return a new SWbemServices object that can be used to query WMI
34func InitializeSWbemServices(c *Client, connectServerArgs ...interface{}) (*SWbemServices, error) {
35	//fmt.Println("InitializeSWbemServices: Starting")
36	//TODO: implement connectServerArgs as optional argument for init with connectServer call
37	s := new(SWbemServices)
38	s.cWMIClient = c
39	s.queries = make(chan *queryRequest)
40	initError := make(chan error)
41	go s.process(initError)
42
43	err, ok := <-initError
44	if ok {
45		return nil, err //Send error to caller
46	}
47	//fmt.Println("InitializeSWbemServices: Finished")
48	return s, nil
49}
50
51// Close will clear and release all of the SWbemServices resources
52func (s *SWbemServices) Close() error {
53	s.lQueryorClose.Lock()
54	if s == nil || s.sWbemLocatorIDispatch == nil {
55		s.lQueryorClose.Unlock()
56		return fmt.Errorf("SWbemServices is not Initialized")
57	}
58	if s.queries == nil {
59		s.lQueryorClose.Unlock()
60		return fmt.Errorf("SWbemServices has been closed")
61	}
62	//fmt.Println("Close: sending close request")
63	var result error
64	ce := make(chan error)
65	s.closeError = ce //Race condition if multiple callers to close. May need to lock here
66	close(s.queries)  //Tell background to shut things down
67	s.lQueryorClose.Unlock()
68	err, ok := <-ce
69	if ok {
70		result = err
71	}
72	//fmt.Println("Close: finished")
73	return result
74}
75
76func (s *SWbemServices) process(initError chan error) {
77	//fmt.Println("process: starting background thread initialization")
78	//All OLE/WMI calls must happen on the same initialized thead, so lock this goroutine
79	runtime.LockOSThread()
80	defer runtime.LockOSThread()
81
82	err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
83	if err != nil {
84		oleCode := err.(*ole.OleError).Code()
85		if oleCode != ole.S_OK && oleCode != S_FALSE {
86			initError <- fmt.Errorf("ole.CoInitializeEx error: %v", err)
87			return
88		}
89	}
90	defer ole.CoUninitialize()
91
92	unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
93	if err != nil {
94		initError <- fmt.Errorf("CreateObject SWbemLocator error: %v", err)
95		return
96	} else if unknown == nil {
97		initError <- ErrNilCreateObject
98		return
99	}
100	defer unknown.Release()
101	s.sWbemLocatorIUnknown = unknown
102
103	dispatch, err := s.sWbemLocatorIUnknown.QueryInterface(ole.IID_IDispatch)
104	if err != nil {
105		initError <- fmt.Errorf("SWbemLocator QueryInterface error: %v", err)
106		return
107	}
108	defer dispatch.Release()
109	s.sWbemLocatorIDispatch = dispatch
110
111	// we can't do the ConnectServer call outside the loop unless we find a way to track and re-init the connectServerArgs
112	//fmt.Println("process: initialized. closing initError")
113	close(initError)
114	//fmt.Println("process: waiting for queries")
115	for q := range s.queries {
116		//fmt.Printf("process: new query: len(query)=%d\n", len(q.query))
117		errQuery := s.queryBackground(q)
118		//fmt.Println("process: s.queryBackground finished")
119		if errQuery != nil {
120			q.finished <- errQuery
121		}
122		close(q.finished)
123	}
124	//fmt.Println("process: queries channel closed")
125	s.queries = nil //set channel to nil so we know it is closed
126	//TODO: I think the Release/Clear calls can panic if things are in a bad state.
127	//TODO: May need to recover from panics and send error to method caller instead.
128	close(s.closeError)
129}
130
131// Query runs the WQL query using a SWbemServices instance and appends the values to dst.
132//
133// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
134// the query must have the same name in dst. Supported types are all signed and
135// unsigned integers, time.Time, string, bool, or a pointer to one of those.
136// Array types are not supported.
137//
138// By default, the local machine and default namespace are used. These can be
139// changed using connectServerArgs. See
140// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
141func (s *SWbemServices) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
142	s.lQueryorClose.Lock()
143	if s == nil || s.sWbemLocatorIDispatch == nil {
144		s.lQueryorClose.Unlock()
145		return fmt.Errorf("SWbemServices is not Initialized")
146	}
147	if s.queries == nil {
148		s.lQueryorClose.Unlock()
149		return fmt.Errorf("SWbemServices has been closed")
150	}
151
152	//fmt.Println("Query: Sending query request")
153	qr := queryRequest{
154		query:    query,
155		dst:      dst,
156		args:     connectServerArgs,
157		finished: make(chan error),
158	}
159	s.queries <- &qr
160	s.lQueryorClose.Unlock()
161	err, ok := <-qr.finished
162	if ok {
163		//fmt.Println("Query: Finished with error")
164		return err //Send error to caller
165	}
166	//fmt.Println("Query: Finished")
167	return nil
168}
169
170func (s *SWbemServices) queryBackground(q *queryRequest) error {
171	if s == nil || s.sWbemLocatorIDispatch == nil {
172		return fmt.Errorf("SWbemServices is not Initialized")
173	}
174	wmi := s.sWbemLocatorIDispatch //Should just rename in the code, but this will help as we break things apart
175	//fmt.Println("queryBackground: Starting")
176
177	dv := reflect.ValueOf(q.dst)
178	if dv.Kind() != reflect.Ptr || dv.IsNil() {
179		return ErrInvalidEntityType
180	}
181	dv = dv.Elem()
182	mat, elemType := checkMultiArg(dv)
183	if mat == multiArgTypeInvalid {
184		return ErrInvalidEntityType
185	}
186
187	// service is a SWbemServices
188	serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", q.args...)
189	if err != nil {
190		return err
191	}
192	service := serviceRaw.ToIDispatch()
193	defer serviceRaw.Clear()
194
195	// result is a SWBemObjectSet
196	resultRaw, err := oleutil.CallMethod(service, "ExecQuery", q.query)
197	if err != nil {
198		return err
199	}
200	result := resultRaw.ToIDispatch()
201	defer resultRaw.Clear()
202
203	count, err := oleInt64(result, "Count")
204	if err != nil {
205		return err
206	}
207
208	enumProperty, err := result.GetProperty("_NewEnum")
209	if err != nil {
210		return err
211	}
212	defer enumProperty.Clear()
213
214	enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
215	if err != nil {
216		return err
217	}
218	if enum == nil {
219		return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
220	}
221	defer enum.Release()
222
223	// Initialize a slice with Count capacity
224	dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
225
226	var errFieldMismatch error
227	for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
228		if err != nil {
229			return err
230		}
231
232		err := func() error {
233			// item is a SWbemObject, but really a Win32_Process
234			item := itemRaw.ToIDispatch()
235			defer item.Release()
236
237			ev := reflect.New(elemType)
238			if err = s.cWMIClient.loadEntity(ev.Interface(), item); err != nil {
239				if _, ok := err.(*ErrFieldMismatch); ok {
240					// We continue loading entities even in the face of field mismatch errors.
241					// If we encounter any other error, that other error is returned. Otherwise,
242					// an ErrFieldMismatch is returned.
243					errFieldMismatch = err
244				} else {
245					return err
246				}
247			}
248			if mat != multiArgTypeStructPtr {
249				ev = ev.Elem()
250			}
251			dv.Set(reflect.Append(dv, ev))
252			return nil
253		}()
254		if err != nil {
255			return err
256		}
257	}
258	//fmt.Println("queryBackground: Finished")
259	return errFieldMismatch
260}
261