1package ipinfo
2
3import (
4	"bytes"
5	"context"
6	"encoding/json"
7	"net"
8	"strings"
9	"sync"
10	"time"
11
12	"golang.org/x/sync/errgroup"
13)
14
const (
	// batchMaxSize is the maximum number of lookups the IPinfo API accepts
	// in a single batch request; larger inputs are chunked to this size.
	batchMaxSize           = 1000
	// batchReqTimeoutDefault is the default per-batch request timeout in
	// seconds, used when `BatchReqOpts.TimeoutPerBatch` is 0.
	batchReqTimeoutDefault = 5
)
19
// Internal batch type used by common batch functionality to temporarily store
// the URL-to-result mapping in a half-decoded state (specifically the value
// not being decoded yet). This allows us to decode the value to a proper
// concrete type like `Core` or `ASNDetails` after analyzing the key to
// determine which one it should be.
//
// Keys are the request URLs exactly as sent to the batch endpoint.
type batch map[string]json.RawMessage
26
// Batch is a mapped result of any valid API endpoint (e.g. `<ip>`,
// `<ip>/<field>`, `<asn>`, etc) to its corresponding data.
//
// The corresponding value will be either `*Core`, `*ASNDetails` or a generic
// map for unknown value results.
//
// Keys that parse as IPs map to `*Core`; keys prefixed with "AS" map to
// `*ASNDetails`; anything else gets the generic decoding.
type Batch map[string]interface{}
33
// BatchCore is a mapped result of IPs to their corresponding `Core` data.
// Keys are the string form of the looked-up IPs.
type BatchCore map[string]*Core
36
// BatchASNDetails is a mapped result of ASNs to their corresponding
// `ASNDetails` data. Keys are "AS"-prefixed ASN strings (e.g. "AS123").
type BatchASNDetails map[string]*ASNDetails
40
// BatchReqOpts are options input into batch request functions.
//
// The zero value is valid and selects all defaults.
type BatchReqOpts struct {
	// BatchSize is the internal batch size used per API request; the IPinfo
	// API has a maximum batch size, but the batch request functions available
	// in this library do not. Therefore the library chunks the input slices
	// internally into chunks of size `BatchSize`, clipping to the maximum
	// allowed by the IPinfo API.
	//
	// 0 means to use the default batch size which is the max allowed by the
	// IPinfo API.
	BatchSize uint32

	// TimeoutPerBatch is the timeout in seconds that each batch of size
	// `BatchSize` will have for its own request.
	//
	// 0 means to use a default of 5 seconds; any negative number will turn it
	// off; turning it off does _not_ disable the effects of `TimeoutTotal`.
	TimeoutPerBatch int64

	// TimeoutTotal is the total timeout in seconds for all batch requests in a
	// batch request function to complete.
	//
	// 0 means no total timeout; `TimeoutPerBatch` will still apply.
	TimeoutTotal uint64

	// Filter, if turned on, will filter out a URL whose value was deemed empty
	// on the server.
	Filter bool
}
70
71/* GENERIC */
72
// GetBatch does a batch request for all `urls` at once.
//
// It delegates to DefaultClient; see the Client method of the same name for
// chunking, timeout and caching behavior.
func GetBatch(
	urls []string,
	opts BatchReqOpts,
) (Batch, error) {
	return DefaultClient.GetBatch(urls, opts)
}
80
81// GetBatch does a batch request for all `urls` at once.
82func (c *Client) GetBatch(
83	urls []string,
84	opts BatchReqOpts,
85) (Batch, error) {
86	var batchSize int
87	var timeoutPerBatch int64
88	var totalTimeoutCtx context.Context
89	var totalTimeoutCancel context.CancelFunc
90	var lookupUrls []string
91	var result Batch
92	var mu sync.Mutex
93
94	// if the cache is available, filter out URLs already cached.
95	result = make(Batch, len(urls))
96	if c.Cache != nil {
97		lookupUrls = make([]string, 0, len(urls)/2)
98		for _, url := range urls {
99			if res, err := c.Cache.Get(cacheKey(url)); err == nil {
100				result[url] = res
101			} else {
102				lookupUrls = append(lookupUrls, url)
103			}
104		}
105	} else {
106		lookupUrls = urls
107	}
108
109	// everything cached; exit early.
110	if len(lookupUrls) == 0 {
111		return result, nil
112	}
113
114	// use correct batch size; default/clip to `batchMaxSize`.
115	if opts.BatchSize == 0 || opts.BatchSize > batchMaxSize {
116		batchSize = batchMaxSize
117	} else {
118		batchSize = int(opts.BatchSize)
119	}
120
121	// use correct timeout per batch; either default or user-provided.
122	if opts.TimeoutPerBatch == 0 {
123		timeoutPerBatch = batchReqTimeoutDefault
124	} else {
125		timeoutPerBatch = opts.TimeoutPerBatch
126	}
127
128	// use correct timeout total; either ignore it or apply user-provided.
129	if opts.TimeoutTotal > 0 {
130		totalTimeoutCtx, totalTimeoutCancel = context.WithTimeout(
131			context.Background(),
132			time.Duration(opts.TimeoutTotal)*time.Second,
133		)
134		defer totalTimeoutCancel()
135	} else {
136		totalTimeoutCtx = context.Background()
137	}
138
139	errg, ctx := errgroup.WithContext(totalTimeoutCtx)
140	for i := 0; i < len(lookupUrls); i += batchSize {
141		end := i + batchSize
142		if end > len(lookupUrls) {
143			end = len(lookupUrls)
144		}
145
146		urlsChunk := lookupUrls[i:end]
147		errg.Go(func() error {
148			var postURL string
149
150			// prepare request.
151
152			var timeoutPerBatchCtx context.Context
153			var timeoutPerBatchCancel context.CancelFunc
154			if timeoutPerBatch > 0 {
155				timeoutPerBatchCtx, timeoutPerBatchCancel = context.WithTimeout(
156					ctx,
157					time.Duration(timeoutPerBatch)*time.Second,
158				)
159				defer timeoutPerBatchCancel()
160			} else {
161				timeoutPerBatchCtx = context.Background()
162			}
163
164			if opts.Filter {
165				postURL = "batch?filter=1"
166			} else {
167				postURL = "batch"
168			}
169
170			jsonArrStr, err := json.Marshal(urlsChunk)
171			if err != nil {
172				return err
173			}
174			jsonBuf := bytes.NewBuffer(jsonArrStr)
175
176			req, err := c.newRequest(timeoutPerBatchCtx, "POST", postURL, jsonBuf)
177			if err != nil {
178				return err
179			}
180			req.Header.Set("Content-Type", "application/json")
181
182			// temporarily make a new local result map so that we can read the
183			// network data into it; once we have it local we'll merge it with
184			// `result` in a concurrency-safe way.
185			localResult := new(batch)
186			if _, err := c.do(req, localResult); err != nil {
187				return err
188			}
189
190			// update final result.
191			mu.Lock()
192			defer mu.Unlock()
193			for k, v := range *localResult {
194				if strings.HasPrefix(k, "AS") {
195					decodedV := new(ASNDetails)
196					if err := json.Unmarshal(v, decodedV); err != nil {
197						return err
198					}
199
200					decodedV.setCountryName()
201					result[k] = decodedV
202				} else if net.ParseIP(k) != nil {
203					decodedV := new(Core)
204					if err := json.Unmarshal(v, decodedV); err != nil {
205						return err
206					}
207
208					decodedV.setCountryName()
209					result[k] = decodedV
210				} else {
211					decodedV := new(interface{})
212					if err := json.Unmarshal(v, decodedV); err != nil {
213						return err
214					}
215
216					result[k] = decodedV
217				}
218			}
219
220			return nil
221		})
222	}
223	if err := errg.Wait(); err != nil {
224		return result, err
225	}
226
227	// we delay inserting into the cache until now because:
228	// 1. it's likely more cache-line friendly.
229	// 2. doing it while updating `result` inside the request workers would be
230	//    problematic if the cache is external since we take a mutex lock for
231	//    that entire period.
232	if c.Cache != nil {
233		for _, url := range lookupUrls {
234			if v, exists := result[url]; exists {
235				if err := c.Cache.Set(cacheKey(url), v); err != nil {
236					// NOTE: still return the result even if the cache fails.
237					return result, err
238				}
239			}
240		}
241	}
242
243	return result, nil
244}
245
246/* CORE (net.IP) */
247
// GetIPInfoBatch does a batch request for all `ips` at once.
//
// It delegates to DefaultClient; see the Client method of the same name.
func GetIPInfoBatch(
	ips []net.IP,
	opts BatchReqOpts,
) (BatchCore, error) {
	return DefaultClient.GetIPInfoBatch(ips, opts)
}
255
256// GetIPInfoBatch does a batch request for all `ips` at once.
257func (c *Client) GetIPInfoBatch(
258	ips []net.IP,
259	opts BatchReqOpts,
260) (BatchCore, error) {
261	ipstrs := make([]string, 0, len(ips))
262	for _, ip := range ips {
263		ipstrs = append(ipstrs, ip.String())
264	}
265
266	return c.GetIPStrInfoBatch(ipstrs, opts)
267}
268
269/* CORE (string) */
270
// GetIPStrInfoBatch does a batch request for all `ips` at once.
//
// It delegates to DefaultClient; see the Client method of the same name.
func GetIPStrInfoBatch(
	ips []string,
	opts BatchReqOpts,
) (BatchCore, error) {
	return DefaultClient.GetIPStrInfoBatch(ips, opts)
}
278
279// GetIPStrInfoBatch does a batch request for all `ips` at once.
280func (c *Client) GetIPStrInfoBatch(
281	ips []string,
282	opts BatchReqOpts,
283) (BatchCore, error) {
284	intermediateRes, err := c.GetBatch(ips, opts)
285
286	// if we have items in the result, don't throw them away; we'll convert
287	// below and return the error together if it existed.
288	if err != nil && len(intermediateRes) == 0 {
289		return nil, err
290	}
291
292	res := make(BatchCore, len(intermediateRes))
293	for k, v := range intermediateRes {
294		res[k] = v.(*Core)
295	}
296
297	return res, err
298}
299
/* ASN */

// GetASNDetailsBatch does a batch request for all `asns` at once.
//
// It delegates to DefaultClient; see the Client method of the same name.
func GetASNDetailsBatch(
	asns []string,
	opts BatchReqOpts,
) (BatchASNDetails, error) {
	return DefaultClient.GetASNDetailsBatch(asns, opts)
}
309
310// GetASNDetailsBatch does a batch request for all `asns` at once.
311func (c *Client) GetASNDetailsBatch(
312	asns []string,
313	opts BatchReqOpts,
314) (BatchASNDetails, error) {
315	intermediateRes, err := c.GetBatch(asns, opts)
316
317	// if we have items in the result, don't throw them away; we'll convert
318	// below and return the error together if it existed.
319	if err != nil && len(intermediateRes) == 0 {
320		return nil, err
321	}
322
323	res := make(BatchASNDetails, len(intermediateRes))
324	for k, v := range intermediateRes {
325		res[k] = v.(*ASNDetails)
326	}
327	return res, err
328}
329