1// Copyright (C) 2020 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package metabase
5
6import (
7	"database/sql/driver"
8	"encoding/binary"
9
10	"github.com/jackc/pgtype"
11
12	"storj.io/common/storj"
13)
14
// encryptionParameters wraps a pointer to storj.EncryptionParameters so that
// the pointed-to value can be converted to and from a single int64 database
// value through the Value and Scan methods.
type encryptionParameters struct {
	*storj.EncryptionParameters
}
18
// Check that EncryptionParameters layout doesn't change.
//
// This is a compile-time assertion: the assignment only compiles while
// storj.EncryptionParameters consists of exactly these fields, in this order
// and with these types.
var _ struct {
	CipherSuite storj.CipherSuite
	BlockSize   int32
} = storj.EncryptionParameters{}
24
25// Value implements sql/driver.Valuer interface.
26func (params encryptionParameters) Value() (driver.Value, error) {
27	var bytes [8]byte
28	bytes[0] = byte(params.CipherSuite)
29	binary.LittleEndian.PutUint32(bytes[1:], uint32(params.BlockSize))
30	return int64(binary.LittleEndian.Uint64(bytes[:])), nil
31}
32
33// Scan implements sql.Scanner interface.
34func (params encryptionParameters) Scan(value interface{}) error {
35	switch value := value.(type) {
36	case int64:
37		var bytes [8]byte
38		binary.LittleEndian.PutUint64(bytes[:], uint64(value))
39		params.CipherSuite = storj.CipherSuite(bytes[0])
40		params.BlockSize = int32(binary.LittleEndian.Uint32(bytes[1:]))
41		return nil
42	default:
43		return Error.New("unable to scan %T into EncryptionParameters", value)
44	}
45}
46
47// Value implements sql/driver.Valuer interface.
48func (params SegmentPosition) Value() (driver.Value, error) {
49	return int64(params.Encode()), nil
50}
51
52// Scan implements sql.Scanner interface.
53func (params *SegmentPosition) Scan(value interface{}) error {
54	switch value := value.(type) {
55	case int64:
56		*params = SegmentPositionFromEncoded(uint64(value))
57		return nil
58	default:
59		return Error.New("unable to scan %T into EncryptionParameters", value)
60	}
61}
62
// redundancyScheme wraps a pointer to storj.RedundancyScheme so that the
// pointed-to value can be converted to and from a single int64 database value
// through the Value and Scan methods.
type redundancyScheme struct {
	*storj.RedundancyScheme
}
66
// Check that RedundancyScheme layout doesn't change.
//
// This is a compile-time assertion: the assignment only compiles while
// storj.RedundancyScheme consists of exactly these fields, in this order and
// with these types.
var _ struct {
	Algorithm      storj.RedundancyAlgorithm
	ShareSize      int32
	RequiredShares int16
	RepairShares   int16
	OptimalShares  int16
	TotalShares    int16
} = storj.RedundancyScheme{}
76
77func (params redundancyScheme) Value() (driver.Value, error) {
78	switch {
79	case params.ShareSize < 0 || params.ShareSize >= 1<<24:
80		return nil, Error.New("invalid share size %v", params.ShareSize)
81	case params.RequiredShares < 0 || params.RequiredShares >= 1<<8:
82		return nil, Error.New("invalid required shares %v", params.RequiredShares)
83	case params.RepairShares < 0 || params.RepairShares >= 1<<8:
84		return nil, Error.New("invalid repair shares %v", params.RepairShares)
85	case params.OptimalShares < 0 || params.OptimalShares >= 1<<8:
86		return nil, Error.New("invalid optimal shares %v", params.OptimalShares)
87	case params.TotalShares < 0 || params.TotalShares >= 1<<8:
88		return nil, Error.New("invalid total shares %v", params.TotalShares)
89	}
90
91	var bytes [8]byte
92	bytes[0] = byte(params.Algorithm)
93
94	// little endian uint32
95	bytes[1] = byte(params.ShareSize >> 0)
96	bytes[2] = byte(params.ShareSize >> 8)
97	bytes[3] = byte(params.ShareSize >> 16)
98
99	bytes[4] = byte(params.RequiredShares)
100	bytes[5] = byte(params.RepairShares)
101	bytes[6] = byte(params.OptimalShares)
102	bytes[7] = byte(params.TotalShares)
103
104	return int64(binary.LittleEndian.Uint64(bytes[:])), nil
105}
106
107func (params redundancyScheme) Scan(value interface{}) error {
108	switch value := value.(type) {
109	case int64:
110		var bytes [8]byte
111		binary.LittleEndian.PutUint64(bytes[:], uint64(value))
112
113		params.Algorithm = storj.RedundancyAlgorithm(bytes[0])
114
115		// little endian uint32
116		params.ShareSize = int32(bytes[1]) | int32(bytes[2])<<8 | int32(bytes[3])<<16
117
118		params.RequiredShares = int16(bytes[4])
119		params.RepairShares = int16(bytes[5])
120		params.OptimalShares = int16(bytes[6])
121		params.TotalShares = int16(bytes[7])
122
123		return nil
124	default:
125		return Error.New("unable to scan %T into RedundancyScheme", value)
126	}
127}
128
129// Value implements sql/driver.Valuer interface.
130func (pieces Pieces) Value() (driver.Value, error) {
131	if len(pieces) == 0 {
132		arr := &pgtype.ByteaArray{Status: pgtype.Null}
133		return arr.Value()
134	}
135
136	elems := make([]pgtype.Bytea, len(pieces))
137	for i, piece := range pieces {
138		var buf [2 + len(piece.StorageNode)]byte
139		binary.BigEndian.PutUint16(buf[0:], piece.Number)
140		copy(buf[2:], piece.StorageNode[:])
141
142		elems[i].Bytes = buf[:]
143		elems[i].Status = pgtype.Present
144	}
145
146	arr := &pgtype.ByteaArray{
147		Elements:   elems,
148		Dimensions: []pgtype.ArrayDimension{{Length: int32(len(pieces)), LowerBound: 1}},
149		Status:     pgtype.Present,
150	}
151	return arr.Value()
152}
153
// Error values returned by Pieces.Scan when the scanned array is malformed.
type (
	// unexpectedDimension is returned when the array does not have exactly one dimension.
	unexpectedDimension struct{}
	// invalidElementLength is returned when an element does not have the
	// byte length of an encoded piece.
	invalidElementLength struct{}
)

// Error implements the error interface.
func (unexpectedDimension) Error() string {
	return "unexpected data dimension"
}

// Error implements the error interface.
func (invalidElementLength) Error() string {
	return "invalid element length"
}
159
160// Scan implements sql.Scanner interface.
161func (pieces *Pieces) Scan(value interface{}) error {
162	var arr pgtype.ByteaArray
163	if err := arr.Scan(value); err != nil {
164		return err
165	}
166
167	if len(arr.Dimensions) == 0 {
168		*pieces = nil
169		return nil
170	} else if len(arr.Dimensions) != 1 {
171		return unexpectedDimension{}
172	}
173
174	scan := make(Pieces, len(arr.Elements))
175	for i, elem := range arr.Elements {
176		piece := Piece{}
177		if len(elem.Bytes) != 2+len(piece.StorageNode) {
178			return invalidElementLength{}
179		}
180
181		piece.Number = binary.BigEndian.Uint16(elem.Bytes[0:])
182		copy(piece.StorageNode[:], elem.Bytes[2:])
183		scan[i] = piece
184	}
185
186	*pieces = scan
187	return nil
188}
189