1package solver
2
3import (
4	"context"
5	"strings"
6	"sync"
7	"time"
8
9	digest "github.com/opencontainers/go-digest"
10	"github.com/pkg/errors"
11	"golang.org/x/sync/errgroup"
12)
13
14func newCombinedCacheManager(cms []CacheManager, main CacheManager) CacheManager {
15	return &combinedCacheManager{cms: cms, main: main}
16}
17
18type combinedCacheManager struct {
19	cms    []CacheManager
20	main   CacheManager
21	id     string
22	idOnce sync.Once
23}
24
25func (cm *combinedCacheManager) ID() string {
26	cm.idOnce.Do(func() {
27		ids := make([]string, len(cm.cms))
28		for i, c := range cm.cms {
29			ids[i] = c.ID()
30		}
31		cm.id = digest.FromBytes([]byte(strings.Join(ids, ","))).String()
32	})
33	return cm.id
34}
35
36func (cm *combinedCacheManager) Query(inp []CacheKeyWithSelector, inputIndex Index, dgst digest.Digest, outputIndex Index) ([]*CacheKey, error) {
37	eg, _ := errgroup.WithContext(context.TODO())
38	keys := make(map[string]*CacheKey, len(cm.cms))
39	var mu sync.Mutex
40	for _, c := range cm.cms {
41		func(c CacheManager) {
42			eg.Go(func() error {
43				recs, err := c.Query(inp, inputIndex, dgst, outputIndex)
44				if err != nil {
45					return err
46				}
47				mu.Lock()
48				for _, r := range recs {
49					if _, ok := keys[r.ID]; !ok || c == cm.main {
50						keys[r.ID] = r
51					}
52				}
53				mu.Unlock()
54				return nil
55			})
56		}(c)
57	}
58
59	if err := eg.Wait(); err != nil {
60		return nil, err
61	}
62
63	out := make([]*CacheKey, 0, len(keys))
64	for _, k := range keys {
65		out = append(out, k)
66	}
67	return out, nil
68}
69
70func (cm *combinedCacheManager) Load(ctx context.Context, rec *CacheRecord) (res Result, err error) {
71	results, err := rec.cacheManager.LoadWithParents(ctx, rec)
72	if err != nil {
73		return nil, err
74	}
75	defer func() {
76		for i, res := range results {
77			if err == nil && i == 0 {
78				continue
79			}
80			res.Result.Release(context.TODO())
81		}
82	}()
83	if rec.cacheManager != cm.main {
84		for _, res := range results {
85			if _, err := cm.main.Save(res.CacheKey, res.Result, res.CacheResult.CreatedAt); err != nil {
86				return nil, err
87			}
88		}
89	}
90	return results[0].Result, nil
91}
92
93func (cm *combinedCacheManager) Save(key *CacheKey, s Result, createdAt time.Time) (*ExportableCacheKey, error) {
94	return cm.main.Save(key, s, createdAt)
95}
96
97func (cm *combinedCacheManager) Records(ck *CacheKey) ([]*CacheRecord, error) {
98	if len(ck.ids) == 0 {
99		return nil, errors.Errorf("no results")
100	}
101
102	records := map[string]*CacheRecord{}
103	var mu sync.Mutex
104
105	eg, _ := errgroup.WithContext(context.TODO())
106	for c := range ck.ids {
107		func(c *cacheManager) {
108			eg.Go(func() error {
109				recs, err := c.Records(ck)
110				if err != nil {
111					return err
112				}
113				mu.Lock()
114				for _, rec := range recs {
115					if _, ok := records[rec.ID]; !ok || c == cm.main {
116						if c == cm.main {
117							rec.Priority = 1
118						}
119						records[rec.ID] = rec
120					}
121				}
122				mu.Unlock()
123				return nil
124			})
125		}(c)
126	}
127
128	if err := eg.Wait(); err != nil {
129		return nil, err
130	}
131
132	out := make([]*CacheRecord, 0, len(records))
133	for _, rec := range records {
134		out = append(out, rec)
135	}
136	return out, nil
137}
138