1// Copyright (c) 2014 The btcsuite developers
2// Use of this source code is governed by an ISC
3// license that can be found in the LICENSE file.
4
5package btcjson
6
7import (
8	"encoding/json"
9	"fmt"
10	"reflect"
11	"sort"
12	"strconv"
13	"strings"
14	"sync"
15)
16
// UsageFlag defines flags that specify additional properties about the
// circumstances under which a command can be used.
type UsageFlag uint32

const (
	// UFWalletOnly indicates that the command can only be used with an RPC
	// server that supports wallet commands.
	UFWalletOnly UsageFlag = 1 << iota

	// UFWebsocketOnly indicates that the command can only be used when
	// communicating with an RPC server over websockets.  This typically
	// applies to notifications and notification registration functions
	// since neither makes sense when using a single-shot HTTP-POST request.
	UFWebsocketOnly

	// UFNotification indicates that the command is actually a notification.
	// This means when it is marshalled, the ID must be nil.
	UFNotification

	// highestUsageFlagBit is the bit immediately above the last valid usage
	// flag; it is not itself a valid flag.  It bounds the loop in the
	// UsageFlag stringer, highestUsageFlagBit-1 forms the mask RegisterCmd
	// uses to reject unknown flag bits, and it lets tests ensure all of the
	// above constants have been covered.
	highestUsageFlagBit
)
41
42// Map of UsageFlag values back to their constant names for pretty printing.
43var usageFlagStrings = map[UsageFlag]string{
44	UFWalletOnly:    "UFWalletOnly",
45	UFWebsocketOnly: "UFWebsocketOnly",
46	UFNotification:  "UFNotification",
47}
48
49// String returns the UsageFlag in human-readable form.
50func (fl UsageFlag) String() string {
51	// No flags are set.
52	if fl == 0 {
53		return "0x0"
54	}
55
56	// Add individual bit flags.
57	s := ""
58	for flag := UFWalletOnly; flag < highestUsageFlagBit; flag <<= 1 {
59		if fl&flag == flag {
60			s += usageFlagStrings[flag] + "|"
61			fl -= flag
62		}
63	}
64
65	// Add remaining value as raw hex.
66	s = strings.TrimRight(s, "|")
67	if fl != 0 {
68		s += "|0x" + strconv.FormatUint(uint64(fl), 16)
69	}
70	s = strings.TrimLeft(s, "|")
71	return s
72}
73
74// methodInfo keeps track of information about each registered method such as
75// the parameter information.
76type methodInfo struct {
77	maxParams    int
78	numReqParams int
79	numOptParams int
80	defaults     map[int]reflect.Value
81	flags        UsageFlag
82	usage        string
83}
84
85var (
86	// These fields are used to map the registered types to method names.
87	registerLock         sync.RWMutex
88	methodToConcreteType = make(map[string]reflect.Type)
89	methodToInfo         = make(map[string]methodInfo)
90	concreteTypeToMethod = make(map[reflect.Type]string)
91)
92
// baseKindString describes rt by its underlying kind, prefixed with one "*"
// for every pointer indirection followed to reach it.  For example, the type
// **int yields "**int" and a plain struct type yields "struct".
func baseKindString(rt reflect.Type) string {
	var prefix strings.Builder
	for ; rt.Kind() == reflect.Ptr; rt = rt.Elem() {
		prefix.WriteByte('*')
	}
	return fmt.Sprintf("%s%s", prefix.String(), rt.Kind())
}
104
// isAcceptableKind reports whether a struct field of the given kind may be
// used as a command parameter.  RegisterCmd calls it with the field's kind,
// or, for optional (pointer) fields, with the kind of the pointed-to type, so
// a reflect.Ptr kind here means a second level of indirection, which is not
// supported.
//
// Kinds that encoding/json cannot marshal (channels, complex numbers,
// functions and unsafe pointers) are rejected.  Interfaces are rejected as
// well (presumably because their concrete type could not be recovered when
// unmarshalling).
func isAcceptableKind(kind reflect.Kind) bool {
	switch kind {
	case reflect.Chan, reflect.Complex64, reflect.Complex128,
		reflect.Func, reflect.UnsafePointer, reflect.Ptr,
		reflect.Interface:
		return false
	default:
		return true
	}
}
126
127// RegisterCmd registers a new command that will automatically marshal to and
128// from JSON-RPC with full type checking and positional parameter support.  It
129// also accepts usage flags which identify the circumstances under which the
130// command can be used.
131//
132// This package automatically registers all of the exported commands by default
133// using this function, however it is also exported so callers can easily
134// register custom types.
135//
136// The type format is very strict since it needs to be able to automatically
137// marshal to and from JSON-RPC 1.0.  The following enumerates the requirements:
138//
139//   - The provided command must be a single pointer to a struct
140//   - All fields must be exported
141//   - The order of the positional parameters in the marshalled JSON will be in
142//     the same order as declared in the struct definition
143//   - Struct embedding is not supported
144//   - Struct fields may NOT be channels, functions, complex, or interface
145//   - A field in the provided struct with a pointer is treated as optional
146//   - Multiple indirections (i.e **int) are not supported
147//   - Once the first optional field (pointer) is encountered, the remaining
148//     fields must also be optional fields (pointers) as required by positional
149//     params
150//   - A field that has a 'jsonrpcdefault' struct tag must be an optional field
151//     (pointer)
152//
153// NOTE: This function only needs to be able to examine the structure of the
154// passed struct, so it does not need to be an actual instance.  Therefore, it
155// is recommended to simply pass a nil pointer cast to the appropriate type.
156// For example, (*FooCmd)(nil).
157func RegisterCmd(method string, cmd interface{}, flags UsageFlag) error {
158	registerLock.Lock()
159	defer registerLock.Unlock()
160
161	if _, ok := methodToConcreteType[method]; ok {
162		str := fmt.Sprintf("method %q is already registered", method)
163		return makeError(ErrDuplicateMethod, str)
164	}
165
166	// Ensure that no unrecognized flag bits were specified.
167	if ^(highestUsageFlagBit-1)&flags != 0 {
168		str := fmt.Sprintf("invalid usage flags specified for method "+
169			"%s: %v", method, flags)
170		return makeError(ErrInvalidUsageFlags, str)
171	}
172
173	rtp := reflect.TypeOf(cmd)
174	if rtp.Kind() != reflect.Ptr {
175		str := fmt.Sprintf("type must be *struct not '%s (%s)'", rtp,
176			rtp.Kind())
177		return makeError(ErrInvalidType, str)
178	}
179	rt := rtp.Elem()
180	if rt.Kind() != reflect.Struct {
181		str := fmt.Sprintf("type must be *struct not '%s (*%s)'",
182			rtp, rt.Kind())
183		return makeError(ErrInvalidType, str)
184	}
185
186	// Enumerate the struct fields to validate them and gather parameter
187	// information.
188	numFields := rt.NumField()
189	numOptFields := 0
190	defaults := make(map[int]reflect.Value)
191	for i := 0; i < numFields; i++ {
192		rtf := rt.Field(i)
193		if rtf.Anonymous {
194			str := fmt.Sprintf("embedded fields are not supported "+
195				"(field name: %q)", rtf.Name)
196			return makeError(ErrEmbeddedType, str)
197		}
198		if rtf.PkgPath != "" {
199			str := fmt.Sprintf("unexported fields are not supported "+
200				"(field name: %q)", rtf.Name)
201			return makeError(ErrUnexportedField, str)
202		}
203
204		// Disallow types that can't be JSON encoded.  Also, determine
205		// if the field is optional based on it being a pointer.
206		var isOptional bool
207		switch kind := rtf.Type.Kind(); kind {
208		case reflect.Ptr:
209			isOptional = true
210			kind = rtf.Type.Elem().Kind()
211			fallthrough
212		default:
213			if !isAcceptableKind(kind) {
214				str := fmt.Sprintf("unsupported field type "+
215					"'%s (%s)' (field name %q)", rtf.Type,
216					baseKindString(rtf.Type), rtf.Name)
217				return makeError(ErrUnsupportedFieldType, str)
218			}
219		}
220
221		// Count the optional fields and ensure all fields after the
222		// first optional field are also optional.
223		if isOptional {
224			numOptFields++
225		} else {
226			if numOptFields > 0 {
227				str := fmt.Sprintf("all fields after the first "+
228					"optional field must also be optional "+
229					"(field name %q)", rtf.Name)
230				return makeError(ErrNonOptionalField, str)
231			}
232		}
233
234		// Ensure the default value can be unsmarshalled into the type
235		// and that defaults are only specified for optional fields.
236		if tag := rtf.Tag.Get("jsonrpcdefault"); tag != "" {
237			if !isOptional {
238				str := fmt.Sprintf("required fields must not "+
239					"have a default specified (field name "+
240					"%q)", rtf.Name)
241				return makeError(ErrNonOptionalDefault, str)
242			}
243
244			rvf := reflect.New(rtf.Type.Elem())
245			err := json.Unmarshal([]byte(tag), rvf.Interface())
246			if err != nil {
247				str := fmt.Sprintf("default value of %q is "+
248					"the wrong type (field name %q)", tag,
249					rtf.Name)
250				return makeError(ErrMismatchedDefault, str)
251			}
252			defaults[i] = rvf
253		}
254	}
255
256	// Update the registration maps.
257	methodToConcreteType[method] = rtp
258	methodToInfo[method] = methodInfo{
259		maxParams:    numFields,
260		numReqParams: numFields - numOptFields,
261		numOptParams: numOptFields,
262		defaults:     defaults,
263		flags:        flags,
264	}
265	concreteTypeToMethod[rtp] = method
266	return nil
267}
268
269// MustRegisterCmd performs the same function as RegisterCmd except it panics
270// if there is an error.  This should only be called from package init
271// functions.
272func MustRegisterCmd(method string, cmd interface{}, flags UsageFlag) {
273	if err := RegisterCmd(method, cmd, flags); err != nil {
274		panic(fmt.Sprintf("failed to register type %q: %v\n", method,
275			err))
276	}
277}
278
279// RegisteredCmdMethods returns a sorted list of methods for all registered
280// commands.
281func RegisteredCmdMethods() []string {
282	registerLock.Lock()
283	defer registerLock.Unlock()
284
285	methods := make([]string, 0, len(methodToInfo))
286	for k := range methodToInfo {
287		methods = append(methods, k)
288	}
289
290	sort.Sort(sort.StringSlice(methods))
291	return methods
292}
293