1// Copyright 2015 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// +build windows
6
7package registry
8
9import (
10	"errors"
11	"io"
12	"syscall"
13	"unicode/utf16"
14	"unsafe"
15)
16
// Registry value types. Each constant has the same numeric value as the
// Windows REG_* constant named in its trailing comment.
const (
	NONE                       = 0  // REG_NONE
	SZ                         = 1  // REG_SZ
	EXPAND_SZ                  = 2  // REG_EXPAND_SZ
	BINARY                     = 3  // REG_BINARY
	DWORD                      = 4  // REG_DWORD
	DWORD_BIG_ENDIAN           = 5  // REG_DWORD_BIG_ENDIAN
	LINK                       = 6  // REG_LINK
	MULTI_SZ                   = 7  // REG_MULTI_SZ
	RESOURCE_LIST              = 8  // REG_RESOURCE_LIST
	FULL_RESOURCE_DESCRIPTOR   = 9  // REG_FULL_RESOURCE_DESCRIPTOR
	RESOURCE_REQUIREMENTS_LIST = 10 // REG_RESOURCE_REQUIREMENTS_LIST
	QWORD                      = 11 // REG_QWORD
)
32
33var (
34	// ErrShortBuffer is returned when the buffer was too short for the operation.
35	ErrShortBuffer = syscall.ERROR_MORE_DATA
36
37	// ErrNotExist is returned when a registry key or value does not exist.
38	ErrNotExist = syscall.ERROR_FILE_NOT_FOUND
39
40	// ErrUnexpectedType is returned by Get*Value when the value's type was unexpected.
41	ErrUnexpectedType = errors.New("unexpected key value type")
42)
43
44// GetValue retrieves the type and data for the specified value associated
45// with an open key k. It fills up buffer buf and returns the retrieved
46// byte count n. If buf is too small to fit the stored value it returns
47// ErrShortBuffer error along with the required buffer size n.
48// If no buffer is provided, it returns true and actual buffer size n.
49// If no buffer is provided, GetValue returns the value's type only.
50// If the value does not exist, the error returned is ErrNotExist.
51//
52// GetValue is a low level function. If value's type is known, use the appropriate
53// Get*Value function instead.
54func (k Key) GetValue(name string, buf []byte) (n int, valtype uint32, err error) {
55	pname, err := syscall.UTF16PtrFromString(name)
56	if err != nil {
57		return 0, 0, err
58	}
59	var pbuf *byte
60	if len(buf) > 0 {
61		pbuf = (*byte)(unsafe.Pointer(&buf[0]))
62	}
63	l := uint32(len(buf))
64	err = syscall.RegQueryValueEx(syscall.Handle(k), pname, nil, &valtype, pbuf, &l)
65	if err != nil {
66		return int(l), valtype, err
67	}
68	return int(l), valtype, nil
69}
70
71func (k Key) getValue(name string, buf []byte) (data []byte, valtype uint32, err error) {
72	p, err := syscall.UTF16PtrFromString(name)
73	if err != nil {
74		return nil, 0, err
75	}
76	var t uint32
77	n := uint32(len(buf))
78	for {
79		err = syscall.RegQueryValueEx(syscall.Handle(k), p, nil, &t, (*byte)(unsafe.Pointer(&buf[0])), &n)
80		if err == nil {
81			return buf[:n], t, nil
82		}
83		if err != syscall.ERROR_MORE_DATA {
84			return nil, 0, err
85		}
86		if n <= uint32(len(buf)) {
87			return nil, 0, err
88		}
89		buf = make([]byte, n)
90	}
91}
92
93// GetStringValue retrieves the string value for the specified
94// value name associated with an open key k. It also returns the value's type.
95// If value does not exist, GetStringValue returns ErrNotExist.
96// If value is not SZ or EXPAND_SZ, it will return the correct value
97// type and ErrUnexpectedType.
98func (k Key) GetStringValue(name string) (val string, valtype uint32, err error) {
99	data, typ, err2 := k.getValue(name, make([]byte, 64))
100	if err2 != nil {
101		return "", typ, err2
102	}
103	switch typ {
104	case SZ, EXPAND_SZ:
105	default:
106		return "", typ, ErrUnexpectedType
107	}
108	if len(data) == 0 {
109		return "", typ, nil
110	}
111	u := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
112	return syscall.UTF16ToString(u), typ, nil
113}
114
115// GetMUIStringValue retrieves the localized string value for
116// the specified value name associated with an open key k.
117// If the value name doesn't exist or the localized string value
118// can't be resolved, GetMUIStringValue returns ErrNotExist.
119// GetMUIStringValue panics if the system doesn't support
120// regLoadMUIString; use LoadRegLoadMUIString to check if
121// regLoadMUIString is supported before calling this function.
122func (k Key) GetMUIStringValue(name string) (string, error) {
123	pname, err := syscall.UTF16PtrFromString(name)
124	if err != nil {
125		return "", err
126	}
127
128	buf := make([]uint16, 1024)
129	var buflen uint32
130	var pdir *uint16
131
132	err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
133	if err == syscall.ERROR_FILE_NOT_FOUND { // Try fallback path
134
135		// Try to resolve the string value using the system directory as
136		// a DLL search path; this assumes the string value is of the form
137		// @[path]\dllname,-strID but with no path given, e.g. @tzres.dll,-320.
138
139		// This approach works with tzres.dll but may have to be revised
140		// in the future to allow callers to provide custom search paths.
141
142		var s string
143		s, err = ExpandString("%SystemRoot%\\system32\\")
144		if err != nil {
145			return "", err
146		}
147		pdir, err = syscall.UTF16PtrFromString(s)
148		if err != nil {
149			return "", err
150		}
151
152		err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
153	}
154
155	for err == syscall.ERROR_MORE_DATA { // Grow buffer if needed
156		if buflen <= uint32(len(buf)) {
157			break // Buffer not growing, assume race; break
158		}
159		buf = make([]uint16, buflen)
160		err = regLoadMUIString(syscall.Handle(k), pname, &buf[0], uint32(len(buf)), &buflen, 0, pdir)
161	}
162
163	if err != nil {
164		return "", err
165	}
166
167	return syscall.UTF16ToString(buf), nil
168}
169
170// ExpandString expands environment-variable strings and replaces
171// them with the values defined for the current user.
172// Use ExpandString to expand EXPAND_SZ strings.
173func ExpandString(value string) (string, error) {
174	if value == "" {
175		return "", nil
176	}
177	p, err := syscall.UTF16PtrFromString(value)
178	if err != nil {
179		return "", err
180	}
181	r := make([]uint16, 100)
182	for {
183		n, err := expandEnvironmentStrings(p, &r[0], uint32(len(r)))
184		if err != nil {
185			return "", err
186		}
187		if n <= uint32(len(r)) {
188			return syscall.UTF16ToString(r[:n]), nil
189		}
190		r = make([]uint16, n)
191	}
192}
193
194// GetStringsValue retrieves the []string value for the specified
195// value name associated with an open key k. It also returns the value's type.
196// If value does not exist, GetStringsValue returns ErrNotExist.
197// If value is not MULTI_SZ, it will return the correct value
198// type and ErrUnexpectedType.
199func (k Key) GetStringsValue(name string) (val []string, valtype uint32, err error) {
200	data, typ, err2 := k.getValue(name, make([]byte, 64))
201	if err2 != nil {
202		return nil, typ, err2
203	}
204	if typ != MULTI_SZ {
205		return nil, typ, ErrUnexpectedType
206	}
207	if len(data) == 0 {
208		return nil, typ, nil
209	}
210	p := (*[1 << 29]uint16)(unsafe.Pointer(&data[0]))[: len(data)/2 : len(data)/2]
211	if len(p) == 0 {
212		return nil, typ, nil
213	}
214	if p[len(p)-1] == 0 {
215		p = p[:len(p)-1] // remove terminating null
216	}
217	val = make([]string, 0, 5)
218	from := 0
219	for i, c := range p {
220		if c == 0 {
221			val = append(val, string(utf16.Decode(p[from:i])))
222			from = i + 1
223		}
224	}
225	return val, typ, nil
226}
227
228// GetIntegerValue retrieves the integer value for the specified
229// value name associated with an open key k. It also returns the value's type.
230// If value does not exist, GetIntegerValue returns ErrNotExist.
231// If value is not DWORD or QWORD, it will return the correct value
232// type and ErrUnexpectedType.
233func (k Key) GetIntegerValue(name string) (val uint64, valtype uint32, err error) {
234	data, typ, err2 := k.getValue(name, make([]byte, 8))
235	if err2 != nil {
236		return 0, typ, err2
237	}
238	switch typ {
239	case DWORD:
240		if len(data) != 4 {
241			return 0, typ, errors.New("DWORD value is not 4 bytes long")
242		}
243		var val32 uint32
244		copy((*[4]byte)(unsafe.Pointer(&val32))[:], data)
245		return uint64(val32), DWORD, nil
246	case QWORD:
247		if len(data) != 8 {
248			return 0, typ, errors.New("QWORD value is not 8 bytes long")
249		}
250		copy((*[8]byte)(unsafe.Pointer(&val))[:], data)
251		return val, QWORD, nil
252	default:
253		return 0, typ, ErrUnexpectedType
254	}
255}
256
257// GetBinaryValue retrieves the binary value for the specified
258// value name associated with an open key k. It also returns the value's type.
259// If value does not exist, GetBinaryValue returns ErrNotExist.
260// If value is not BINARY, it will return the correct value
261// type and ErrUnexpectedType.
262func (k Key) GetBinaryValue(name string) (val []byte, valtype uint32, err error) {
263	data, typ, err2 := k.getValue(name, make([]byte, 64))
264	if err2 != nil {
265		return nil, typ, err2
266	}
267	if typ != BINARY {
268		return nil, typ, ErrUnexpectedType
269	}
270	return data, typ, nil
271}
272
273func (k Key) setValue(name string, valtype uint32, data []byte) error {
274	p, err := syscall.UTF16PtrFromString(name)
275	if err != nil {
276		return err
277	}
278	if len(data) == 0 {
279		return regSetValueEx(syscall.Handle(k), p, 0, valtype, nil, 0)
280	}
281	return regSetValueEx(syscall.Handle(k), p, 0, valtype, &data[0], uint32(len(data)))
282}
283
284// SetDWordValue sets the data and type of a name value
285// under key k to value and DWORD.
286func (k Key) SetDWordValue(name string, value uint32) error {
287	return k.setValue(name, DWORD, (*[4]byte)(unsafe.Pointer(&value))[:])
288}
289
290// SetQWordValue sets the data and type of a name value
291// under key k to value and QWORD.
292func (k Key) SetQWordValue(name string, value uint64) error {
293	return k.setValue(name, QWORD, (*[8]byte)(unsafe.Pointer(&value))[:])
294}
295
296func (k Key) setStringValue(name string, valtype uint32, value string) error {
297	v, err := syscall.UTF16FromString(value)
298	if err != nil {
299		return err
300	}
301	buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
302	return k.setValue(name, valtype, buf)
303}
304
305// SetStringValue sets the data and type of a name value
306// under key k to value and SZ. The value must not contain a zero byte.
307func (k Key) SetStringValue(name, value string) error {
308	return k.setStringValue(name, SZ, value)
309}
310
311// SetExpandStringValue sets the data and type of a name value
312// under key k to value and EXPAND_SZ. The value must not contain a zero byte.
313func (k Key) SetExpandStringValue(name, value string) error {
314	return k.setStringValue(name, EXPAND_SZ, value)
315}
316
317// SetStringsValue sets the data and type of a name value
318// under key k to value and MULTI_SZ. The value strings
319// must not contain a zero byte.
320func (k Key) SetStringsValue(name string, value []string) error {
321	ss := ""
322	for _, s := range value {
323		for i := 0; i < len(s); i++ {
324			if s[i] == 0 {
325				return errors.New("string cannot have 0 inside")
326			}
327		}
328		ss += s + "\x00"
329	}
330	v := utf16.Encode([]rune(ss + "\x00"))
331	buf := (*[1 << 29]byte)(unsafe.Pointer(&v[0]))[: len(v)*2 : len(v)*2]
332	return k.setValue(name, MULTI_SZ, buf)
333}
334
335// SetBinaryValue sets the data and type of a name value
336// under key k to value and BINARY.
337func (k Key) SetBinaryValue(name string, value []byte) error {
338	return k.setValue(name, BINARY, value)
339}
340
341// DeleteValue removes a named value from the key k.
342func (k Key) DeleteValue(name string) error {
343	return regDeleteValue(syscall.Handle(k), syscall.StringToUTF16Ptr(name))
344}
345
346// ReadValueNames returns the value names of key k.
347// The parameter n controls the number of returned names,
348// analogous to the way os.File.Readdirnames works.
349func (k Key) ReadValueNames(n int) ([]string, error) {
350	ki, err := k.Stat()
351	if err != nil {
352		return nil, err
353	}
354	names := make([]string, 0, ki.ValueCount)
355	buf := make([]uint16, ki.MaxValueNameLen+1) // extra room for terminating null character
356loopItems:
357	for i := uint32(0); ; i++ {
358		if n > 0 {
359			if len(names) == n {
360				return names, nil
361			}
362		}
363		l := uint32(len(buf))
364		for {
365			err := regEnumValue(syscall.Handle(k), i, &buf[0], &l, nil, nil, nil, nil)
366			if err == nil {
367				break
368			}
369			if err == syscall.ERROR_MORE_DATA {
370				// Double buffer size and try again.
371				l = uint32(2 * len(buf))
372				buf = make([]uint16, l)
373				continue
374			}
375			if err == _ERROR_NO_MORE_ITEMS {
376				break loopItems
377			}
378			return names, err
379		}
380		names = append(names, syscall.UTF16ToString(buf[:l]))
381	}
382	if n > len(names) {
383		return names, io.EOF
384	}
385	return names, nil
386}
387