1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package storagenodedb
5
6import (
7	"context"
8	"database/sql"
9	"errors"
10	"time"
11
12	"github.com/zeebo/errs"
13
14	"storj.io/common/pb"
15	"storj.io/common/storj"
16	"storj.io/private/tagsql"
17	"storj.io/storj/storagenode/orders"
18	"storj.io/storj/storagenode/orders/ordersfile"
19)
20
// ErrOrders is the error class used to wrap the errors returned by the
// ordersdb database methods.
var ErrOrders = errs.Class("ordersdb")
23
// OrdersDBName is the name that identifies the orders database.
const OrdersDBName = "orders"
26
// ordersDB gives access to the orders stored in the unsent_order and
// order_archive_ tables.
//
// The database operations used by its methods (ExecContext, QueryContext and
// BeginTx) come from the embedded dbContainerImpl.
type ordersDB struct {
	dbContainerImpl
}
30
31// Enqueue inserts order to the unsent list.
32func (db *ordersDB) Enqueue(ctx context.Context, info *ordersfile.Info) (err error) {
33	defer mon.Task()(&ctx)(&err)
34
35	limitSerialized, err := pb.Marshal(info.Limit)
36	if err != nil {
37		return ErrOrders.Wrap(err)
38	}
39
40	orderSerialized, err := pb.Marshal(info.Order)
41	if err != nil {
42		return ErrOrders.Wrap(err)
43	}
44
45	// TODO: remove uplink_cert_id
46	_, err = db.ExecContext(ctx, `
47		INSERT INTO unsent_order(
48			satellite_id, serial_number,
49			order_limit_serialized, order_serialized, order_limit_expiration,
50			uplink_cert_id
51		) VALUES (?,?, ?,?,?, ?)
52	`, info.Limit.SatelliteId, info.Limit.SerialNumber, limitSerialized, orderSerialized, info.Limit.OrderExpiration.UTC(), 0)
53
54	return ErrOrders.Wrap(err)
55}
56
57// ListUnsent returns orders that haven't been sent yet.
58//
59// If there is some unmarshal error while reading an order, the method proceed
60// with the following ones and the function will return the ones which have
61// been successfully read but returning an error with information of the ones
62// which have not. In case of database or other system error, the method will
63// stop without any further processing and will return an error without any
64// order.
65func (db *ordersDB) ListUnsent(ctx context.Context, limit int) (_ []*ordersfile.Info, err error) {
66	defer mon.Task()(&ctx)(&err)
67
68	rows, err := db.QueryContext(ctx, `
69		SELECT order_limit_serialized, order_serialized
70		FROM unsent_order
71		LIMIT ?
72	`, limit)
73	if err != nil {
74		if errors.Is(err, sql.ErrNoRows) {
75			return nil, nil
76		}
77		return nil, ErrOrders.Wrap(err)
78	}
79
80	var unmarshalErrors errs.Group
81	defer func() { err = errs.Combine(err, unmarshalErrors.Err(), rows.Close()) }()
82
83	var infos []*ordersfile.Info
84	for rows.Next() {
85		var limitSerialized []byte
86		var orderSerialized []byte
87
88		err := rows.Scan(&limitSerialized, &orderSerialized)
89		if err != nil {
90			return nil, ErrOrders.Wrap(err)
91		}
92
93		var info ordersfile.Info
94		info.Limit = &pb.OrderLimit{}
95		info.Order = &pb.Order{}
96
97		err = pb.Unmarshal(limitSerialized, info.Limit)
98		if err != nil {
99			unmarshalErrors.Add(ErrOrders.Wrap(err))
100			continue
101		}
102
103		err = pb.Unmarshal(orderSerialized, info.Order)
104		if err != nil {
105			unmarshalErrors.Add(ErrOrders.Wrap(err))
106			continue
107		}
108
109		infos = append(infos, &info)
110	}
111
112	return infos, ErrOrders.Wrap(rows.Err())
113}
114
115// ListUnsentBySatellite returns orders that haven't been sent yet and are not expired.
116// The orders are ordered by the Satellite ID.
117//
118// If there is some unmarshal error while reading an order, the method proceed
119// with the following ones and the function will return the ones which have
120// been successfully read but returning an error with information of the ones
121// which have not. In case of database or other system error, the method will
122// stop without any further processing and will return an error without any
123// order.
124func (db *ordersDB) ListUnsentBySatellite(ctx context.Context) (_ map[storj.NodeID][]*ordersfile.Info, err error) {
125	defer mon.Task()(&ctx)(&err)
126	// TODO: add some limiting
127
128	rows, err := db.QueryContext(ctx, `
129		SELECT order_limit_serialized, order_serialized
130		FROM unsent_order
131		WHERE order_limit_expiration >= $1
132	`, time.Now().UTC())
133	if err != nil {
134		if errors.Is(err, sql.ErrNoRows) {
135			return nil, nil
136		}
137		return nil, ErrOrders.Wrap(err)
138	}
139
140	var unmarshalErrors errs.Group
141	defer func() { err = errs.Combine(err, unmarshalErrors.Err(), rows.Close()) }()
142
143	infos := map[storj.NodeID][]*ordersfile.Info{}
144	for rows.Next() {
145		var limitSerialized []byte
146		var orderSerialized []byte
147
148		err := rows.Scan(&limitSerialized, &orderSerialized)
149		if err != nil {
150			return nil, ErrOrders.Wrap(err)
151		}
152
153		var info ordersfile.Info
154		info.Limit = &pb.OrderLimit{}
155		info.Order = &pb.Order{}
156
157		err = pb.Unmarshal(limitSerialized, info.Limit)
158		if err != nil {
159			unmarshalErrors.Add(ErrOrders.Wrap(err))
160			continue
161		}
162
163		err = pb.Unmarshal(orderSerialized, info.Order)
164		if err != nil {
165			unmarshalErrors.Add(ErrOrders.Wrap(err))
166			continue
167		}
168
169		infos[info.Limit.SatelliteId] = append(infos[info.Limit.SatelliteId], &info)
170	}
171
172	return infos, ErrOrders.Wrap(rows.Err())
173}
174
175// Archive marks order as being handled.
176//
177// If any of the request contains an order which doesn't exist the method will
178// follow with the next ones without interrupting the operation and it will
179// return an error of the class orders.OrderNotFoundError. Any other error, will
180// abort the operation, rolling back the transaction.
181func (db *ordersDB) Archive(ctx context.Context, archivedAt time.Time, requests ...orders.ArchiveRequest) (err error) {
182	defer mon.Task()(&ctx)(&err)
183
184	// change input parameter to UTC timezone before we send it to the database
185	archivedAt = archivedAt.UTC()
186
187	tx, err := db.BeginTx(ctx, nil)
188	if err != nil {
189		return ErrOrders.Wrap(err)
190	}
191
192	var notFoundErrs errs.Group
193	defer func() {
194		if err == nil {
195			err = tx.Commit()
196			if err == nil {
197				if len(notFoundErrs) > 0 {
198					// Return a class error to allow to the caler to identify this case
199					err = orders.OrderNotFoundError.Wrap(notFoundErrs.Err())
200				}
201			}
202		} else {
203			err = errs.Combine(err, tx.Rollback())
204		}
205	}()
206
207	for _, req := range requests {
208		err := db.archiveOne(ctx, tx, archivedAt, req)
209		if err != nil {
210			if orders.OrderNotFoundError.Has(err) {
211				notFoundErrs.Add(err)
212				continue
213			}
214
215			return err
216		}
217	}
218
219	return nil
220}
221
222// archiveOne marks order as being handled.
223func (db *ordersDB) archiveOne(ctx context.Context, tx tagsql.Tx, archivedAt time.Time, req orders.ArchiveRequest) (err error) {
224	defer mon.Task()(&ctx)(&err)
225
226	result, err := tx.ExecContext(ctx, `
227		INSERT INTO order_archive_ (
228			satellite_id, serial_number,
229			order_limit_serialized, order_serialized,
230			uplink_cert_id,
231			status, archived_at
232		) SELECT
233			satellite_id, serial_number,
234			order_limit_serialized, order_serialized,
235			uplink_cert_id,
236			?, ?
237		FROM unsent_order
238		WHERE satellite_id = ? AND serial_number = ?;
239
240		DELETE FROM unsent_order
241		WHERE satellite_id = ? AND serial_number = ?;
242	`, int(req.Status), archivedAt, req.Satellite, req.Serial, req.Satellite, req.Serial)
243	if err != nil {
244		return ErrOrders.Wrap(err)
245	}
246
247	count, err := result.RowsAffected()
248	if err != nil {
249		return ErrOrders.Wrap(err)
250	}
251	if count == 0 {
252		return orders.OrderNotFoundError.New("satellite: %s, serial number: %s",
253			req.Satellite.String(), req.Serial.String(),
254		)
255	}
256
257	return nil
258}
259
260// ListArchived returns orders that have been sent.
261func (db *ordersDB) ListArchived(ctx context.Context, limit int) (_ []*orders.ArchivedInfo, err error) {
262	defer mon.Task()(&ctx)(&err)
263
264	rows, err := db.QueryContext(ctx, `
265		SELECT order_limit_serialized, order_serialized, status, archived_at
266		FROM order_archive_
267		LIMIT ?
268	`, limit)
269	if err != nil {
270		if errors.Is(err, sql.ErrNoRows) {
271			return nil, nil
272		}
273		return nil, ErrOrders.Wrap(err)
274	}
275	defer func() { err = errs.Combine(err, rows.Close()) }()
276
277	var infos []*orders.ArchivedInfo
278	for rows.Next() {
279		var limitSerialized []byte
280		var orderSerialized []byte
281
282		var status int
283		var archivedAt time.Time
284
285		err := rows.Scan(&limitSerialized, &orderSerialized, &status, &archivedAt)
286		if err != nil {
287			return nil, ErrOrders.Wrap(err)
288		}
289
290		var info orders.ArchivedInfo
291		info.Limit = &pb.OrderLimit{}
292		info.Order = &pb.Order{}
293
294		info.Status = orders.Status(status)
295		info.ArchivedAt = archivedAt
296
297		err = pb.Unmarshal(limitSerialized, info.Limit)
298		if err != nil {
299			return nil, ErrOrders.Wrap(err)
300		}
301
302		err = pb.Unmarshal(orderSerialized, info.Order)
303		if err != nil {
304			return nil, ErrOrders.Wrap(err)
305		}
306
307		infos = append(infos, &info)
308	}
309
310	return infos, ErrOrders.Wrap(rows.Err())
311}
312
313// CleanArchive deletes all entries older than ttl.
314func (db *ordersDB) CleanArchive(ctx context.Context, deleteBefore time.Time) (_ int, err error) {
315	defer mon.Task()(&ctx)(&err)
316
317	result, err := db.ExecContext(ctx, `
318		DELETE FROM order_archive_
319		WHERE archived_at <= ?
320	`, deleteBefore.UTC())
321	if err != nil {
322		if errors.Is(err, sql.ErrNoRows) {
323			return 0, nil
324		}
325		return 0, ErrOrders.Wrap(err)
326	}
327	count, err := result.RowsAffected()
328	if err != nil {
329		return 0, ErrOrders.Wrap(err)
330	}
331	return int(count), nil
332}
333