1package toml
2
3import (
4	"bytes"
5	"fmt"
6	"io"
7	"math"
8	"math/big"
9	"reflect"
10	"sort"
11	"strconv"
12	"strings"
13	"time"
14)
15
// valueComplexity classifies an entry of a Tree for output ordering: plain
// key/value pairs versus tables and arrays of tables.
type valueComplexity int

// The zero value is not a valid complexity; numbering starts at 1.
const (
	valueSimple  valueComplexity = 1
	valueComplex valueComplexity = 2
)

// sortNode pairs a key of a Tree with its complexity so the writer knows in
// which order, and in which form, to emit the entry.
type sortNode struct {
	key        string
	complexity valueComplexity
}
27
// encodeMultilineTomlString encodes value as the body of a TOML multi-line
// basic string, i.e. the text placed between the opening `"""\n` and the
// closing `"""`.
//
// Unlike encodeTomlString, tabs, newlines and carriage returns are written
// verbatim so the text keeps its shape. The output starts with commented, and
// every newline is followed by commented, so a commented-out value stays
// commented on each line.
//
// Backslashes are escaped, because an unescaped backslash would start an
// escape sequence when the string is parsed again. Quotation marks are
// written verbatim except where they could end the string early: the third
// consecutive quote of a run, and a quote that is the last character of
// value (it would otherwise merge with the closing delimiter), are escaped.
// Remaining control characters (U+0000-U+001F, U+007F) use \b, \f or \uXXXX.
func encodeMultilineTomlString(value string, commented string) string {
	var b bytes.Buffer

	b.WriteString(commented)
	quoteRun := 0 // consecutive unescaped '"' characters just written
	for i, r := range value {
		if r != '"' {
			quoteRun = 0
		}
		switch r {
		case '\b':
			b.WriteString(`\b`)
		case '\t':
			b.WriteString("\t")
		case '\n':
			b.WriteString("\n" + commented)
		case '\f':
			b.WriteString(`\f`)
		case '\r':
			b.WriteString("\r")
		case '"':
			// '"' is a single byte, so comparing the byte index detects the
			// final character of value.
			if quoteRun == 2 || i == len(value)-1 {
				b.WriteString(`\"`)
				quoteRun = 0
			} else {
				b.WriteByte('"')
				quoteRun++
			}
		case '\\':
			b.WriteString(`\\`)
		default:
			// Compare the full rune: narrowing to uint16 would misclassify
			// runes above U+FFFF (e.g. U+1000A) as control characters.
			if r < 0x20 || r == 0x7F {
				fmt.Fprintf(&b, "\\u%04X", r)
			} else {
				b.WriteRune(r)
			}
		}
	}
	return b.String()
}
62
// encodeTomlString escapes value so it can be placed between the double
// quotes of a TOML basic string (it is also used for quoted keys).
//
// Quotation marks, backslashes and all control characters (U+0000-U+001F and
// U+007F) are escaped, using the short forms \b \t \n \f \r where TOML defines
// them and \uXXXX otherwise. Every other rune, including those outside the
// Basic Multilingual Plane, is copied unchanged.
func encodeTomlString(value string) string {
	var b bytes.Buffer

	for _, r := range value {
		switch r {
		case '\b':
			b.WriteString(`\b`)
		case '\t':
			b.WriteString(`\t`)
		case '\n':
			b.WriteString(`\n`)
		case '\f':
			b.WriteString(`\f`)
		case '\r':
			b.WriteString(`\r`)
		case '"':
			b.WriteString(`\"`)
		case '\\':
			b.WriteString(`\\`)
		default:
			// Compare the full rune: narrowing to uint16 would misclassify
			// runes above U+FFFF (e.g. U+1000A) as control characters.
			if r < 0x20 || r == 0x7F {
				fmt.Fprintf(&b, "\\u%04X", r)
			} else {
				b.WriteRune(r)
			}
		}
	}
	return b.String()
}
94
95func tomlValueStringRepresentation(v interface{}, commented string, indent string, arraysOneElementPerLine bool) (string, error) {
96	// this interface check is added to dereference the change made in the writeTo function.
97	// That change was made to allow this function to see formatting options.
98	tv, ok := v.(*tomlValue)
99	if ok {
100		v = tv.value
101	} else {
102		tv = &tomlValue{}
103	}
104
105	switch value := v.(type) {
106	case uint64:
107		return strconv.FormatUint(value, 10), nil
108	case int64:
109		return strconv.FormatInt(value, 10), nil
110	case float64:
111		// Default bit length is full 64
112		bits := 64
113		// Float panics if nan is used
114		if !math.IsNaN(value) {
115			// if 32 bit accuracy is enough to exactly show, use 32
116			_, acc := big.NewFloat(value).Float32()
117			if acc == big.Exact {
118				bits = 32
119			}
120		}
121		if math.Trunc(value) == value {
122			return strings.ToLower(strconv.FormatFloat(value, 'f', 1, bits)), nil
123		}
124		return strings.ToLower(strconv.FormatFloat(value, 'f', -1, bits)), nil
125	case string:
126		if tv.multiline {
127			return "\"\"\"\n" + encodeMultilineTomlString(value, commented) + "\"\"\"", nil
128		}
129		return "\"" + encodeTomlString(value) + "\"", nil
130	case []byte:
131		b, _ := v.([]byte)
132		return tomlValueStringRepresentation(string(b), commented, indent, arraysOneElementPerLine)
133	case bool:
134		if value {
135			return "true", nil
136		}
137		return "false", nil
138	case time.Time:
139		return value.Format(time.RFC3339), nil
140	case LocalDate:
141		return value.String(), nil
142	case LocalDateTime:
143		return value.String(), nil
144	case LocalTime:
145		return value.String(), nil
146	case nil:
147		return "", nil
148	}
149
150	rv := reflect.ValueOf(v)
151
152	if rv.Kind() == reflect.Slice {
153		var values []string
154		for i := 0; i < rv.Len(); i++ {
155			item := rv.Index(i).Interface()
156			itemRepr, err := tomlValueStringRepresentation(item, commented, indent, arraysOneElementPerLine)
157			if err != nil {
158				return "", err
159			}
160			values = append(values, itemRepr)
161		}
162		if arraysOneElementPerLine && len(values) > 1 {
163			stringBuffer := bytes.Buffer{}
164			valueIndent := indent + `  ` // TODO: move that to a shared encoder state
165
166			stringBuffer.WriteString("[\n")
167
168			for _, value := range values {
169				stringBuffer.WriteString(valueIndent)
170				stringBuffer.WriteString(commented + value)
171				stringBuffer.WriteString(`,`)
172				stringBuffer.WriteString("\n")
173			}
174
175			stringBuffer.WriteString(indent + commented + "]")
176
177			return stringBuffer.String(), nil
178		}
179		return "[" + strings.Join(values, ",") + "]", nil
180	}
181	return "", fmt.Errorf("unsupported value type %T: %v", v, v)
182}
183
184func getTreeArrayLine(trees []*Tree) (line int) {
185	// get lowest line number that is not 0
186	for _, tv := range trees {
187		if tv.position.Line < line || line == 0 {
188			line = tv.position.Line
189		}
190	}
191	return
192}
193
194func sortByLines(t *Tree) (vals []sortNode) {
195	var (
196		line  int
197		lines []int
198		tv    *Tree
199		tom   *tomlValue
200		node  sortNode
201	)
202	vals = make([]sortNode, 0)
203	m := make(map[int]sortNode)
204
205	for k := range t.values {
206		v := t.values[k]
207		switch v.(type) {
208		case *Tree:
209			tv = v.(*Tree)
210			line = tv.position.Line
211			node = sortNode{key: k, complexity: valueComplex}
212		case []*Tree:
213			line = getTreeArrayLine(v.([]*Tree))
214			node = sortNode{key: k, complexity: valueComplex}
215		default:
216			tom = v.(*tomlValue)
217			line = tom.position.Line
218			node = sortNode{key: k, complexity: valueSimple}
219		}
220		lines = append(lines, line)
221		vals = append(vals, node)
222		m[line] = node
223	}
224	sort.Ints(lines)
225
226	for i, line := range lines {
227		vals[i] = m[line]
228	}
229
230	return vals
231}
232
233func sortAlphabetical(t *Tree) (vals []sortNode) {
234	var (
235		node     sortNode
236		simpVals []string
237		compVals []string
238	)
239	vals = make([]sortNode, 0)
240	m := make(map[string]sortNode)
241
242	for k := range t.values {
243		v := t.values[k]
244		switch v.(type) {
245		case *Tree, []*Tree:
246			node = sortNode{key: k, complexity: valueComplex}
247			compVals = append(compVals, node.key)
248		default:
249			node = sortNode{key: k, complexity: valueSimple}
250			simpVals = append(simpVals, node.key)
251		}
252		vals = append(vals, node)
253		m[node.key] = node
254	}
255
256	// Simples first to match previous implementation
257	sort.Strings(simpVals)
258	i := 0
259	for _, key := range simpVals {
260		vals[i] = m[key]
261		i++
262	}
263
264	sort.Strings(compVals)
265	for _, key := range compVals {
266		vals[i] = m[key]
267		i++
268	}
269
270	return vals
271}
272
273func (t *Tree) writeTo(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool) (int64, error) {
274	return t.writeToOrdered(w, indent, keyspace, bytesCount, arraysOneElementPerLine, OrderAlphabetical, false)
275}
276
277func (t *Tree) writeToOrdered(w io.Writer, indent, keyspace string, bytesCount int64, arraysOneElementPerLine bool, ord marshalOrder, parentCommented bool) (int64, error) {
278	var orderedVals []sortNode
279
280	switch ord {
281	case OrderPreserve:
282		orderedVals = sortByLines(t)
283	default:
284		orderedVals = sortAlphabetical(t)
285	}
286
287	for _, node := range orderedVals {
288		switch node.complexity {
289		case valueComplex:
290			k := node.key
291			v := t.values[k]
292
293			combinedKey := k
294			if keyspace != "" {
295				combinedKey = keyspace + "." + combinedKey
296			}
297
298			switch node := v.(type) {
299			// node has to be of those two types given how keys are sorted above
300			case *Tree:
301				tv, ok := t.values[k].(*Tree)
302				if !ok {
303					return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k])
304				}
305				if tv.comment != "" {
306					comment := strings.Replace(tv.comment, "\n", "\n"+indent+"#", -1)
307					start := "# "
308					if strings.HasPrefix(comment, "#") {
309						start = ""
310					}
311					writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment)
312					bytesCount += int64(writtenBytesCountComment)
313					if errc != nil {
314						return bytesCount, errc
315					}
316				}
317
318				var commented string
319				if parentCommented || t.commented || tv.commented {
320					commented = "# "
321				}
322				writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[", combinedKey, "]\n")
323				bytesCount += int64(writtenBytesCount)
324				if err != nil {
325					return bytesCount, err
326				}
327				bytesCount, err = node.writeToOrdered(w, indent+"  ", combinedKey, bytesCount, arraysOneElementPerLine, ord, parentCommented || t.commented || tv.commented)
328				if err != nil {
329					return bytesCount, err
330				}
331			case []*Tree:
332				for _, subTree := range node {
333					var commented string
334					if parentCommented || t.commented || subTree.commented {
335						commented = "# "
336					}
337					writtenBytesCount, err := writeStrings(w, "\n", indent, commented, "[[", combinedKey, "]]\n")
338					bytesCount += int64(writtenBytesCount)
339					if err != nil {
340						return bytesCount, err
341					}
342
343					bytesCount, err = subTree.writeToOrdered(w, indent+"  ", combinedKey, bytesCount, arraysOneElementPerLine, ord, parentCommented || t.commented || subTree.commented)
344					if err != nil {
345						return bytesCount, err
346					}
347				}
348			}
349		default: // Simple
350			k := node.key
351			v, ok := t.values[k].(*tomlValue)
352			if !ok {
353				return bytesCount, fmt.Errorf("invalid value type at %s: %T", k, t.values[k])
354			}
355
356			var commented string
357			if parentCommented || t.commented || v.commented {
358				commented = "# "
359			}
360			repr, err := tomlValueStringRepresentation(v, commented, indent, arraysOneElementPerLine)
361			if err != nil {
362				return bytesCount, err
363			}
364
365			if v.comment != "" {
366				comment := strings.Replace(v.comment, "\n", "\n"+indent+"#", -1)
367				start := "# "
368				if strings.HasPrefix(comment, "#") {
369					start = ""
370				}
371				writtenBytesCountComment, errc := writeStrings(w, "\n", indent, start, comment, "\n")
372				bytesCount += int64(writtenBytesCountComment)
373				if errc != nil {
374					return bytesCount, errc
375				}
376			}
377
378			quotedKey := quoteKeyIfNeeded(k)
379			writtenBytesCount, err := writeStrings(w, indent, commented, quotedKey, " = ", repr, "\n")
380			bytesCount += int64(writtenBytesCount)
381			if err != nil {
382				return bytesCount, err
383			}
384		}
385	}
386
387	return bytesCount, nil
388}
389
390// quote a key if it does not fit the bare key format (A-Za-z0-9_-)
391// quoted keys use the same rules as strings
392func quoteKeyIfNeeded(k string) string {
393	// when encoding a map with the 'quoteMapKeys' option enabled, the tree will contain
394	// keys that have already been quoted.
395	// not an ideal situation, but good enough of a stop gap.
396	if len(k) >= 2 && k[0] == '"' && k[len(k)-1] == '"' {
397		return k
398	}
399	isBare := true
400	for _, r := range k {
401		if !isValidBareChar(r) {
402			isBare = false
403			break
404		}
405	}
406	if isBare {
407		return k
408	}
409	return quoteKey(k)
410}
411
412func quoteKey(k string) string {
413	return "\"" + encodeTomlString(k) + "\""
414}
415
// writeStrings writes each string in s to w, in order. It returns the total
// number of bytes written. It stops at the first write error and returns it
// together with the byte count accumulated so far (including the partial
// write that failed).
func writeStrings(w io.Writer, s ...string) (int, error) {
	total := 0
	for _, piece := range s {
		written, err := io.WriteString(w, piece)
		total += written
		if err != nil {
			return total, err
		}
	}
	return total, nil
}
427
428// WriteTo encode the Tree as Toml and writes it to the writer w.
429// Returns the number of bytes written in case of success, or an error if anything happened.
430func (t *Tree) WriteTo(w io.Writer) (int64, error) {
431	return t.writeTo(w, "", "", 0, false)
432}
433
434// ToTomlString generates a human-readable representation of the current tree.
435// Output spans multiple lines, and is suitable for ingest by a TOML parser.
436// If the conversion cannot be performed, ToString returns a non-nil error.
437func (t *Tree) ToTomlString() (string, error) {
438	b, err := t.Marshal()
439	if err != nil {
440		return "", err
441	}
442	return string(b), nil
443}
444
445// String generates a human-readable representation of the current tree.
446// Alias of ToString. Present to implement the fmt.Stringer interface.
447func (t *Tree) String() string {
448	result, _ := t.ToTomlString()
449	return result
450}
451
452// ToMap recursively generates a representation of the tree using Go built-in structures.
453// The following types are used:
454//
455//	* bool
456//	* float64
457//	* int64
458//	* string
459//	* uint64
460//	* time.Time
461//	* map[string]interface{} (where interface{} is any of this list)
462//	* []interface{} (where interface{} is any of this list)
463func (t *Tree) ToMap() map[string]interface{} {
464	result := map[string]interface{}{}
465
466	for k, v := range t.values {
467		switch node := v.(type) {
468		case []*Tree:
469			var array []interface{}
470			for _, item := range node {
471				array = append(array, item.ToMap())
472			}
473			result[k] = array
474		case *Tree:
475			result[k] = node.ToMap()
476		case *tomlValue:
477			result[k] = node.value
478		}
479	}
480	return result
481}
482