1package toml
2
3import (
4	"bytes"
5	"encoding"
6	"errors"
7	"fmt"
8	"io"
9	"reflect"
10	"sort"
11	"strconv"
12	"strings"
13	"time"
14)
15
// Default struct tag keys consulted when marshaling and unmarshaling. Some of
// them can be overridden through the SetTag* methods of Encoder and Decoder.
const (
	tagFieldName    = "toml"      // field name override (and options such as omitempty)
	tagFieldComment = "comment"   // comment text emitted with the value
	tagCommented    = "commented" // emit the value commented out
	tagMultiline    = "multiline" // multiline flag forwarded to SetOptions
	tagLiteral      = "literal"   // literal flag forwarded to SetOptions
	tagDefault      = "default"   // default value applied by the decoder
)
24
// tomlOpts holds the per-field options that tomlOptions derives from a struct
// field and its tags; both the encoder and the decoder consume it.
type tomlOpts struct {
	// name is the TOML key used for the field.
	name         string
	// nameFromTag records that name was given explicitly by the tag; the
	// encoder does not flatten such anonymous struct fields into their parent.
	nameFromTag  bool
	// comment, commented, multiline and literal are forwarded to SetOptions
	// when the encoder stores the field's value.
	comment      string
	commented    bool
	multiline    bool
	literal      bool
	// include is false for fields that must be skipped entirely.
	include      bool
	// omitempty makes the encoder skip the field when its value is zero.
	omitempty    bool
	// defaultValue, when non-empty, is parsed and assigned by the decoder if
	// no matching key is present in the input.
	defaultValue string
}
36
// encOpts holds output-formatting options. Both Encoder and Decoder embed it,
// although only the Encoder reads these fields in this part of the file.
type encOpts struct {
	// quoteMapKeys renders map keys through tomlValueStringRepresentation
	// (quoted) instead of using them verbatim.
	quoteMapKeys            bool
	// arraysOneElementPerLine is forwarded to the tree writer so arrays are
	// written with one element per line.
	arraysOneElementPerLine bool
}
41
42var encOptsDefaults = encOpts{
43	quoteMapKeys: false,
44}
45
// annotation names the struct tag keys from which field options are read.
// NewEncoder starts from annotationDefault and some SetTag* methods override
// individual keys; the decoder fills in only tag.
type annotation struct {
	tag          string // key holding the field name (e.g. toml:"name,omitempty")
	comment      string // key holding the comment text
	commented    string // key holding the "commented" flag
	multiline    string // key holding the multiline flag
	literal      string // key holding the literal flag
	defaultValue string // key holding the default value used by the decoder
}
54
55var annotationDefault = annotation{
56	tag:          tagFieldName,
57	comment:      tagFieldComment,
58	commented:    tagCommented,
59	multiline:    tagMultiline,
60	literal:      tagLiteral,
61	defaultValue: tagDefault,
62}
63
// MarshalOrder selects the order in which an Encoder writes fields to the
// output stream.
type MarshalOrder int

const (
	// OrderAlphabetical sorts fields alphabetically. NewEncoder uses it by
	// default.
	OrderAlphabetical MarshalOrder = 1
	// OrderPreserve keeps fields in the order they are encountered, for
	// example the declaration order of fields in a struct.
	OrderPreserve MarshalOrder = 2
)
74
75var timeType = reflect.TypeOf(time.Time{})
76var marshalerType = reflect.TypeOf(new(Marshaler)).Elem()
77var unmarshalerType = reflect.TypeOf(new(Unmarshaler)).Elem()
78var textMarshalerType = reflect.TypeOf(new(encoding.TextMarshaler)).Elem()
79var textUnmarshalerType = reflect.TypeOf(new(encoding.TextUnmarshaler)).Elem()
80var localDateType = reflect.TypeOf(LocalDate{})
81var localTimeType = reflect.TypeOf(LocalTime{})
82var localDateTimeType = reflect.TypeOf(LocalDateTime{})
83var mapStringInterfaceType = reflect.TypeOf(map[string]interface{}{})
84
85// Check if the given marshal type maps to a Tree primitive
86func isPrimitive(mtype reflect.Type) bool {
87	switch mtype.Kind() {
88	case reflect.Ptr:
89		return isPrimitive(mtype.Elem())
90	case reflect.Bool:
91		return true
92	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
93		return true
94	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
95		return true
96	case reflect.Float32, reflect.Float64:
97		return true
98	case reflect.String:
99		return true
100	case reflect.Struct:
101		return isTimeType(mtype)
102	default:
103		return false
104	}
105}
106
107func isTimeType(mtype reflect.Type) bool {
108	return mtype == timeType || mtype == localDateType || mtype == localDateTimeType || mtype == localTimeType
109}
110
111// Check if the given marshal type maps to a Tree slice or array
112func isTreeSequence(mtype reflect.Type) bool {
113	switch mtype.Kind() {
114	case reflect.Ptr:
115		return isTreeSequence(mtype.Elem())
116	case reflect.Slice, reflect.Array:
117		return isTree(mtype.Elem())
118	default:
119		return false
120	}
121}
122
123// Check if the given marshal type maps to a slice or array of a custom marshaler type
124func isCustomMarshalerSequence(mtype reflect.Type) bool {
125	switch mtype.Kind() {
126	case reflect.Ptr:
127		return isCustomMarshalerSequence(mtype.Elem())
128	case reflect.Slice, reflect.Array:
129		return isCustomMarshaler(mtype.Elem()) || isCustomMarshaler(reflect.New(mtype.Elem()).Type())
130	default:
131		return false
132	}
133}
134
135// Check if the given marshal type maps to a slice or array of a text marshaler type
136func isTextMarshalerSequence(mtype reflect.Type) bool {
137	switch mtype.Kind() {
138	case reflect.Ptr:
139		return isTextMarshalerSequence(mtype.Elem())
140	case reflect.Slice, reflect.Array:
141		return isTextMarshaler(mtype.Elem()) || isTextMarshaler(reflect.New(mtype.Elem()).Type())
142	default:
143		return false
144	}
145}
146
147// Check if the given marshal type maps to a non-Tree slice or array
148func isOtherSequence(mtype reflect.Type) bool {
149	switch mtype.Kind() {
150	case reflect.Ptr:
151		return isOtherSequence(mtype.Elem())
152	case reflect.Slice, reflect.Array:
153		return !isTreeSequence(mtype)
154	default:
155		return false
156	}
157}
158
159// Check if the given marshal type maps to a Tree
160func isTree(mtype reflect.Type) bool {
161	switch mtype.Kind() {
162	case reflect.Ptr:
163		return isTree(mtype.Elem())
164	case reflect.Map:
165		return true
166	case reflect.Struct:
167		return !isPrimitive(mtype)
168	default:
169		return false
170	}
171}
172
173func isCustomMarshaler(mtype reflect.Type) bool {
174	return mtype.Implements(marshalerType)
175}
176
177func callCustomMarshaler(mval reflect.Value) ([]byte, error) {
178	return mval.Interface().(Marshaler).MarshalTOML()
179}
180
181func isTextMarshaler(mtype reflect.Type) bool {
182	return mtype.Implements(textMarshalerType) && !isTimeType(mtype)
183}
184
// callTextMarshaler invokes MarshalText on mval, which must hold an
// encoding.TextMarshaler (the type assertion panics otherwise).
func callTextMarshaler(mval reflect.Value) ([]byte, error) {
	tm := mval.Interface().(encoding.TextMarshaler)
	return tm.MarshalText()
}
188
189func isCustomUnmarshaler(mtype reflect.Type) bool {
190	return mtype.Implements(unmarshalerType)
191}
192
193func callCustomUnmarshaler(mval reflect.Value, tval interface{}) error {
194	return mval.Interface().(Unmarshaler).UnmarshalTOML(tval)
195}
196
197func isTextUnmarshaler(mtype reflect.Type) bool {
198	return mtype.Implements(textUnmarshalerType)
199}
200
// callTextUnmarshaler passes text to the UnmarshalText method of mval, which
// must hold an encoding.TextUnmarshaler (typically a pointer).
func callTextUnmarshaler(mval reflect.Value, text []byte) error {
	tu := mval.Interface().(encoding.TextUnmarshaler)
	return tu.UnmarshalText(text)
}
204
// Marshaler is implemented by types that can encode themselves as TOML.
// MarshalTOML returns the encoded bytes or an error.
type Marshaler interface {
	MarshalTOML() (data []byte, err error)
}
210
// Unmarshaler is implemented by types that can decode a TOML description of
// themselves. The decoder passes the decoded value: the result of
// Tree.ToMap for tables, or the raw TOML value otherwise.
type Unmarshaler interface {
	UnmarshalTOML(value interface{}) error
}
216
217/*
218Marshal returns the TOML encoding of v.  Behavior is similar to the Go json
219encoder, except that there is no concept of a Marshaler interface or MarshalTOML
220function for sub-structs, and currently only definite types can be marshaled
221(i.e. no `interface{}`).
222
223The following struct annotations are supported:
224
225  toml:"Field"      Overrides the field's name to output.
226  omitempty         When set, empty values and groups are not emitted.
227  comment:"comment" Emits a # comment on the same line. This supports new lines.
228  commented:"true"  Emits the value as commented.
229
230Note that pointers are automatically assigned the "omitempty" option, as TOML
231explicitly does not handle null values (saying instead the label should be
232dropped).
233
234Tree structural types and corresponding marshal types:
235
236  *Tree                            (*)struct, (*)map[string]interface{}
237  []*Tree                          (*)[](*)struct, (*)[](*)map[string]interface{}
238  []interface{} (as interface{})   (*)[]primitive, (*)[]([]interface{})
239  interface{}                      (*)primitive
240
241Tree primitive types and corresponding marshal types:
242
243  uint64     uint, uint8-uint64, pointers to same
244  int64      int, int8-uint64, pointers to same
245  float64    float32, float64, pointers to same
246  string     string, pointers to same
247  bool       bool, pointers to same
248  time.LocalTime  time.LocalTime{}, pointers to same
249
250For additional flexibility, use the Encoder API.
251*/
252func Marshal(v interface{}) ([]byte, error) {
253	return NewEncoder(nil).marshal(v)
254}
255
// Encoder writes TOML values to an output stream. Configure it with the
// chainable option methods, then call Encode.
type Encoder struct {
	// w receives the output of Encode. Marshal uses an Encoder with a nil w
	// and never writes to it.
	w io.Writer
	// encOpts holds array and map-key formatting options.
	encOpts
	// annotation names the struct tag keys consulted for field options.
	annotation
	// line is the line number given to the next Tree or value position; it
	// is advanced by wrapTomlValue.
	line            int
	// col is initialised to 1 by NewEncoder; it is not read in this part of
	// the file.
	col             int
	// order selects the field ordering used when writing the tree.
	order           MarshalOrder
	// promoteAnon disables flattening of anonymous struct fields.
	promoteAnon     bool
	// compactComments drops the blank line written before comments.
	compactComments bool
	// indentation is the per-level indentation string.
	indentation     string
}
268
269// NewEncoder returns a new encoder that writes to w.
270func NewEncoder(w io.Writer) *Encoder {
271	return &Encoder{
272		w:           w,
273		encOpts:     encOptsDefaults,
274		annotation:  annotationDefault,
275		line:        0,
276		col:         1,
277		order:       OrderAlphabetical,
278		indentation: "  ",
279	}
280}
281
282// Encode writes the TOML encoding of v to the stream.
283//
284// See the documentation for Marshal for details.
285func (e *Encoder) Encode(v interface{}) error {
286	b, err := e.marshal(v)
287	if err != nil {
288		return err
289	}
290	if _, err := e.w.Write(b); err != nil {
291		return err
292	}
293	return nil
294}
295
296// QuoteMapKeys sets up the encoder to encode
297// maps with string type keys with quoted TOML keys.
298//
299// This relieves the character limitations on map keys.
300func (e *Encoder) QuoteMapKeys(v bool) *Encoder {
301	e.quoteMapKeys = v
302	return e
303}
304
305// ArraysWithOneElementPerLine sets up the encoder to encode arrays
306// with more than one element on multiple lines instead of one.
307//
308// For example:
309//
310//   A = [1,2,3]
311//
312// Becomes
313//
314//   A = [
315//     1,
316//     2,
317//     3,
318//   ]
319func (e *Encoder) ArraysWithOneElementPerLine(v bool) *Encoder {
320	e.arraysOneElementPerLine = v
321	return e
322}
323
324// Order allows to change in which order fields will be written to the output stream.
325func (e *Encoder) Order(ord MarshalOrder) *Encoder {
326	e.order = ord
327	return e
328}
329
330// Indentation allows to change indentation when marshalling.
331func (e *Encoder) Indentation(indent string) *Encoder {
332	e.indentation = indent
333	return e
334}
335
336// SetTagName allows changing default tag "toml"
337func (e *Encoder) SetTagName(v string) *Encoder {
338	e.tag = v
339	return e
340}
341
342// SetTagComment allows changing default tag "comment"
343func (e *Encoder) SetTagComment(v string) *Encoder {
344	e.comment = v
345	return e
346}
347
348// SetTagCommented allows changing default tag "commented"
349func (e *Encoder) SetTagCommented(v string) *Encoder {
350	e.commented = v
351	return e
352}
353
354// SetTagMultiline allows changing default tag "multiline"
355func (e *Encoder) SetTagMultiline(v string) *Encoder {
356	e.multiline = v
357	return e
358}
359
360// PromoteAnonymous allows to change how anonymous struct fields are marshaled.
361// Usually, they are marshaled as if the inner exported fields were fields in
362// the outer struct. However, if an anonymous struct field is given a name in
363// its TOML tag, it is treated like a regular struct field with that name.
364// rather than being anonymous.
365//
366// In case anonymous promotion is enabled, all anonymous structs are promoted
367// and treated like regular struct fields.
368func (e *Encoder) PromoteAnonymous(promote bool) *Encoder {
369	e.promoteAnon = promote
370	return e
371}
372
373// CompactComments removes the new line before each comment in the tree.
374func (e *Encoder) CompactComments(cc bool) *Encoder {
375	e.compactComments = cc
376	return e
377}
378
379func (e *Encoder) marshal(v interface{}) ([]byte, error) {
380	// Check if indentation is valid
381	for _, char := range e.indentation {
382		if !isSpace(char) {
383			return []byte{}, fmt.Errorf("invalid indentation: must only contains space or tab characters")
384		}
385	}
386
387	mtype := reflect.TypeOf(v)
388	if mtype == nil {
389		return []byte{}, errors.New("nil cannot be marshaled to TOML")
390	}
391
392	switch mtype.Kind() {
393	case reflect.Struct, reflect.Map:
394	case reflect.Ptr:
395		if mtype.Elem().Kind() != reflect.Struct {
396			return []byte{}, errors.New("Only pointer to struct can be marshaled to TOML")
397		}
398		if reflect.ValueOf(v).IsNil() {
399			return []byte{}, errors.New("nil pointer cannot be marshaled to TOML")
400		}
401	default:
402		return []byte{}, errors.New("Only a struct or map can be marshaled to TOML")
403	}
404
405	sval := reflect.ValueOf(v)
406	if isCustomMarshaler(mtype) {
407		return callCustomMarshaler(sval)
408	}
409	if isTextMarshaler(mtype) {
410		return callTextMarshaler(sval)
411	}
412	t, err := e.valueToTree(mtype, sval)
413	if err != nil {
414		return []byte{}, err
415	}
416
417	var buf bytes.Buffer
418	_, err = t.writeToOrdered(&buf, "", "", 0, e.arraysOneElementPerLine, e.order, e.indentation, e.compactComments, false)
419
420	return buf.Bytes(), err
421}
422
423// Create next tree with a position based on Encoder.line
424func (e *Encoder) nextTree() *Tree {
425	return newTreeWithPosition(Position{Line: e.line, Col: 1})
426}
427
428// Convert given marshal struct or map value to toml tree
429func (e *Encoder) valueToTree(mtype reflect.Type, mval reflect.Value) (*Tree, error) {
430	if mtype.Kind() == reflect.Ptr {
431		return e.valueToTree(mtype.Elem(), mval.Elem())
432	}
433	tval := e.nextTree()
434	switch mtype.Kind() {
435	case reflect.Struct:
436		switch mval.Interface().(type) {
437		case Tree:
438			reflect.ValueOf(tval).Elem().Set(mval)
439		default:
440			for i := 0; i < mtype.NumField(); i++ {
441				mtypef, mvalf := mtype.Field(i), mval.Field(i)
442				opts := tomlOptions(mtypef, e.annotation)
443				if opts.include && ((mtypef.Type.Kind() != reflect.Interface && !opts.omitempty) || !isZero(mvalf)) {
444					val, err := e.valueToToml(mtypef.Type, mvalf)
445					if err != nil {
446						return nil, err
447					}
448					if tree, ok := val.(*Tree); ok && mtypef.Anonymous && !opts.nameFromTag && !e.promoteAnon {
449						e.appendTree(tval, tree)
450					} else {
451						val = e.wrapTomlValue(val, tval)
452						tval.SetPathWithOptions([]string{opts.name}, SetOptions{
453							Comment:   opts.comment,
454							Commented: opts.commented,
455							Multiline: opts.multiline,
456							Literal:   opts.literal,
457						}, val)
458					}
459				}
460			}
461		}
462	case reflect.Map:
463		keys := mval.MapKeys()
464		if e.order == OrderPreserve && len(keys) > 0 {
465			// Sorting []reflect.Value is not straight forward.
466			//
467			// OrderPreserve will support deterministic results when string is used
468			// as the key to maps.
469			typ := keys[0].Type()
470			kind := keys[0].Kind()
471			if kind == reflect.String {
472				ikeys := make([]string, len(keys))
473				for i := range keys {
474					ikeys[i] = keys[i].Interface().(string)
475				}
476				sort.Strings(ikeys)
477				for i := range ikeys {
478					keys[i] = reflect.ValueOf(ikeys[i]).Convert(typ)
479				}
480			}
481		}
482		for _, key := range keys {
483			mvalf := mval.MapIndex(key)
484			if (mtype.Elem().Kind() == reflect.Ptr || mtype.Elem().Kind() == reflect.Interface) && mvalf.IsNil() {
485				continue
486			}
487			val, err := e.valueToToml(mtype.Elem(), mvalf)
488			if err != nil {
489				return nil, err
490			}
491			val = e.wrapTomlValue(val, tval)
492			if e.quoteMapKeys {
493				keyStr, err := tomlValueStringRepresentation(key.String(), "", "", e.order, e.arraysOneElementPerLine)
494				if err != nil {
495					return nil, err
496				}
497				tval.SetPath([]string{keyStr}, val)
498			} else {
499				tval.SetPath([]string{key.String()}, val)
500			}
501		}
502	}
503	return tval, nil
504}
505
506// Convert given marshal slice to slice of Toml trees
507func (e *Encoder) valueToTreeSlice(mtype reflect.Type, mval reflect.Value) ([]*Tree, error) {
508	tval := make([]*Tree, mval.Len(), mval.Len())
509	for i := 0; i < mval.Len(); i++ {
510		val, err := e.valueToTree(mtype.Elem(), mval.Index(i))
511		if err != nil {
512			return nil, err
513		}
514		tval[i] = val
515	}
516	return tval, nil
517}
518
519// Convert given marshal slice to slice of toml values
520func (e *Encoder) valueToOtherSlice(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
521	tval := make([]interface{}, mval.Len(), mval.Len())
522	for i := 0; i < mval.Len(); i++ {
523		val, err := e.valueToToml(mtype.Elem(), mval.Index(i))
524		if err != nil {
525			return nil, err
526		}
527		tval[i] = val
528	}
529	return tval, nil
530}
531
532// Convert given marshal value to toml value
533func (e *Encoder) valueToToml(mtype reflect.Type, mval reflect.Value) (interface{}, error) {
534	if mtype.Kind() == reflect.Ptr {
535		switch {
536		case isCustomMarshaler(mtype):
537			return callCustomMarshaler(mval)
538		case isTextMarshaler(mtype):
539			b, err := callTextMarshaler(mval)
540			return string(b), err
541		default:
542			return e.valueToToml(mtype.Elem(), mval.Elem())
543		}
544	}
545	if mtype.Kind() == reflect.Interface {
546		return e.valueToToml(mval.Elem().Type(), mval.Elem())
547	}
548	switch {
549	case isCustomMarshaler(mtype):
550		return callCustomMarshaler(mval)
551	case isTextMarshaler(mtype):
552		b, err := callTextMarshaler(mval)
553		return string(b), err
554	case isTree(mtype):
555		return e.valueToTree(mtype, mval)
556	case isOtherSequence(mtype), isCustomMarshalerSequence(mtype), isTextMarshalerSequence(mtype):
557		return e.valueToOtherSlice(mtype, mval)
558	case isTreeSequence(mtype):
559		return e.valueToTreeSlice(mtype, mval)
560	default:
561		switch mtype.Kind() {
562		case reflect.Bool:
563			return mval.Bool(), nil
564		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
565			if mtype.Kind() == reflect.Int64 && mtype == reflect.TypeOf(time.Duration(1)) {
566				return fmt.Sprint(mval), nil
567			}
568			return mval.Int(), nil
569		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
570			return mval.Uint(), nil
571		case reflect.Float32, reflect.Float64:
572			return mval.Float(), nil
573		case reflect.String:
574			return mval.String(), nil
575		case reflect.Struct:
576			return mval.Interface(), nil
577		default:
578			return nil, fmt.Errorf("Marshal can't handle %v(%v)", mtype, mtype.Kind())
579		}
580	}
581}
582
583func (e *Encoder) appendTree(t, o *Tree) error {
584	for key, value := range o.values {
585		if _, ok := t.values[key]; ok {
586			continue
587		}
588		if tomlValue, ok := value.(*tomlValue); ok {
589			tomlValue.position.Col = t.position.Col
590		}
591		t.values[key] = value
592	}
593	return nil
594}
595
596// Create a toml value with the current line number as the position line
597func (e *Encoder) wrapTomlValue(val interface{}, parent *Tree) interface{} {
598	_, isTree := val.(*Tree)
599	_, isTreeS := val.([]*Tree)
600	if isTree || isTreeS {
601		e.line++
602		return val
603	}
604
605	ret := &tomlValue{
606		value: val,
607		position: Position{
608			e.line,
609			parent.position.Col,
610		},
611	}
612	e.line++
613	return ret
614}
615
616// Unmarshal attempts to unmarshal the Tree into a Go struct pointed by v.
617// Neither Unmarshaler interfaces nor UnmarshalTOML functions are supported for
618// sub-structs, and only definite types can be unmarshaled.
619func (t *Tree) Unmarshal(v interface{}) error {
620	d := Decoder{tval: t, tagName: tagFieldName}
621	return d.unmarshal(v)
622}
623
624// Marshal returns the TOML encoding of Tree.
625// See Marshal() documentation for types mapping table.
626func (t *Tree) Marshal() ([]byte, error) {
627	var buf bytes.Buffer
628	_, err := t.WriteTo(&buf)
629	if err != nil {
630		return nil, err
631	}
632	return buf.Bytes(), nil
633}
634
635// Unmarshal parses the TOML-encoded data and stores the result in the value
636// pointed to by v. Behavior is similar to the Go json encoder, except that there
637// is no concept of an Unmarshaler interface or UnmarshalTOML function for
638// sub-structs, and currently only definite types can be unmarshaled to (i.e. no
639// `interface{}`).
640//
641// The following struct annotations are supported:
642//
643//   toml:"Field" Overrides the field's name to map to.
644//   default:"foo" Provides a default value.
645//
646// For default values, only fields of the following types are supported:
647//   * string
648//   * bool
649//   * int
650//   * int64
651//   * float64
652//
653// See Marshal() documentation for types mapping table.
654func Unmarshal(data []byte, v interface{}) error {
655	t, err := LoadReader(bytes.NewReader(data))
656	if err != nil {
657		return err
658	}
659	return t.Unmarshal(v)
660}
661
// Decoder reads and decodes TOML values from an input stream.
type Decoder struct {
	// r is the input parsed by Decode.
	r    io.Reader
	// tval is the parsed document that unmarshal decodes from.
	tval *Tree
	encOpts
	// tagName is the struct tag key holding field names (default "toml").
	tagName string
	// strict makes input keys without a corresponding struct member an error.
	strict  bool
	// visitor records the keys consumed during decoding (push/pop/visit). It
	// is only initialised when strict is set, and validate is called after
	// every decode; presumably it is a no-op otherwise (verify in its
	// definition).
	visitor visitorState
}
671
672// NewDecoder returns a new decoder that reads from r.
673func NewDecoder(r io.Reader) *Decoder {
674	return &Decoder{
675		r:       r,
676		encOpts: encOptsDefaults,
677		tagName: tagFieldName,
678	}
679}
680
681// Decode reads a TOML-encoded value from it's input
682// and unmarshals it in the value pointed at by v.
683//
684// See the documentation for Marshal for details.
685func (d *Decoder) Decode(v interface{}) error {
686	var err error
687	d.tval, err = LoadReader(d.r)
688	if err != nil {
689		return err
690	}
691	return d.unmarshal(v)
692}
693
694// SetTagName allows changing default tag "toml"
695func (d *Decoder) SetTagName(v string) *Decoder {
696	d.tagName = v
697	return d
698}
699
700// Strict allows changing to strict decoding. Any fields that are found in the
701// input data and do not have a corresponding struct member cause an error.
702func (d *Decoder) Strict(strict bool) *Decoder {
703	d.strict = strict
704	return d
705}
706
707func (d *Decoder) unmarshal(v interface{}) error {
708	mtype := reflect.TypeOf(v)
709	if mtype == nil {
710		return errors.New("nil cannot be unmarshaled from TOML")
711	}
712	if mtype.Kind() != reflect.Ptr {
713		return errors.New("only a pointer to struct or map can be unmarshaled from TOML")
714	}
715
716	elem := mtype.Elem()
717
718	switch elem.Kind() {
719	case reflect.Struct, reflect.Map:
720	case reflect.Interface:
721		elem = mapStringInterfaceType
722	default:
723		return errors.New("only a pointer to struct or map can be unmarshaled from TOML")
724	}
725
726	if reflect.ValueOf(v).IsNil() {
727		return errors.New("nil pointer cannot be unmarshaled from TOML")
728	}
729
730	vv := reflect.ValueOf(v).Elem()
731
732	if d.strict {
733		d.visitor = newVisitorState(d.tval)
734	}
735
736	sval, err := d.valueFromTree(elem, d.tval, &vv)
737	if err != nil {
738		return err
739	}
740	if err := d.visitor.validate(); err != nil {
741		return err
742	}
743	reflect.ValueOf(v).Elem().Set(sval)
744	return nil
745}
746
747// Convert toml tree to marshal struct or map, using marshal type. When mval1
748// is non-nil, merge fields into the given value instead of allocating a new one.
749func (d *Decoder) valueFromTree(mtype reflect.Type, tval *Tree, mval1 *reflect.Value) (reflect.Value, error) {
750	if mtype.Kind() == reflect.Ptr {
751		return d.unwrapPointer(mtype, tval, mval1)
752	}
753
754	// Check if pointer to value implements the Unmarshaler interface.
755	if mvalPtr := reflect.New(mtype); isCustomUnmarshaler(mvalPtr.Type()) {
756		d.visitor.visitAll()
757
758		if tval == nil {
759			return mvalPtr.Elem(), nil
760		}
761
762		if err := callCustomUnmarshaler(mvalPtr, tval.ToMap()); err != nil {
763			return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err)
764		}
765		return mvalPtr.Elem(), nil
766	}
767
768	var mval reflect.Value
769	switch mtype.Kind() {
770	case reflect.Struct:
771		if mval1 != nil {
772			mval = *mval1
773		} else {
774			mval = reflect.New(mtype).Elem()
775		}
776
777		switch mval.Interface().(type) {
778		case Tree:
779			mval.Set(reflect.ValueOf(tval).Elem())
780		default:
781			for i := 0; i < mtype.NumField(); i++ {
782				mtypef := mtype.Field(i)
783				an := annotation{tag: d.tagName}
784				opts := tomlOptions(mtypef, an)
785				if !opts.include {
786					continue
787				}
788				baseKey := opts.name
789				keysToTry := []string{
790					baseKey,
791					strings.ToLower(baseKey),
792					strings.ToTitle(baseKey),
793					strings.ToLower(string(baseKey[0])) + baseKey[1:],
794				}
795
796				found := false
797				if tval != nil {
798					for _, key := range keysToTry {
799						exists := tval.HasPath([]string{key})
800						if !exists {
801							continue
802						}
803
804						d.visitor.push(key)
805						val := tval.GetPath([]string{key})
806						fval := mval.Field(i)
807						mvalf, err := d.valueFromToml(mtypef.Type, val, &fval)
808						if err != nil {
809							return mval, formatError(err, tval.GetPositionPath([]string{key}))
810						}
811						mval.Field(i).Set(mvalf)
812						found = true
813						d.visitor.pop()
814						break
815					}
816				}
817
818				if !found && opts.defaultValue != "" {
819					mvalf := mval.Field(i)
820					var val interface{}
821					var err error
822					switch mvalf.Kind() {
823					case reflect.String:
824						val = opts.defaultValue
825					case reflect.Bool:
826						val, err = strconv.ParseBool(opts.defaultValue)
827					case reflect.Uint:
828						val, err = strconv.ParseUint(opts.defaultValue, 10, 0)
829					case reflect.Uint8:
830						val, err = strconv.ParseUint(opts.defaultValue, 10, 8)
831					case reflect.Uint16:
832						val, err = strconv.ParseUint(opts.defaultValue, 10, 16)
833					case reflect.Uint32:
834						val, err = strconv.ParseUint(opts.defaultValue, 10, 32)
835					case reflect.Uint64:
836						val, err = strconv.ParseUint(opts.defaultValue, 10, 64)
837					case reflect.Int:
838						val, err = strconv.ParseInt(opts.defaultValue, 10, 0)
839					case reflect.Int8:
840						val, err = strconv.ParseInt(opts.defaultValue, 10, 8)
841					case reflect.Int16:
842						val, err = strconv.ParseInt(opts.defaultValue, 10, 16)
843					case reflect.Int32:
844						val, err = strconv.ParseInt(opts.defaultValue, 10, 32)
845					case reflect.Int64:
846						// Check if the provided number has a non-numeric extension.
847						var hasExtension bool
848						if len(opts.defaultValue) > 0 {
849							lastChar := opts.defaultValue[len(opts.defaultValue)-1]
850							if lastChar < '0' || lastChar > '9' {
851								hasExtension = true
852							}
853						}
854						// If the value is a time.Duration with extension, parse as duration.
855						// If the value is an int64 or a time.Duration without extension, parse as number.
856						if hasExtension && mvalf.Type().String() == "time.Duration" {
857							val, err = time.ParseDuration(opts.defaultValue)
858						} else {
859							val, err = strconv.ParseInt(opts.defaultValue, 10, 64)
860						}
861					case reflect.Float32:
862						val, err = strconv.ParseFloat(opts.defaultValue, 32)
863					case reflect.Float64:
864						val, err = strconv.ParseFloat(opts.defaultValue, 64)
865					default:
866						return mvalf, fmt.Errorf("unsupported field type for default option")
867					}
868
869					if err != nil {
870						return mvalf, err
871					}
872					mvalf.Set(reflect.ValueOf(val).Convert(mvalf.Type()))
873				}
874
875				// save the old behavior above and try to check structs
876				if !found && opts.defaultValue == "" && mtypef.Type.Kind() == reflect.Struct {
877					tmpTval := tval
878					if !mtypef.Anonymous {
879						tmpTval = nil
880					}
881					fval := mval.Field(i)
882					v, err := d.valueFromTree(mtypef.Type, tmpTval, &fval)
883					if err != nil {
884						return v, err
885					}
886					mval.Field(i).Set(v)
887				}
888			}
889		}
890	case reflect.Map:
891		mval = reflect.MakeMap(mtype)
892		for _, key := range tval.Keys() {
893			d.visitor.push(key)
894			// TODO: path splits key
895			val := tval.GetPath([]string{key})
896			mvalf, err := d.valueFromToml(mtype.Elem(), val, nil)
897			if err != nil {
898				return mval, formatError(err, tval.GetPositionPath([]string{key}))
899			}
900			mval.SetMapIndex(reflect.ValueOf(key).Convert(mtype.Key()), mvalf)
901			d.visitor.pop()
902		}
903	}
904	return mval, nil
905}
906
907// Convert toml value to marshal struct/map slice, using marshal type
908func (d *Decoder) valueFromTreeSlice(mtype reflect.Type, tval []*Tree) (reflect.Value, error) {
909	mval, err := makeSliceOrArray(mtype, len(tval))
910	if err != nil {
911		return mval, err
912	}
913
914	for i := 0; i < len(tval); i++ {
915		d.visitor.push(strconv.Itoa(i))
916		val, err := d.valueFromTree(mtype.Elem(), tval[i], nil)
917		if err != nil {
918			return mval, err
919		}
920		mval.Index(i).Set(val)
921		d.visitor.pop()
922	}
923	return mval, nil
924}
925
926// Convert toml value to marshal primitive slice, using marshal type
927func (d *Decoder) valueFromOtherSlice(mtype reflect.Type, tval []interface{}) (reflect.Value, error) {
928	mval, err := makeSliceOrArray(mtype, len(tval))
929	if err != nil {
930		return mval, err
931	}
932
933	for i := 0; i < len(tval); i++ {
934		val, err := d.valueFromToml(mtype.Elem(), tval[i], nil)
935		if err != nil {
936			return mval, err
937		}
938		mval.Index(i).Set(val)
939	}
940	return mval, nil
941}
942
943// Convert toml value to marshal primitive slice, using marshal type
944func (d *Decoder) valueFromOtherSliceI(mtype reflect.Type, tval interface{}) (reflect.Value, error) {
945	val := reflect.ValueOf(tval)
946	length := val.Len()
947
948	mval, err := makeSliceOrArray(mtype, length)
949	if err != nil {
950		return mval, err
951	}
952
953	for i := 0; i < length; i++ {
954		val, err := d.valueFromToml(mtype.Elem(), val.Index(i).Interface(), nil)
955		if err != nil {
956			return mval, err
957		}
958		mval.Index(i).Set(val)
959	}
960	return mval, nil
961}
962
// makeSliceOrArray returns a new settable value of type mtype that can hold
// tLength elements: a slice of exactly that length, or a zero array.
//
// It returns an error if mtype is an array shorter than tLength, or if mtype is
// neither a slice nor an array. Previously the latter returned an invalid
// reflect.Value with a nil error, which callers would then index.
func makeSliceOrArray(mtype reflect.Type, tLength int) (reflect.Value, error) {
	switch mtype.Kind() {
	case reflect.Slice:
		return reflect.MakeSlice(mtype, tLength, tLength), nil
	case reflect.Array:
		// Allocate mtype itself. reflect.ArrayOf would build the unnamed [N]T
		// type, so the result's Type() would differ from a named array type
		// such as `type Vec [3]int`.
		mval := reflect.New(mtype).Elem()
		if tLength > mtype.Len() {
			return mval, fmt.Errorf("unmarshal: TOML array length (%v) exceeds destination array length (%v)", tLength, mtype.Len())
		}
		return mval, nil
	default:
		return reflect.Value{}, fmt.Errorf("unmarshal: cannot decode a TOML array into %v", mtype)
	}
}
977
978// Convert toml value to marshal value, using marshal type. When mval1 is non-nil
979// and the given type is a struct value, merge fields into it.
// valueFromToml converts the TOML value tval into a reflect.Value of type
// mtype.
//
// mval1, when non-nil, is the value currently stored at the destination.
// It is forwarded to valueFromTree only for struct targets. For interface
// targets, a non-nil existing value decides the concrete type to decode
// into. Pointer target types are delegated to unwrapPointer.
//
// Dispatch is on the dynamic type of tval:
//   - *Tree: decoded via valueFromTree (or a map[string]interface{} for
//     empty interface targets).
//   - []*Tree: decoded via valueFromTreeSlice (or a
//     []map[string]interface{} for empty interface targets).
//   - []interface{}: decoded via valueFromOtherSlice.
//   - anything else: decoded via a custom Unmarshaler, a
//     TextUnmarshaler, or a reflect conversion with range checks.
//
// An error is returned when tval cannot be converted to mtype or would
// overflow it.
func (d *Decoder) valueFromToml(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) {
	if mtype.Kind() == reflect.Ptr {
		return d.unwrapPointer(mtype, tval, mval1)
	}

	switch t := tval.(type) {
	case *Tree:
		// Only struct targets get the existing destination value; other
		// tree-shaped targets are built fresh by valueFromTree.
		var mval11 *reflect.Value
		if mtype.Kind() == reflect.Struct {
			mval11 = mval1
		}

		if isTree(mtype) {
			return d.valueFromTree(mtype, t, mval11)
		}

		if mtype.Kind() == reflect.Interface {
			// Unset interface: decode into a generic map. Otherwise decode
			// into the dynamic type already held by the interface.
			if mval1 == nil || mval1.IsNil() {
				return d.valueFromTree(reflect.TypeOf(map[string]interface{}{}), t, nil)
			} else {
				return d.valueFromToml(mval1.Elem().Type(), t, nil)
			}
		}

		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a tree", tval, tval)
	case []*Tree:
		if isTreeSequence(mtype) {
			return d.valueFromTreeSlice(mtype, t)
		}
		if mtype.Kind() == reflect.Interface {
			// Unset interface: decode into a slice of generic maps.
			// Otherwise reuse the interface's current dynamic type.
			if mval1 == nil || mval1.IsNil() {
				return d.valueFromTreeSlice(reflect.TypeOf([]map[string]interface{}{}), t)
			} else {
				ival := mval1.Elem()
				return d.valueFromToml(mval1.Elem().Type(), t, &ival)
			}
		}
		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to trees", tval, tval)
	case []interface{}:
		// Mark the current key path as decoded so validate() does not
		// report it.
		d.visitor.visit()
		if isOtherSequence(mtype) {
			return d.valueFromOtherSlice(mtype, t)
		}
		if mtype.Kind() == reflect.Interface {
			if mval1 == nil || mval1.IsNil() {
				return d.valueFromOtherSlice(reflect.TypeOf([]interface{}{}), t)
			} else {
				ival := mval1.Elem()
				return d.valueFromToml(mval1.Elem().Type(), t, &ival)
			}
		}
		return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to a slice", tval, tval)
	default:
		// Scalar (leaf) value: mark the current key path as decoded.
		d.visitor.visit()
		mvalPtr := reflect.New(mtype)

		// Check if pointer to value implements the Unmarshaler interface.
		if isCustomUnmarshaler(mvalPtr.Type()) {
			if err := callCustomUnmarshaler(mvalPtr, tval); err != nil {
				return reflect.ValueOf(nil), fmt.Errorf("unmarshal toml: %v", err)
			}
			return mvalPtr.Elem(), nil
		}

		// Check if pointer to value implements the encoding.TextUnmarshaler.
		// Types for which isTimeType reports true are excluded here and
		// fall through to the conversion logic below.
		if isTextUnmarshaler(mvalPtr.Type()) && !isTimeType(mtype) {
			if err := d.unmarshalText(tval, mvalPtr); err != nil {
				return reflect.ValueOf(nil), fmt.Errorf("unmarshal text: %v", err)
			}
			return mvalPtr.Elem(), nil
		}

		switch mtype.Kind() {
		case reflect.Bool, reflect.Struct:
			val := reflect.ValueOf(tval)

			// LocalDate and LocalDateTime values decoded into a time.Time
			// target are interpreted in the time.Local zone.
			switch val.Type() {
			case localDateType:
				localDate := val.Interface().(LocalDate)
				switch mtype {
				case timeType:
					return reflect.ValueOf(time.Date(localDate.Year, localDate.Month, localDate.Day, 0, 0, 0, 0, time.Local)), nil
				}
			case localDateTimeType:
				localDateTime := val.Interface().(LocalDateTime)
				switch mtype {
				case timeType:
					return reflect.ValueOf(time.Date(
						localDateTime.Date.Year,
						localDateTime.Date.Month,
						localDateTime.Date.Day,
						localDateTime.Time.Hour,
						localDateTime.Time.Minute,
						localDateTime.Time.Second,
						localDateTime.Time.Nanosecond,
						time.Local)), nil
				}
			}

			// For struct targets this only succeeds when tval's type is
			// convertible to mtype. Presumably tval is then already a
			// time/local-time value of the target type (TODO confirm).
			if !val.Type().ConvertibleTo(mtype) {
				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
			}

			return val.Convert(mtype), nil
		case reflect.String:
			val := reflect.ValueOf(tval)
			// reflect reports int64 as convertible to string (the
			// conversion yields the rune with that code point), so reject
			// integers explicitly.
			if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Int64 {
				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
			}

			return val.Convert(mtype), nil
		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
			val := reflect.ValueOf(tval)
			// time.Duration targets also accept strings understood by
			// time.ParseDuration.
			if mtype.Kind() == reflect.Int64 && mtype == reflect.TypeOf(time.Duration(1)) && val.Kind() == reflect.String {
				d, err := time.ParseDuration(val.String())
				if err != nil {
					return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v. %s", tval, tval, mtype.String(), err)
				}
				return reflect.ValueOf(d), nil
			}
			// Floats are rejected rather than silently truncated.
			if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Float64 {
				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
			}
			if reflect.Indirect(reflect.New(mtype)).OverflowInt(val.Convert(reflect.TypeOf(int64(0))).Int()) {
				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
			}

			return val.Convert(mtype), nil
		case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
			val := reflect.ValueOf(tval)
			if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Float64 {
				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
			}

			// Reject negative values before the unsigned overflow check,
			// since converting them to uint64 would wrap around.
			if val.Convert(reflect.TypeOf(int(1))).Int() < 0 {
				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) is negative so does not fit in %v", tval, tval, mtype.String())
			}
			if reflect.Indirect(reflect.New(mtype)).OverflowUint(val.Convert(reflect.TypeOf(uint64(0))).Uint()) {
				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
			}

			return val.Convert(mtype), nil
		case reflect.Float32, reflect.Float64:
			val := reflect.ValueOf(tval)
			// Integer TOML values are not accepted for float targets.
			if !val.Type().ConvertibleTo(mtype) || val.Kind() == reflect.Int64 {
				return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v", tval, tval, mtype.String())
			}
			if reflect.Indirect(reflect.New(mtype)).OverflowFloat(val.Convert(reflect.TypeOf(float64(0))).Float()) {
				return reflect.ValueOf(nil), fmt.Errorf("%v(%T) would overflow %v", tval, tval, mtype.String())
			}

			return val.Convert(mtype), nil
		case reflect.Interface:
			// Unset interface: store the raw TOML value. Otherwise convert
			// to the interface's current dynamic type.
			if mval1 == nil || mval1.IsNil() {
				return reflect.ValueOf(tval), nil
			} else {
				ival := mval1.Elem()
				return d.valueFromToml(mval1.Elem().Type(), t, &ival)
			}
		case reflect.Slice, reflect.Array:
			if isOtherSequence(mtype) && isOtherSequence(reflect.TypeOf(t)) {
				return d.valueFromOtherSliceI(mtype, t)
			}
			return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v(%v)", tval, tval, mtype, mtype.Kind())
		default:
			return reflect.ValueOf(nil), fmt.Errorf("Can't convert %v(%T) to %v(%v)", tval, tval, mtype, mtype.Kind())
		}
	}
}
1151
1152func (d *Decoder) unwrapPointer(mtype reflect.Type, tval interface{}, mval1 *reflect.Value) (reflect.Value, error) {
1153	var melem *reflect.Value
1154
1155	if mval1 != nil && !mval1.IsNil() && (mtype.Elem().Kind() == reflect.Struct || mtype.Elem().Kind() == reflect.Interface) {
1156		elem := mval1.Elem()
1157		melem = &elem
1158	}
1159
1160	val, err := d.valueFromToml(mtype.Elem(), tval, melem)
1161	if err != nil {
1162		return reflect.ValueOf(nil), err
1163	}
1164	mval := reflect.New(mtype.Elem())
1165	mval.Elem().Set(val)
1166	return mval, nil
1167}
1168
1169func (d *Decoder) unmarshalText(tval interface{}, mval reflect.Value) error {
1170	var buf bytes.Buffer
1171	fmt.Fprint(&buf, tval)
1172	return callTextUnmarshaler(mval, buf.Bytes())
1173}
1174
1175func tomlOptions(vf reflect.StructField, an annotation) tomlOpts {
1176	tag := vf.Tag.Get(an.tag)
1177	parse := strings.Split(tag, ",")
1178	var comment string
1179	if c := vf.Tag.Get(an.comment); c != "" {
1180		comment = c
1181	}
1182	commented, _ := strconv.ParseBool(vf.Tag.Get(an.commented))
1183	multiline, _ := strconv.ParseBool(vf.Tag.Get(an.multiline))
1184	literal, _ := strconv.ParseBool(vf.Tag.Get(an.literal))
1185	defaultValue := vf.Tag.Get(tagDefault)
1186	result := tomlOpts{
1187		name:         vf.Name,
1188		nameFromTag:  false,
1189		comment:      comment,
1190		commented:    commented,
1191		multiline:    multiline,
1192		literal:      literal,
1193		include:      true,
1194		omitempty:    false,
1195		defaultValue: defaultValue,
1196	}
1197	if parse[0] != "" {
1198		if parse[0] == "-" && len(parse) == 1 {
1199			result.include = false
1200		} else {
1201			result.name = strings.Trim(parse[0], " ")
1202			result.nameFromTag = true
1203		}
1204	}
1205	if vf.PkgPath != "" {
1206		result.include = false
1207	}
1208	if len(parse) > 1 && strings.Trim(parse[1], " ") == "omitempty" {
1209		result.omitempty = true
1210	}
1211	if vf.Type.Kind() == reflect.Ptr {
1212		result.omitempty = true
1213	}
1214	return result
1215}
1216
// isZero reports whether val counts as empty. Slices, arrays and maps are
// empty when they hold no elements, so a fixed-size array of zeros is not
// empty. Any other value is empty when it deeply equals its type's zero
// value.
func isZero(val reflect.Value) bool {
	typ := val.Type()
	if k := typ.Kind(); k == reflect.Slice || k == reflect.Array || k == reflect.Map {
		return val.Len() == 0
	}
	zero := reflect.Zero(typ)
	return reflect.DeepEqual(val.Interface(), zero.Interface())
}
1225
1226func formatError(err error, pos Position) error {
1227	if err.Error()[0] == '(' { // Error already contains position information
1228		return err
1229	}
1230	return fmt.Errorf("%s: %s", pos, err)
1231}
1232
// visitorState keeps track of which keys were unmarshaled.
// When active is false, every method is a no-op and validate reports nothing.
type visitorState struct {
	// tree is the root tree the key set was built from.
	tree *Tree
	// path is the key path currently being decoded, maintained by push/pop.
	path []string
	// keys holds the dot-joined paths of leaf values not yet visited.
	keys map[string]struct{}
	// active enables tracking; all methods do nothing when it is false.
	active bool
}
1240
1241func newVisitorState(tree *Tree) visitorState {
1242	path, result := []string{}, map[string]struct{}{}
1243	insertKeys(path, result, tree)
1244	return visitorState{
1245		tree:   tree,
1246		path:   path[:0],
1247		keys:   result,
1248		active: true,
1249	}
1250}
1251
1252func (s *visitorState) push(key string) {
1253	if s.active {
1254		s.path = append(s.path, key)
1255	}
1256}
1257
1258func (s *visitorState) pop() {
1259	if s.active {
1260		s.path = s.path[:len(s.path)-1]
1261	}
1262}
1263
1264func (s *visitorState) visit() {
1265	if s.active {
1266		delete(s.keys, strings.Join(s.path, "."))
1267	}
1268}
1269
1270func (s *visitorState) visitAll() {
1271	if s.active {
1272		for k := range s.keys {
1273			if strings.HasPrefix(k, strings.Join(s.path, ".")) {
1274				delete(s.keys, k)
1275			}
1276		}
1277	}
1278}
1279
1280func (s *visitorState) validate() error {
1281	if !s.active {
1282		return nil
1283	}
1284	undecoded := make([]string, 0, len(s.keys))
1285	for key := range s.keys {
1286		undecoded = append(undecoded, key)
1287	}
1288	sort.Strings(undecoded)
1289	if len(undecoded) > 0 {
1290		return fmt.Errorf("undecoded keys: %q", undecoded)
1291	}
1292	return nil
1293}
1294
1295func insertKeys(path []string, m map[string]struct{}, tree *Tree) {
1296	for k, v := range tree.values {
1297		switch node := v.(type) {
1298		case []*Tree:
1299			for i, item := range node {
1300				insertKeys(append(path, k, strconv.Itoa(i)), m, item)
1301			}
1302		case *Tree:
1303			insertKeys(append(path, k), m, node)
1304		case *tomlValue:
1305			m[strings.Join(append(path, k), ".")] = struct{}{}
1306		}
1307	}
1308}
1309