1package toml
2
3import (
4	"bytes"
5	"errors"
6	"fmt"
7	"io"
8	"reflect"
9	"strconv"
10	"strings"
11	"time"
12)
13
// tomlOpts is the per-field configuration derived from a struct field's
// `toml`, `comment` and `commented` tags by tomlOptions.
type tomlOpts struct {
	name    string // key the field is stored under in the TOML tree
	comment string // text of the `comment` tag

	commented, include, omitempty bool
}
21
// encOpts holds the options shared by Encoder and Decoder.
type encOpts struct {
	quoteMapKeys            bool // write string map keys as quoted TOML keys
	arraysOneElementPerLine bool // forwarded to the tree writer for array layout
}

// encOptsDefaults is the configuration NewEncoder and NewDecoder start from;
// every option is off.
var encOptsDefaults = encOpts{
	quoteMapKeys:            false,
	arraysOneElementPerLine: false,
}
30
31var timeType = reflect.TypeOf(time.Time{})
32var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
33
34// Check if the given marshall type maps to a Tree primitive
35func isPrimitive(mtype reflect.Type) bool {
36	switch mtype.Kind() {
37	case reflect.Ptr:
38		return isPrimitive(mtype.Elem())
39	case reflect.Bool:
40		return true
41	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
42		return true
43	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
44		return true
45	case reflect.Float32, reflect.Float64:
46		return true
47	case reflect.String:
48		return true
49	case reflect.Struct:
50		return mtype == timeType || isCustomMarshaler(mtype)
51	default:
52		return false
53	}
54}
55
56// Check if the given marshall type maps to a Tree slice
57func isTreeSlice(mtype reflect.Type) bool {
58	switch mtype.Kind() {
59	case reflect.Slice:
60		return !isOtherSlice(mtype)
61	default:
62		return false
63	}
64}
65
66// Check if the given marshall type maps to a non-Tree slice
67func isOtherSlice(mtype reflect.Type) bool {
68	switch mtype.Kind() {
69	case reflect.Ptr:
70		return isOtherSlice(mtype.Elem())
71	case reflect.Slice:
72		return isPrimitive(mtype.Elem()) || isOtherSlice(mtype.Elem())
73	default:
74		return false
75	}
76}
77
78// Check if the given marshall type maps to a Tree
79func isTree(mtype reflect.Type) bool {
80	switch mtype.Kind() {
81	case reflect.Map:
82		return true
83	case reflect.Struct:
84		return !isPrimitive(mtype)
85	default:
86		return false
87	}
88}
89
90func isCustomMarshaler(mtype reflect.Type) bool {
91	return mtype.Implements(marshalerType)
92}
93
94func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
95	return mval.Interface().(Marshaler).MarshalTOML()
96}
97
// Marshaler is implemented by types that produce their own valid TOML
// encoding. MarshalTOML returns the encoded bytes, or an error if the value
// cannot be encoded.
type Marshaler interface {
	MarshalTOML() ([]byte, error)
}
103
104/*
105Marshal returns the TOML encoding of v.  Behavior is similar to the Go json
106encoder, except that there is no concept of a Marshaler interface or MarshalTOML
107function for sub-structs, and currently only definite types can be marshaled
108(i.e. no `interface{}`).
109
110The following struct annotations are supported:
111
112  toml:"Field"      Overrides the field's name to output.
113  omitempty         When set, empty values and groups are not emitted.
114  comment:"comment" Emits a # comment on the same line. This supports new lines.
115  commented:"true"  Emits the value as commented.
116
117Note that pointers are automatically assigned the "omitempty" option, as TOML
118explicitly does not handle null values (saying instead the label should be
119dropped).
120
121Tree structural types and corresponding marshal types:
122
123  *Tree                            (*)struct, (*)map[string]interface{}
124  []*Tree                          (*)[](*)struct, (*)[](*)map[string]interface{}
125  []interface{} (as interface{})   (*)[]primitive, (*)[]([]interface{})
126  interface{}                      (*)primitive
127
128Tree primitive types and corresponding marshal types:
129
130  uint64     uint, uint8-uint64, pointers to same
131  int64      int, int8-uint64, pointers to same
132  float64    float32, float64, pointers to same
133  string     string, pointers to same
134  bool       bool, pointers to same
135  time.Time  time.Time{}, pointers to same
136*/
137func Marshal(v interface{}) ([]byte, error) {
138	return NewEncoder(nil).marshal(v)
139}
140
141// Encoder writes TOML values to an output stream.
142type Encoder struct {
143	w io.Writer
144	encOpts
145}
146
147// NewEncoder returns a new encoder that writes to w.
148func NewEncoder(w io.Writer) *Encoder {
149	return &Encoder{
150		w:       w,
151		encOpts: encOptsDefaults,
152	}
153}
154
155// Encode writes the TOML encoding of v to the stream.
156//
157// See the documentation for Marshal for details.
158func (e *Encoder) Encode(v interface{}) error {
159	b, err := e.marshal(v)
160	if err != nil {
161		return err
162	}
163	if _, err := e.w.Write(b); err != nil {
164		return err
165	}
166	return nil
167}
168
169// QuoteMapKeys sets up the encoder to encode
170// maps with string type keys with quoted TOML keys.
171//
172// This relieves the character limitations on map keys.
173func (e *Encoder) QuoteMapKeys(v bool) *Encoder {
174	e.quoteMapKeys = v
175	return e
176}
177
178// ArraysWithOneElementPerLine sets up the encoder to encode arrays
179// with more than one element on multiple lines instead of one.
180//
181// For example:
182//
183//   A = [1,2,3]
184//
185// Becomes
186//
187//   A = [
188//     1,
189//     2,
190//     3
191//   ]
192func (e *Encoder) ArraysWithOneElementPerLine(v bool) *Encoder {
193	e.arraysOneElementPerLine = v
194	return e
195}
196
197func (e *Encoder) marshal(v interface{}) ([]byte, error) {
198	mtype := reflect.TypeOf(v)
199	if mtype.Kind() != reflect.Struct {
200		return []byte{}, errors.New("Only a struct can be marshaled to TOML")
201	}
202	sval := reflect.ValueOf(v)
203	if isCustomMarshaler(mtype) {
204		return callCustomMarshaler(sval)
205	}
206	t, err := e.valueToTree(mtype, sval)
207	if err != nil {
208		return []byte{}, err
209	}
210
211	var buf bytes.Buffer
212	_, err = t.writeTo(&buf, "", "", 0, e.arraysOneElementPerLine)
213
214	return buf.Bytes(), err
215}
216
217// Convert given marshal struct or map value to toml tree
218func (e *Encoder) valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) {
219	if mtype.Kind() == reflect.Ptr {
220		return e.valueToTree(mtype.Elem(), mval.Elem())
221	}
222	tval := newTree()
223	switch mtype.Kind() {
224	case reflect.Struct:
225		for i := 0; i < mtype.NumField(); i++ {
226			mtypef, mvalf := mtype.Field(i), mval.Field(i)
227			opts := tomlOptions(mtypef)
228			if opts.include && (!opts.omitempty || !isZero(mvalf)) {
229				val, err := e.valueToToml(mtypef.Type, mvalf)
230				if err != nil {
231					return nil, err
232				}
233				tval.SetWithComment(opts.name, opts.comment, opts.commented, val)
234			}
235		}
236	case reflect.Map:
237		for _, key := range mval.MapKeys() {
238			mvalf := mval.MapIndex(key)
239			val, err := e.valueToToml(mtype.Elem(), mvalf)
240			if err != nil {
241				return nil, err
242			}
243			if e.quoteMapKeys {
244				keyStr, err := tomlValueStringRepresentation(key.String(), "", e.arraysOneElementPerLine)
245				if err != nil {
246					return nil, err
247				}
248				tval.SetPath([]string{keyStr}, val)
249			} else {
250				tval.Set(key.String(), val)
251			}
252		}
253	}
254	return tval, nil
255}
256
257// Convert given marshal slice to slice of Toml trees
258func (e *Encoder) valueToTreeSlice(mtype reflect.Type, mval reflect.Value) ([]*Tree, error) {
259	tval := make([]*Tree, mval.Len(), mval.Len())
260	for i := 0; i < mval.Len(); i++ {
261		val, err := e.valueToTree(mtype.Elem(), mval.Index(i))
262		if err != nil {
263			return nil, err
264		}
265		tval[i] = val
266	}
267	return tval, nil
268}
269
270// Convert given marshal slice to slice of toml values
271func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
272	tval := make([]interface{}, mval.Len(), mval.Len())
273	for i := 0; i < mval.Len(); i++ {
274		val, err := e.valueToToml(mtype.Elem(), mval.Index(i))
275		if err != nil {
276			return nil, err
277		}
278		tval[i] = val
279	}
280	return tval, nil
281}
282
283// Convert given marshal value to toml value
284func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
285	if mtype.Kind() == reflect.Ptr {
286		return e.valueToToml(mtype.Elem(), mval.Elem())
287	}
288	switch {
289	case isCustomMarshaler(mtype):
290		return callCustomMarshaler(mval)
291	case isTree(mtype):
292		return e.valueToTree(mtype, mval)
293	case isTreeSlice(mtype):
294		return e.valueToTreeSlice(mtype, mval)
295	case isOtherSlice(mtype):
296		return e.valueToOtherSlice(mtype, mval)
297	default:
298		switch mtype.Kind() {
299		case reflect.Bool:
300			return mval.Bool(), nil
301		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
302			return mval.Int(), nil
303		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
304			return mval.Uint(), nil
305		case reflect.Float32, reflect.Float64:
306			return mval.Float(), nil
307		case reflect.String:
308			return mval.String(), nil
309		case reflect.Struct:
310			return mval.Interface().(time.Time), nil
311		default:
312			return nil, fmt.Errorf("Marshal can't handle %v(%v)", mtype, mtype.Kind())
313		}
314	}
315}
316
317// Unmarshal attempts to unmarshal the Tree into a Go struct pointed by v.
318// Neither Unmarshaler interfaces nor UnmarshalTOML functions are supported for
319// sub-structs, and only definite types can be unmarshaled.
320func (t *Tree) Unmarshal(v interface{}) error {
321	d := Decoder{tval: t}
322	return d.unmarshal(v)
323}
324
325// Marshal returns the TOML encoding of Tree.
326// See Marshal() documentation for types mapping table.
327func (t *Tree) Marshal() ([]byte, error) {
328	var buf bytes.Buffer
329	err := NewEncoder(&buf).Encode(t)
330	return buf.Bytes(), err
331}
332
333// Unmarshal parses the TOML-encoded data and stores the result in the value
334// pointed to by v. Behavior is similar to the Go json encoder, except that there
335// is no concept of an Unmarshaler interface or UnmarshalTOML function for
336// sub-structs, and currently only definite types can be unmarshaled to (i.e. no
337// `interface{}`).
338//
339// The following struct annotations are supported:
340//
341//   toml:"Field" Overrides the field's name to map to.
342//
343// See Marshal() documentation for types mapping table.
344func Unmarshal(data []byte, v interface{}) error {
345	t, err := LoadReader(bytes.NewReader(data))
346	if err != nil {
347		return err
348	}
349	return t.Unmarshal(v)
350}
351
352// Decoder reads and decodes TOML values from an input stream.
353type Decoder struct {
354	r    io.Reader
355	tval *Tree
356	encOpts
357}
358
359// NewDecoder returns a new decoder that reads from r.
360func NewDecoder(r io.Reader) *Decoder {
361	return &Decoder{
362		r:       r,
363		encOpts: encOptsDefaults,
364	}
365}
366
367// Decode reads a TOML-encoded value from it's input
368// and unmarshals it in the value pointed at by v.
369//
370// See the documentation for Marshal for details.
371func (d *Decoder) Decode(v interface{}) error {
372	var err error
373	d.tval, err = LoadReader(d.r)
374	if err != nil {
375		return err
376	}
377	return d.unmarshal(v)
378}
379
380func (d *Decoder) unmarshal(v interface{}) error {
381	mtype := reflect.TypeOf(v)
382	if mtype.Kind() != reflect.Ptr || mtype.Elem().Kind() != reflect.Struct {
383		return errors.New("Only a pointer to struct can be unmarshaled from TOML")
384	}
385
386	sval, err := d.valueFromTree(mtype.Elem(), d.tval)
387	if err != nil {
388		return err
389	}
390	reflect.ValueOf(v).Elem().Set(sval)
391	return nil
392}
393
394// Convert toml tree to marshal struct or map, using marshal type
395func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree) (reflect.Value, error) {
396	if mtype.Kind() == reflect.Ptr {
397		return d.unwrapPointer(mtype, tval)
398	}
399	var mval reflect.Value
400	switch mtype.Kind() {
401	case reflect.Struct:
402		mval = reflect.New(mtype).Elem()
403		for i := 0; i < mtype.NumField(); i++ {
404			mtypef := mtype.Field(i)
405			opts := tomlOptions(mtypef)
406			if opts.include {
407				baseKey := opts.name
408				keysToTry := []string{baseKey, strings.ToLower(baseKey), strings.ToTitle(baseKey)}
409				for _, key := range keysToTry {
410					exists := tval.Has(key)
411					if !exists {
412						continue
413					}
414					val := tval.Get(key)
415					mvalf, err := d.valueFromToml(mtypef.Type, val)
416					if err != nil {
417						return mval, formatError(err, tval.GetPosition(key))
418					}
419					mval.Field(i).Set(mvalf)
420					break
421				}
422			}
423		}
424	case reflect.Map:
425		mval = reflect.MakeMap(mtype)
426		for _, key := range tval.Keys() {
427			// TODO: path splits key
428			val := tval.GetPath([]string{key})
429			mvalf, err := d.valueFromToml(mtype.Elem(), val)
430			if err != nil {
431				return mval, formatError(err, tval.GetPosition(key))
432			}
433			mval.SetMapIndex(reflect.ValueOf(key), mvalf)
434		}
435	}
436	return mval, nil
437}
438
439// Convert toml value to marshal struct/map slice, using marshal type
440func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
441	mval := reflect.MakeSlice(mtype, len(tval), len(tval))
442	for i := 0; i < len(tval); i++ {
443		val, err := d.valueFromTree(mtype.Elem(), tval[i])
444		if err != nil {
445			return mval, err
446		}
447		mval.Index(i).Set(val)
448	}
449	return mval, nil
450}
451
452// Convert toml value to marshal primitive slice, using marshal type
453func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) {
454	mval := reflect.MakeSlice(mtype, len(tval), len(tval))
455	for i := 0; i < len(tval); i++ {
456		val, err := d.valueFromToml(mtype.Elem(), tval[i])
457		if err != nil {
458			return mval, err
459		}
460		mval.Index(i).Set(val)
461	}
462	return mval, nil
463}
464
465// Convert toml value to marshal value, using marshal type
466func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
467	if mtype.Kind() == reflect.Ptr {
468		return d.unwrapPointer(mtype, tval)
469	}
470
471	switch tval.(type) {
472	case *Tree:
473		if isTree(mtype) {
474			return d.valueFromTree(mtype, tval.(*Tree))
475		}
476		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a tree", tval, tval)
477	case []*Tree:
478		if isTreeSlice(mtype) {
479			return d.valueFromTreeSlice(mtype, tval.([]*Tree))
480		}
481		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval)
482	case []interface{}:
483		if isOtherSlice(mtype) {
484			return d.valueFromOtherSlice(mtype, tval.([]interface{}))
485		}
486		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval)
487	default:
488		switch mtype.Kind() {
489		case reflect.Bool, reflect.Struct:
490			val := reflect.ValueOf(tval)
491			// if this passes for when mtype is reflect.Struct, tval is a time.Time
492			if !val.Type().ConvertibleTo(mtype) {
493				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
494			}
495
496			return val.Convert(mtype), nil
497		case reflect.String:
498			val := reflect.ValueOf(tval)
499			// stupidly, int64 is convertible to string. So special case this.
500			if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Int64 {
501				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
502			}
503
504			return val.Convert(mtype), nil
505		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
506			val := reflect.ValueOf(tval)
507			if !val.Type().ConvertibleTo(mtype) {
508				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
509			}
510			if reflect.Indirect(reflect.New(mtype)).OverflowInt(val.Int()) {
511				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
512			}
513
514			return val.Convert(mtype), nil
515		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
516			val := reflect.ValueOf(tval)
517			if !val.Type().ConvertibleTo(mtype) {
518				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
519			}
520			if val.Int() < 0 {
521				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) is negative so does not fit in %v", tval, tval, mtype.String())
522			}
523			if reflect.Indirect(reflect.New(mtype)).OverflowUint(uint64(val.Int())) {
524				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
525			}
526
527			return val.Convert(mtype), nil
528		case reflect.Float32, reflect.Float64:
529			val := reflect.ValueOf(tval)
530			if !val.Type().ConvertibleTo(mtype) {
531				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
532			}
533			if reflect.Indirect(reflect.New(mtype)).OverflowFloat(val.Float()) {
534				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
535			}
536
537			return val.Convert(mtype), nil
538		default:
539			return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v(%v)", tval, tval, mtype, mtype.Kind())
540		}
541	}
542}
543
544func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
545	val, err := d.valueFromToml(mtype.Elem(), tval)
546	if err != nil {
547		return reflect.ValueOf(nil), err
548	}
549	mval := reflect.New(mtype.Elem())
550	mval.Elem().Set(val)
551	return mval, nil
552}
553
554func tomlOptions(vf reflect.StructField) tomlOpts {
555	tag := vf.Tag.Get("toml")
556	parse := strings.Split(tag, ",")
557	var comment string
558	if c := vf.Tag.Get("comment"); c != "" {
559		comment = c
560	}
561	commented, _ := strconv.ParseBool(vf.Tag.Get("commented"))
562	result := tomlOpts{name: vf.Name, comment: comment, commented: commented, include: true, omitempty: false}
563	if parse[0] != "" {
564		if parse[0] == "-" && len(parse) == 1 {
565			result.include = false
566		} else {
567			result.name = strings.Trim(parse[0], " ")
568		}
569	}
570	if vf.PkgPath != "" {
571		result.include = false
572	}
573	if len(parse) > 1 && strings.Trim(parse[1], " ") == "omitempty" {
574		result.omitempty = true
575	}
576	if vf.Type.Kind() == reflect.Ptr {
577		result.omitempty = true
578	}
579	return result
580}
581
// isZero reports whether val counts as empty for the omitempty option.
// Maps, arrays and slices are empty when their length is zero. Any other
// value is empty when it is deeply equal to the zero value of its type.
func isZero(val reflect.Value) bool {
	kind := val.Type().Kind()
	if kind == reflect.Map || kind == reflect.Array || kind == reflect.Slice {
		return val.Len() == 0
	}
	zero := reflect.Zero(val.Type())
	return reflect.DeepEqual(val.Interface(), zero.Interface())
}
594
595func formatError(err error, pos Position) error {
596	if err.Error()[0] == '(' { // Error already contains position information
597		return err
598	}
599	return fmt.Errorf("%s: %s", pos, err)
600}
601