1// Copyright 2014 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5// Package slice provides a slice sorting function.
6//
7// It uses gross, low-level operations to make it easy to sort
8// arbitrary slices with only a less function, without defining a new
9// type with Len and Swap operations.
10package slice
11
12import (
13	"fmt"
14	"reflect"
15	"sort"
16	"unsafe"
17)
18
const (
	// useReflectSwap forces SortInterface to use the pure-reflection
	// swapper for every element type. It is off by default.
	useReflectSwap = false

	// ptrSize is the size in bytes of a pointer on the target platform.
	ptrSize = unsafe.Sizeof((*int)(nil))
)
22
23// Sort sorts the provided slice using the function less.
24// If slice is not a slice, Sort panics.
25func Sort(slice interface{}, less func(i, j int) bool) {
26	sort.Sort(SortInterface(slice, less))
27}
28
29// SortInterface returns a sort.Interface to sort the provided slice
30// using the function less.
31func SortInterface(slice interface{}, less func(i, j int) bool) sort.Interface {
32	sv := reflect.ValueOf(slice)
33	if sv.Kind() != reflect.Slice {
34		panic(fmt.Sprintf("slice.Sort called with non-slice value of type %T", slice))
35	}
36
37	size := sv.Type().Elem().Size()
38	ss := &lenLesser{
39		less:  less,
40		slice: sv,
41		size:  size,
42		len:   sv.Len(),
43	}
44
45	var baseMem unsafe.Pointer
46	if ss.len > 0 {
47		baseMem = unsafe.Pointer(sv.Index(0).Addr().Pointer())
48	}
49
50	switch {
51	case useReflectSwap:
52		return &reflectSwap{
53			temp:      reflect.New(sv.Type().Elem()).Elem(),
54			lenLesser: ss,
55		}
56	case uintptr(size) == ptrSize:
57		return &pointerSwap{
58			baseMem:   baseMem,
59			lenLesser: ss,
60		}
61	case size == 8:
62		return &swap8{
63			baseMem:   baseMem,
64			lenLesser: ss,
65		}
66	case size == 4:
67		return &swap4{
68			baseMem:   baseMem,
69			lenLesser: ss,
70		}
71	default:
72		// Make a properly-typed (for GC) chunk of memory for swap
73		// operations.
74		temp := reflect.New(sv.Type().Elem()).Elem()
75		tempMem := unsafe.Pointer(temp.Addr().Pointer())
76		ms := newMemSwap(size, baseMem, tempMem)
77		ms.lenLesser = ss
78		return ms
79	}
80}
81
82func newMemSwap(size uintptr, baseMem, tempMem unsafe.Pointer) *memSwap {
83	tempSlice := *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{
84		Data: uintptr(tempMem),
85		Len:  int(size),
86		Cap:  int(size),
87	}))
88	ms := &memSwap{
89		imem: *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: uintptr(baseMem), Len: int(size), Cap: int(size)})),
90		jmem: *(*[]byte)(unsafe.Pointer(&reflect.SliceHeader{Data: uintptr(baseMem), Len: int(size), Cap: int(size)})),
91		temp: tempSlice,
92		size: size,
93		base: baseMem,
94	}
95	ms.ibase = (*uintptr)(unsafe.Pointer(&ms.imem))
96	ms.jbase = (*uintptr)(unsafe.Pointer(&ms.jmem))
97	return ms
98}
99
// lenLesser implements the Len and Less methods of sort.Interface.
// The swapper types in this file embed it and supply Swap themselves.
type lenLesser struct {
	less  func(i, j int) bool // caller-supplied ordering predicate
	slice reflect.Value       // the slice being sorted
	len   int                 // element count captured at construction
	size  uintptr             // size in bytes of one element
}

// Len reports the element count recorded when the sorter was built.
func (s *lenLesser) Len() int {
	n := s.len
	return n
}

// Less forwards to the caller's predicate unchanged.
func (s *lenLesser) Less(i, j int) bool {
	fn := s.less
	return fn(i, j)
}
112
113// reflectSwap is the pure reflect-based swap. It's compiled out by
114// default because it's ridiculously slow. But it's kept here in case
115// you want to see for yourself.
116type reflectSwap struct {
117	temp reflect.Value
118	*lenLesser
119}
120
121func (s *reflectSwap) Swap(i, j int) {
122	s.temp.Set(s.slice.Index(i))
123	s.slice.Index(i).Set(s.slice.Index(j))
124	s.slice.Index(j).Set(s.temp)
125}
126
127// pointerSwap swaps pointers.
128type pointerSwap struct {
129	baseMem unsafe.Pointer
130	*lenLesser
131}
132
133func (s *pointerSwap) Swap(i, j int) {
134	base := s.baseMem
135	ip := (*unsafe.Pointer)(unsafe.Pointer(uintptr(base) + uintptr(i)*ptrSize))
136	jp := (*unsafe.Pointer)(unsafe.Pointer(uintptr(base) + uintptr(j)*ptrSize))
137	*ip, *jp = *jp, *ip
138}
139
140// swap8 swaps 8-byte non-pointer elements.
141type swap8 struct {
142	baseMem unsafe.Pointer
143	*lenLesser
144}
145
146func (s *swap8) Swap(i, j int) {
147	base := s.baseMem
148	ip := (*uint64)(unsafe.Pointer(uintptr(base) + uintptr(i)*8))
149	jp := (*uint64)(unsafe.Pointer(uintptr(base) + uintptr(j)*8))
150	*ip, *jp = *jp, *ip
151}
152
153// swap4 swaps 4-byte non-pointer elements.
154type swap4 struct {
155	baseMem unsafe.Pointer
156	*lenLesser
157}
158
159func (s *swap4) Swap(i, j int) {
160	base := s.baseMem
161	ip := (*uint32)(unsafe.Pointer(uintptr(base) + uintptr(i)*4))
162	jp := (*uint32)(unsafe.Pointer(uintptr(base) + uintptr(j)*4))
163	*ip, *jp = *jp, *ip
164}
165
166// memSwap swaps regions of memory
167type memSwap struct {
168	imem  []byte
169	jmem  []byte
170	temp  []byte   // properly typed slice of memory to use as temp space
171	ibase *uintptr // ibase points to the Data word of imem
172	jbase *uintptr // jbase points to the Data word of jmem
173	size  uintptr
174	base  unsafe.Pointer
175	*lenLesser
176}
177
178func (s *memSwap) Swap(i, j int) {
179	imem, jmem, temp := s.imem, s.jmem, s.temp
180	base, size := s.base, s.size
181	*(*uintptr)(unsafe.Pointer(&imem)) = uintptr(base) + size*uintptr(i)
182	*(*uintptr)(unsafe.Pointer(&jmem)) = uintptr(base) + size*uintptr(j)
183	copy(temp, imem)
184	copy(imem, jmem)
185	copy(jmem, temp)
186}
187