1package db
2
3import (
4	"context"
5	"encoding/json"
6	"fmt"
7	"io/ioutil"
8	"os"
9
10	"github.com/cheggaaa/pb/v3"
11	"github.com/go-redis/redis/v8"
12	"github.com/knqyf263/go-cpe/common"
13	"github.com/knqyf263/go-cpe/naming"
14	c "github.com/kotakanbe/go-cve-dictionary/config"
15	log "github.com/kotakanbe/go-cve-dictionary/log"
16	"github.com/kotakanbe/go-cve-dictionary/models"
17)
18
19/**
20# Redis Data Structure
21
22- HASH
23  ┌──────────────┬──────────┬──────────┬──────────────────────────────────┐
24  │    HASH      │  FIELD   │  VALUE   │             PURPOSE              │
25  └──────────────┴──────────┴──────────┴──────────────────────────────────┘
26  ┌──────────────┬──────────┬───────────┬─────────────────────────────────┐
27  │CVE#${CVEID}  │NVD or JVN│${CVEJSON} │Get CVEJSON by CVEID             │
28  ├──────────────┼──────────┼───────────┼─────────────────────────────────┤
29  │CVE#C#${CVEID}│NVD or JVN│${CPEJSON} │Get CPEJSON BY CVEID             │
30  ├──────────────┼──────────┼───────────┼─────────────────────────────────┤
31  │ CVE#Meta     │${URL}    │${METAJSON}│Get FeedMeta BY URL              │
32  └──────────────┴──────────┴───────────┴─────────────────────────────────┘
33
- ZSET (sorted set)
35  ┌─────────────────────────┬──────────┬─────────────┬─────────────────────────────────────┐
36  │       KEY               │  SCORE   │  MEMBER     │             PURPOSE                 │
37  └─────────────────────────┴──────────┴─────────────┴─────────────────────────────────────┘
38  ┌─────────────────────────┬──────────┬─────────────┬─────────────────────────────────────┐
39  │CVE#${Vendor}::${Product}│    0     │[]${CVEID}   │Get related []CVEID by Vendor,Product│
40  └─────────────────────────┴──────────┴─────────────┴─────────────────────────────────────┘
41
42**/
43
const (
	// dialectRedis is the database type name of this driver (presumably
	// compared against the configured dbType elsewhere in the package).
	dialectRedis = "redis"

	// hashKeyPrefix starts every key this driver writes: per-CVE hashes
	// (CVE#<CVEID>), the feed-meta hash (CVE#Meta) and the vendor/product
	// sorted sets (CVE#<Vendor>::<Product>).
	hashKeyPrefix = "CVE#"

	// cpeHashKeyPrefix starts the per-CVE hashes that hold CPE lists
	// (CVE#C#<CVEID>).
	cpeHashKeyPrefix = hashKeyPrefix + "C#"
)
49
// RedisDriver is Driver for Redis.
// It holds the database type name it was created with and the go-redis
// client through which every query and insert in this file is issued.
type RedisDriver struct {
	// name is the dbType passed to NewRedis; returned by Name.
	name string
	// conn is the client created in OpenDB; nil until OpenDB has run.
	conn *redis.Client
}

// Name returns the database type name this driver was created with.
func (r *RedisDriver) Name() string {
	return r.name
}
60
61// NewRedis return Redis driver
62func NewRedis(dbType, dbpath string, debugSQL bool) (driver *RedisDriver, locked bool, err error) {
63	driver = &RedisDriver{
64		name: dbType,
65	}
66	log.Debugf("Opening DB (%s).", driver.Name())
67	if err = driver.OpenDB(dbType, dbpath, debugSQL); err != nil {
68		return
69	}
70
71	return
72}
73
74// OpenDB opens Database
75func (r *RedisDriver) OpenDB(dbType, dbPath string, debugSQL bool) (err error) {
76	ctx := context.Background()
77	var option *redis.Options
78	if option, err = redis.ParseURL(dbPath); err != nil {
79		log.Errorf("%s", err)
80		return fmt.Errorf("Failed to Parse Redis URL. dbpath: %s, err: %s", dbPath, err)
81	}
82	r.conn = redis.NewClient(option)
83	if err = r.conn.Ping(ctx).Err(); err != nil {
84		return fmt.Errorf("Failed to open DB. dbtype: %s, dbpath: %s, err: %s", dbType, dbPath, err)
85	}
86	return nil
87}
88
89// CloseDB close Database
90func (r *RedisDriver) CloseDB() (err error) {
91	if err = r.conn.Close(); err != nil {
92		log.Errorf("Failed to close DB. Type: %s. err: %s", r.name, err)
93		return
94	}
95	return
96}
97
98// Get Select Cve information from DB.
99func (r *RedisDriver) Get(cveID string) (*models.CveDetail, error) {
100	ctx := context.Background()
101	var cveResult, cpeResult *redis.StringStringMapCmd
102	if cveResult = r.conn.HGetAll(ctx, hashKeyPrefix+cveID); cveResult.Err() != nil {
103		return nil, cveResult.Err()
104	}
105	if cpeResult = r.conn.HGetAll(ctx, cpeHashKeyPrefix+cveID); cpeResult.Err() != nil {
106		return nil, cpeResult.Err()
107	}
108	return r.unmarshal(cveID, cveResult, cpeResult)
109}
110
111// GetMulti Select Cves information from DB.
112func (r *RedisDriver) GetMulti(cveIDs []string) (map[string]models.CveDetail, error) {
113	ctx := context.Background()
114	cveDetails := map[string]models.CveDetail{}
115	pipe := r.conn.Pipeline()
116	cveRs, cpeRs := map[string]*redis.StringStringMapCmd{}, map[string]*redis.StringStringMapCmd{}
117	for _, cveID := range cveIDs {
118		cveRs[cveID] = pipe.HGetAll(ctx, hashKeyPrefix+cveID)
119		cpeRs[cveID] = pipe.HGetAll(ctx, cpeHashKeyPrefix+cveID)
120	}
121	if _, err := pipe.Exec(ctx); err != nil {
122		if err != redis.Nil {
123			return nil, fmt.Errorf("Failed to get multi cve json. err : %s", err)
124		}
125	}
126
127	for cveID, cveResult := range cveRs {
128		cpeResult := cpeRs[cveID]
129		cveDetail, err := r.unmarshal(cveID, cveResult, cpeResult)
130		if err != nil {
131			return nil, err
132		}
133		cveDetails[cveID] = *cveDetail
134	}
135	return cveDetails, nil
136}
137
138func (r *RedisDriver) unmarshal(cveID string, cveResult, cpeResult *redis.StringStringMapCmd) (*models.CveDetail, error) {
139	var err error
140	jvn := &models.Jvn{}
141	if j, ok := cveResult.Val()["Jvn"]; ok {
142		if err = json.Unmarshal([]byte(j), jvn); err != nil {
143			return nil, err
144		}
145	}
146	if jc, ok := cpeResult.Val()["Jvn"]; ok {
147		if err = json.Unmarshal([]byte(jc), &jvn.Cpes); err != nil {
148			return nil, err
149		}
150	}
151	if jvn.CveID == "" {
152		jvn = nil
153	}
154
155	nvdjson := &models.NvdJSON{}
156	if j, ok := cveResult.Val()["NvdJSON"]; ok {
157		if err = json.Unmarshal([]byte(j), nvdjson); err != nil {
158			return nil, err
159		}
160	}
161	if jc, ok := cpeResult.Val()["NvdJSON"]; ok {
162		if err = json.Unmarshal([]byte(jc), &nvdjson.Cpes); err != nil {
163			return nil, err
164		}
165	}
166	if nvdjson.CveID == "" {
167		nvdjson = nil
168	}
169
170	return &models.CveDetail{
171		CveID:   cveID,
172		NvdJSON: nvdjson,
173		Jvn:     jvn,
174	}, nil
175}
176
177// GetCveIDsByCpeURI Select Cve Ids by by pseudo-CPE
178func (r *RedisDriver) GetCveIDsByCpeURI(uri string) ([]string, error) {
179	// TODO
180	return nil, nil
181}
182
183// GetByCpeURI Select Cve information from DB.
184func (r *RedisDriver) GetByCpeURI(uri string) ([]models.CveDetail, error) {
185	ctx := context.Background()
186	specified, err := naming.UnbindURI(uri)
187	if err != nil {
188		return nil, err
189	}
190	vendor := fmt.Sprintf("%s", specified.Get(common.AttributeVendor))
191	product := fmt.Sprintf("%s", specified.Get(common.AttributeProduct))
192	key := fmt.Sprintf("%s%s::%s", hashKeyPrefix, vendor, product)
193
194	var result *redis.StringSliceCmd
195	if result = r.conn.ZRange(ctx, key, 0, -1); result.Err() != nil {
196		return nil, result.Err()
197	}
198
199	uniqCveIDs := map[string]bool{}
200	for _, v := range result.Val() {
201		uniqCveIDs[v] = true
202	}
203	details := []models.CveDetail{}
204	for cveID := range uniqCveIDs {
205		d, err := r.Get(cveID)
206		if err != nil {
207			return nil, err
208		}
209		match, err := matchCpe(uri, d)
210		if err != nil {
211			log.Warnf("Failed to compare the version:%s %s %#v",
212				err, uri, d)
213			// continue matching
214			continue
215		}
216		if match {
217			details = append(details, *d)
218		}
219	}
220	return details, nil
221}
222
223func matchCpe(uri string, cve *models.CveDetail) (bool, error) {
224	cpes := []models.Cpe{}
225	if cve.NvdJSON != nil {
226		cpes = append(cpes, cve.NvdJSON.Cpes...)
227	}
228	if cve.Jvn != nil {
229		cpes = append(cpes, cve.Jvn.Cpes...)
230	}
231	for _, cpe := range cpes {
232		match, err := match(uri, cpe)
233		if err != nil {
234			log.Debugf("Failed to match: %s", err)
235
236			// Try to exact match by vendor, product and version if the version in CPE is not a semVer style.
237			if cve.NvdJSON == nil {
238				continue
239			}
240			ok, err := matchExactByAffects(uri, cve.NvdJSON.Affects)
241			if err != nil {
242				return false, err
243			}
244			if ok {
245				return true, nil
246			}
247			continue
248		} else if match {
249			return true, nil
250		}
251	}
252	return false, nil
253}
254
255// InsertJvn insert items fetched from JVN.
256func (r *RedisDriver) InsertJvn(cves []models.CveDetail) error {
257	ctx := context.Background()
258	log.Infof("Inserting fetched CVEs...")
259	var err error
260	var refreshedJvns []string
261	bar := pb.New(len(cves))
262	if c.Conf.Quiet {
263		bar.SetWriter(ioutil.Discard)
264	} else {
265		bar.SetWriter(os.Stderr)
266	}
267	bar.Start()
268
269	for chunked := range chunkSlice(cves, 10) {
270		var pipe redis.Pipeliner
271		pipe = r.conn.Pipeline()
272		for _, c := range chunked {
273			bar.Increment()
274
275			cpes := make([]models.Cpe, len(c.Jvn.Cpes))
276			copy(cpes, c.Jvn.Cpes)
277			c.Jvn.Cpes = nil
278
279			var jj []byte
280			if jj, err = json.Marshal(c.Jvn); err != nil {
281				return fmt.Errorf("Failed to marshal json. err: %s", err)
282			}
283			refreshedJvns = append(refreshedJvns, c.CveID)
284			if result := pipe.HSet(ctx, hashKeyPrefix+c.CveID, "Jvn", string(jj)); result.Err() != nil {
285				return fmt.Errorf("Failed to HSet CVE. err: %s", result.Err())
286			}
287
288			for _, cpe := range cpes {
289				if result := pipe.ZAdd(
290					ctx,
291					fmt.Sprintf("%s%s::%s", hashKeyPrefix, cpe.Vendor, cpe.Product),
292					&redis.Z{Score: 0, Member: c.CveID},
293				); result.Err() != nil {
294					return fmt.Errorf("Failed to ZAdd cpe. err: %s", result.Err())
295				}
296			}
297			var jc []byte
298			if jc, err = json.Marshal(cpes); err != nil {
299				return fmt.Errorf("Failed to marshal json. err: %s", err)
300			}
301			if result := pipe.HSet(ctx, cpeHashKeyPrefix+c.CveID, "Jvn", string(jc)); result.Err() != nil {
302				return fmt.Errorf("Failed to HSet CPE. err: %s", result.Err())
303			}
304		}
305		if _, err = pipe.Exec(ctx); err != nil {
306			return fmt.Errorf("Failed to exec pipeline. err: %s", err)
307		}
308	}
309	bar.Finish()
310	log.Infof("Refreshed %d Jvns.", len(refreshedJvns))
311	//  log.Debugf("%v", refreshedJvns)
312	return nil
313}
314
315// CountNvd count nvd table
316func (r *RedisDriver) CountNvd() (int, error) {
317	ctx := context.Background()
318	var result *redis.StringSliceCmd
319	if result = r.conn.Keys(ctx, hashKeyPrefix+"CVE*"); result.Err() != nil {
320		return 0, result.Err()
321	}
322	return len(result.Val()), nil
323}
324
325// InsertNvdJSON Cve information from DB.
326func (r *RedisDriver) InsertNvdJSON(cves []models.CveDetail) error {
327	ctx := context.Background()
328	log.Infof("Inserting CVEs...")
329	var err error
330	var refreshedNvds []string
331	bar := pb.New(len(cves))
332	if c.Conf.Quiet {
333		bar.SetWriter(ioutil.Discard)
334	} else {
335		bar.SetWriter(os.Stderr)
336	}
337	bar.Start()
338
339	for chunked := range chunkSlice(cves, 10) {
340		var pipe redis.Pipeliner
341		pipe = r.conn.Pipeline()
342		for _, c := range chunked {
343			bar.Increment()
344
345			cpes := make([]models.Cpe, len(c.NvdJSON.Cpes))
346			copy(cpes, c.NvdJSON.Cpes)
347			c.NvdJSON.Cpes = nil
348			var jn []byte
349			if jn, err = json.Marshal(c.NvdJSON); err != nil {
350				return fmt.Errorf("Failed to marshal json. err: %s", err)
351			}
352			refreshedNvds = append(refreshedNvds, c.CveID)
353			if result := pipe.HSet(ctx, hashKeyPrefix+c.CveID, "NvdJSON", string(jn)); result.Err() != nil {
354				return fmt.Errorf("Failed to HSet CVE. err: %s", result.Err())
355			}
356
357			for _, cpe := range cpes {
358				if result := pipe.ZAdd(
359					ctx,
360					fmt.Sprintf("%s%s::%s", hashKeyPrefix, cpe.Vendor, cpe.Product),
361					&redis.Z{Score: 0, Member: c.CveID},
362				); result.Err() != nil {
363					return fmt.Errorf("Failed to ZAdd cpe. err: %s", result.Err())
364				}
365			}
366			var jc []byte
367			if jc, err = json.Marshal(cpes); err != nil {
368				return fmt.Errorf("Failed to marshal json. err: %s", err)
369			}
370			if result := pipe.HSet(ctx, cpeHashKeyPrefix+c.CveID, "NvdJSON", string(jc)); result.Err() != nil {
371				return fmt.Errorf("Failed to HSet NVD CPE. err: %s", result.Err())
372			}
373		}
374		if _, err = pipe.Exec(ctx); err != nil {
375			return fmt.Errorf("Failed to exec pipeline. err: %s", err)
376		}
377	}
378	bar.Finish()
379
380	log.Infof("Refreshed %d Nvds.", len(refreshedNvds))
381	//  log.Debugf("%v", refreshedNvds)
382	return nil
383}
384
385// GetFetchedFeedMeta selects hash in metafile of the year
386func (r *RedisDriver) GetFetchedFeedMeta(url string) (*models.FeedMeta, error) {
387	ctx := context.Background()
388	var result *redis.StringStringMapCmd
389	if result = r.conn.HGetAll(ctx, hashKeyPrefix+"Meta"); result.Err() != nil {
390		return nil, result.Err()
391	}
392	meta := &models.FeedMeta{}
393	if s, ok := result.Val()[url]; ok {
394		if err := json.Unmarshal([]byte(s), meta); err != nil {
395			return nil, err
396		}
397		return meta, nil
398	}
399	return meta, nil
400}
401
402// UpsertFeedHash selects hash in metafile of the year
403func (r *RedisDriver) UpsertFeedHash(m models.FeedMeta) error {
404	ctx := context.Background()
405	jn, err := json.Marshal(m)
406	if err != nil {
407		return fmt.Errorf("Failed to marshal json. err: %s", err)
408	}
409
410	var pipe redis.Pipeliner
411	pipe = r.conn.Pipeline()
412	if result := pipe.HSet(ctx, hashKeyPrefix+"Meta", m.URL, jn); result.Err() != nil {
413		return fmt.Errorf("Failed to HSet META. err: %s", result.Err())
414	}
415	if _, err := pipe.Exec(ctx); err != nil {
416		return fmt.Errorf("Failed to exec pipeline. err: %s", err)
417	}
418	return nil
419}
420
421// GetFetchedFeedMetas selects a list of FeedMeta
422func (r *RedisDriver) GetFetchedFeedMetas() (metas []models.FeedMeta, err error) {
423	ctx := context.Background()
424	var result *redis.StringStringMapCmd
425	if result = r.conn.HGetAll(ctx, hashKeyPrefix+"Meta"); result.Err() != nil {
426		return nil, result.Err()
427	}
428	for _, s := range result.Val() {
429		m := models.FeedMeta{}
430		if err := json.Unmarshal([]byte(s), &m); err != nil {
431			return nil, err
432		}
433		metas = append(metas, m)
434	}
435	return
436}
437