1// Copyright 2021 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"sync"
18	"time"
19)
20
// requestTimerStatus is the lifecycle state of a requestTimer. A timer starts
// in requestTimerNew (also the zero value) and leaves it at most once: both
// Stop and the timer callback only act while the status is still
// requestTimerNew.
type requestTimerStatus int

const (
	// requestTimerNew means the timer has neither been stopped nor fired.
	requestTimerNew = requestTimerStatus(iota)
	// requestTimerStopped means Stop was called before the timeout fired.
	requestTimerStopped
	// requestTimerTriggered means the timeout fired before Stop was called.
	requestTimerTriggered
)
28
29// requestTimer bounds the duration of a request and executes `onTimeout` if
30// the timer is triggered.
31type requestTimer struct {
32	onTimeout  func()
33	timeoutErr error
34	timer      *time.Timer
35	mu         sync.Mutex
36	status     requestTimerStatus
37}
38
39func newRequestTimer(duration time.Duration, onTimeout func(), timeoutErr error) *requestTimer {
40	rt := &requestTimer{
41		onTimeout:  onTimeout,
42		timeoutErr: timeoutErr,
43		status:     requestTimerNew,
44	}
45	rt.timer = time.AfterFunc(duration, rt.onTriggered)
46	return rt
47}
48
49func (rt *requestTimer) onTriggered() {
50	rt.mu.Lock()
51	defer rt.mu.Unlock()
52	if rt.status == requestTimerNew {
53		rt.status = requestTimerTriggered
54		rt.onTimeout()
55	}
56}
57
58// Stop should be called upon a successful request to prevent the timer from
59// expiring.
60func (rt *requestTimer) Stop() {
61	rt.mu.Lock()
62	defer rt.mu.Unlock()
63	if rt.status == requestTimerNew {
64		rt.status = requestTimerStopped
65		rt.timer.Stop()
66	}
67}
68
69// ResolveError returns `timeoutErr` if the timer triggered, or otherwise
70// `originalErr`.
71func (rt *requestTimer) ResolveError(originalErr error) error {
72	rt.mu.Lock()
73	defer rt.mu.Unlock()
74	if rt.status == requestTimerTriggered {
75		return rt.timeoutErr
76	}
77	return originalErr
78}
79