1// Copyright 2017 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package loadtest implements load testing for pubsub,
16// following the interface defined in https://github.com/GoogleCloudPlatform/pubsub/tree/master/load-test-framework/ .
17//
18// This package is experimental.
19package loadtest
20
21import (
22	"bytes"
23	"context"
24	"errors"
25	"log"
26	"runtime"
27	"strconv"
28	"sync"
29	"sync/atomic"
30	"time"
31
32	"cloud.google.com/go/pubsub"
33	pb "cloud.google.com/go/pubsub/loadtest/pb"
34	"golang.org/x/time/rate"
35)
36
37type pubServerConfig struct {
38	topic     *pubsub.Topic
39	msgData   []byte
40	batchSize int32
41}
42
// PubServer is a dummy Pub/Sub server for load testing. It is the publisher
// side of the test: Start configures it and each Execute call publishes one
// batch of messages.
type PubServer struct {
	// ID identifies this publisher; it is sent as the "clientId" attribute of
	// every message. SubServer parses that attribute as a base-10 int64 and
	// nacks messages for which parsing fails.
	ID string

	// cfg holds the pubServerConfig installed by Start.
	cfg atomic.Value

	// seqNum is the next unused message sequence number. It is advanced
	// atomically by the size of each published batch.
	seqNum int32
}
50
51// Start starts the server.
52func (l *PubServer) Start(ctx context.Context, req *pb.StartRequest) (*pb.StartResponse, error) {
53	log.Println("received start")
54	c, err := pubsub.NewClient(ctx, req.Project)
55	if err != nil {
56		return nil, err
57	}
58	dur := req.PublishBatchDuration.AsDuration()
59	l.init(c, req.Topic, req.MessageSize, req.PublishBatchSize, dur)
60	log.Println("started")
61	return &pb.StartResponse{}, nil
62}
63
64func (l *PubServer) init(c *pubsub.Client, topicName string, msgSize, batchSize int32, batchDur time.Duration) {
65	topic := c.Topic(topicName)
66	topic.PublishSettings = pubsub.PublishSettings{
67		DelayThreshold: batchDur,
68		CountThreshold: 950,
69		ByteThreshold:  9500000,
70	}
71
72	l.cfg.Store(pubServerConfig{
73		topic:     topic,
74		msgData:   bytes.Repeat([]byte{'A'}, int(msgSize)),
75		batchSize: batchSize,
76	})
77}
78
79// Execute executes a request.
80func (l *PubServer) Execute(ctx context.Context, _ *pb.ExecuteRequest) (*pb.ExecuteResponse, error) {
81	latencies, err := l.publishBatch()
82	if err != nil {
83		log.Printf("error: %v", err)
84		return nil, err
85	}
86	return &pb.ExecuteResponse{Latencies: latencies}, nil
87}
88
89func (l *PubServer) publishBatch() ([]int64, error) {
90	var cfg pubServerConfig
91	if c, ok := l.cfg.Load().(pubServerConfig); ok {
92		cfg = c
93	} else {
94		return nil, errors.New("config not loaded")
95	}
96
97	start := time.Now()
98	latencies := make([]int64, cfg.batchSize)
99	startStr := strconv.FormatInt(start.UnixNano()/1e6, 10)
100	seqNum := atomic.AddInt32(&l.seqNum, cfg.batchSize) - cfg.batchSize
101
102	rs := make([]*pubsub.PublishResult, cfg.batchSize)
103	for i := int32(0); i < cfg.batchSize; i++ {
104		rs[i] = cfg.topic.Publish(context.TODO(), &pubsub.Message{
105			Data: cfg.msgData,
106			Attributes: map[string]string{
107				"sendTime":       startStr,
108				"clientId":       l.ID,
109				"sequenceNumber": strconv.Itoa(int(seqNum + i)),
110			},
111		})
112	}
113	for i, r := range rs {
114		_, err := r.Get(context.Background())
115		if err != nil {
116			return nil, err
117		}
118		// TODO(jba,pongad): fix latencies
119		// Later values will be skewed by earlier ones, since we wait for the
120		// results in order. (On the other hand, it may not matter much, since
121		// messages are added to bundles in order and bundles get sent more or
122		// less in order.) If we want more accurate values, we can either start
123		// a goroutine for each result (similar to the original code using a
124		// callback), or call reflect.Select with the Ready channels of the
125		// results.
126		latencies[i] = time.Since(start).Nanoseconds() / 1e6
127	}
128	return latencies, nil
129}
130
131// SubServer is a dummy Pub/Sub server for load testing.
132type SubServer struct {
133	// TODO(deklerk): what is this actually for?
134	lim *rate.Limiter
135
136	mu        sync.Mutex
137	idents    []*pb.MessageIdentifier
138	latencies []int64
139}
140
141// Start starts the server.
142func (s *SubServer) Start(ctx context.Context, req *pb.StartRequest) (*pb.StartResponse, error) {
143	log.Println("received start")
144	s.lim = rate.NewLimiter(rate.Every(time.Second), 1)
145
146	c, err := pubsub.NewClient(ctx, req.Project)
147	if err != nil {
148		return nil, err
149	}
150
151	// Load test API doesn't define any way to stop right now.
152	go func() {
153		sub := c.Subscription(req.GetPubsubOptions().Subscription)
154		sub.ReceiveSettings.NumGoroutines = 10 * runtime.GOMAXPROCS(0)
155		err := sub.Receive(context.Background(), s.callback)
156		log.Fatal(err)
157	}()
158
159	log.Println("started")
160	return &pb.StartResponse{}, nil
161}
162
163func (s *SubServer) callback(_ context.Context, m *pubsub.Message) {
164	id, err := strconv.ParseInt(m.Attributes["clientId"], 10, 64)
165	if err != nil {
166		log.Println(err)
167		m.Nack()
168		return
169	}
170
171	seqNum, err := strconv.ParseInt(m.Attributes["sequenceNumber"], 10, 32)
172	if err != nil {
173		log.Println(err)
174		m.Nack()
175		return
176	}
177
178	sendTimeMillis, err := strconv.ParseInt(m.Attributes["sendTime"], 10, 64)
179	if err != nil {
180		log.Println(err)
181		m.Nack()
182		return
183	}
184
185	latency := time.Now().UnixNano()/1e6 - sendTimeMillis
186	ident := &pb.MessageIdentifier{
187		PublisherClientId: id,
188		SequenceNumber:    int32(seqNum),
189	}
190
191	s.mu.Lock()
192	s.idents = append(s.idents, ident)
193	s.latencies = append(s.latencies, latency)
194	s.mu.Unlock()
195	m.Ack()
196}
197
198// Execute executes the request.
199func (s *SubServer) Execute(ctx context.Context, _ *pb.ExecuteRequest) (*pb.ExecuteResponse, error) {
200	// Throttle so the load tester doesn't spam us and consume all our CPU.
201	if err := s.lim.Wait(ctx); err != nil {
202		return nil, err
203	}
204
205	s.mu.Lock()
206	idents := s.idents
207	s.idents = nil
208	latencies := s.latencies
209	s.latencies = nil
210	s.mu.Unlock()
211
212	return &pb.ExecuteResponse{
213		Latencies:        latencies,
214		ReceivedMessages: idents,
215	}, nil
216}
217