1package db
2
3import (
4	"database/sql"
5	"encoding/json"
6	"errors"
7	"fmt"
8
9	sq "github.com/Masterminds/squirrel"
10	"github.com/concourse/concourse/atc"
11	"github.com/lib/pq"
12	uuid "github.com/nu7hatch/gouuid"
13)
14
// Sentinel errors produced by the volume lifecycle code in this file.
//
//   - ErrVolumeCannotBeDestroyedWithChildrenPresent: returned by
//     createdVolume.Destroying when the state change violates the
//     "volumes_parent_id_fkey" foreign-key constraint.
//   - ErrVolumeStateTransitionFailed: returned by volumeStateTransition when
//     its UPDATE affected no rows.
//   - ErrVolumeMissing: returned by the createdVolume.Initialize* methods when
//     the UPDATE keyed on the volume id affected no rows.
//   - ErrInvalidResourceCache: returned by findVolumeResourceTypeByCacheID
//     when the resource config has neither a base resource type nor a parent
//     resource cache.
var (
	ErrVolumeCannotBeDestroyedWithChildrenPresent = errors.New("volume cannot be destroyed as children are present")
	ErrVolumeStateTransitionFailed                = errors.New("could not transition volume state")
	ErrVolumeMissing                              = errors.New("volume no longer in db")
	ErrInvalidResourceCache                       = errors.New("invalid resource cache")
)
21
22type ErrVolumeMarkStateFailed struct {
23	State VolumeState
24}
25
26func (e ErrVolumeMarkStateFailed) Error() string {
27	return fmt.Sprintf("could not mark volume as %s", e.State)
28}
29
// ErrVolumeMarkCreatedFailed reports that the volume identified by Handle
// could not be moved into the created state. It is returned by
// creatingVolume.Created when the state transition matched no row.
type ErrVolumeMarkCreatedFailed struct {
	// Handle is the handle of the volume that could not be marked created.
	Handle string
}

// Error implements the error interface, including the volume handle.
func (failure ErrVolumeMarkCreatedFailed) Error() string {
	const format = "failed to mark volume as created %s"

	return fmt.Sprintf(format, failure.Handle)
}
37
// VolumeState is the lifecycle state written to the "state" column of the
// volumes table.
type VolumeState string

// Lifecycle states. The transitions performed in this file are
// creating -> created (Created), creating -> failed (Failed) and
// created -> destroying (Destroying). Rows are deleted only while they are in
// the destroying or failed state (see the Destroy methods).
const (
	VolumeStateCreating   VolumeState = "creating"
	VolumeStateCreated    VolumeState = "created"
	VolumeStateDestroying VolumeState = "destroying"
	VolumeStateFailed     VolumeState = "failed"
)
46
// VolumeType classifies what a volume is used for.
type VolumeType string

// Known volume types. Within this file, VolumeTypeContainer is assigned to
// volumes made by CreateChildForContainer and VolumeTypeResource is assigned
// by InitializeResourceCache; the other values are not referenced here.
//
// NOTE(review): VolumeTypeUknown is misspelled ("Uknown"), but it is an
// exported identifier, so renaming it would break callers.
const (
	VolumeTypeContainer     VolumeType = "container"
	VolumeTypeResource      VolumeType = "resource"
	VolumeTypeResourceType  VolumeType = "resource-type"
	VolumeTypeResourceCerts VolumeType = "resource-certs"
	VolumeTypeTaskCache     VolumeType = "task-cache"
	VolumeTypeArtifact      VolumeType = "artifact"
	VolumeTypeUknown        VolumeType = "unknown" // for migration to life
)
58
//go:generate counterfeiter . CreatingVolume

// CreatingVolume is a volume that is still in the "creating" state. From
// here it can move either to "created" or to "failed".
type CreatingVolume interface {
	// Handle returns the volume's handle.
	Handle() string
	// ID returns the volume's database id.
	ID() int
	// Created transitions the volume to the created state and returns the
	// resulting CreatedVolume.
	Created() (CreatedVolume, error)
	// Failed transitions the volume to the failed state and returns the
	// resulting FailedVolume.
	Failed() (FailedVolume, error)
}
67
68type creatingVolume struct {
69	id                       int
70	workerName               string
71	handle                   string
72	path                     string
73	teamID                   int
74	typ                      VolumeType
75	containerHandle          string
76	parentHandle             string
77	resourceCacheID          int
78	workerBaseResourceTypeID int
79	workerTaskCacheID        int
80	workerResourceCertsID    int
81	workerArtifactID         int
82	conn                     Conn
83}
84
85func (volume *creatingVolume) ID() int { return volume.id }
86
87func (volume *creatingVolume) Handle() string { return volume.handle }
88
89func (volume *creatingVolume) Created() (CreatedVolume, error) {
90	err := volumeStateTransition(
91		volume.id,
92		volume.conn,
93		VolumeStateCreating,
94		VolumeStateCreated,
95	)
96	if err != nil {
97		if err == ErrVolumeStateTransitionFailed {
98			return nil, ErrVolumeMarkCreatedFailed{Handle: volume.handle}
99		}
100		return nil, err
101	}
102
103	return &createdVolume{
104		id:                       volume.id,
105		workerName:               volume.workerName,
106		typ:                      volume.typ,
107		handle:                   volume.handle,
108		path:                     volume.path,
109		teamID:                   volume.teamID,
110		conn:                     volume.conn,
111		containerHandle:          volume.containerHandle,
112		parentHandle:             volume.parentHandle,
113		resourceCacheID:          volume.resourceCacheID,
114		workerBaseResourceTypeID: volume.workerBaseResourceTypeID,
115		workerTaskCacheID:        volume.workerTaskCacheID,
116		workerResourceCertsID:    volume.workerResourceCertsID,
117	}, nil
118}
119
120func (volume *creatingVolume) Failed() (FailedVolume, error) {
121	err := volumeStateTransition(
122		volume.id,
123		volume.conn,
124		VolumeStateCreating,
125		VolumeStateFailed,
126	)
127	if err != nil {
128		if err == ErrVolumeStateTransitionFailed {
129			return nil, ErrVolumeMarkStateFailed{VolumeStateFailed}
130		}
131		return nil, err
132	}
133
134	return &failedVolume{
135		id:         volume.id,
136		workerName: volume.workerName,
137		handle:     volume.handle,
138		conn:       volume.conn,
139	}, nil
140}
141
//go:generate counterfeiter . CreatedVolume

// CreatedVolume is a volume in the "created" state. The method notes below
// describe the createdVolume implementation in this file.
//
// TODO-Later Consider separating CORE & Runtime concerns by breaking this abstraction up.
type CreatedVolume interface {
	// Handle returns the volume's handle.
	Handle() string
	// Path returns the volume's path.
	Path() string
	// Type returns the volume's type.
	Type() VolumeType
	// TeamID returns the id of the owning team (0 when unset).
	TeamID() int
	// WorkerArtifactID returns the associated worker artifact id.
	WorkerArtifactID() int
	// CreateChildForContainer inserts a new child volume row for the given
	// container; the string argument is stored as the child's path.
	CreateChildForContainer(CreatingContainer, string) (CreatingVolume, error)
	// Destroying transitions the volume to the destroying state.
	Destroying() (DestroyingVolume, error)
	// WorkerName returns the name of the worker the volume lives on.
	WorkerName() string

	// InitializeResourceCache associates the volume with the worker's copy
	// of the given resource cache.
	InitializeResourceCache(UsedResourceCache) error
	// GetResourceCacheID returns the resource cache id held by the volume.
	GetResourceCacheID() int
	// InitializeArtifact saves a worker artifact with the given name and
	// build id and links the volume to it.
	InitializeArtifact(name string, buildID int) (WorkerArtifact, error)
	// InitializeTaskCache links the volume to the worker task cache for the
	// given job, step and path, releasing any volume previously linked to it.
	InitializeTaskCache(jobID int, stepName string, path string) error

	// ContainerHandle returns the handle of the owning container.
	ContainerHandle() string
	// ParentHandle returns the handle of the parent volume.
	ParentHandle() string
	// ResourceType resolves the resource type behind the volume's resource
	// cache; it returns (nil, nil) when the volume has no resource cache id.
	ResourceType() (*VolumeResourceType, error)
	// BaseResourceType looks up the worker base resource type; it returns
	// (nil, nil) when the volume has no worker base resource type id.
	BaseResourceType() (*UsedWorkerBaseResourceType, error)
	// TaskIdentifier returns the pipeline name, job name and step name of
	// the volume's task cache, or three empty strings when it has none.
	TaskIdentifier() (string, string, string, error)
}
165
// createdVolume is the database-backed CreatedVolume.
//
// id keys the row in the volumes table and conn is the connection used for
// all queries. resourceCacheID, workerBaseResourceTypeID and
// workerTaskCacheID are row ids in resource_caches,
// worker_base_resource_types and worker_task_caches respectively (see the
// lookups below); a zero value is treated as "not set" by ResourceType,
// BaseResourceType and TaskIdentifier.
type createdVolume struct {
	id                       int
	workerName               string
	handle                   string
	path                     string
	teamID                   int
	typ                      VolumeType
	containerHandle          string
	parentHandle             string
	resourceCacheID          int
	workerBaseResourceTypeID int
	workerTaskCacheID        int
	workerResourceCertsID    int
	workerArtifactID         int
	conn                     Conn
}
182
// VolumeResourceType describes the resource type a resource-cache volume was
// produced from, as resolved by findVolumeResourceTypeByCacheID. That
// function sets exactly one of WorkerBaseResourceType (the chain ends at a
// base resource type on the worker) or ResourceType (the type is itself
// backed by another resource cache), together with the cache's Version.
type VolumeResourceType struct {
	WorkerBaseResourceType *UsedWorkerBaseResourceType
	ResourceType           *VolumeResourceType
	Version                atc.Version
}
188
189func (volume *createdVolume) Handle() string          { return volume.handle }
190func (volume *createdVolume) Path() string            { return volume.path }
191func (volume *createdVolume) WorkerName() string      { return volume.workerName }
192func (volume *createdVolume) Type() VolumeType        { return volume.typ }
193func (volume *createdVolume) TeamID() int             { return volume.teamID }
194func (volume *createdVolume) ContainerHandle() string { return volume.containerHandle }
195func (volume *createdVolume) ParentHandle() string    { return volume.parentHandle }
196func (volume *createdVolume) WorkerArtifactID() int   { return volume.workerArtifactID }
197
198func (volume *createdVolume) ResourceType() (*VolumeResourceType, error) {
199	if volume.resourceCacheID == 0 {
200		return nil, nil
201	}
202
203	return volume.findVolumeResourceTypeByCacheID(volume.resourceCacheID)
204}
205
206func (volume *createdVolume) BaseResourceType() (*UsedWorkerBaseResourceType, error) {
207	if volume.workerBaseResourceTypeID == 0 {
208		return nil, nil
209	}
210
211	return volume.findWorkerBaseResourceTypeByID(volume.workerBaseResourceTypeID)
212}
213
214func (volume *createdVolume) TaskIdentifier() (string, string, string, error) {
215	if volume.workerTaskCacheID == 0 {
216		return "", "", "", nil
217	}
218
219	var pipelineName string
220	var jobName string
221	var stepName string
222
223	err := psql.Select("p.name, j.name, tc.step_name").
224		From("worker_task_caches wtc").
225		LeftJoin("task_caches tc on tc.id = wtc.task_cache_id").
226		LeftJoin("jobs j ON j.id = tc.job_id").
227		LeftJoin("pipelines p ON p.id = j.pipeline_id").
228		Where(sq.Eq{
229			"wtc.id": volume.workerTaskCacheID,
230		}).
231		RunWith(volume.conn).
232		QueryRow().
233		Scan(&pipelineName, &jobName, &stepName)
234	if err != nil {
235		return "", "", "", err
236	}
237
238	return pipelineName, jobName, stepName, nil
239}
240
241func (volume *createdVolume) findVolumeResourceTypeByCacheID(resourceCacheID int) (*VolumeResourceType, error) {
242	var versionString []byte
243	var sqBaseResourceTypeID sql.NullInt64
244	var sqResourceCacheID sql.NullInt64
245
246	err := psql.Select("rc.version, rcfg.base_resource_type_id, rcfg.resource_cache_id").
247		From("resource_caches rc").
248		LeftJoin("resource_configs rcfg ON rcfg.id = rc.resource_config_id").
249		Where(sq.Eq{
250			"rc.id": resourceCacheID,
251		}).
252		RunWith(volume.conn).
253		QueryRow().
254		Scan(&versionString, &sqBaseResourceTypeID, &sqResourceCacheID)
255	if err != nil {
256		return nil, err
257	}
258
259	var version atc.Version
260	err = json.Unmarshal(versionString, &version)
261	if err != nil {
262		return nil, err
263	}
264
265	if sqBaseResourceTypeID.Valid {
266		workerBaseResourceType, err := volume.findWorkerBaseResourceTypeByBaseResourceTypeID(int(sqBaseResourceTypeID.Int64))
267		if err != nil {
268			return nil, err
269		}
270
271		return &VolumeResourceType{
272			WorkerBaseResourceType: workerBaseResourceType,
273			Version:                version,
274		}, nil
275	}
276
277	if sqResourceCacheID.Valid {
278		resourceType, err := volume.findVolumeResourceTypeByCacheID(int(sqResourceCacheID.Int64))
279		if err != nil {
280			return nil, err
281		}
282
283		return &VolumeResourceType{
284			ResourceType: resourceType,
285			Version:      version,
286		}, nil
287	}
288
289	return nil, ErrInvalidResourceCache
290}
291
292func (volume *createdVolume) findWorkerBaseResourceTypeByID(workerBaseResourceTypeID int) (*UsedWorkerBaseResourceType, error) {
293	var name string
294	var version string
295
296	err := psql.Select("brt.name, wbrt.version").
297		From("worker_base_resource_types wbrt").
298		LeftJoin("base_resource_types brt ON brt.id = wbrt.base_resource_type_id").
299		Where(sq.Eq{
300			"wbrt.id":          workerBaseResourceTypeID,
301			"wbrt.worker_name": volume.workerName,
302		}).
303		RunWith(volume.conn).
304		QueryRow().
305		Scan(&name, &version)
306	if err != nil {
307		return nil, err
308	}
309
310	return &UsedWorkerBaseResourceType{
311		ID:         workerBaseResourceTypeID,
312		Name:       name,
313		Version:    version,
314		WorkerName: volume.workerName,
315	}, nil
316}
317
318func (volume *createdVolume) findWorkerBaseResourceTypeByBaseResourceTypeID(baseResourceTypeID int) (*UsedWorkerBaseResourceType, error) {
319	var id int
320	var name string
321	var version string
322
323	err := psql.Select("wbrt.id, brt.name, wbrt.version").
324		From("worker_base_resource_types wbrt").
325		LeftJoin("base_resource_types brt ON brt.id = wbrt.base_resource_type_id").
326		Where(sq.Eq{
327			"brt.id":           baseResourceTypeID,
328			"wbrt.worker_name": volume.workerName,
329		}).
330		RunWith(volume.conn).
331		QueryRow().
332		Scan(&id, &name, &version)
333	if err != nil {
334		return nil, err
335	}
336
337	return &UsedWorkerBaseResourceType{
338		ID:         id,
339		Name:       name,
340		Version:    version,
341		WorkerName: volume.workerName,
342	}, nil
343}
344
345func (volume *createdVolume) InitializeResourceCache(resourceCache UsedResourceCache) error {
346	tx, err := volume.conn.Begin()
347	if err != nil {
348		return err
349	}
350
351	defer tx.Rollback()
352
353	workerResourceCache, err := WorkerResourceCache{
354		WorkerName:    volume.WorkerName(),
355		ResourceCache: resourceCache,
356	}.FindOrCreate(tx)
357	if err != nil {
358		return err
359	}
360
361	rows, err := psql.Update("volumes").
362		Set("worker_resource_cache_id", workerResourceCache.ID).
363		Set("team_id", nil).
364		Where(sq.Eq{"id": volume.id}).
365		RunWith(tx).
366		Exec()
367	if err != nil {
368		if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == pqUniqueViolationErrCode {
369			// another volume was 'blessed' as the cache volume - leave this one
370			// owned by the container so it just expires when the container is GCed
371			return nil
372		}
373
374		return err
375	}
376
377	affected, err := rows.RowsAffected()
378	if err != nil {
379		return err
380	}
381
382	if affected == 0 {
383		return ErrVolumeMissing
384	}
385
386	err = tx.Commit()
387	if err != nil {
388		return err
389	}
390
391	volume.resourceCacheID = resourceCache.ID()
392	volume.typ = VolumeTypeResource
393
394	return nil
395}
396
// GetResourceCacheID returns the resource cache id held by the volume; it is
// set by InitializeResourceCache after a successful commit.
func (volume *createdVolume) GetResourceCacheID() int {
	return volume.resourceCacheID
}
400
401func (volume *createdVolume) InitializeArtifact(name string, buildID int) (WorkerArtifact, error) {
402	tx, err := volume.conn.Begin()
403	if err != nil {
404		return nil, err
405	}
406
407	defer Rollback(tx)
408
409	atcWorkerArtifact := atc.WorkerArtifact{
410		Name:    name,
411		BuildID: buildID,
412	}
413
414	workerArtifact, err := saveWorkerArtifact(tx, volume.conn, atcWorkerArtifact)
415	if err != nil {
416		return nil, err
417	}
418
419	rows, err := psql.Update("volumes").
420		Set("worker_artifact_id", workerArtifact.ID()).
421		Where(sq.Eq{"id": volume.id}).
422		RunWith(tx).
423		Exec()
424	if err != nil {
425		return nil, err
426	}
427
428	affected, err := rows.RowsAffected()
429	if err != nil {
430		return nil, err
431	}
432
433	if affected == 0 {
434		return nil, ErrVolumeMissing
435	}
436
437	err = tx.Commit()
438	if err != nil {
439		return nil, err
440	}
441
442	return workerArtifact, nil
443}
444
445func (volume *createdVolume) InitializeTaskCache(jobID int, stepName string, path string) error {
446	tx, err := volume.conn.Begin()
447	if err != nil {
448		return err
449	}
450
451	defer Rollback(tx)
452
453	usedTaskCache, err := usedTaskCache{
454		jobID:    jobID,
455		stepName: stepName,
456		path:     path,
457	}.findOrCreate(tx)
458	if err != nil {
459		return err
460	}
461
462	usedWorkerTaskCache, err := WorkerTaskCache{
463		WorkerName: volume.WorkerName(),
464		TaskCache:  usedTaskCache,
465	}.findOrCreate(tx)
466	if err != nil {
467		return err
468	}
469
470	// release other old volumes for gc
471	_, err = psql.Update("volumes").
472		Set("worker_task_cache_id", nil).
473		Where(sq.Eq{"worker_task_cache_id": usedWorkerTaskCache.ID}).
474		RunWith(tx).
475		Exec()
476	if err != nil {
477		return err
478	}
479
480	rows, err := psql.Update("volumes").
481		Set("worker_task_cache_id", usedWorkerTaskCache.ID).
482		Where(sq.Eq{"id": volume.id}).
483		RunWith(tx).
484		Exec()
485	if err != nil {
486		if pqErr, ok := err.(*pq.Error); ok && pqErr.Code.Name() == pqUniqueViolationErrCode {
487			// another volume was 'blessed' as the cache volume - leave this one
488			// owned by the container so it just expires when the container is GCed
489			return nil
490		}
491
492		return err
493	}
494
495	affected, err := rows.RowsAffected()
496	if err != nil {
497		return err
498	}
499
500	if affected == 0 {
501		return ErrVolumeMissing
502	}
503
504	err = tx.Commit()
505	if err != nil {
506		return err
507	}
508
509	return nil
510}
511
512func (volume *createdVolume) CreateChildForContainer(container CreatingContainer, mountPath string) (CreatingVolume, error) {
513	tx, err := volume.conn.Begin()
514	if err != nil {
515		return nil, err
516	}
517
518	defer Rollback(tx)
519
520	handle, err := uuid.NewV4()
521	if err != nil {
522		return nil, err
523	}
524
525	columnNames := []string{
526		"worker_name",
527		"parent_id",
528		"parent_state",
529		"handle",
530		"container_id",
531		"path",
532	}
533	columnValues := []interface{}{
534		volume.workerName,
535		volume.id,
536		VolumeStateCreated,
537		handle.String(),
538		container.ID(),
539		mountPath,
540	}
541
542	if volume.teamID != 0 {
543		columnNames = append(columnNames, "team_id")
544		columnValues = append(columnValues, volume.teamID)
545	}
546
547	var volumeID int
548	err = psql.Insert("volumes").
549		Columns(columnNames...).
550		Values(columnValues...).
551		Suffix("RETURNING id").
552		RunWith(tx).
553		QueryRow().
554		Scan(&volumeID)
555	if err != nil {
556		return nil, err
557	}
558
559	err = tx.Commit()
560	if err != nil {
561		return nil, err
562	}
563
564	return &creatingVolume{
565		id:              volumeID,
566		workerName:      volume.workerName,
567		handle:          handle.String(),
568		path:            mountPath,
569		teamID:          volume.teamID,
570		typ:             VolumeTypeContainer,
571		containerHandle: container.Handle(),
572		parentHandle:    volume.Handle(),
573		conn:            volume.conn,
574	}, nil
575}
576
577func (volume *createdVolume) Destroying() (DestroyingVolume, error) {
578	err := volumeStateTransition(
579		volume.id,
580		volume.conn,
581		VolumeStateCreated,
582		VolumeStateDestroying,
583	)
584	if err != nil {
585		if err == ErrVolumeStateTransitionFailed {
586			return nil, ErrVolumeMarkStateFailed{VolumeStateDestroying}
587
588		}
589
590		if pqErr, ok := err.(*pq.Error); ok &&
591			pqErr.Code.Name() == pqFKeyViolationErrCode &&
592			pqErr.Constraint == "volumes_parent_id_fkey" {
593			return nil, ErrVolumeCannotBeDestroyedWithChildrenPresent
594		}
595
596		return nil, err
597	}
598
599	return &destroyingVolume{
600		id:         volume.id,
601		workerName: volume.workerName,
602		handle:     volume.handle,
603		conn:       volume.conn,
604	}, nil
605}
606
//go:generate counterfeiter . DestroyingVolume

// DestroyingVolume is a volume in the "destroying" state, awaiting removal.
type DestroyingVolume interface {
	// Handle returns the volume's handle.
	Handle() string
	// Destroy deletes the volume row; the bool reports whether a row was
	// actually deleted.
	Destroy() (bool, error)
	// WorkerName returns the name of the worker the volume lives on.
	WorkerName() string
}
613
614type destroyingVolume struct {
615	id         int
616	workerName string
617	handle     string
618	conn       Conn
619}
620
621func (volume *destroyingVolume) Handle() string     { return volume.handle }
622func (volume *destroyingVolume) WorkerName() string { return volume.workerName }
623
624func (volume *destroyingVolume) Destroy() (bool, error) {
625	rows, err := psql.Delete("volumes").
626		Where(sq.Eq{
627			"id":    volume.id,
628			"state": VolumeStateDestroying,
629		}).
630		RunWith(volume.conn).
631		Exec()
632	if err != nil {
633		return false, err
634	}
635
636	affected, err := rows.RowsAffected()
637	if err != nil {
638		return false, err
639	}
640
641	if affected == 0 {
642		return false, nil
643	}
644
645	return true, nil
646}
647
// FailedVolume is a volume in the "failed" state, awaiting removal.
type FailedVolume interface {
	// Handle returns the volume's handle.
	Handle() string
	// Destroy deletes the volume row; the bool reports whether a row was
	// actually deleted.
	Destroy() (bool, error)
	// WorkerName returns the name of the worker the volume lives on.
	WorkerName() string
}
653
654type failedVolume struct {
655	id         int
656	workerName string
657	handle     string
658	conn       Conn
659}
660
661func (volume *failedVolume) Handle() string     { return volume.handle }
662func (volume *failedVolume) WorkerName() string { return volume.workerName }
663
664func (volume *failedVolume) Destroy() (bool, error) {
665	rows, err := psql.Delete("volumes").
666		Where(sq.Eq{
667			"id":    volume.id,
668			"state": VolumeStateFailed,
669		}).
670		RunWith(volume.conn).
671		Exec()
672	if err != nil {
673		return false, err
674	}
675
676	affected, err := rows.RowsAffected()
677	if err != nil {
678		return false, err
679	}
680
681	if affected == 0 {
682		return false, nil
683	}
684
685	return true, nil
686}
687
688func volumeStateTransition(volumeID int, conn Conn, from, to VolumeState) error {
689	rows, err := psql.Update("volumes").
690		Set("state", string(to)).
691		Where(sq.And{
692			sq.Eq{"id": volumeID},
693			sq.Or{
694				sq.Eq{"state": string(from)},
695				sq.Eq{"state": string(to)},
696			},
697		}).
698		RunWith(conn).
699		Exec()
700	if err != nil {
701		return err
702	}
703
704	affected, err := rows.RowsAffected()
705	if err != nil {
706		return err
707	}
708
709	if affected == 0 {
710		return ErrVolumeStateTransitionFailed
711	}
712
713	return nil
714}
715