1package s3
2
3import (
4	"bytes"
5	"crypto/md5"
6	"encoding/hex"
7	"fmt"
8	"github.com/keybase/client/go/libkb"
9	"io"
10	"io/ioutil"
11	"sort"
12	"sync"
13
14	"golang.org/x/net/context"
15)
16
17type Mem struct {
18	mc     *MemConn
19	mcMake sync.Once
20}
21
22var _ Root = &Mem{}
23
24func (m *Mem) New(g *libkb.GlobalContext, signer Signer, region Region) Connection {
25	return m.NewMemConn()
26}
27
28type MemConn struct {
29	buckets map[string]*MemBucket
30	sync.Mutex
31}
32
33func (m *Mem) NewMemConn() *MemConn {
34	m.mcMake.Do(func() {
35		m.mc = &MemConn{
36			buckets: make(map[string]*MemBucket),
37		}
38	})
39	return m.mc
40}
41
42var _ Connection = &MemConn{}
43
44func (s *MemConn) SetAccessKey(key string) {}
45
46func (s *MemConn) Bucket(name string) BucketInt {
47	s.Lock()
48	defer s.Unlock()
49	b, ok := s.buckets[name]
50	if ok {
51		return b
52	}
53	b = NewMemBucket(s, name)
54	s.buckets[name] = b
55	return b
56}
57
58func (s *MemConn) AllMultis() []*MemMulti {
59	s.Lock()
60	defer s.Unlock()
61	var all []*MemMulti
62	for _, b := range s.buckets {
63		for _, m := range b.multis {
64			all = append(all, m)
65		}
66	}
67	return all
68}
69
70type MemBucket struct {
71	conn    *MemConn
72	name    string
73	objects map[string][]byte
74	multis  map[string]*MemMulti
75	sync.Mutex
76}
77
78func NewMemBucket(conn *MemConn, name string) *MemBucket {
79	return &MemBucket{
80		conn:    conn,
81		name:    name,
82		objects: make(map[string][]byte),
83		multis:  make(map[string]*MemMulti),
84	}
85}
86
87var _ BucketInt = &MemBucket{}
88
89func (b *MemBucket) GetReader(ctx context.Context, path string) (io.ReadCloser, error) {
90	b.Lock()
91	defer b.Unlock()
92	obj, ok := b.objects[path]
93	if !ok {
94		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
95	}
96	return ioutil.NopCloser(bytes.NewBuffer(obj)), nil
97}
98
99func (b *MemBucket) GetReaderWithRange(ctx context.Context, path string, begin, end int64) (io.ReadCloser, error) {
100	b.Lock()
101	defer b.Unlock()
102	obj, ok := b.objects[path]
103	if !ok {
104		return nil, fmt.Errorf("bucket %q, path %q does not exist", b.name, path)
105	}
106	if end >= int64(len(obj)) {
107		end = int64(len(obj))
108	}
109	return ioutil.NopCloser(bytes.NewBuffer(obj[begin:end])), nil
110}
111
112func (b *MemBucket) PutReader(ctx context.Context, path string, r io.Reader, length int64, contType string, perm ACL, options Options) error {
113	b.Lock()
114	defer b.Unlock()
115
116	var buf bytes.Buffer
117	_, err := buf.ReadFrom(r)
118	if err != nil {
119		return err
120	}
121	b.objects[path] = buf.Bytes()
122
123	return nil
124}
125
126func (b *MemBucket) Del(ctx context.Context, path string) error {
127	b.Lock()
128	defer b.Unlock()
129
130	delete(b.objects, path)
131	return nil
132}
133
134func (b *MemBucket) setObject(path string, data []byte) {
135	b.Lock()
136	defer b.Unlock()
137	b.objects[path] = data
138}
139
140func (b *MemBucket) Multi(ctx context.Context, key, contType string, perm ACL) (MultiInt, error) {
141	b.Lock()
142	defer b.Unlock()
143	m, ok := b.multis[key]
144	if ok {
145		return m, nil
146	}
147	m = NewMemMulti(b, key)
148	b.multis[key] = m
149	return m, nil
150}
151
152type MemMulti struct {
153	bucket      *MemBucket
154	path        string
155	parts       map[int]*part
156	numPutParts int
157	sync.Mutex
158}
159
160var _ MultiInt = &MemMulti{}
161
162func NewMemMulti(b *MemBucket, path string) *MemMulti {
163	return &MemMulti{
164		bucket: b,
165		path:   path,
166		parts:  make(map[int]*part),
167	}
168}
169
170func (m *MemMulti) ListParts(ctx context.Context) ([]Part, error) {
171	m.Lock()
172	defer m.Unlock()
173
174	var ps []Part
175	for _, p := range m.parts {
176		ps = append(ps, p.export())
177	}
178	return ps, nil
179}
180
181func (m *MemMulti) Complete(ctx context.Context, parts []Part) error {
182	m.Lock()
183	defer m.Unlock()
184
185	// match parts coming in with existing parts
186	var scratch partList
187	for _, p := range parts {
188		if pp, ok := m.parts[p.N]; ok {
189			scratch = append(scratch, pp)
190		}
191	}
192
193	// assemble into one block
194	sort.Sort(scratch)
195	var buf bytes.Buffer
196	for _, p := range scratch {
197		buf.Write(p.data)
198	}
199
200	// store in bucket
201	m.bucket.setObject(m.path, buf.Bytes())
202
203	return nil
204}
205
206func (m *MemMulti) PutPart(ctx context.Context, index int, r io.ReadSeeker) (Part, error) {
207	m.Lock()
208	defer m.Unlock()
209
210	var buf bytes.Buffer
211	_, err := buf.ReadFrom(r)
212	if err != nil {
213		return Part{}, err
214	}
215	p := newPart(index, buf)
216	m.parts[index] = p
217
218	m.numPutParts++
219
220	return p.export(), nil
221}
222
223// NumPutParts returns the number of times PutPart was called.
224func (m *MemMulti) NumPutParts() int {
225	m.Lock()
226	defer m.Unlock()
227
228	return m.numPutParts
229}
230
// part is one uploaded chunk of a multipart upload.
type part struct {
	index int    // part number passed to PutPart
	hash  string // lowercase hex MD5 digest of data
	data  []byte // chunk contents
}

// newPart wraps the contents of buf as part number index and records the
// hex-encoded MD5 digest of those bytes. The returned part shares buf's
// underlying byte slice.
func newPart(index int, buf bytes.Buffer) *part {
	payload := buf.Bytes()
	digest := md5.Sum(payload)
	return &part{
		index: index,
		hash:  hex.EncodeToString(digest[:]),
		data:  payload,
	}
}
246
247func (p *part) export() Part {
248	return Part{N: p.index, ETag: `"` + p.hash + `"`, Size: int64(len(p.data))}
249}
250
251type partList []*part
252
253func (x partList) Len() int           { return len(x) }
254func (x partList) Less(a, b int) bool { return x[a].index < x[b].index }
255func (x partList) Swap(a, b int)      { x[a], x[b] = x[b], x[a] }
256