1package mapiter
2
3import (
4	"context"
5	"reflect"
6	"sync"
7
8	"github.com/pkg/errors"
9)
10
11// Iterate creates an iterator from arbitrary map types. This is not
12// the most efficient tool, but it's the quickest way to create an
13// iterator for maps.
14// Also, note that you cannot make any assumptions on the order of
15// pairs being returned.
16func Iterate(ctx context.Context, m interface{}) (Iterator, error) {
17	mrv := reflect.ValueOf(m)
18
19	if mrv.Kind() != reflect.Map {
20		return nil, errors.Errorf(`argument must be a map (%s)`, mrv.Type())
21	}
22
23	ch := make(chan *Pair)
24	go func(ctx context.Context, ch chan *Pair, mrv reflect.Value) {
25		defer close(ch)
26		for _, key := range mrv.MapKeys() {
27			value := mrv.MapIndex(key)
28			pair := &Pair{
29				Key:   key.Interface(),
30				Value: value.Interface(),
31			}
32			select {
33			case <-ctx.Done():
34				return
35			case ch <- pair:
36			}
37		}
38	}(ctx, ch, mrv)
39
40	return New(ch), nil
41}
42
43// Source represents a map that knows how to create an iterator
44type Source interface {
45	Iterate(context.Context) Iterator
46}
47
// Pair holds a single key/value entry read from a map.
type Pair struct {
	Key, Value interface{}
}
53
54// Iterator iterates through keys and values of a map
55type Iterator interface {
56	Next(context.Context) bool
57	Pair() *Pair
58}
59
60type iter struct {
61	ch   chan *Pair
62	mu   sync.RWMutex
63	next *Pair
64}
65
// Visitor receives the pairs of a map one at a time, e.g. from Walk.
type Visitor interface {
	// Visit is called with one key and its value. Walk stops at the
	// first non-nil error returned here.
	Visit(key, value interface{}) error
}
70
// VisitorFunc adapts an ordinary function to the Visitor interface.
type VisitorFunc func(interface{}, interface{}) error

// Visit invokes the underlying function with key and value and returns
// whatever error it produces.
func (fn VisitorFunc) Visit(key, value interface{}) error {
	return fn(key, value)
}
77
78func New(ch chan *Pair) Iterator {
79	return &iter{
80		ch: ch,
81	}
82}
83
84// Next returns true if there are more items to read from the iterator
85func (i *iter) Next(ctx context.Context) bool {
86	i.mu.RLock()
87	if i.ch == nil {
88		i.mu.RUnlock()
89		return false
90	}
91	i.mu.RUnlock()
92
93	i.mu.Lock()
94	defer i.mu.Unlock()
95	select {
96	case <-ctx.Done():
97		i.ch = nil
98		return false
99	case v, ok := <-i.ch:
100		if !ok {
101			i.ch = nil
102			return false
103		}
104		i.next = v
105		return true
106	}
107
108	//nolint:govet
109	return false // never reached
110}
111
112// Pair returns the currently buffered Pair. Calling Next() will reset its value
113func (i *iter) Pair() *Pair {
114	i.mu.RLock()
115	defer i.mu.RUnlock()
116	return i.next
117}
118
119// Walk walks through each element in the map
120func Walk(ctx context.Context, s Source, v Visitor) error {
121	for i := s.Iterate(ctx); i.Next(ctx); {
122		pair := i.Pair()
123		if err := v.Visit(pair.Key, pair.Value); err != nil {
124			return errors.Wrapf(err, `failed to visit key %s`, pair.Key)
125		}
126	}
127	return nil
128}
129
130// AsMap returns the values obtained from the source as a map
131func AsMap(ctx context.Context, s interface{}, v interface{}) error {
132	var iter Iterator
133	switch reflect.ValueOf(s).Kind() {
134	case reflect.Map:
135		x, err := Iterate(ctx, s)
136		if err != nil {
137			return errors.Wrap(err, `failed to iterate over map type`)
138		}
139		iter = x
140	default:
141		ssrc, ok := s.(Source)
142		if !ok {
143			return errors.Errorf(`cannot iterate over %T: not a mapiter.Source type`, s)
144		}
145		iter = ssrc.Iterate(ctx)
146	}
147
148	dst := reflect.ValueOf(v)
149
150	// dst MUST be a pointer to a map type
151	if kind := dst.Kind(); kind != reflect.Ptr {
152		return errors.Errorf(`dst must be a pointer to a map (%s)`, dst.Type())
153	}
154
155	dst = dst.Elem()
156	if dst.Kind() != reflect.Map {
157		return errors.Errorf(`dst must be a pointer to a map (%s)`, dst.Type())
158	}
159
160	if dst.IsNil() {
161		dst.Set(reflect.MakeMap(dst.Type()))
162	}
163
164	// dst must be assignable
165	if !dst.CanSet() {
166		return errors.New(`dst is not writeable`)
167	}
168
169	keytyp := dst.Type().Key()
170	valtyp := dst.Type().Elem()
171
172	for iter.Next(ctx) {
173		pair := iter.Pair()
174
175		rvkey := reflect.ValueOf(pair.Key)
176		rvvalue := reflect.ValueOf(pair.Value)
177
178		if !rvkey.Type().AssignableTo(keytyp) {
179			return errors.Errorf(`cannot assign key of type %s to map key of type %s`, rvkey.Type(), keytyp)
180		}
181
182		switch rvvalue.Kind() {
183		// we can only check if we can assign to rvvalue to valtyp if it's non-nil
184		case reflect.Invalid:
185			rvvalue = reflect.New(valtyp).Elem()
186		default:
187			if !rvvalue.Type().AssignableTo(valtyp) {
188				return errors.Errorf(`cannot assign value of type %s to map value of type %s`, rvvalue.Type(), valtyp)
189			}
190		}
191
192		dst.SetMapIndex(rvkey, rvvalue)
193	}
194
195	return nil
196}
197