1
2
3
4
5 package template
6
7 import (
8 "bytes"
9 "errors"
10 "fmt"
11 "io"
12 "net/url"
13 "reflect"
14 "strings"
15 "sync"
16 "unicode"
17 "unicode/utf8"
18 )
19
20
21
22
23
24
25
26
27
28
29
30
// FuncMap is the type of the map from function names to function values.
// When a FuncMap is installed through addValueFuncs, every name must satisfy
// goodName, every value must be a function, and every function must satisfy
// goodFunc; otherwise addValueFuncs panics.
type FuncMap map[string]interface{}
32
33
34
35
36
37 func builtins() FuncMap {
38 return FuncMap{
39 "and": and,
40 "call": call,
41 "html": HTMLEscaper,
42 "index": index,
43 "slice": slice,
44 "js": JSEscaper,
45 "len": length,
46 "not": not,
47 "or": or,
48 "print": fmt.Sprint,
49 "printf": fmt.Sprintf,
50 "println": fmt.Sprintln,
51 "urlquery": URLQueryEscaper,
52
53
54 "eq": eq,
55 "ge": ge,
56 "gt": gt,
57 "le": le,
58 "lt": lt,
59 "ne": ne,
60 }
61 }
62
// builtinFuncsOnce holds the reflect.Value form of the builtin functions.
// The embedded sync.Once guards the one-time initialization of v, which is
// performed by builtinFuncs.
var builtinFuncsOnce struct {
	sync.Once
	v map[string]reflect.Value
}
67
68
69
70 func builtinFuncs() map[string]reflect.Value {
71 builtinFuncsOnce.Do(func() {
72 builtinFuncsOnce.v = createValueFuncs(builtins())
73 })
74 return builtinFuncsOnce.v
75 }
76
77
78 func createValueFuncs(funcMap FuncMap) map[string]reflect.Value {
79 m := make(map[string]reflect.Value)
80 addValueFuncs(m, funcMap)
81 return m
82 }
83
84
85 func addValueFuncs(out map[string]reflect.Value, in FuncMap) {
86 for name, fn := range in {
87 if !goodName(name) {
88 panic(fmt.Errorf("function name %q is not a valid identifier", name))
89 }
90 v := reflect.ValueOf(fn)
91 if v.Kind() != reflect.Func {
92 panic("value for " + name + " not a function")
93 }
94 if !goodFunc(v.Type()) {
95 panic(fmt.Errorf("can't install method/function %q with %d results", name, v.Type().NumOut()))
96 }
97 out[name] = v
98 }
99 }
100
101
102
103 func addFuncs(out, in FuncMap) {
104 for name, fn := range in {
105 out[name] = fn
106 }
107 }
108
109
110 func goodFunc(typ reflect.Type) bool {
111
112 switch {
113 case typ.NumOut() == 1:
114 return true
115 case typ.NumOut() == 2 && typ.Out(1) == errorType:
116 return true
117 }
118 return false
119 }
120
121
// goodName reports whether name is a valid function name: it must be
// non-empty and consist only of underscores, Unicode letters and Unicode
// digits, and it must not start with a digit.
func goodName(name string) bool {
	if len(name) == 0 {
		return false
	}
	for pos, ch := range name {
		if ch == '_' || unicode.IsLetter(ch) {
			continue
		}
		// Digits are allowed anywhere except the first position.
		if pos > 0 && unicode.IsDigit(ch) {
			continue
		}
		return false
	}
	return true
}
137
138
139 func findFunction(name string, tmpl *Template) (reflect.Value, bool) {
140 if tmpl != nil && tmpl.common != nil {
141 tmpl.muFuncs.RLock()
142 defer tmpl.muFuncs.RUnlock()
143 if fn := tmpl.execFuncs[name]; fn.IsValid() {
144 return fn, true
145 }
146 }
147 if fn := builtinFuncs()[name]; fn.IsValid() {
148 return fn, true
149 }
150 return reflect.Value{}, false
151 }
152
153
154
155 func prepareArg(value reflect.Value, argType reflect.Type) (reflect.Value, error) {
156 if !value.IsValid() {
157 if !canBeNil(argType) {
158 return reflect.Value{}, fmt.Errorf("value is nil; should be of type %s", argType)
159 }
160 value = reflect.Zero(argType)
161 }
162 if value.Type().AssignableTo(argType) {
163 return value, nil
164 }
165 if intLike(value.Kind()) && intLike(argType.Kind()) && value.Type().ConvertibleTo(argType) {
166 value = value.Convert(argType)
167 return value, nil
168 }
169 return reflect.Value{}, fmt.Errorf("value has type %s; should be %s", value.Type(), argType)
170 }
171
// intLike reports whether typ is one of the signed or unsigned integer kinds,
// including uintptr.
func intLike(typ reflect.Kind) bool {
	switch typ {
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		return true
	default:
		return false
	}
}
181
182
// indexArg converts index to an int usable as an index or slice bound.
// It returns an error if index is not an integer kind, or if the value is
// negative, does not fit in an int, or is greater than cap. A value equal to
// cap is accepted.
func indexArg(index reflect.Value, cap int) (int, error) {
	var pos int64
	switch index.Kind() {
	case reflect.Invalid:
		return 0, fmt.Errorf("cannot index slice/array with nil")
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		pos = index.Int()
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
		// Values above the int64 range wrap to a negative number and are
		// rejected by the range check below.
		pos = int64(index.Uint())
	default:
		return 0, fmt.Errorf("cannot index slice/array with type %s", index.Type())
	}
	n := int(pos)
	// n < 0 catches values that overflow a narrower int.
	if pos < 0 || n < 0 || n > cap {
		return 0, fmt.Errorf("index out of range: %d", pos)
	}
	return n, nil
}
200
201
202
203
204
205
206 func index(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
207 item = indirectInterface(item)
208 if !item.IsValid() {
209 return reflect.Value{}, fmt.Errorf("index of untyped nil")
210 }
211 for _, index := range indexes {
212 index = indirectInterface(index)
213 var isNil bool
214 if item, isNil = indirect(item); isNil {
215 return reflect.Value{}, fmt.Errorf("index of nil pointer")
216 }
217 switch item.Kind() {
218 case reflect.Array, reflect.Slice, reflect.String:
219 x, err := indexArg(index, item.Len())
220 if err != nil {
221 return reflect.Value{}, err
222 }
223 item = item.Index(x)
224 case reflect.Map:
225 index, err := prepareArg(index, item.Type().Key())
226 if err != nil {
227 return reflect.Value{}, err
228 }
229 if x := item.MapIndex(index); x.IsValid() {
230 item = x
231 } else {
232 item = reflect.Zero(item.Type().Elem())
233 }
234 case reflect.Invalid:
235
236 panic("unreachable")
237 default:
238 return reflect.Value{}, fmt.Errorf("can't index item of type %s", item.Type())
239 }
240 }
241 return item, nil
242 }
243
244
245
246
247
248
249
250 func slice(item reflect.Value, indexes ...reflect.Value) (reflect.Value, error) {
251 item = indirectInterface(item)
252 if !item.IsValid() {
253 return reflect.Value{}, fmt.Errorf("slice of untyped nil")
254 }
255 if len(indexes) > 3 {
256 return reflect.Value{}, fmt.Errorf("too many slice indexes: %d", len(indexes))
257 }
258 var cap int
259 switch item.Kind() {
260 case reflect.String:
261 if len(indexes) == 3 {
262 return reflect.Value{}, fmt.Errorf("cannot 3-index slice a string")
263 }
264 cap = item.Len()
265 case reflect.Array, reflect.Slice:
266 cap = item.Cap()
267 default:
268 return reflect.Value{}, fmt.Errorf("can't slice item of type %s", item.Type())
269 }
270
271 idx := [3]int{0, item.Len()}
272 for i, index := range indexes {
273 x, err := indexArg(index, cap)
274 if err != nil {
275 return reflect.Value{}, err
276 }
277 idx[i] = x
278 }
279
280 if idx[0] > idx[1] {
281 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[0], idx[1])
282 }
283 if len(indexes) < 3 {
284 return item.Slice(idx[0], idx[1]), nil
285 }
286
287 if idx[1] > idx[2] {
288 return reflect.Value{}, fmt.Errorf("invalid slice index: %d > %d", idx[1], idx[2])
289 }
290 return item.Slice3(idx[0], idx[1], idx[2]), nil
291 }
292
293
294
295
296 func length(item reflect.Value) (int, error) {
297 item, isNil := indirect(item)
298 if isNil {
299 return 0, fmt.Errorf("len of nil pointer")
300 }
301 switch item.Kind() {
302 case reflect.Array, reflect.Chan, reflect.Map, reflect.Slice, reflect.String:
303 return item.Len(), nil
304 }
305 return 0, fmt.Errorf("len of type %s", item.Type())
306 }
307
308
309
310
311
312 func call(fn reflect.Value, args ...reflect.Value) (reflect.Value, error) {
313 fn = indirectInterface(fn)
314 if !fn.IsValid() {
315 return reflect.Value{}, fmt.Errorf("call of nil")
316 }
317 typ := fn.Type()
318 if typ.Kind() != reflect.Func {
319 return reflect.Value{}, fmt.Errorf("non-function of type %s", typ)
320 }
321 if !goodFunc(typ) {
322 return reflect.Value{}, fmt.Errorf("function called with %d args; should be 1 or 2", typ.NumOut())
323 }
324 numIn := typ.NumIn()
325 var dddType reflect.Type
326 if typ.IsVariadic() {
327 if len(args) < numIn-1 {
328 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want at least %d", len(args), numIn-1)
329 }
330 dddType = typ.In(numIn - 1).Elem()
331 } else {
332 if len(args) != numIn {
333 return reflect.Value{}, fmt.Errorf("wrong number of args: got %d want %d", len(args), numIn)
334 }
335 }
336 argv := make([]reflect.Value, len(args))
337 for i, arg := range args {
338 arg = indirectInterface(arg)
339
340 argType := dddType
341 if !typ.IsVariadic() || i < numIn-1 {
342 argType = typ.In(i)
343 }
344
345 var err error
346 if argv[i], err = prepareArg(arg, argType); err != nil {
347 return reflect.Value{}, fmt.Errorf("arg %d: %s", i, err)
348 }
349 }
350 return safeCall(fn, argv)
351 }
352
353
354
// safeCall calls fun with args and returns its first result. If fun has two
// results and the second is non-nil, it is returned as the error. A panic
// during the call is recovered and returned as the error: the panic value
// itself if it is an error, otherwise its %v formatting.
func safeCall(fun reflect.Value, args []reflect.Value) (val reflect.Value, err error) {
	defer func() {
		recovered := recover()
		if recovered == nil {
			return
		}
		if asErr, isErr := recovered.(error); isErr {
			err = asErr
			return
		}
		err = fmt.Errorf("%v", recovered)
	}()
	results := fun.Call(args)
	if len(results) == 2 {
		if failure := results[1]; !failure.IsNil() {
			return results[0], failure.Interface().(error)
		}
	}
	return results[0], nil
}
371
372
373
374 func truth(arg reflect.Value) bool {
375 t, _ := isTrue(indirectInterface(arg))
376 return t
377 }
378
379
380
381 func and(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
382 if !truth(arg0) {
383 return arg0
384 }
385 for i := range args {
386 arg0 = args[i]
387 if !truth(arg0) {
388 break
389 }
390 }
391 return arg0
392 }
393
394
395
396 func or(arg0 reflect.Value, args ...reflect.Value) reflect.Value {
397 if truth(arg0) {
398 return arg0
399 }
400 for i := range args {
401 arg0 = args[i]
402 if truth(arg0) {
403 break
404 }
405 }
406 return arg0
407 }
408
409
410 func not(arg reflect.Value) bool {
411 return !truth(arg)
412 }
413
414
415
416
417
// Errors returned by the comparison functions.
var (
	// errBadComparisonType is returned by basicKind for a value that is not
	// one of the basic kinds, and by lt for bool and complex operands.
	errBadComparisonType = errors.New("invalid type for comparison")
	// errBadComparison is returned by eq and lt when the two operands have
	// different kinds and are not a signed/unsigned integer pair.
	errBadComparison = errors.New("incompatible types for comparison")
	// errNoComparison is returned by eq when it receives only one argument.
	errNoComparison = errors.New("missing argument for comparison")
)
423
// kind is the coarse classification of values used by the comparison
// functions; basicKind maps a reflect.Value onto one of the constants below.
type kind int

const (
	// invalidKind is the zero value; basicKind returns it for anything that
	// is not a bool, integer, float, complex or string.
	invalidKind kind = iota
	boolKind
	complexKind
	intKind
	floatKind
	stringKind
	uintKind
)
435
436 func basicKind(v reflect.Value) (kind, error) {
437 switch v.Kind() {
438 case reflect.Bool:
439 return boolKind, nil
440 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
441 return intKind, nil
442 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr:
443 return uintKind, nil
444 case reflect.Float32, reflect.Float64:
445 return floatKind, nil
446 case reflect.Complex64, reflect.Complex128:
447 return complexKind, nil
448 case reflect.String:
449 return stringKind, nil
450 }
451 return invalidKind, errBadComparisonType
452 }
453
454
455 func eq(arg1 reflect.Value, arg2 ...reflect.Value) (bool, error) {
456 arg1 = indirectInterface(arg1)
457 if arg1 != zero {
458 if t1 := arg1.Type(); !t1.Comparable() {
459 return false, fmt.Errorf("uncomparable type %s: %v", t1, arg1)
460 }
461 }
462 if len(arg2) == 0 {
463 return false, errNoComparison
464 }
465 k1, _ := basicKind(arg1)
466 for _, arg := range arg2 {
467 arg = indirectInterface(arg)
468 k2, _ := basicKind(arg)
469 truth := false
470 if k1 != k2 {
471
472 switch {
473 case k1 == intKind && k2 == uintKind:
474 truth = arg1.Int() >= 0 && uint64(arg1.Int()) == arg.Uint()
475 case k1 == uintKind && k2 == intKind:
476 truth = arg.Int() >= 0 && arg1.Uint() == uint64(arg.Int())
477 default:
478 return false, errBadComparison
479 }
480 } else {
481 switch k1 {
482 case boolKind:
483 truth = arg1.Bool() == arg.Bool()
484 case complexKind:
485 truth = arg1.Complex() == arg.Complex()
486 case floatKind:
487 truth = arg1.Float() == arg.Float()
488 case intKind:
489 truth = arg1.Int() == arg.Int()
490 case stringKind:
491 truth = arg1.String() == arg.String()
492 case uintKind:
493 truth = arg1.Uint() == arg.Uint()
494 default:
495 if arg == zero {
496 truth = arg1 == arg
497 } else {
498 if t2 := arg.Type(); !t2.Comparable() {
499 return false, fmt.Errorf("uncomparable type %s: %v", t2, arg)
500 }
501 truth = arg1.Interface() == arg.Interface()
502 }
503 }
504 }
505 if truth {
506 return true, nil
507 }
508 }
509 return false, nil
510 }
511
512
513 func ne(arg1, arg2 reflect.Value) (bool, error) {
514
515 equal, err := eq(arg1, arg2)
516 return !equal, err
517 }
518
519
520 func lt(arg1, arg2 reflect.Value) (bool, error) {
521 arg1 = indirectInterface(arg1)
522 k1, err := basicKind(arg1)
523 if err != nil {
524 return false, err
525 }
526 arg2 = indirectInterface(arg2)
527 k2, err := basicKind(arg2)
528 if err != nil {
529 return false, err
530 }
531 truth := false
532 if k1 != k2 {
533
534 switch {
535 case k1 == intKind && k2 == uintKind:
536 truth = arg1.Int() < 0 || uint64(arg1.Int()) < arg2.Uint()
537 case k1 == uintKind && k2 == intKind:
538 truth = arg2.Int() >= 0 && arg1.Uint() < uint64(arg2.Int())
539 default:
540 return false, errBadComparison
541 }
542 } else {
543 switch k1 {
544 case boolKind, complexKind:
545 return false, errBadComparisonType
546 case floatKind:
547 truth = arg1.Float() < arg2.Float()
548 case intKind:
549 truth = arg1.Int() < arg2.Int()
550 case stringKind:
551 truth = arg1.String() < arg2.String()
552 case uintKind:
553 truth = arg1.Uint() < arg2.Uint()
554 default:
555 panic("invalid kind")
556 }
557 }
558 return truth, nil
559 }
560
561
562 func le(arg1, arg2 reflect.Value) (bool, error) {
563
564 lessThan, err := lt(arg1, arg2)
565 if lessThan || err != nil {
566 return lessThan, err
567 }
568 return eq(arg1, arg2)
569 }
570
571
572 func gt(arg1, arg2 reflect.Value) (bool, error) {
573
574 lessOrEqual, err := le(arg1, arg2)
575 if err != nil {
576 return false, err
577 }
578 return !lessOrEqual, nil
579 }
580
581
582 func ge(arg1, arg2 reflect.Value) (bool, error) {
583
584 lessThan, err := lt(arg1, arg2)
585 if err != nil {
586 return false, err
587 }
588 return !lessThan, nil
589 }
590
591
592
// HTML replacements for the bytes that HTMLEscape rewrites. The quote
// characters use numeric character references. The NUL byte is replaced by
// the Unicode replacement character U+FFFD.
var (
	htmlQuot = []byte("&#34;")
	htmlApos = []byte("&#39;")
	htmlAmp  = []byte("&amp;")
	htmlLt   = []byte("&lt;")
	htmlGt   = []byte("&gt;")
	htmlNull = []byte("\uFFFD")
)

// HTMLEscape writes to w the escaped HTML equivalent of the plain text data
// b. The bytes ", ', &, < and > are replaced by character references and NUL
// is replaced by U+FFFD; every other byte is copied unchanged. Errors
// returned by w.Write are not reported.
func HTMLEscape(w io.Writer, b []byte) {
	last := 0 // start of the run of bytes not yet written
	for i, c := range b {
		var html []byte
		switch c {
		case '\000':
			html = htmlNull
		case '"':
			html = htmlQuot
		case '\'':
			html = htmlApos
		case '&':
			html = htmlAmp
		case '<':
			html = htmlLt
		case '>':
			html = htmlGt
		default:
			continue
		}
		// Flush the unescaped run before the replacement.
		w.Write(b[last:i])
		w.Write(html)
		last = i + 1
	}
	w.Write(b[last:])
}
629
630
631 func HTMLEscapeString(s string) string {
632
633 if !strings.ContainsAny(s, "'\"&<>\000") {
634 return s
635 }
636 var b bytes.Buffer
637 HTMLEscape(&b, []byte(s))
638 return b.String()
639 }
640
641
642
643 func HTMLEscaper(args ...interface{}) string {
644 return HTMLEscapeString(evalArgs(args))
645 }
646
647
648
// JavaScript replacements used by JSEscape.
var (
	jsLowUni = []byte(`\u00`)            // prefix for the two-hex-digit form
	hex      = []byte("0123456789ABCDEF") // upper-case hex digits

	jsBackslash = []byte(`\\`)
	jsApos      = []byte(`\'`)
	jsQuot      = []byte(`\"`)
	jsLt        = []byte(`\u003C`)
	jsGt        = []byte(`\u003E`)
	jsAmp       = []byte(`\u0026`)
	jsEq        = []byte(`\u003D`)
)

// JSEscape writes to w the escaped JavaScript equivalent of the plain text
// data b.
//
// Backslash and the quote characters get backslash escapes; <, >, & and = are
// written as \uXXXX; other ASCII control characters are written as \u00XX.
// Non-ASCII runes are copied unchanged when unicode.IsPrint accepts them and
// are otherwise written as \uXXXX escapes, using a UTF-16 surrogate pair for
// runes above U+FFFF. Errors returned by w are not reported.
func JSEscape(w io.Writer, b []byte) {
	last := 0 // start of the run of bytes not yet written
	for i := 0; i < len(b); i++ {
		c := b[i]

		if !jsIsSpecial(rune(c)) {
			// Nothing to escape; the byte stays in the pending run.
			continue
		}
		w.Write(b[last:i])

		if c < utf8.RuneSelf {
			// Single-byte (ASCII) character that needs escaping.
			switch c {
			case '\\':
				w.Write(jsBackslash)
			case '\'':
				w.Write(jsApos)
			case '"':
				w.Write(jsQuot)
			case '<':
				w.Write(jsLt)
			case '>':
				w.Write(jsGt)
			case '&':
				w.Write(jsAmp)
			case '=':
				w.Write(jsEq)
			default:
				w.Write(jsLowUni)
				hi, lo := c>>4, c&0x0f
				w.Write(hex[hi : hi+1])
				w.Write(hex[lo : lo+1])
			}
		} else {
			// Multi-byte sequence: decode it and decide on the whole rune.
			r, size := utf8.DecodeRune(b[i:])
			switch {
			case unicode.IsPrint(r):
				w.Write(b[i : i+size])
			case r > 0xFFFF:
				// A \u escape holds exactly four hex digits, so a rune
				// outside the BMP must be written as a surrogate pair.
				// Printing it with %04X would emit five or six digits,
				// which JavaScript reads as a different character followed
				// by literal digits.
				r -= 0x10000
				fmt.Fprintf(w, "\\u%04X\\u%04X", 0xD800+(r>>10), 0xDC00+(r&0x3FF))
			default:
				fmt.Fprintf(w, "\\u%04X", r)
			}
			// Skip the remaining bytes of the rune.
			i += size - 1
		}
		last = i + 1
	}
	w.Write(b[last:])
}

// JSEscapeString returns the escaped JavaScript equivalent of the plain text
// data s. A string with no special characters is returned unchanged.
func JSEscapeString(s string) string {
	if strings.IndexFunc(s, jsIsSpecial) < 0 {
		return s
	}
	var b bytes.Buffer
	JSEscape(&b, []byte(s))
	return b.String()
}

// jsIsSpecial reports whether r needs inspection by JSEscape: the characters
// \ ' " < > & =, anything below the space character, and anything at or above
// utf8.RuneSelf.
func jsIsSpecial(r rune) bool {
	switch r {
	case '\\', '\'', '"', '<', '>', '&', '=':
		return true
	}
	return r < ' ' || utf8.RuneSelf <= r
}
731
732
733
734 func JSEscaper(args ...interface{}) string {
735 return JSEscapeString(evalArgs(args))
736 }
737
738
739
740 func URLQueryEscaper(args ...interface{}) string {
741 return url.QueryEscape(evalArgs(args))
742 }
743
744
745
746
747
748
749 func evalArgs(args []interface{}) string {
750 ok := false
751 var s string
752
753 if len(args) == 1 {
754 s, ok = args[0].(string)
755 }
756 if !ok {
757 for i, arg := range args {
758 a, ok := printableValue(reflect.ValueOf(arg))
759 if ok {
760 args[i] = a
761 }
762 }
763 s = fmt.Sprint(args...)
764 }
765 return s
766 }
767
View as plain text