1// Copyright 2018 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package fmtsort provides a general stable ordering mechanism
6// for maps, on behalf of the fmt and text/template packages.
7// It is not guaranteed to be efficient and works only for types
8// that are valid map keys.
9package fmtsort
10
11import (
12	"reflect"
13	"sort"
14)
15
16// Note: Throughout this package we avoid calling reflect.Value.Interface as
17// it is not always legal to do so and it's easier to avoid the issue than to face it.
18
19// SortedMap represents a map's keys and values. The keys and values are
20// aligned in index order: Value[i] is the value in the map corresponding to Key[i].
21type SortedMap struct {
22	Key   []reflect.Value
23	Value []reflect.Value
24}
25
26func (o *SortedMap) Len() int           { return len(o.Key) }
27func (o *SortedMap) Less(i, j int) bool { return compare(o.Key[i], o.Key[j]) < 0 }
28func (o *SortedMap) Swap(i, j int) {
29	o.Key[i], o.Key[j] = o.Key[j], o.Key[i]
30	o.Value[i], o.Value[j] = o.Value[j], o.Value[i]
31}
32
33// Sort accepts a map and returns a SortedMap that has the same keys and
34// values but in a stable sorted order according to the keys, modulo issues
35// raised by unorderable key values such as NaNs.
36//
37// The ordering rules are more general than with Go's < operator:
38//
39//  - when applicable, nil compares low
40//  - ints, floats, and strings order by <
41//  - NaN compares less than non-NaN floats
42//  - bool compares false before true
43//  - complex compares real, then imag
44//  - pointers compare by machine address
45//  - channel values compare by machine address
46//  - structs compare each field in turn
47//  - arrays compare each element in turn.
48//    Otherwise identical arrays compare by length.
49//  - interface values compare first by reflect.Type describing the concrete type
50//    and then by concrete value as described in the previous rules.
51//
52func Sort(mapValue reflect.Value) *SortedMap {
53	if mapValue.Type().Kind() != reflect.Map {
54		return nil
55	}
56	key := make([]reflect.Value, mapValue.Len())
57	value := make([]reflect.Value, len(key))
58	iter := mapValue.MapRange()
59	for i := 0; iter.Next(); i++ {
60		key[i] = iter.Key()
61		value[i] = iter.Value()
62	}
63	sorted := &SortedMap{
64		Key:   key,
65		Value: value,
66	}
67	sort.Stable(sorted)
68	return sorted
69}
70
71// compare compares two values of the same type. It returns -1, 0, 1
72// according to whether a > b (1), a == b (0), or a < b (-1).
73// If the types differ, it returns -1.
74// See the comment on Sort for the comparison rules.
75func compare(aVal, bVal reflect.Value) int {
76	aType, bType := aVal.Type(), bVal.Type()
77	if aType != bType {
78		return -1 // No good answer possible, but don't return 0: they're not equal.
79	}
80	switch aVal.Kind() {
81	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
82		a, b := aVal.Int(), bVal.Int()
83		switch {
84		case a < b:
85			return -1
86		case a > b:
87			return 1
88		default:
89			return 0
90		}
91	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
92		a, b := aVal.Uint(), bVal.Uint()
93		switch {
94		case a < b:
95			return -1
96		case a > b:
97			return 1
98		default:
99			return 0
100		}
101	case reflect.String:
102		a, b := aVal.String(), bVal.String()
103		switch {
104		case a < b:
105			return -1
106		case a > b:
107			return 1
108		default:
109			return 0
110		}
111	case reflect.Float32, reflect.Float64:
112		return floatCompare(aVal.Float(), bVal.Float())
113	case reflect.Complex64, reflect.Complex128:
114		a, b := aVal.Complex(), bVal.Complex()
115		if c := floatCompare(real(a), real(b)); c != 0 {
116			return c
117		}
118		return floatCompare(imag(a), imag(b))
119	case reflect.Bool:
120		a, b := aVal.Bool(), bVal.Bool()
121		switch {
122		case a == b:
123			return 0
124		case a:
125			return 1
126		default:
127			return -1
128		}
129	case reflect.Ptr:
130		a, b := aVal.Pointer(), bVal.Pointer()
131		switch {
132		case a < b:
133			return -1
134		case a > b:
135			return 1
136		default:
137			return 0
138		}
139	case reflect.Chan:
140		if c, ok := nilCompare(aVal, bVal); ok {
141			return c
142		}
143		ap, bp := aVal.Pointer(), bVal.Pointer()
144		switch {
145		case ap < bp:
146			return -1
147		case ap > bp:
148			return 1
149		default:
150			return 0
151		}
152	case reflect.Struct:
153		for i := 0; i < aVal.NumField(); i++ {
154			if c := compare(aVal.Field(i), bVal.Field(i)); c != 0 {
155				return c
156			}
157		}
158		return 0
159	case reflect.Array:
160		for i := 0; i < aVal.Len(); i++ {
161			if c := compare(aVal.Index(i), bVal.Index(i)); c != 0 {
162				return c
163			}
164		}
165		return 0
166	case reflect.Interface:
167		if c, ok := nilCompare(aVal, bVal); ok {
168			return c
169		}
170		c := compare(reflect.ValueOf(aVal.Elem().Type()), reflect.ValueOf(bVal.Elem().Type()))
171		if c != 0 {
172			return c
173		}
174		return compare(aVal.Elem(), bVal.Elem())
175	default:
176		// Certain types cannot appear as keys (maps, funcs, slices), but be explicit.
177		panic("bad type in compare: " + aType.String())
178	}
179}
180
// nilCompare orders two values relative to nil. When at least one of aVal and
// bVal is nil the boolean result is true and the int is the ordering: 0 when
// both are nil, -1 when only aVal is nil, 1 when only bVal is nil (nil sorts
// first). When neither is nil it returns (0, false) and the caller must
// compare the values by other means. Both arguments must represent a chan,
// func, interface, map, pointer, or slice.
func nilCompare(aVal, bVal reflect.Value) (int, bool) {
	aIsNil := aVal.IsNil()
	bIsNil := bVal.IsNil()
	switch {
	case aIsNil && bIsNil:
		return 0, true
	case aIsNil:
		return -1, true
	case bIsNil:
		return 1, true
	default:
		return 0, false
	}
}
198
// floatCompare compares two floating-point values and returns -1, 0 or 1.
// NaNs compare low: a NaN is less than every non-NaN value (including -Inf).
// Two NaNs compare equal, so the result is a consistent ordering: it never
// claims both a < b and b < a, as sort.Interface's Less requires.
func floatCompare(a, b float64) int {
	aNaN, bNaN := isNaN(a), isNaN(b)
	switch {
	case aNaN && bNaN:
		return 0
	case aNaN:
		return -1
	case bNaN:
		return 1
	case a < b:
		return -1
	case a > b:
		return 1
	}
	return 0
}

// isNaN reports whether a is an IEEE 754 "not-a-number" value; NaN is the
// only float64 that does not compare equal to itself.
func isNaN(a float64) bool {
	return a != a
}
217