1package db
2
3import (
4	"database/sql"
5	"strings"
6
7	"encoding/json"
8
9	sq "github.com/Masterminds/squirrel"
10	"github.com/concourse/concourse/atc"
11	"github.com/concourse/concourse/atc/db/lock"
12)
13
14//go:generate counterfeiter . TeamFactory
15
// TeamFactory creates and looks up teams. The method notes below describe
// the teamFactory implementation in this file.
type TeamFactory interface {
	// CreateTeam inserts a new, non-admin team and returns it.
	CreateTeam(atc.Team) (Team, error)
	// FindTeam looks a team up by name, ignoring case. The boolean result
	// is false (with a nil error) when no matching row exists.
	FindTeam(string) (Team, bool, error)
	// GetTeams returns every team, ordered by name ascending.
	GetTeams() ([]Team, error)
	// GetByID returns a Team carrying only the given ID; it does not query
	// the database, so the team is not checked for existence.
	GetByID(teamID int) Team
	// CreateDefaultTeamIfNotExists flags the default team
	// (atc.DefaultTeamName) as admin and returns it, creating it as an
	// admin team first if it cannot be found.
	CreateDefaultTeamIfNotExists() (Team, error)
	// NotifyResourceScanner sends a notification on the connection's bus
	// using the atc.ComponentLidarScanner channel.
	NotifyResourceScanner() error
	// NotifyCacher sends a notification on the connection's bus using the
	// atc.TeamCacheChannel channel.
	NotifyCacher() error
}
25
// teamFactory is the database-backed TeamFactory. Both fields are handed
// on to every team value the factory builds.
type teamFactory struct {
	// conn runs the team queries and provides the notification bus.
	conn Conn
	// lockFactory is not used by the factory itself in this file; it is
	// only copied into the teams it returns.
	lockFactory lock.LockFactory
}
30
31func NewTeamFactory(conn Conn, lockFactory lock.LockFactory) TeamFactory {
32	return &teamFactory{
33		conn:        conn,
34		lockFactory: lockFactory,
35	}
36}
37
38func (factory *teamFactory) CreateTeam(t atc.Team) (Team, error) {
39	return factory.createTeam(t, false)
40}
41
42func (factory *teamFactory) createTeam(t atc.Team, admin bool) (Team, error) {
43	tx, err := factory.conn.Begin()
44	if err != nil {
45		return nil, err
46	}
47
48	defer Rollback(tx)
49
50	auth, err := json.Marshal(t.Auth)
51	if err != nil {
52		return nil, err
53	}
54
55	row := psql.Insert("teams").
56		Columns("name, auth, admin").
57		Values(t.Name, auth, admin).
58		Suffix("RETURNING id, name, admin, auth").
59		RunWith(tx).
60		QueryRow()
61
62	team := &team{
63		conn:        factory.conn,
64		lockFactory: factory.lockFactory,
65	}
66
67	err = factory.scanTeam(team, row)
68	if err != nil {
69		return nil, err
70	}
71
72	err = tx.Commit()
73	if err != nil {
74		return nil, err
75	}
76
77	return team, nil
78}
79
80func (factory *teamFactory) GetByID(teamID int) Team {
81	return &team{
82		id:          teamID,
83		conn:        factory.conn,
84		lockFactory: factory.lockFactory,
85	}
86}
87
88func (factory *teamFactory) FindTeam(teamName string) (Team, bool, error) {
89	team := &team{
90		conn:        factory.conn,
91		lockFactory: factory.lockFactory,
92	}
93
94	row := psql.Select("id, name, admin, auth").
95		From("teams").
96		Where(sq.Eq{"LOWER(name)": strings.ToLower(teamName)}).
97		RunWith(factory.conn).
98		QueryRow()
99
100	err := factory.scanTeam(team, row)
101
102	if err != nil {
103		if err == sql.ErrNoRows {
104			return nil, false, nil
105		}
106		return nil, false, err
107	}
108
109	return team, true, nil
110}
111
112func (factory *teamFactory) GetTeams() ([]Team, error) {
113	rows, err := psql.Select("id, name, admin, auth").
114		From("teams").
115		OrderBy("name ASC").
116		RunWith(factory.conn).
117		Query()
118	if err != nil {
119		return nil, err
120	}
121	defer Close(rows)
122
123	teams := []Team{}
124
125	for rows.Next() {
126		team := &team{
127			conn:        factory.conn,
128			lockFactory: factory.lockFactory,
129		}
130
131		err = factory.scanTeam(team, rows)
132		if err != nil {
133			return nil, err
134		}
135
136		teams = append(teams, team)
137	}
138
139	return teams, nil
140}
141
142func (factory *teamFactory) CreateDefaultTeamIfNotExists() (Team, error) {
143	_, err := psql.Update("teams").
144		Set("admin", true).
145		Where(sq.Eq{"LOWER(name)": strings.ToLower(atc.DefaultTeamName)}).
146		RunWith(factory.conn).
147		Exec()
148
149	if err != nil && err != sql.ErrNoRows {
150		return nil, err
151	}
152
153	t, found, err := factory.FindTeam(atc.DefaultTeamName)
154	if err != nil {
155		return nil, err
156	}
157
158	if found {
159		return t, nil
160	}
161
162	//not found, have to create
163	return factory.createTeam(atc.Team{
164		Name: atc.DefaultTeamName,
165	},
166		true,
167	)
168}
169
170func (factory *teamFactory) NotifyResourceScanner() error {
171	return factory.conn.Bus().Notify(atc.ComponentLidarScanner)
172}
173
174func (factory *teamFactory) NotifyCacher() error {
175	return factory.conn.Bus().Notify(atc.TeamCacheChannel)
176}
177
178func (factory *teamFactory) scanTeam(t *team, rows scannable) error {
179	var providerAuth sql.NullString
180
181	err := rows.Scan(
182		&t.id,
183		&t.name,
184		&t.admin,
185		&providerAuth,
186	)
187
188	if providerAuth.Valid {
189		err = json.Unmarshal([]byte(providerAuth.String), &t.auth)
190		if err != nil {
191			return err
192		}
193	}
194
195	return err
196}
197