1// Copyright 2011 The Go Authors. All rights reserved.
2// Use of this source code is governed by a BSD-style
3// license that can be found in the LICENSE file.
4
5package template
6
7import (
8	"bytes"
9	"errors"
10	"fmt"
11	"io"
12	"net/url"
13	"reflect"
14	"strings"
15	"unicode"
16	"unicode/utf8"
17)
18
// FuncMap is the type of the map defining the mapping from names to functions.
// Each function must have either a single return value, or two return values
// of which the second has type error. In the two-value case, if the second
// (error) return value evaluates to non-nil during execution, execution
// terminates and Execute returns that error.
type FuncMap map[string]interface{}
25
// builtins maps the name of each predefined function to its implementation.
// findFunction consults this table (through builtinFuncs) only after the
// template's own function map, so a template-level entry with the same name
// is found first.
var builtins = FuncMap{
	"and":      and,
	"call":     call,
	"html":     HTMLEscaper,
	"index":    index,
	"js":       JSEscaper,
	"len":      length,
	"not":      not,
	"or":       or,
	"print":    fmt.Sprint,
	"printf":   fmt.Sprintf,
	"println":  fmt.Sprintln,
	"urlquery": URLQueryEscaper,

	// Comparisons
	"eq": eq, // ==
	"ge": ge, // >=
	"gt": gt, // >
	"le": le, // <=
	"lt": lt, // <
	"ne": ne, // !=
}
48
// builtinFuncs holds the entries of builtins converted to reflect.Values by
// createValueFuncs. findFunction looks names up in this map.
var builtinFuncs = createValueFuncs(builtins)
50
51// createValueFuncs turns a FuncMap into a map[string]reflect.Value
52func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
53	m := make(map[string]reflect.Value)
54	addValueFuncs(m, funcMap)
55	return m
56}
57
58// addValueFuncs adds to values the functions in funcs, converting them to reflect.Values.
59func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
60	for name, fn := range in {
61		v := reflect.ValueOf(fn)
62		if v.Kind() != reflect.Func {
63			panic("value for " + name + " not a function")
64		}
65		if !goodFunc(v.Type()) {
66			panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
67		}
68		out[name] = v
69	}
70}
71
72// addFuncs adds to values the functions in funcs. It does no checking of the input -
73// call addValueFuncs first.
74func addFuncs(out, in FuncMap) {
75	for name, fn := range in {
76		out[name] = fn
77	}
78}
79
80// goodFunc checks that the function or method has the right result signature.
81func goodFunc(typ reflect.Type) bool {
82	// We allow functions with 1 result or 2 results where the second is an error.
83	switch {
84	case typ.NumOut() == 1:
85		return true
86	case typ.NumOut() == 2 && typ.Out(1) == errorType:
87		return true
88	}
89	return false
90}
91
92// findFunction looks for a function in the template, and global map.
93func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
94	if tmpl != nil && tmpl.common != nil {
95		if fn := tmpl.execFuncs[name]; fn.IsValid() {
96			return fn, true
97		}
98	}
99	if fn := builtinFuncs[name]; fn.IsValid() {
100		return fn, true
101	}
102	return reflect.Value{}, false
103}
104
105// Indexing.
106
107// index returns the result of indexing its first argument by the following
108// arguments.  Thus "index x 1 2 3" is, in Go syntax, x[1][2][3]. Each
109// indexed item must be a map, slice, or array.
110func index(item interface{}, indices ...interface{}) (interface{}, error) {
111	v := reflect.ValueOf(item)
112	for _, i := range indices {
113		index := reflect.ValueOf(i)
114		var isNil bool
115		if v, isNil = indirect(v); isNil {
116			return nil, fmt.Errorf("index of nil pointer")
117		}
118		switch v.Kind() {
119		case reflect.Array, reflect.Slice, reflect.String:
120			var x int64
121			switch index.Kind() {
122			case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
123				x = index.Int()
124			case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
125				x = int64(index.Uint())
126			default:
127				return nil, fmt.Errorf("cannot index slice/array with type %s", index.Type())
128			}
129			if x < 0 || x >= int64(v.Len()) {
130				return nil, fmt.Errorf("index out of range: %d", x)
131			}
132			v = v.Index(int(x))
133		case reflect.Map:
134			if !index.IsValid() {
135				index = reflect.Zero(v.Type().Key())
136			}
137			if !index.Type().AssignableTo(v.Type().Key()) {
138				return nil, fmt.Errorf("%s is not index type for %s", index.Type(), v.Type())
139			}
140			if x := v.MapIndex(index); x.IsValid() {
141				v = x
142			} else {
143				v = reflect.Zero(v.Type().Elem())
144			}
145		default:
146			return nil, fmt.Errorf("can't index item of type %s", v.Type())
147		}
148	}
149	return v.Interface(), nil
150}
151
152// Length
153
154// length returns the length of the item, with an error if it has no defined length.
155func length(item interface{}) (int, error) {
156	v, isNil := indirect(reflect.ValueOf(item))
157	if isNil {
158		return 0, fmt.Errorf("len of nil pointer")
159	}
160	switch v.Kind() {
161	case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
162		return v.Len(), nil
163	}
164	return 0, fmt.Errorf("len of type %s", v.Type())
165}
166
167// Function invocation
168
169// call returns the result of evaluating the first argument as a function.
170// The function must return 1 result, or 2 results, the second of which is an error.
171func call(fn interface{}, args ...interface{}) (interface{}, error) {
172	v := reflect.ValueOf(fn)
173	typ := v.Type()
174	if typ.Kind() != reflect.Func {
175		return nil, fmt.Errorf("non-function of type %s", typ)
176	}
177	if !goodFunc(typ) {
178		return nil, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
179	}
180	numIn := typ.NumIn()
181	var dddType reflect.Type
182	if typ.IsVariadic() {
183		if len(args) < numIn-1 {
184			return nil, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
185		}
186		dddType = typ.In(numIn - 1).Elem()
187	} else {
188		if len(args) != numIn {
189			return nil, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
190		}
191	}
192	argv := make([]reflect.Value, len(args))
193	for i, arg := range args {
194		value := reflect.ValueOf(arg)
195		// Compute the expected type. Clumsy because of variadics.
196		var argType reflect.Type
197		if !typ.IsVariadic() || i < numIn-1 {
198			argType = typ.In(i)
199		} else {
200			argType = dddType
201		}
202		if !value.IsValid() && canBeNil(argType) {
203			value = reflect.Zero(argType)
204		}
205		if !value.Type().AssignableTo(argType) {
206			return nil, fmt.Errorf("arg %d has type %s; should be %s", i, value.Type(), argType)
207		}
208		argv[i] = value
209	}
210	result := v.Call(argv)
211	if len(result) == 2 && !result[1].IsNil() {
212		return result[0].Interface(), result[1].Interface().(error)
213	}
214	return result[0].Interface(), nil
215}
216
217// Boolean logic.
218
219func truth(a interface{}) bool {
220	t, _ := isTrue(reflect.ValueOf(a))
221	return t
222}
223
224// and computes the Boolean AND of its arguments, returning
225// the first false argument it encounters, or the last argument.
226func and(arg0 interface{}, args ...interface{}) interface{} {
227	if !truth(arg0) {
228		return arg0
229	}
230	for i := range args {
231		arg0 = args[i]
232		if !truth(arg0) {
233			break
234		}
235	}
236	return arg0
237}
238
239// or computes the Boolean OR of its arguments, returning
240// the first true argument it encounters, or the last argument.
241func or(arg0 interface{}, args ...interface{}) interface{} {
242	if truth(arg0) {
243		return arg0
244	}
245	for i := range args {
246		arg0 = args[i]
247		if truth(arg0) {
248			break
249		}
250	}
251	return arg0
252}
253
254// not returns the Boolean negation of its argument.
255func not(arg interface{}) (truth bool) {
256	truth, _ = isTrue(reflect.ValueOf(arg))
257	return !truth
258}
259
260// Comparison.
261
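// Signed and unsigned integers may be compared with each other; eq and lt
// handle that mix explicitly. Any other pair of differing kinds is rejected
// with errBadComparison.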
263
// Errors returned by the comparison functions.
var (
	// errBadComparisonType is returned by basicKind for a kind it does not
	// classify, and by lt for bool and complex operands.
	errBadComparisonType = errors.New("invalid type for comparison")
	// errBadComparison is returned by eq and lt when the two operands have
	// different kinds, other than a signed/unsigned integer mix.
	errBadComparison     = errors.New("incompatible types for comparison")
	// errNoComparison is returned by eq when it is given only one argument.
	errNoComparison      = errors.New("missing argument for comparison")
)
269
// kind is a coarse classification of reflect.Kind used by the comparison
// functions; basicKind maps a reflect.Value to one of the constants below.
type kind int

const (
	// invalidKind is returned by basicKind together with errBadComparisonType.
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	// NOTE(review): integerKind is not returned by basicKind nor referenced
	// by the comparison functions visible here - confirm whether it is still
	// needed.
	integerKind
	stringKind
	uintKind
)
282
283func basicKind(v reflect.Value) (kind, error) {
284	switch v.Kind() {
285	case reflect.Bool:
286		return boolKind, nil
287	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
288		return intKind, nil
289	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
290		return uintKind, nil
291	case reflect.Float32, reflect.Float64:
292		return floatKind, nil
293	case reflect.Complex64, reflect.Complex128:
294		return complexKind, nil
295	case reflect.String:
296		return stringKind, nil
297	}
298	return invalidKind, errBadComparisonType
299}
300
301// eq evaluates the comparison a == b || a == c || ...
302func eq(arg1 interface{}, arg2 ...interface{}) (bool, error) {
303	v1 := reflect.ValueOf(arg1)
304	k1, err := basicKind(v1)
305	if err != nil {
306		return false, err
307	}
308	if len(arg2) == 0 {
309		return false, errNoComparison
310	}
311	for _, arg := range arg2 {
312		v2 := reflect.ValueOf(arg)
313		k2, err := basicKind(v2)
314		if err != nil {
315			return false, err
316		}
317		truth := false
318		if k1 != k2 {
319			// Special case: Can compare integer values regardless of type's sign.
320			switch {
321			case k1 == intKind && k2 == uintKind:
322				truth = v1.Int() >= 0 && uint64(v1.Int()) == v2.Uint()
323			case k1 == uintKind && k2 == intKind:
324				truth = v2.Int() >= 0 && v1.Uint() == uint64(v2.Int())
325			default:
326				return false, errBadComparison
327			}
328		} else {
329			switch k1 {
330			case boolKind:
331				truth = v1.Bool() == v2.Bool()
332			case complexKind:
333				truth = v1.Complex() == v2.Complex()
334			case floatKind:
335				truth = v1.Float() == v2.Float()
336			case intKind:
337				truth = v1.Int() == v2.Int()
338			case stringKind:
339				truth = v1.String() == v2.String()
340			case uintKind:
341				truth = v1.Uint() == v2.Uint()
342			default:
343				panic("invalid kind")
344			}
345		}
346		if truth {
347			return true, nil
348		}
349	}
350	return false, nil
351}
352
353// ne evaluates the comparison a != b.
354func ne(arg1, arg2 interface{}) (bool, error) {
355	// != is the inverse of ==.
356	equal, err := eq(arg1, arg2)
357	return !equal, err
358}
359
360// lt evaluates the comparison a < b.
361func lt(arg1, arg2 interface{}) (bool, error) {
362	v1 := reflect.ValueOf(arg1)
363	k1, err := basicKind(v1)
364	if err != nil {
365		return false, err
366	}
367	v2 := reflect.ValueOf(arg2)
368	k2, err := basicKind(v2)
369	if err != nil {
370		return false, err
371	}
372	truth := false
373	if k1 != k2 {
374		// Special case: Can compare integer values regardless of type's sign.
375		switch {
376		case k1 == intKind && k2 == uintKind:
377			truth = v1.Int() < 0 || uint64(v1.Int()) < v2.Uint()
378		case k1 == uintKind && k2 == intKind:
379			truth = v2.Int() >= 0 && v1.Uint() < uint64(v2.Int())
380		default:
381			return false, errBadComparison
382		}
383	} else {
384		switch k1 {
385		case boolKind, complexKind:
386			return false, errBadComparisonType
387		case floatKind:
388			truth = v1.Float() < v2.Float()
389		case intKind:
390			truth = v1.Int() < v2.Int()
391		case stringKind:
392			truth = v1.String() < v2.String()
393		case uintKind:
394			truth = v1.Uint() < v2.Uint()
395		default:
396			panic("invalid kind")
397		}
398	}
399	return truth, nil
400}
401
402// le evaluates the comparison <= b.
403func le(arg1, arg2 interface{}) (bool, error) {
404	// <= is < or ==.
405	lessThan, err := lt(arg1, arg2)
406	if lessThan || err != nil {
407		return lessThan, err
408	}
409	return eq(arg1, arg2)
410}
411
412// gt evaluates the comparison a > b.
413func gt(arg1, arg2 interface{}) (bool, error) {
414	// > is the inverse of <=.
415	lessOrEqual, err := le(arg1, arg2)
416	if err != nil {
417		return false, err
418	}
419	return !lessOrEqual, nil
420}
421
422// ge evaluates the comparison a >= b.
423func ge(arg1, arg2 interface{}) (bool, error) {
424	// >= is the inverse of <.
425	lessThan, err := lt(arg1, arg2)
426	if err != nil {
427		return false, err
428	}
429	return !lessThan, nil
430}
431
432// HTML escaping.
433
// Replacement text written by HTMLEscape for the five characters it escapes:
// double quote, single quote, ampersand, less-than, and greater-than.
var (
	htmlQuot = []byte("&#34;") // shorter than "&quot;"
	htmlApos = []byte("&#39;") // shorter than "&apos;" and apos was not in HTML until HTML5
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
)
441
442// HTMLEscape writes to w the escaped HTML equivalent of the plain text data b.
443func HTMLEscape(w io.Writer, b []byte) {
444	last := 0
445	for i, c := range b {
446		var html []byte
447		switch c {
448		case '"':
449			html = htmlQuot
450		case '\'':
451			html = htmlApos
452		case '&':
453			html = htmlAmp
454		case '<':
455			html = htmlLt
456		case '>':
457			html = htmlGt
458		default:
459			continue
460		}
461		w.Write(b[last:i])
462		w.Write(html)
463		last = i + 1
464	}
465	w.Write(b[last:])
466}
467
468// HTMLEscapeString returns the escaped HTML equivalent of the plain text data s.
469func HTMLEscapeString(s string) string {
470	// Avoid allocation if we can.
471	if strings.IndexAny(s, `'"&<>`) < 0 {
472		return s
473	}
474	var b bytes.Buffer
475	HTMLEscape(&b, []byte(s))
476	return b.String()
477}
478
479// HTMLEscaper returns the escaped HTML equivalent of the textual
480// representation of its arguments.
481func HTMLEscaper(args ...interface{}) string {
482	return HTMLEscapeString(evalArgs(args))
483}
484
485// JavaScript escaping.
486
// Byte sequences written by JSEscape.
var (
	// jsLowUni is the prefix of the \u00XX escape used for ASCII bytes that
	// have no dedicated replacement below.
	jsLowUni = []byte(`\u00`)
	// hex holds the uppercase hexadecimal digits; JSEscape indexes it by
	// nibble value.
	hex      = []byte("0123456789ABCDEF")

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\x3C`)
	jsGt        = []byte(`\x3E`)
)
497
498// JSEscape writes to w the escaped JavaScript equivalent of the plain text data b.
499func JSEscape(w io.Writer, b []byte) {
500	last := 0
501	for i := 0; i < len(b); i++ {
502		c := b[i]
503
504		if !jsIsSpecial(rune(c)) {
505			// fast path: nothing to do
506			continue
507		}
508		w.Write(b[last:i])
509
510		if c < utf8.RuneSelf {
511			// Quotes, slashes and angle brackets get quoted.
512			// Control characters get written as \u00XX.
513			switch c {
514			case '\\':
515				w.Write(jsBackslash)
516			case '\'':
517				w.Write(jsApos)
518			case '"':
519				w.Write(jsQuot)
520			case '<':
521				w.Write(jsLt)
522			case '>':
523				w.Write(jsGt)
524			default:
525				w.Write(jsLowUni)
526				t, b := c>>4, c&0x0f
527				w.Write(hex[t : t+1])
528				w.Write(hex[b : b+1])
529			}
530		} else {
531			// Unicode rune.
532			r, size := utf8.DecodeRune(b[i:])
533			if unicode.IsPrint(r) {
534				w.Write(b[i : i+size])
535			} else {
536				fmt.Fprintf(w, "\\u%04X", r)
537			}
538			i += size - 1
539		}
540		last = i + 1
541	}
542	w.Write(b[last:])
543}
544
545// JSEscapeString returns the escaped JavaScript equivalent of the plain text data s.
546func JSEscapeString(s string) string {
547	// Avoid allocation if we can.
548	if strings.IndexFunc(s, jsIsSpecial) < 0 {
549		return s
550	}
551	var b bytes.Buffer
552	JSEscape(&b, []byte(s))
553	return b.String()
554}
555
// jsIsSpecial reports whether r needs attention from JSEscape: any rune
// below space, any rune at or above utf8.RuneSelf, and the characters
// backslash, single quote, double quote, '<', and '>'.
func jsIsSpecial(r rune) bool {
	if r < ' ' || r >= utf8.RuneSelf {
		return true
	}
	return r == '\\' || r == '\'' || r == '"' || r == '<' || r == '>'
}
563
564// JSEscaper returns the escaped JavaScript equivalent of the textual
565// representation of its arguments.
566func JSEscaper(args ...interface{}) string {
567	return JSEscapeString(evalArgs(args))
568}
569
570// URLQueryEscaper returns the escaped value of the textual representation of
571// its arguments in a form suitable for embedding in a URL query.
572func URLQueryEscaper(args ...interface{}) string {
573	return url.QueryEscape(evalArgs(args))
574}
575
576// evalArgs formats the list of arguments into a string. It is therefore equivalent to
577//	fmt.Sprint(args...)
578// except that each argument is indirected (if a pointer), as required,
579// using the same rules as the default string evaluation during template
580// execution.
581func evalArgs(args []interface{}) string {
582	ok := false
583	var s string
584	// Fast path for simple common case.
585	if len(args) == 1 {
586		s, ok = args[0].(string)
587	}
588	if !ok {
589		for i, arg := range args {
590			a, ok := printableValue(reflect.ValueOf(arg))
591			if ok {
592				args[i] = a
593			} // else left fmt do its thing
594		}
595		s = fmt.Sprint(args...)
596	}
597	return s
598}
599