1package toml
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"reflect"
8	"strings"
9	"time"
10)
11
// tomlOpts holds the settings derived from a struct field's `toml` tag.
// Keep the field order stable: positional composite literals of this type
// depend on it.
type tomlOpts struct {
	name      string // key under which the field is stored in the TOML tree
	include   bool   // false when the field is skipped entirely
	omitempty bool   // true when a zero value is left out when marshaling
}
17
18var timeType = reflect.TypeOf(time.Time{})
19var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
20
21// Check if the given marshall type maps to a Tree primitive
22func isPrimitive(mtype reflect.Type) bool {
23	switch mtype.Kind() {
24	case reflect.Ptr:
25		return isPrimitive(mtype.Elem())
26	case reflect.Bool:
27		return true
28	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
29		return true
30	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
31		return true
32	case reflect.Float32, reflect.Float64:
33		return true
34	case reflect.String:
35		return true
36	case reflect.Struct:
37		return mtype == timeType || isCustomMarshaler(mtype)
38	default:
39		return false
40	}
41}
42
43// Check if the given marshall type maps to a Tree slice
44func isTreeSlice(mtype reflect.Type) bool {
45	switch mtype.Kind() {
46	case reflect.Slice:
47		return !isOtherSlice(mtype)
48	default:
49		return false
50	}
51}
52
53// Check if the given marshall type maps to a non-Tree slice
54func isOtherSlice(mtype reflect.Type) bool {
55	switch mtype.Kind() {
56	case reflect.Ptr:
57		return isOtherSlice(mtype.Elem())
58	case reflect.Slice:
59		return isPrimitive(mtype.Elem()) || isOtherSlice(mtype.Elem())
60	default:
61		return false
62	}
63}
64
65// Check if the given marshall type maps to a Tree
66func isTree(mtype reflect.Type) bool {
67	switch mtype.Kind() {
68	case reflect.Map:
69		return true
70	case reflect.Struct:
71		return !isPrimitive(mtype)
72	default:
73		return false
74	}
75}
76
77func isCustomMarshaler(mtype reflect.Type) bool {
78	return mtype.Implements(marshalerType)
79}
80
81func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
82	return mval.Interface().(Marshaler).MarshalTOML()
83}
84
// Marshaler is implemented by types that know how to encode themselves as
// TOML. MarshalTOML returns the encoded bytes, or a non-nil error when the
// value cannot be encoded.
type Marshaler interface {
	MarshalTOML() ([]byte, error)
}
90
91/*
92Marshal returns the TOML encoding of v.  Behavior is similar to the Go json
93encoder, except that there is no concept of a Marshaler interface or MarshalTOML
94function for sub-structs, and currently only definite types can be marshaled
95(i.e. no `interface{}`).
96
97Note that pointers are automatically assigned the "omitempty" option, as TOML
98explicity does not handle null values (saying instead the label should be
99dropped).
100
101Tree structural types and corresponding marshal types:
102
103  *Tree                            (*)struct, (*)map[string]interface{}
104  []*Tree                          (*)[](*)struct, (*)[](*)map[string]interface{}
105  []interface{} (as interface{})   (*)[]primitive, (*)[]([]interface{})
106  interface{}                      (*)primitive
107
108Tree primitive types and corresponding marshal types:
109
110  uint64     uint, uint8-uint64, pointers to same
111  int64      int, int8-uint64, pointers to same
112  float64    float32, float64, pointers to same
113  string     string, pointers to same
114  bool       bool, pointers to same
115  time.Time  time.Time{}, pointers to same
116*/
117func Marshal(v interface{}) ([]byte, error) {
118	mtype := reflect.TypeOf(v)
119	if mtype.Kind() != reflect.Struct {
120		return []byte{}, errors.New("Only a struct can be marshaled to TOML")
121	}
122	sval := reflect.ValueOf(v)
123	if isCustomMarshaler(mtype) {
124		return callCustomMarshaler(sval)
125	}
126	t, err := valueToTree(mtype, sval)
127	if err != nil {
128		return []byte{}, err
129	}
130	s, err := t.ToTomlString()
131	return []byte(s), err
132}
133
134// Convert given marshal struct or map value to toml tree
135func valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) {
136	if mtype.Kind() == reflect.Ptr {
137		return valueToTree(mtype.Elem(), mval.Elem())
138	}
139	tval := newTree()
140	switch mtype.Kind() {
141	case reflect.Struct:
142		for i := 0; i < mtype.NumField(); i++ {
143			mtypef, mvalf := mtype.Field(i), mval.Field(i)
144			opts := tomlOptions(mtypef)
145			if opts.include && (!opts.omitempty || !isZero(mvalf)) {
146				val, err := valueToToml(mtypef.Type, mvalf)
147				if err != nil {
148					return nil, err
149				}
150				tval.Set(opts.name, val)
151			}
152		}
153	case reflect.Map:
154		for _, key := range mval.MapKeys() {
155			mvalf := mval.MapIndex(key)
156			val, err := valueToToml(mtype.Elem(), mvalf)
157			if err != nil {
158				return nil, err
159			}
160			tval.Set(key.String(), val)
161		}
162	}
163	return tval, nil
164}
165
166// Convert given marshal slice to slice of Toml trees
167func valueToTreeSlice(mtype reflect.Type, mval reflect.Value) ([]*Tree, error) {
168	tval := make([]*Tree, mval.Len(), mval.Len())
169	for i := 0; i < mval.Len(); i++ {
170		val, err := valueToTree(mtype.Elem(), mval.Index(i))
171		if err != nil {
172			return nil, err
173		}
174		tval[i] = val
175	}
176	return tval, nil
177}
178
179// Convert given marshal slice to slice of toml values
180func valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
181	tval := make([]interface{}, mval.Len(), mval.Len())
182	for i := 0; i < mval.Len(); i++ {
183		val, err := valueToToml(mtype.Elem(), mval.Index(i))
184		if err != nil {
185			return nil, err
186		}
187		tval[i] = val
188	}
189	return tval, nil
190}
191
192// Convert given marshal value to toml value
193func valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
194	if mtype.Kind() == reflect.Ptr {
195		return valueToToml(mtype.Elem(), mval.Elem())
196	}
197	switch {
198	case isCustomMarshaler(mtype):
199		return callCustomMarshaler(mval)
200	case isTree(mtype):
201		return valueToTree(mtype, mval)
202	case isTreeSlice(mtype):
203		return valueToTreeSlice(mtype, mval)
204	case isOtherSlice(mtype):
205		return valueToOtherSlice(mtype, mval)
206	default:
207		switch mtype.Kind() {
208		case reflect.Bool:
209			return mval.Bool(), nil
210		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
211			return mval.Int(), nil
212		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
213			return mval.Uint(), nil
214		case reflect.Float32, reflect.Float64:
215			return mval.Float(), nil
216		case reflect.String:
217			return mval.String(), nil
218		case reflect.Struct:
219			return mval.Interface().(time.Time), nil
220		default:
221			return nil, fmt.Errorf("Marshal can't handle %v(%v)", mtype, mtype.Kind())
222		}
223	}
224}
225
226// Unmarshal attempts to unmarshal the Tree into a Go struct pointed by v.
227// Neither Unmarshaler interfaces nor UnmarshalTOML functions are supported for
228// sub-structs, and only definite types can be unmarshaled.
229func (t *Tree) Unmarshal(v interface{}) error {
230	mtype := reflect.TypeOf(v)
231	if mtype.Kind() != reflect.Ptr || mtype.Elem().Kind() != reflect.Struct {
232		return errors.New("Only a pointer to struct can be unmarshaled from TOML")
233	}
234
235	sval, err := valueFromTree(mtype.Elem(), t)
236	if err != nil {
237		return err
238	}
239	reflect.ValueOf(v).Elem().Set(sval)
240	return nil
241}
242
243// Unmarshal parses the TOML-encoded data and stores the result in the value
244// pointed to by v. Behavior is similar to the Go json encoder, except that there
245// is no concept of an Unmarshaler interface or UnmarshalTOML function for
246// sub-structs, and currently only definite types can be unmarshaled to (i.e. no
247// `interface{}`).
248//
249// See Marshal() documentation for types mapping table.
250func Unmarshal(data []byte, v interface{}) error {
251	t, err := LoadReader(bytes.NewReader(data))
252	if err != nil {
253		return err
254	}
255	return t.Unmarshal(v)
256}
257
258// Convert toml tree to marshal struct or map, using marshal type
259func valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, error) {
260	if mtype.Kind() == reflect.Ptr {
261		return unwrapPointer(mtype, tval)
262	}
263	var mval reflect.Value
264	switch mtype.Kind() {
265	case reflect.Struct:
266		mval = reflect.New(mtype).Elem()
267		for i := 0; i < mtype.NumField(); i++ {
268			mtypef := mtype.Field(i)
269			opts := tomlOptions(mtypef)
270			if opts.include {
271				baseKey := opts.name
272				keysToTry := []string{baseKey, strings.ToLower(baseKey), strings.ToTitle(baseKey)}
273				for _, key := range keysToTry {
274					exists := tval.Has(key)
275					if !exists {
276						continue
277					}
278					val := tval.Get(key)
279					mvalf, err := valueFromToml(mtypef.Type, val)
280					if err != nil {
281						return mval, formatError(err, tval.GetPosition(key))
282					}
283					mval.Field(i).Set(mvalf)
284					break
285				}
286			}
287		}
288	case reflect.Map:
289		mval = reflect.MakeMap(mtype)
290		for _, key := range tval.Keys() {
291			val := tval.Get(key)
292			mvalf, err := valueFromToml(mtype.Elem(), val)
293			if err != nil {
294				return mval, formatError(err, tval.GetPosition(key))
295			}
296			mval.SetMapIndex(reflect.ValueOf(key), mvalf)
297		}
298	}
299	return mval, nil
300}
301
302// Convert toml value to marshal struct/map slice, using marshal type
303func valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
304	mval := reflect.MakeSlice(mtype, len(tval), len(tval))
305	for i := 0; i < len(tval); i++ {
306		val, err := valueFromTree(mtype.Elem(), tval[i])
307		if err != nil {
308			return mval, err
309		}
310		mval.Index(i).Set(val)
311	}
312	return mval, nil
313}
314
315// Convert toml value to marshal primitive slice, using marshal type
316func valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) {
317	mval := reflect.MakeSlice(mtype, len(tval), len(tval))
318	for i := 0; i < len(tval); i++ {
319		val, err := valueFromToml(mtype.Elem(), tval[i])
320		if err != nil {
321			return mval, err
322		}
323		mval.Index(i).Set(val)
324	}
325	return mval, nil
326}
327
328// Convert toml value to marshal value, using marshal type
329func valueFromToml(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
330	if mtype.Kind() == reflect.Ptr {
331		return unwrapPointer(mtype, tval)
332	}
333	switch {
334	case isTree(mtype):
335		return valueFromTree(mtype, tval.(*Tree))
336	case isTreeSlice(mtype):
337		return valueFromTreeSlice(mtype, tval.([]*Tree))
338	case isOtherSlice(mtype):
339		return valueFromOtherSlice(mtype, tval.([]interface{}))
340	default:
341		switch mtype.Kind() {
342		case reflect.Bool:
343			val, ok := tval.(bool)
344			if !ok {
345				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to bool", tval, tval)
346			}
347			return reflect.ValueOf(val), nil
348		case reflect.Int:
349			val, ok := tval.(int64)
350			if !ok {
351				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
352			}
353			return reflect.ValueOf(int(val)), nil
354		case reflect.Int8:
355			val, ok := tval.(int64)
356			if !ok {
357				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
358			}
359			return reflect.ValueOf(int8(val)), nil
360		case reflect.Int16:
361			val, ok := tval.(int64)
362			if !ok {
363				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
364			}
365			return reflect.ValueOf(int16(val)), nil
366		case reflect.Int32:
367			val, ok := tval.(int64)
368			if !ok {
369				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
370			}
371			return reflect.ValueOf(int32(val)), nil
372		case reflect.Int64:
373			val, ok := tval.(int64)
374			if !ok {
375				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to int", tval, tval)
376			}
377			return reflect.ValueOf(val), nil
378		case reflect.Uint:
379			val, ok := tval.(int64)
380			if !ok {
381				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
382			}
383			return reflect.ValueOf(uint(val)), nil
384		case reflect.Uint8:
385			val, ok := tval.(int64)
386			if !ok {
387				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
388			}
389			return reflect.ValueOf(uint8(val)), nil
390		case reflect.Uint16:
391			val, ok := tval.(int64)
392			if !ok {
393				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
394			}
395			return reflect.ValueOf(uint16(val)), nil
396		case reflect.Uint32:
397			val, ok := tval.(int64)
398			if !ok {
399				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
400			}
401			return reflect.ValueOf(uint32(val)), nil
402		case reflect.Uint64:
403			val, ok := tval.(int64)
404			if !ok {
405				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to uint", tval, tval)
406			}
407			return reflect.ValueOf(uint64(val)), nil
408		case reflect.Float32:
409			val, ok := tval.(float64)
410			if !ok {
411				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval)
412			}
413			return reflect.ValueOf(float32(val)), nil
414		case reflect.Float64:
415			val, ok := tval.(float64)
416			if !ok {
417				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to float", tval, tval)
418			}
419			return reflect.ValueOf(val), nil
420		case reflect.String:
421			val, ok := tval.(string)
422			if !ok {
423				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to string", tval, tval)
424			}
425			return reflect.ValueOf(val), nil
426		case reflect.Struct:
427			val, ok := tval.(time.Time)
428			if !ok {
429				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to time", tval, tval)
430			}
431			return reflect.ValueOf(val), nil
432		default:
433			return reflect.ValueOf(nil), fmt.Errorf("Unmarshal can't handle %v(%v)", mtype, mtype.Kind())
434		}
435	}
436}
437
438func unwrapPointer(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
439	val, err := valueFromToml(mtype.Elem(), tval)
440	if err != nil {
441		return reflect.ValueOf(nil), err
442	}
443	mval := reflect.New(mtype.Elem())
444	mval.Elem().Set(val)
445	return mval, nil
446}
447
448func tomlOptions(vf reflect.StructField) tomlOpts {
449	tag := vf.Tag.Get("toml")
450	parse := strings.Split(tag, ",")
451	result := tomlOpts{vf.Name, true, false}
452	if parse[0] != "" {
453		if parse[0] == "-" && len(parse) == 1 {
454			result.include = false
455		} else {
456			result.name = strings.Trim(parse[0], " ")
457		}
458	}
459	if vf.PkgPath != "" {
460		result.include = false
461	}
462	if len(parse) > 1 && strings.Trim(parse[1], " ") == "omitempty" {
463		result.omitempty = true
464	}
465	if vf.Type.Kind() == reflect.Ptr {
466		result.omitempty = true
467	}
468	return result
469}
470
// isZero reports whether val counts as empty for the omitempty option.
// Maps, arrays and slices are empty when they have no elements. Any other
// value is empty when it is deeply equal to its type's zero value.
func isZero(val reflect.Value) bool {
	typ := val.Type()
	switch typ.Kind() {
	case reflect.Map, reflect.Array, reflect.Slice:
		return val.Len() == 0
	}
	return reflect.DeepEqual(val.Interface(), reflect.Zero(typ).Interface())
}
483
484func formatError(err error, pos Position) error {
485	if err.Error()[0] == '(' { // Error already contains position information
486		return err
487	}
488	return fmt.Errorf("%s: %s", pos, err)
489}
490