// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"crypto/sha256"
18	"math/big"
19	"math/rand"
20)
21
// messageRouter chooses the partition number that a message is sent to.
type messageRouter interface {
	// Route returns a partition number for a message carrying the given
	// ordering key. Implementations are free to disregard the key.
	Route(orderingKey []byte) int
}
27
// roundRobinMsgRouter returns partition numbers in ascending cyclic order,
// wrapping back to 0 after the last partition. The first partition it returns
// is drawn at random from rng when the router is constructed.
type roundRobinMsgRouter struct {
	rng            *rand.Rand // source of the random starting partition
	partitionCount int        // number of partitions to cycle through
	nextPartition  int        // partition the next Route call will return
}

// newRoundRobinMsgRouter returns a router cycling over count partitions,
// starting at a partition in [0, count) chosen by rng. count must be
// positive: rand.Rand.Int63n panics otherwise.
func newRoundRobinMsgRouter(rng *rand.Rand, count int) *roundRobinMsgRouter {
	start := rng.Int63n(int64(count))

	router := new(roundRobinMsgRouter)
	router.rng = rng
	router.partitionCount = count
	router.nextPartition = int(start)
	return router
}

// Route ignores the ordering key. It returns the current partition and
// advances to the following one, modulo partitionCount. It mutates r without
// any locking, so concurrent callers must serialize access themselves.
func (r *roundRobinMsgRouter) Route(_ []byte) int {
	current := r.nextPartition
	r.nextPartition = (current + 1) % r.partitionCount
	return current
}
49
// hashingMsgRouter hashes an ordering key with SHA-256 and uses the digest,
// read as an unsigned big-endian integer modulo the partition count, as the
// partition number. It should only be used for messages with an ordering key;
// Route returns -1 when the key is empty.
//
// Matches implementation at:
// https://github.com/googleapis/java-pubsublite/blob/master/google-cloud-pubsublite/src/main/java/com/google/cloud/pubsublite/internal/DefaultRoutingPolicy.java
type hashingMsgRouter struct {
	// partitionCount is the modulus. Route reads it but never modifies it.
	partitionCount *big.Int
}

// newHashingMsgRouter returns a hashingMsgRouter for count partitions. count
// is expected to be positive; with count == 0, Route panics because big.Int
// division by zero panics.
func newHashingMsgRouter(count int) *hashingMsgRouter {
	return &hashingMsgRouter{
		partitionCount: big.NewInt(int64(count)),
	}
}

// Route returns a partition in [0, count) for a non-empty orderingKey, or -1
// if orderingKey is nil or empty.
func (r *hashingMsgRouter) Route(orderingKey []byte) int {
	if len(orderingKey) == 0 {
		return -1
	}
	digest := sha256.Sum256(orderingKey)
	// SetBytes interprets the digest as an unsigned big-endian integer, so the
	// value (and the Euclidean modulus below) is never negative.
	num := new(big.Int).SetBytes(digest[:])
	// Reduce in place: math/big allows the receiver to alias an operand, which
	// avoids allocating a second big.Int for every keyed message.
	num.Mod(num, r.partitionCount)
	return int(num.Int64())
}
74
75// compositeMsgRouter delegates to different message routers for messages
76// with/without ordering keys.
77type compositeMsgRouter struct {
78	keyedRouter   messageRouter
79	keylessRouter messageRouter
80}
81
82func (r *compositeMsgRouter) Route(orderingKey []byte) int {
83	if len(orderingKey) > 0 {
84		return r.keyedRouter.Route(orderingKey)
85	}
86	return r.keylessRouter.Route(orderingKey)
87}
88
89type messageRouterFactory struct {
90	rng *rand.Rand
91}
92
93func newMessageRouterFactory(rng *rand.Rand) *messageRouterFactory {
94	return &messageRouterFactory{rng: rng}
95}
96
97// New returns a compositeMsgRouter that uses hashingMsgRouter for messages with
98// ordering key and roundRobinMsgRouter for messages without.
99func (f *messageRouterFactory) New(partitionCount int) messageRouter {
100	return &compositeMsgRouter{
101		keyedRouter:   newHashingMsgRouter(partitionCount),
102		keylessRouter: newRoundRobinMsgRouter(f.rng, partitionCount),
103	}
104}
105