1// Copyright (c) 2015 The gocql Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package gocql
6
7import (
8	"bytes"
9	"crypto/md5"
10	"fmt"
11	"math/big"
12	"sort"
13	"strconv"
14	"strings"
15
16	"github.com/gocql/gocql/internal/murmur"
17)
18
// token is an opaque position on the token ring. The implementations in this
// file only compare against tokens of their own concrete type: Less
// type-asserts its argument and panics on a mismatch.
type token interface {
	// Less reports whether the receiver sorts strictly before the argument.
	Less(token) bool
	fmt.Stringer
}

// partitioner is a token partitioner: it hashes partition keys to tokens and
// parses the string form of tokens reported by hosts.
type partitioner interface {
	// ParseString converts the string form of a token back into a token.
	ParseString(string) token
	// Hash returns the token for a serialized partition key.
	Hash([]byte) token
	// Name returns the partitioner's display name.
	Name() string
}
31
32// murmur3 partitioner and token
33type murmur3Partitioner struct{}
34type murmur3Token int64
35
36func (p murmur3Partitioner) Name() string {
37	return "Murmur3Partitioner"
38}
39
40func (p murmur3Partitioner) Hash(partitionKey []byte) token {
41	h1 := murmur.Murmur3H1(partitionKey)
42	return murmur3Token(h1)
43}
44
45// murmur3 little-endian, 128-bit hash, but returns only h1
46func (p murmur3Partitioner) ParseString(str string) token {
47	val, _ := strconv.ParseInt(str, 10, 64)
48	return murmur3Token(val)
49}
50
51func (m murmur3Token) String() string {
52	return strconv.FormatInt(int64(m), 10)
53}
54
55func (m murmur3Token) Less(token token) bool {
56	return m < token.(murmur3Token)
57}
58
59// order preserving partitioner and token
60type orderedPartitioner struct{}
61type orderedToken string
62
63func (p orderedPartitioner) Name() string {
64	return "OrderedPartitioner"
65}
66
67func (p orderedPartitioner) Hash(partitionKey []byte) token {
68	// the partition key is the token
69	return orderedToken(partitionKey)
70}
71
72func (p orderedPartitioner) ParseString(str string) token {
73	return orderedToken(str)
74}
75
76func (o orderedToken) String() string {
77	return string(o)
78}
79
80func (o orderedToken) Less(token token) bool {
81	return o < token.(orderedToken)
82}
83
84// random partitioner and token
85type randomPartitioner struct{}
86type randomToken big.Int
87
88func (r randomPartitioner) Name() string {
89	return "RandomPartitioner"
90}
91
92// 2 ** 128
93var maxHashInt, _ = new(big.Int).SetString("340282366920938463463374607431768211456", 10)
94
95func (p randomPartitioner) Hash(partitionKey []byte) token {
96	sum := md5.Sum(partitionKey)
97	val := new(big.Int)
98	val.SetBytes(sum[:])
99	if sum[0] > 127 {
100		val.Sub(val, maxHashInt)
101		val.Abs(val)
102	}
103
104	return (*randomToken)(val)
105}
106
107func (p randomPartitioner) ParseString(str string) token {
108	val := new(big.Int)
109	val.SetString(str, 10)
110	return (*randomToken)(val)
111}
112
113func (r *randomToken) String() string {
114	return (*big.Int)(r).String()
115}
116
117func (r *randomToken) Less(token token) bool {
118	return -1 == (*big.Int)(r).Cmp((*big.Int)(token.(*randomToken)))
119}
120
121type hostToken struct {
122	token token
123	host  *HostInfo
124}
125
126func (ht hostToken) String() string {
127	return fmt.Sprintf("{token=%v host=%v}", ht.token, ht.host.HostID())
128}
129
// tokenRing is a data structure for organizing the relationship between tokens
// and hosts. partitioner is the partitioner chosen by newTokenRing (nil on a
// zero-value ring, which String tolerates); tokens holds one hostToken per
// token reported by each host, sorted ascending by token so that lookups can
// binary-search it.
type tokenRing struct {
	partitioner partitioner
	tokens      []hostToken
}
135
136func newTokenRing(partitioner string, hosts []*HostInfo) (*tokenRing, error) {
137	tokenRing := &tokenRing{}
138
139	if strings.HasSuffix(partitioner, "Murmur3Partitioner") {
140		tokenRing.partitioner = murmur3Partitioner{}
141	} else if strings.HasSuffix(partitioner, "OrderedPartitioner") {
142		tokenRing.partitioner = orderedPartitioner{}
143	} else if strings.HasSuffix(partitioner, "RandomPartitioner") {
144		tokenRing.partitioner = randomPartitioner{}
145	} else {
146		return nil, fmt.Errorf("Unsupported partitioner '%s'", partitioner)
147	}
148
149	for _, host := range hosts {
150		for _, strToken := range host.Tokens() {
151			token := tokenRing.partitioner.ParseString(strToken)
152			tokenRing.tokens = append(tokenRing.tokens, hostToken{token, host})
153		}
154	}
155
156	sort.Sort(tokenRing)
157
158	return tokenRing, nil
159}
160
161func (t *tokenRing) Len() int {
162	return len(t.tokens)
163}
164
165func (t *tokenRing) Less(i, j int) bool {
166	return t.tokens[i].token.Less(t.tokens[j].token)
167}
168
169func (t *tokenRing) Swap(i, j int) {
170	t.tokens[i], t.tokens[j] = t.tokens[j], t.tokens[i]
171}
172
173func (t *tokenRing) String() string {
174	buf := &bytes.Buffer{}
175	buf.WriteString("TokenRing(")
176	if t.partitioner != nil {
177		buf.WriteString(t.partitioner.Name())
178	}
179	buf.WriteString("){")
180	sep := ""
181	for i, th := range t.tokens {
182		buf.WriteString(sep)
183		sep = ","
184		buf.WriteString("\n\t[")
185		buf.WriteString(strconv.Itoa(i))
186		buf.WriteString("]")
187		buf.WriteString(th.token.String())
188		buf.WriteString(":")
189		buf.WriteString(th.host.ConnectAddress().String())
190	}
191	buf.WriteString("\n}")
192	return string(buf.Bytes())
193}
194
195func (t *tokenRing) GetHostForPartitionKey(partitionKey []byte) (host *HostInfo, endToken token) {
196	if t == nil {
197		return nil, nil
198	}
199
200	return t.GetHostForToken(t.partitioner.Hash(partitionKey))
201}
202
203func (t *tokenRing) GetHostForToken(token token) (host *HostInfo, endToken token) {
204	if t == nil || len(t.tokens) == 0 {
205		return nil, nil
206	}
207
208	// find the primary replica
209	ringIndex := sort.Search(len(t.tokens), func(i int) bool {
210		return !t.tokens[i].token.Less(token)
211	})
212
213	if ringIndex == len(t.tokens) {
214		// wrap around to the first in the ring
215		ringIndex = 0
216	}
217
218	v := t.tokens[ringIndex]
219	return v.host, v.token
220}
221