1package identitytpl
2
3import (
4	"encoding/json"
5	"errors"
6	"fmt"
7	"strconv"
8	"strings"
9	"time"
10
11	"github.com/hashicorp/errwrap"
12	"github.com/hashicorp/vault/sdk/logical"
13)
14
// Sentinel errors returned (directly or via template rendering) by
// PopulateString.
var (
	// ErrUnbalancedTemplatingCharacter is returned when the "{{" and "}}"
	// delimiters in the input do not pair up.
	ErrUnbalancedTemplatingCharacter = errors.New("unbalanced templating characters")
	// ErrNoEntityAttachedToToken is returned when an identity.entity.*
	// directive is used but PopulateStringInput.Entity is nil.
	ErrNoEntityAttachedToToken       = errors.New("string contains entity template directives but no entity was provided")
	// ErrNoGroupsAttachedToToken is returned when an identity.groups.*
	// directive is used but PopulateStringInput.Groups is empty.
	ErrNoGroupsAttachedToToken       = errors.New("string contains groups template directives but no groups were provided")
	// ErrTemplateValueNotFound is returned when a directive is not recognized
	// or the value it refers to cannot be resolved.
	ErrTemplateValueNotFound         = errors.New("no value could be found for one of the template directives")
)

// Template processing modes, selected through PopulateStringInput.Mode.
const (
	// ACLTemplating emits values verbatim. Because it is zero, it is also the
	// mode used when Mode is left unset.
	ACLTemplating = iota // must be the first value for backwards compatibility
	// JSONTemplating emits values as JSON literals (quoted strings, arrays,
	// objects) so they can be spliced into a JSON document.
	JSONTemplating
)
26
// PopulateStringInput carries a template string together with the identity
// data and options PopulateString uses to render it.
type PopulateStringInput struct {
	// String is the template text; directives are delimited by "{{" and "}}".
	String            string
	// ValidityCheckOnly makes PopulateString only check that delimiters are
	// balanced, without rendering; the returned string is then empty.
	ValidityCheckOnly bool
	// Entity supplies values for identity.entity.* directives.
	Entity            *logical.Entity
	// Groups supplies values for identity.groups.* directives and for
	// identity.entity.groups.names / identity.entity.groups.ids.
	Groups            []*logical.Group
	// NamespaceID, when non-empty, restricts name-based group lookups
	// (identity.groups.names.*) to groups in that namespace.
	NamespaceID       string
	Mode              int       // processing mode, ACLTemplating or JSONTemplating
	Now               time.Time // optional, defaults to current time

	// Derived state filled in by PopulateString: the mode-specific renderer
	// and the IDs/names collected from Groups.
	templateHandler templateHandlerFunc
	groupIDs        []string
	groupNames      []string
}
40
// templateHandlerFunc renders a template value (string, []string or
// map[string]string) as output text; when the value is a map, the first of the
// optional keys selects the entry to render. Different handlers are used based
// on mode: in ACL mode strings are emitted verbatim, while in JSON mode they
// are emitted as quoted JSON strings. Some structures, like slices, might be
// rendered in one mode but prohibited in another.
type templateHandlerFunc func(interface{}, ...string) (string, error)
47
48// aclTemplateHandler processes known parameter data types when operating
49// in ACL mode.
50func aclTemplateHandler(v interface{}, keys ...string) (string, error) {
51	switch t := v.(type) {
52	case string:
53		if t == "" {
54			return "", ErrTemplateValueNotFound
55		}
56		return t, nil
57	case []string:
58		return "", ErrTemplateValueNotFound
59	case map[string]string:
60		if len(keys) > 0 {
61			val, ok := t[keys[0]]
62			if ok {
63				return val, nil
64			}
65		}
66		return "", ErrTemplateValueNotFound
67	}
68
69	return "", fmt.Errorf("unknown type: %T", v)
70}
71
// jsonTemplateHandler renders a template value for JSON mode, where every
// output is a JSON literal that can be spliced directly into a JSON document.
//
//   - string: encoded as a JSON string literal.
//   - []string: encoded as a JSON array (note: a nil slice encodes as null).
//   - map[string]string: with a key, that entry's value is encoded as a JSON
//     string, using "" when the key is absent so templates get a default;
//     without a key, the whole map is encoded as a JSON object, and a nil map
//     renders as {}.
//
// Any other dynamic type is reported as an error.
func jsonTemplateHandler(v interface{}, keys ...string) (string, error) {
	jsonMarshaller := func(v interface{}) (string, error) {
		enc, err := json.Marshal(v)
		if err != nil {
			return "", err
		}
		return string(enc), nil
	}

	// jsonString encodes s as a JSON string literal. strconv.Quote is not a
	// substitute: it produces Go syntax, whose escapes for control characters
	// and DEL (\x00, \a, \v, \x7f, ...) are not valid JSON. HTML escaping is
	// disabled so characters such as <, > and & are emitted literally rather
	// than as \u003c-style escapes.
	jsonString := func(s string) (string, error) {
		var sb strings.Builder
		enc := json.NewEncoder(&sb)
		enc.SetEscapeHTML(false)
		if err := enc.Encode(s); err != nil {
			return "", err
		}
		// Encoder.Encode terminates every value with a newline.
		return strings.TrimSuffix(sb.String(), "\n"), nil
	}

	switch t := v.(type) {
	case string:
		return jsonString(t)
	case []string:
		return jsonMarshaller(t)
	case map[string]string:
		if len(keys) > 0 {
			// A missing key reads as "" and renders as an empty JSON string.
			return jsonString(t[keys[0]])
		}
		if t == nil {
			return "{}", nil
		}
		return jsonMarshaller(t)
	}

	return "", fmt.Errorf("unknown type: %T", v)
}
100
101func PopulateString(p PopulateStringInput) (bool, string, error) {
102	if p.String == "" {
103		return false, "", nil
104	}
105
106	// preprocess groups
107	for _, g := range p.Groups {
108		p.groupNames = append(p.groupNames, g.Name)
109		p.groupIDs = append(p.groupIDs, g.ID)
110	}
111
112	// set up mode-specific handler
113	switch p.Mode {
114	case ACLTemplating:
115		p.templateHandler = aclTemplateHandler
116	case JSONTemplating:
117		p.templateHandler = jsonTemplateHandler
118	default:
119		return false, "", fmt.Errorf("unknown mode %q", p.Mode)
120	}
121
122	var subst bool
123	splitStr := strings.Split(p.String, "{{")
124
125	if len(splitStr) >= 1 {
126		if strings.Contains(splitStr[0], "}}") {
127			return false, "", ErrUnbalancedTemplatingCharacter
128		}
129		if len(splitStr) == 1 {
130			return false, p.String, nil
131		}
132	}
133
134	var b strings.Builder
135	if !p.ValidityCheckOnly {
136		b.Grow(2 * len(p.String))
137	}
138
139	for i, str := range splitStr {
140		if i == 0 {
141			if !p.ValidityCheckOnly {
142				b.WriteString(str)
143			}
144			continue
145		}
146		splitPiece := strings.Split(str, "}}")
147		switch len(splitPiece) {
148		case 2:
149			subst = true
150			if !p.ValidityCheckOnly {
151				tmplStr, err := performTemplating(strings.TrimSpace(splitPiece[0]), &p)
152				if err != nil {
153					return false, "", err
154				}
155				b.WriteString(tmplStr)
156				b.WriteString(splitPiece[1])
157			}
158		default:
159			return false, "", ErrUnbalancedTemplatingCharacter
160		}
161	}
162
163	return subst, b.String(), nil
164}
165
166func performTemplating(input string, p *PopulateStringInput) (string, error) {
167	performAliasTemplating := func(trimmed string, alias *logical.Alias) (string, error) {
168		switch {
169		case trimmed == "id":
170			return p.templateHandler(alias.ID)
171
172		case trimmed == "name":
173			return p.templateHandler(alias.Name)
174
175		case trimmed == "metadata":
176			return p.templateHandler(alias.Metadata)
177
178		case strings.HasPrefix(trimmed, "metadata."):
179			split := strings.SplitN(trimmed, ".", 2)
180			return p.templateHandler(alias.Metadata, split[1])
181		}
182
183		return "", ErrTemplateValueNotFound
184	}
185
186	performEntityTemplating := func(trimmed string) (string, error) {
187		switch {
188		case trimmed == "id":
189			return p.templateHandler(p.Entity.ID)
190
191		case trimmed == "name":
192			return p.templateHandler(p.Entity.Name)
193
194		case trimmed == "metadata":
195			return p.templateHandler(p.Entity.Metadata)
196
197		case strings.HasPrefix(trimmed, "metadata."):
198			split := strings.SplitN(trimmed, ".", 2)
199			return p.templateHandler(p.Entity.Metadata, split[1])
200
201		case trimmed == "groups.names":
202			return p.templateHandler(p.groupNames)
203
204		case trimmed == "groups.ids":
205			return p.templateHandler(p.groupIDs)
206
207		case strings.HasPrefix(trimmed, "aliases."):
208			split := strings.SplitN(strings.TrimPrefix(trimmed, "aliases."), ".", 2)
209			if len(split) != 2 {
210				return "", errors.New("invalid alias selector")
211			}
212			var alias *logical.Alias
213			for _, a := range p.Entity.Aliases {
214				if split[0] == a.MountAccessor {
215					alias = a
216					break
217				}
218			}
219			if alias == nil {
220				if p.Mode == ACLTemplating {
221					return "", errors.New("alias not found")
222				}
223
224				// An empty alias is sufficient for generating defaults
225				alias = &logical.Alias{Metadata: make(map[string]string)}
226			}
227			return performAliasTemplating(split[1], alias)
228		}
229
230		return "", ErrTemplateValueNotFound
231	}
232
233	performGroupsTemplating := func(trimmed string) (string, error) {
234		var ids bool
235
236		selectorSplit := strings.SplitN(trimmed, ".", 2)
237
238		switch {
239		case len(selectorSplit) != 2:
240			return "", errors.New("invalid groups selector")
241
242		case selectorSplit[0] == "ids":
243			ids = true
244
245		case selectorSplit[0] == "names":
246
247		default:
248			return "", errors.New("invalid groups selector")
249		}
250		trimmed = selectorSplit[1]
251
252		accessorSplit := strings.SplitN(trimmed, ".", 2)
253		if len(accessorSplit) != 2 {
254			return "", errors.New("invalid groups accessor")
255		}
256		var found *logical.Group
257		for _, group := range p.Groups {
258			var compare string
259			if ids {
260				compare = group.ID
261			} else {
262				if p.NamespaceID != "" && group.NamespaceID != p.NamespaceID {
263					continue
264				}
265				compare = group.Name
266			}
267
268			if compare == accessorSplit[0] {
269				found = group
270				break
271			}
272		}
273
274		if found == nil {
275			return "", fmt.Errorf("entity is not a member of group %q", accessorSplit[0])
276		}
277
278		trimmed = accessorSplit[1]
279
280		switch {
281		case trimmed == "id":
282			return found.ID, nil
283
284		case trimmed == "name":
285			if found.Name == "" {
286				return "", ErrTemplateValueNotFound
287			}
288			return found.Name, nil
289
290		case strings.HasPrefix(trimmed, "metadata."):
291			val, ok := found.Metadata[strings.TrimPrefix(trimmed, "metadata.")]
292			if !ok {
293				return "", ErrTemplateValueNotFound
294			}
295			return val, nil
296		}
297
298		return "", ErrTemplateValueNotFound
299	}
300
301	performTimeTemplating := func(trimmed string) (string, error) {
302		now := p.Now
303		if now.IsZero() {
304			now = time.Now()
305		}
306
307		opsSplit := strings.SplitN(trimmed, ".", 3)
308
309		if opsSplit[0] != "now" {
310			return "", fmt.Errorf("invalid time selector %q", opsSplit[0])
311		}
312
313		result := now
314		switch len(opsSplit) {
315		case 1:
316			// return current time
317		case 2:
318			return "", errors.New("missing time operand")
319
320		case 3:
321			duration, err := time.ParseDuration(opsSplit[2])
322			if err != nil {
323				return "", errwrap.Wrapf("invalid duration: {{err}}", err)
324			}
325
326			switch opsSplit[1] {
327			case "plus":
328				result = result.Add(duration)
329			case "minus":
330				result = result.Add(-duration)
331			default:
332				return "", fmt.Errorf("invalid time operator %q", opsSplit[1])
333			}
334		}
335
336		return strconv.FormatInt(result.Unix(), 10), nil
337	}
338
339	switch {
340	case strings.HasPrefix(input, "identity.entity."):
341		if p.Entity == nil {
342			return "", ErrNoEntityAttachedToToken
343		}
344		return performEntityTemplating(strings.TrimPrefix(input, "identity.entity."))
345
346	case strings.HasPrefix(input, "identity.groups."):
347		if len(p.Groups) == 0 {
348			return "", ErrNoGroupsAttachedToToken
349		}
350		return performGroupsTemplating(strings.TrimPrefix(input, "identity.groups."))
351
352	case strings.HasPrefix(input, "time."):
353		return performTimeTemplating(strings.TrimPrefix(input, "time."))
354	}
355
356	return "", ErrTemplateValueNotFound
357}
358