1package toml
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"math"
8	"math/big"
9	"reflect"
10	"sort"
11	"strconv"
12	"strings"
13	"time"
14)
15
// valueComplexity classifies a tree entry for ordering and writing purposes.
type valueComplexity int

const (
	// valueSimple marks an entry that is written as a "key = value" line.
	valueSimple valueComplexity = 1 + iota
	// valueComplex marks an entry holding a *Tree or a []*Tree, which is
	// written as a table or an array of tables.
	valueComplex
)

// sortNode pairs a key of a Tree with the complexity of the value stored
// under that key. The sort helpers return slices of sortNode to describe the
// order in which the keys must be written.
type sortNode struct {
	// key is the name of the entry inside the tree's values map.
	key string
	// complexity tells the writer which code path handles the entry.
	complexity valueComplexity
}
27
// encodeMultilineTomlString encodes value as the body of a TOML multi-line
// basic string (the text placed between the triple quotes).
//
// It mirrors encodeTomlString, except that:
//   - tab, newline and carriage return are written literally instead of escaped;
//   - every newline is followed by commented, and commented is also written
//     once before the first character, so each output line carries the prefix;
//   - a quotation mark is only escaped when it is the third in a row or the
//     last character of value, which keeps the string from being terminated
//     early by a run of three quotes;
//   - a backslash is copied through verbatim.
//
// Backspace and form feed use their short escapes; every other control
// character (U+0000..U+001F and U+007F) is written as a \uXXXX escape.
//
// NOTE(review): because backslashes are not escaped, input containing a
// backslash may be read back as an escape sequence by a TOML parser. This is
// the documented historical behavior and is preserved here; confirm it is
// still intended.
func encodeMultilineTomlString(value string, commented string) string {
	var b bytes.Buffer
	adjacentQuoteCount := 0

	b.WriteString(commented)
	for i, rr := range value {
		if rr != '"' {
			adjacentQuoteCount = 0
		} else {
			adjacentQuoteCount++
		}
		switch rr {
		case '\b':
			b.WriteString(`\b`)
		case '\t':
			b.WriteString("\t")
		case '\n':
			b.WriteString("\n" + commented)
		case '\f':
			b.WriteString(`\f`)
		case '\r':
			b.WriteString("\r")
		case '"':
			// i is a byte offset; a quotation mark is one byte wide, so this
			// comparison detects a quote in the final position.
			if adjacentQuoteCount >= 3 || i == len(value)-1 {
				adjacentQuoteCount = 0
				b.WriteString(`\"`)
			} else {
				b.WriteString(`"`)
			}
		case '\\':
			b.WriteString(`\`)
		default:
			// Compare the full rune: truncating it to 16 bits would make
			// supplementary-plane characters such as U+10000 look like
			// control characters and corrupt them.
			if rr <= 0x1F || rr == 0x7F {
				fmt.Fprintf(&b, "\\u%04X", rr)
			} else {
				b.WriteRune(rr)
			}
		}
	}
	return b.String()
}
73
// encodeTomlString escapes value so that it can be placed between the quotes
// of a single-line TOML basic string.
//
// Backspace, tab, newline, form feed, carriage return, the quotation mark and
// the backslash are written with their short escape forms. Every other
// control character (U+0000..U+001F and U+007F) is written as a \uXXXX
// escape. All remaining runes, including those outside the Basic Multilingual
// Plane, are copied through unchanged.
func encodeTomlString(value string) string {
	var b bytes.Buffer

	for _, rr := range value {
		switch rr {
		case '\b':
			b.WriteString(`\b`)
		case '\t':
			b.WriteString(`\t`)
		case '\n':
			b.WriteString(`\n`)
		case '\f':
			b.WriteString(`\f`)
		case '\r':
			b.WriteString(`\r`)
		case '"':
			b.WriteString(`\"`)
		case '\\':
			b.WriteString(`\\`)
		default:
			// Compare the full rune: truncating it to 16 bits would make
			// supplementary-plane characters such as U+10000 look like
			// control characters and corrupt them.
			if rr <= 0x1F || rr == 0x7F {
				fmt.Fprintf(&b, "\\u%04X", rr)
			} else {
				b.WriteRune(rr)
			}
		}
	}
	return b.String()
}
105
106func tomlTreeStringRepresentation(t *Tree, ord marshalOrder) (string, error) {
107	var orderedVals []sortNode
108	switch ord {
109	case OrderPreserve:
110		orderedVals = sortByLines(t)
111	default:
112		orderedVals = sortAlphabetical(t)
113	}
114
115	var values []string
116	for _, node := range orderedVals {
117		k := node.key
118		v := t.values[k]
119
120		repr, err := tomlValueStringRepresentation(v, "", "", ord, false)
121		if err != nil {
122			return "", err
123		}
124		values = append(values, quoteKeyIfNeeded(k)+" = "+repr)
125	}
126	return "{ " + strings.Join(values, ", ") + " }", nil
127}
128
129func tomlValueStringRepresentation(v interface{}, commented string, indent string, ord marshalOrder, arraysOneElementPerLine bool) (string, error) {
130	// this interface check is added to dereference the change made in the writeTo function.
131	// That change was made to allow this function to see formatting options.
132	tv, ok := v.(*tomlValue)
133	if ok {
134		v = tv.value
135	} else {
136		tv = &tomlValue{}
137	}
138
139	switch value := v.(type) {
140	case uint64:
141		return strconv.FormatUint(value, 10), nil
142	case int64:
143		return strconv.FormatInt(value, 10), nil
144	case float64:
145		// Default bit length is full 64
146		bits := 64
147		// Float panics if nan is used
148		if !math.IsNaN(value) {
149			// if 32 bit accuracy is enough to exactly show, use 32
150			_, acc := big.NewFloat(value).Float32()
151			if acc == big.Exact {
152				bits = 32
153			}
154		}
155		if math.Trunc(value) == value {
156			return strings.ToLower(strconv.FormatFloat(value, 'f', 1, bits)), nil
157		}
158		return strings.ToLower(strconv.FormatFloat(value, 'f', -1, bits)), nil
159	case string:
160		if tv.multiline {
161			return "\"\"\"\n" + encodeMultilineTomlString(value, commented) + "\"\"\"", nil
162		}
163		return "\"" + encodeTomlString(value) + "\"", nil
164	case []byte:
165		b, _ := v.([]byte)
166		return string(b), nil
167	case bool:
168		if value {
169			return "true", nil
170		}
171		return "false", nil
172	case time.Time:
173		return value.Format(time.RFC3339), nil
174	case LocalDate:
175		return value.String(), nil
176	case LocalDateTime:
177		return value.String(), nil
178	case LocalTime:
179		return value.String(), nil
180	case *Tree:
181		return tomlTreeStringRepresentation(value, ord)
182	case nil:
183		return "", nil
184	}
185
186	rv := reflect.ValueOf(v)
187
188	if rv.Kind() == reflect.Slice {
189		var values []string
190		for i := 0; i < rv.Len(); i++ {
191			item := rv.Index(i).Interface()
192			itemRepr, err := tomlValueStringRepresentation(item, commented, indent, ord, arraysOneElementPerLine)
193			if err != nil {
194				return "", err
195			}
196			values = append(values, itemRepr)
197		}
198		if arraysOneElementPerLine && len(values) > 1 {
199			stringBuffer := bytes.Buffer{}
200			valueIndent := indent + `  ` // TODO: move that to a shared encoder state
201
202			stringBuffer.WriteString("[\n")
203
204			for _, value := range values {
205				stringBuffer.WriteString(valueIndent)
206				stringBuffer.WriteString(commented + value)
207				stringBuffer.WriteString(`,`)
208				stringBuffer.WriteString("\n")
209			}
210
211			stringBuffer.WriteString(indent + commented + "]")
212
213			return stringBuffer.String(), nil
214		}
215		return "[" + strings.Join(values, ", ") + "]", nil
216	}
217	return "", fmt.Errorf("unsupported value type %T: %v", v, v)
218}
219
220func getTreeArrayLine(trees []*Tree) (line int) {
221	// get lowest line number that is not 0
222	for _, tv := range trees {
223		if tv.position.Line < line || line == 0 {
224			line = tv.position.Line
225		}
226	}
227	return
228}
229
230func sortByLines(t *Tree) (vals []sortNode) {
231	var (
232		line  int
233		lines []int
234		tv    *Tree
235		tom   *tomlValue
236		node  sortNode
237	)
238	vals = make([]sortNode, 0)
239	m := make(map[int]sortNode)
240
241	for k := range t.values {
242		v := t.values[k]
243		switch v.(type) {
244		case *Tree:
245			tv = v.(*Tree)
246			line = tv.position.Line
247			node = sortNode{key: k, complexity: valueComplex}
248		case []*Tree:
249			line = getTreeArrayLine(v.([]*Tree))
250			node = sortNode{key: k, complexity: valueComplex}
251		default:
252			tom = v.(*tomlValue)
253			line = tom.position.Line
254			node = sortNode{key: k, complexity: valueSimple}
255		}
256		lines = append(lines, line)
257		vals = append(vals, node)
258		m[line] = node
259	}
260	sort.Ints(lines)
261
262	for i, line := range lines {
263		vals[i] = m[line]
264	}
265
266	return vals
267}
268
269func sortAlphabetical(t *Tree) (vals []sortNode) {
270	var (
271		node     sortNode
272		simpVals []string
273		compVals []string
274	)
275	vals = make([]sortNode, 0)
276	m := make(map[string]sortNode)
277
278	for k := range t.values {
279		v := t.values[k]
280		switch v.(type) {
281		case *Tree, []*Tree:
282			node = sortNode{key: k, complexity: valueComplex}
283			compVals = append(compVals, node.key)
284		default:
285			node = sortNode{key: k, complexity: valueSimple}
286			simpVals = append(simpVals, node.key)
287		}
288		vals = append(vals, node)
289		m[node.key] = node
290	}
291
292	// Simples first to match previous implementation
293	sort.Strings(simpVals)
294	i := 0
295	for _, key := range simpVals {
296		vals[i] = m[key]
297		i++
298	}
299
300	sort.Strings(compVals)
301	for _, key := range compVals {
302		vals[i] = m[key]
303		i++
304	}
305
306	return vals
307}
308
309func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool) (int64, error) {
310	return t.writeToOrdered(w, indent, keyspace, bytesCount, arraysOneElementPerLine, OrderAlphabetical, "  ", false)
311}
312
313func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool, ord marshalOrder, indentString string, parentCommented bool) (int64, error) {
314	var orderedVals []sortNode
315
316	switch ord {
317	case OrderPreserve:
318		orderedVals = sortByLines(t)
319	default:
320		orderedVals = sortAlphabetical(t)
321	}
322
323	for _, node := range orderedVals {
324		switch node.complexity {
325		case valueComplex:
326			k := node.key
327			v := t.values[k]
328
329			combinedKey := quoteKeyIfNeeded(k)
330			if keyspace != "" {
331				combinedKey = keyspace + "." + combinedKey
332			}
333
334			switch node := v.(type) {
335			// node has to be of those two types given how keys are sorted above
336			case *Tree:
337				tv, ok := t.values[k].(*Tree)
338				if !ok {
339					return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k])
340				}
341				if tv.comment != "" {
342					comment := strings.Replace(tv.comment, "\n", "\n"+indent+"#", -1)
343					start := "# "
344					if strings.HasPrefix(comment, "#") {
345						start = ""
346					}
347					writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment)
348					bytesCount += int64(writtenBytesCountComment)
349					if errc != nil {
350						return bytesCount, errc
351					}
352				}
353
354				var commented string
355				if parentCommented || t.commented || tv.commented {
356					commented = "# "
357				}
358				writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[", combinedKey, "]\n")
359				bytesCount += int64(writtenBytesCount)
360				if err != nil {
361					return bytesCount, err
362				}
363				bytesCount, err = node.writeToOrdered(w, indent+indentString, combinedKey, bytesCount, arraysOneElementPerLine, ord, indentString, parentCommented || t.commented || tv.commented)
364				if err != nil {
365					return bytesCount, err
366				}
367			case []*Tree:
368				for _, subTree := range node {
369					var commented string
370					if parentCommented || t.commented || subTree.commented {
371						commented = "# "
372					}
373					writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[[", combinedKey, "]]\n")
374					bytesCount += int64(writtenBytesCount)
375					if err != nil {
376						return bytesCount, err
377					}
378
379					bytesCount, err = subTree.writeToOrdered(w, indent+indentString, combinedKey, bytesCount, arraysOneElementPerLine, ord, indentString, parentCommented || t.commented || subTree.commented)
380					if err != nil {
381						return bytesCount, err
382					}
383				}
384			}
385		default: // Simple
386			k := node.key
387			v, ok := t.values[k].(*tomlValue)
388			if !ok {
389				return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k])
390			}
391
392			var commented string
393			if parentCommented || t.commented || v.commented {
394				commented = "# "
395			}
396			repr, err := tomlValueStringRepresentation(v, commented, indent, ord, arraysOneElementPerLine)
397			if err != nil {
398				return bytesCount, err
399			}
400
401			if v.comment != "" {
402				comment := strings.Replace(v.comment, "\n", "\n"+indent+"#", -1)
403				start := "# "
404				if strings.HasPrefix(comment, "#") {
405					start = ""
406				}
407				writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment, "\n")
408				bytesCount += int64(writtenBytesCountComment)
409				if errc != nil {
410					return bytesCount, errc
411				}
412			}
413
414			quotedKey := quoteKeyIfNeeded(k)
415			writtenBytesCount, err := writeStrings(w, indent, commented, quotedKey, " = ", repr, "\n")
416			bytesCount += int64(writtenBytesCount)
417			if err != nil {
418				return bytesCount, err
419			}
420		}
421	}
422
423	return bytesCount, nil
424}
425
426// quote a key if it does not fit the bare key format (A-Za-z0-9_-)
427// quoted keys use the same rules as strings
428func quoteKeyIfNeeded(k string) string {
429	// when encoding a map with the 'quoteMapKeys' option enabled, the tree will contain
430	// keys that have already been quoted.
431	// not an ideal situation, but good enough of a stop gap.
432	if len(k) >= 2 && k[0] == '"' && k[len(k)-1] == '"' {
433		return k
434	}
435	isBare := true
436	for _, r := range k {
437		if !isValidBareChar(r) {
438			isBare = false
439			break
440		}
441	}
442	if isBare {
443		return k
444	}
445	return quoteKey(k)
446}
447
448func quoteKey(k string) string {
449	return "\"" + encodeTomlString(k) + "\""
450}
451
// writeStrings writes each string of s to w in order. It returns the total
// number of bytes written and stops at the first write error, in which case
// the count includes the bytes of the failed, partial write.
func writeStrings(w io.Writer, s ...string) (int, error) {
	total := 0
	for _, piece := range s {
		written, err := io.WriteString(w, piece)
		total += written
		if err != nil {
			return total, err
		}
	}
	return total, nil
}
463
464// WriteTo encode the Tree as Toml and writes it to the writer w.
465// Returns the number of bytes written in case of success, or an error if anything happened.
466func (t *Tree) WriteTo(w io.Writer) (int64, error) {
467	return t.writeTo(w, "", "", 0, false)
468}
469
470// ToTomlString generates a human-readable representation of the current tree.
471// Output spans multiple lines, and is suitable for ingest by a TOML parser.
472// If the conversion cannot be performed, ToString returns a non-nil error.
473func (t *Tree) ToTomlString() (string, error) {
474	b, err := t.Marshal()
475	if err != nil {
476		return "", err
477	}
478	return string(b), nil
479}
480
481// String generates a human-readable representation of the current tree.
482// Alias of ToString. Present to implement the fmt.Stringer interface.
483func (t *Tree) String() string {
484	result, _ := t.ToTomlString()
485	return result
486}
487
488// ToMap recursively generates a representation of the tree using Go built-in structures.
489// The following types are used:
490//
491//	* bool
492//	* float64
493//	* int64
494//	* string
495//	* uint64
496//	* time.Time
497//	* map[string]interface{} (where interface{} is any of this list)
498//	* []interface{} (where interface{} is any of this list)
499func (t *Tree) ToMap() map[string]interface{} {
500	result := map[string]interface{}{}
501
502	for k, v := range t.values {
503		switch node := v.(type) {
504		case []*Tree:
505			var array []interface{}
506			for _, item := range node {
507				array = append(array, item.ToMap())
508			}
509			result[k] = array
510		case *Tree:
511			result[k] = node.ToMap()
512		case *tomlValue:
513			result[k] = node.value
514		}
515	}
516	return result
517}
518