1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16 package sql
17
18 import (
19 "context"
20 "database/sql/driver"
21 "errors"
22 "fmt"
23 "io"
24 "reflect"
25 "runtime"
26 "sort"
27 "strconv"
28 "sync"
29 "sync/atomic"
30 "time"
31 )
32
var (
	// driversMu guards drivers.
	driversMu sync.RWMutex
	// drivers maps each registered driver name to its implementation.
	drivers = make(map[string]driver.Driver)
)

// nowFunc returns the current time. It is a package variable rather than a
// direct time.Now call, presumably so tests can substitute a fake clock.
var nowFunc = time.Now

// Register makes a database driver available under the given name.
// It panics if driver is nil or if a driver is already registered under
// name.
func Register(name string, driver driver.Driver) {
	driversMu.Lock()
	defer driversMu.Unlock()
	switch _, exists := drivers[name]; {
	case driver == nil:
		panic("sql: Register driver is nil")
	case exists:
		panic("sql: Register called twice for driver " + name)
	}
	drivers[name] = driver
}

// unregisterAllDrivers discards every registered driver by swapping in a
// fresh, empty registry.
func unregisterAllDrivers() {
	driversMu.Lock()
	drivers = make(map[string]driver.Driver)
	driversMu.Unlock()
}

// Drivers returns the names of the registered drivers, sorted
// lexicographically. It returns nil when no driver is registered.
func Drivers() (names []string) {
	driversMu.RLock()
	for n := range drivers {
		names = append(names, n)
	}
	driversMu.RUnlock()
	sort.Strings(names)
	return names
}
74
75
76
77
78
79
80
// NamedArg is a named argument: a parameter name paired with its value.
// Use Named for a concise way to build one.
type NamedArg struct {
	// _Named_Fields_Required is unexported, so code outside this package
	// cannot build a NamedArg with a positional composite literal; callers
	// must name the fields (or use Named).
	_Named_Fields_Required struct{}

	// Name is the name of the parameter placeholder.
	Name string

	// Value is the value bound to the parameter.
	Value interface{}
}

// Named returns a NamedArg that binds value to the parameter called name.
func Named(name string, value interface{}) NamedArg {
	var arg NamedArg
	arg.Name, arg.Value = name, value
	return arg
}
117
118
// IsolationLevel is the transaction isolation level used in TxOptions.
type IsolationLevel int

// Isolation levels that can be requested through TxOptions.Isolation.
// NOTE(review): which levels a database accepts is decided by the driver
// (ctxDriverBegin, outside this chunk) — confirm there.
const (
	LevelDefault IsolationLevel = iota
	LevelReadUncommitted
	LevelReadCommitted
	LevelWriteCommitted
	LevelRepeatableRead
	LevelSnapshot
	LevelSerializable
	LevelLinearizable
)

// String returns the human-readable name of the isolation level, or
// "IsolationLevel(N)" for a value outside the defined set.
func (i IsolationLevel) String() string {
	names := [...]string{
		LevelDefault:         "Default",
		LevelReadUncommitted: "Read Uncommitted",
		LevelReadCommitted:   "Read Committed",
		LevelWriteCommitted:  "Write Committed",
		LevelRepeatableRead:  "Repeatable Read",
		LevelSnapshot:        "Snapshot",
		LevelSerializable:    "Serializable",
		LevelLinearizable:    "Linearizable",
	}
	if i >= 0 && int(i) < len(names) {
		return names[i]
	}
	return "IsolationLevel(" + strconv.Itoa(int(i)) + ")"
}

// Compile-time check that IsolationLevel implements fmt.Stringer.
var _ fmt.Stringer = LevelDefault

// TxOptions holds the options passed to DB.BeginTx.
type TxOptions struct {
	// Isolation is the requested isolation level; its zero value is
	// LevelDefault.
	Isolation IsolationLevel
	// ReadOnly requests a read-only transaction.
	ReadOnly bool
}

// RawBytes is a byte slice type. NOTE(review): presumably used as a Scan
// destination that aliases driver-owned memory and is only valid until the
// next Next/Scan/Close — confirm against Rows.Scan, outside this chunk.
type RawBytes []byte
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187 type NullString struct {
188 String string
189 Valid bool
190 }
191
192
193 func (ns *NullString) Scan(value interface{}) error {
194 if value == nil {
195 ns.String, ns.Valid = "", false
196 return nil
197 }
198 ns.Valid = true
199 return convertAssign(&ns.String, value)
200 }
201
202
203 func (ns NullString) Value() (driver.Value, error) {
204 if !ns.Valid {
205 return nil, nil
206 }
207 return ns.String, nil
208 }
209
210
211
212
213 type NullInt64 struct {
214 Int64 int64
215 Valid bool
216 }
217
218
219 func (n *NullInt64) Scan(value interface{}) error {
220 if value == nil {
221 n.Int64, n.Valid = 0, false
222 return nil
223 }
224 n.Valid = true
225 return convertAssign(&n.Int64, value)
226 }
227
228
229 func (n NullInt64) Value() (driver.Value, error) {
230 if !n.Valid {
231 return nil, nil
232 }
233 return n.Int64, nil
234 }
235
236
237
238
239 type NullFloat64 struct {
240 Float64 float64
241 Valid bool
242 }
243
244
245 func (n *NullFloat64) Scan(value interface{}) error {
246 if value == nil {
247 n.Float64, n.Valid = 0, false
248 return nil
249 }
250 n.Valid = true
251 return convertAssign(&n.Float64, value)
252 }
253
254
255 func (n NullFloat64) Value() (driver.Value, error) {
256 if !n.Valid {
257 return nil, nil
258 }
259 return n.Float64, nil
260 }
261
262
263
264
265 type NullBool struct {
266 Bool bool
267 Valid bool
268 }
269
270
271 func (n *NullBool) Scan(value interface{}) error {
272 if value == nil {
273 n.Bool, n.Valid = false, false
274 return nil
275 }
276 n.Valid = true
277 return convertAssign(&n.Bool, value)
278 }
279
280
281 func (n NullBool) Value() (driver.Value, error) {
282 if !n.Valid {
283 return nil, nil
284 }
285 return n.Bool, nil
286 }
287
288
// Scanner is implemented by types that can be used as a scan destination
// (the Null* types in this file implement it).
type Scanner interface {
	// Scan assigns src, a value supplied by the database driver, to the
	// receiver. The Null* types in this file treat a nil src as SQL NULL.
	// NOTE(review): the full set of concrete src types comes from the
	// driver/conversion layer outside this chunk — confirm there.
	Scan(src interface{}) error
}
310
311
312
313
314
315
316
317
318
// Out describes an output argument. By its name and fields it is meant for
// stored-procedure OUTPUT (and, with In, INOUT) parameters.
// NOTE(review): Out is consumed by argument-conversion code outside this
// chunk, and driver support varies — confirm there.
//
// The unexported _Named_Fields_Required field means code outside this
// package cannot use a positional composite literal; callers must name the
// fields, e.g. Out{Dest: &v}.
type Out struct {
	_Named_Fields_Required struct{}

	// Dest is presumably a pointer to the value that receives the output
	// parameter's result (verify in the conversion code).
	Dest interface{}

	// In presumably marks the parameter as INOUT, meaning the pointed-to
	// value of Dest is also sent as input (verify in the conversion code).
	In bool
}

// ErrNoRows is the error for a query that produced no rows.
// NOTE(review): it is presumably returned by Row.Scan, which is outside
// this chunk — confirm there.
var ErrNoRows = errors.New("sql: no rows in result set")
336
337
338
339
340
341
342
343
344
345
346
347
348
// DB is a database handle holding a pool of zero or more underlying driver
// connections. Pool state is guarded by mu. The fields documented as atomic
// are instead accessed with sync/atomic.
type DB struct {
	// Atomic access only (conn adds to it, Stats loads it). It is kept
	// first in the struct, presumably so it stays 64-bit aligned for atomic
	// operations on 32-bit platforms (see the sync/atomic package notes).
	waitDuration int64 // total time callers have waited in conn for a connection

	connector driver.Connector // opens new driver connections

	// numClosed counts driverConns that have been finally closed.
	// finalClose increments it atomically, and prepareDC snapshots it into
	// Stmt.lastNumClosed.
	numClosed uint64

	mu           sync.Mutex                  // protects the following fields
	freeConn     []*driverConn               // idle connections available for reuse
	connRequests map[uint64]chan connRequest // callers blocked in conn waiting for a connection
	nextRequest  uint64                      // next key to use in connRequests
	numOpen      int                         // number of opened and pending-open connections

	// openerCh carries one signal per connection that connectionOpener
	// should open. maybeOpenNewConnections sends on it; connectionOpener
	// stops when the context cancelled by stop is done.
	openerCh   chan struct{}
	resetterCh chan *driverConn // locked connections queued for connectionResetter
	closed     bool             // set by Close
	dep        map[finalCloser]depSet
	lastPut    map[*driverConn]string // stack trace of each conn's last put; debugGetPut only
	// maxIdle: zero means defaultMaxIdleConns; negative means 0 (see
	// maxIdleConnsLocked).
	maxIdle     int
	maxOpen     int           // <= 0 means unlimited
	maxLifetime time.Duration // maximum age for reuse; <= 0 means no limit
	// cleanerCh wakes connectionCleaner. It is nil while no cleaner is
	// running, and Close closes it.
	cleanerCh         chan struct{}
	waitCount         int64 // total number of times conn had to wait for a connection
	maxIdleClosed     int64 // connections closed because the idle pool was full or shrunk
	maxLifetimeClosed int64 // connections closed by connectionCleaner for exceeding maxLifetime

	stop func() // cancels the context of connectionOpener and connectionResetter
}
385
386
// connReuseStrategy determines how (*DB).conn obtains a database connection.
type connReuseStrategy uint8

const (
	// alwaysNewConn skips the idle pool. conn opens a new connection, or,
	// when maxOpen connections are already open, waits for one to be handed
	// over by putConnDBLocked.
	alwaysNewConn connReuseStrategy = iota
	// cachedOrNewConn returns an idle (cached) connection if one is
	// available; otherwise it behaves like alwaysNewConn.
	cachedOrNewConn
)
397
398
399
400
401
// driverConn wraps a driver.Conn together with a mutex. The code in this
// file holds that mutex around calls into ci, either via withLock(dc, ...)
// or because dc is the Locker of every driverStmt prepared on it.
type driverConn struct {
	db        *DB
	createdAt time.Time // when ci was opened; compared against maxLifetime

	sync.Mutex  // guards the following fields
	ci          driver.Conn
	closed      bool
	finalClosed bool // ci.Close has been called (set by finalClose)
	// openStmt holds statements prepared without a stmtConnGrabber (see
	// prepareLocked); finalClose closes them.
	openStmt map[*driverStmt]bool
	// lastErr is the result of the last session reset (set by
	// resetSession); conn checks it before reusing the connection.
	lastErr error

	// guarded by db.mu
	inUse bool
	// onPut holds callbacks run by putConn (with db.mu held) when the
	// connection is next returned.
	onPut      []func()
	dbmuClosed bool // set under db.mu by Close; NOTE(review): no reader in this chunk
}
418
419 func (dc *driverConn) releaseConn(err error) {
420 dc.db.putConn(dc, err, true)
421 }
422
423 func (dc *driverConn) removeOpenStmt(ds *driverStmt) {
424 dc.Lock()
425 defer dc.Unlock()
426 delete(dc.openStmt, ds)
427 }
428
429 func (dc *driverConn) expired(timeout time.Duration) bool {
430 if timeout <= 0 {
431 return false
432 }
433 return dc.createdAt.Add(timeout).Before(nowFunc())
434 }
435
436
437
438 func (dc *driverConn) prepareLocked(ctx context.Context, cg stmtConnGrabber, query string) (*driverStmt, error) {
439 si, err := ctxDriverPrepare(ctx, dc.ci, query)
440 if err != nil {
441 return nil, err
442 }
443 ds := &driverStmt{Locker: dc, si: si}
444
445
446 if cg != nil {
447 return ds, nil
448 }
449
450
451
452
453
454 if dc.openStmt == nil {
455 dc.openStmt = make(map[*driverStmt]bool)
456 }
457 dc.openStmt[ds] = true
458 return ds, nil
459 }
460
461
462
463
464
465
466 func (dc *driverConn) resetSession(ctx context.Context) {
467 defer dc.Unlock()
468 if dc.closed {
469 return
470 }
471 dc.lastErr = dc.ci.(driver.SessionResetter).ResetSession(ctx)
472 }
473
474
475 func (dc *driverConn) closeDBLocked() func() error {
476 dc.Lock()
477 defer dc.Unlock()
478 if dc.closed {
479 return func() error { return errors.New("sql: duplicate driverConn close") }
480 }
481 dc.closed = true
482 return dc.db.removeDepLocked(dc, dc)
483 }
484
485 func (dc *driverConn) Close() error {
486 dc.Lock()
487 if dc.closed {
488 dc.Unlock()
489 return errors.New("sql: duplicate driverConn close")
490 }
491 dc.closed = true
492 dc.Unlock()
493
494
495 dc.db.mu.Lock()
496 dc.dbmuClosed = true
497 fn := dc.db.removeDepLocked(dc, dc)
498 dc.db.mu.Unlock()
499 return fn()
500 }
501
// finalClose closes dc's remaining tracked statements and then the
// underlying driver connection, and updates the DB's connection accounting.
// It is the finalCloser hook returned by removeDepLocked once the last
// dependency on dc is removed. db.mu must not be held, because this method
// acquires it.
func (dc *driverConn) finalClose() error {
	var err error

	// Each *driverStmt uses dc as its Locker, so ds.Close would deadlock if
	// called while dc is locked. Copy the set out under the lock, then close
	// the statements unlocked. (withLock is defined elsewhere; presumably it
	// holds dc's mutex for the duration of the callback.)
	var openStmt []*driverStmt
	withLock(dc, func() {
		openStmt = make([]*driverStmt, 0, len(dc.openStmt))
		for ds := range dc.openStmt {
			openStmt = append(openStmt, ds)
		}
		dc.openStmt = nil
	})
	for _, ds := range openStmt {
		ds.Close()
	}
	withLock(dc, func() {
		dc.finalClosed = true
		err = dc.ci.Close()
		dc.ci = nil
	})

	// The connection no longer counts toward numOpen, which may let a
	// waiting request get a freshly opened connection.
	dc.db.mu.Lock()
	dc.db.numOpen--
	dc.db.maybeOpenNewConnections()
	dc.db.mu.Unlock()

	// prepareDC snapshots this counter (atomically) into Stmt.lastNumClosed.
	atomic.AddUint64(&dc.db.numClosed, 1)
	return err
}
532
533
534
535
// driverStmt pairs a driver.Stmt with the Locker that serializes access to
// it. That Locker is the *driverConn the statement was prepared on.
type driverStmt struct {
	sync.Locker // the owning *driverConn
	si          driver.Stmt
	closed      bool
	closeErr    error // result of the first Close, replayed on later calls
}

// Close closes the driver statement once, while holding the lock. Repeated
// calls do not reach the driver again; they return the first call's error.
func (ds *driverStmt) Close() error {
	ds.Lock()
	defer ds.Unlock()
	if !ds.closed {
		ds.closed = true
		ds.closeErr = ds.si.Close()
	}
	return ds.closeErr
}
555
556
// depSet is the set of dependencies (the dep values passed to addDep) that
// currently keep a finalCloser open.
type depSet map[interface{}]bool // set of true bools

// finalCloser is implemented by objects whose lifetime is tracked by
// (*DB).addDep / removeDep reference counting.
type finalCloser interface {
	// finalClose is called once the last dependency on the object has been
	// removed. Callers in this file (removeDep, driverConn.Close, DB.Close)
	// invoke it only after releasing db.mu.
	finalClose() error
}
566
567
568
569 func (db *DB) addDep(x finalCloser, dep interface{}) {
570
571 db.mu.Lock()
572 defer db.mu.Unlock()
573 db.addDepLocked(x, dep)
574 }
575
576 func (db *DB) addDepLocked(x finalCloser, dep interface{}) {
577 if db.dep == nil {
578 db.dep = make(map[finalCloser]depSet)
579 }
580 xdep := db.dep[x]
581 if xdep == nil {
582 xdep = make(depSet)
583 db.dep[x] = xdep
584 }
585 xdep[dep] = true
586 }
587
588
589
590
591
592 func (db *DB) removeDep(x finalCloser, dep interface{}) error {
593 db.mu.Lock()
594 fn := db.removeDepLocked(x, dep)
595 db.mu.Unlock()
596 return fn()
597 }
598
599 func (db *DB) removeDepLocked(x finalCloser, dep interface{}) func() error {
600
601
602 xdep, ok := db.dep[x]
603 if !ok {
604 panic(fmt.Sprintf("unpaired removeDep: no deps for %T", x))
605 }
606
607 l0 := len(xdep)
608 delete(xdep, dep)
609
610 switch len(xdep) {
611 case l0:
612
613 panic(fmt.Sprintf("unpaired removeDep: no %T dep on %T", dep, x))
614 case 0:
615
616 delete(db.dep, x)
617 return x.finalClose
618 default:
619
620 return func() error { return nil }
621 }
622 }
623
624
625
626
627
628
// connectionRequestQueueSize is the buffer size of DB.openerCh, i.e. how
// many "open a connection" signals may be queued for connectionOpener.
var connectionRequestQueueSize = 1000000

// dsnConnector adapts a plain driver.Driver to the driver.Connector
// interface that OpenDB expects. The driver can only open a connection
// from a DSN string.
type dsnConnector struct {
	dsn    string
	driver driver.Driver
}

// Connect opens a new connection using the stored DSN. The context is
// ignored because driver.Driver.Open does not accept one.
func (t dsnConnector) Connect(_ context.Context) (driver.Conn, error) {
	drv, name := t.driver, t.dsn
	return drv.Open(name)
}

// Driver returns the wrapped driver.
func (t dsnConnector) Driver() driver.Driver { return t.driver }
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660 func OpenDB(c driver.Connector) *DB {
661 ctx, cancel := context.WithCancel(context.Background())
662 db := &DB{
663 connector: c,
664 openerCh: make(chan struct{}, connectionRequestQueueSize),
665 resetterCh: make(chan *driverConn, 50),
666 lastPut: make(map[*driverConn]string),
667 connRequests: make(map[uint64]chan connRequest),
668 stop: cancel,
669 }
670
671 go db.connectionOpener(ctx)
672 go db.connectionResetter(ctx)
673
674 return db
675 }
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694 func Open(driverName, dataSourceName string) (*DB, error) {
695 driversMu.RLock()
696 driveri, ok := drivers[driverName]
697 driversMu.RUnlock()
698 if !ok {
699 return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
700 }
701
702 if driverCtx, ok := driveri.(driver.DriverContext); ok {
703 connector, err := driverCtx.OpenConnector(dataSourceName)
704 if err != nil {
705 return nil, err
706 }
707 return OpenDB(connector), nil
708 }
709
710 return OpenDB(dsnConnector{dsn: dataSourceName, driver: driveri}), nil
711 }
712
713 func (db *DB) pingDC(ctx context.Context, dc *driverConn, release func(error)) error {
714 var err error
715 if pinger, ok := dc.ci.(driver.Pinger); ok {
716 withLock(dc, func() {
717 err = pinger.Ping(ctx)
718 })
719 }
720 release(err)
721 return err
722 }
723
724
725
726 func (db *DB) PingContext(ctx context.Context) error {
727 var dc *driverConn
728 var err error
729
730 for i := 0; i < maxBadConnRetries; i++ {
731 dc, err = db.conn(ctx, cachedOrNewConn)
732 if err != driver.ErrBadConn {
733 break
734 }
735 }
736 if err == driver.ErrBadConn {
737 dc, err = db.conn(ctx, alwaysNewConn)
738 }
739 if err != nil {
740 return err
741 }
742
743 return db.pingDC(ctx, dc, dc.releaseConn)
744 }
745
746
747
748 func (db *DB) Ping() error {
749 return db.PingContext(context.Background())
750 }
751
752
753
754
755
756
757
// Close closes the database and stops handing out connections. It:
//   - marks the DB closed;
//   - closes all idle connections;
//   - closes every pending request channel, so goroutines waiting in conn
//     return errDBClosed;
//   - wakes connectionCleaner (which then exits);
//   - stops the opener and resetter goroutines via db.stop.
//
// Connections currently in use are not closed here. They are closed when
// returned, because putConnDBLocked refuses them once the DB is closed.
//
// Close is idempotent; later calls return nil. If closing an idle
// connection fails, the last such error is returned.
func (db *DB) Close() error {
	db.mu.Lock()
	if db.closed { // Make DB.Close idempotent
		db.mu.Unlock()
		return nil
	}
	if db.cleanerCh != nil {
		// Wakes connectionCleaner, which then sees db.closed and exits.
		close(db.cleanerCh)
	}
	var err error
	// Detach the idle connections while holding db.mu. The returned
	// functions may call finalClose, which re-acquires db.mu, so they run
	// only after unlocking below.
	fns := make([]func() error, 0, len(db.freeConn))
	for _, dc := range db.freeConn {
		fns = append(fns, dc.closeDBLocked())
	}
	db.freeConn = nil
	db.closed = true
	// A closed request channel makes its waiter in conn return errDBClosed.
	for _, req := range db.connRequests {
		close(req)
	}
	db.mu.Unlock()
	for _, fn := range fns {
		err1 := fn()
		if err1 != nil {
			err = err1
		}
	}
	// Cancels the context shared by connectionOpener and connectionResetter.
	db.stop()
	return err
}
787
788 const defaultMaxIdleConns = 2
789
790 func (db *DB) maxIdleConnsLocked() int {
791 n := db.maxIdle
792 switch {
793 case n == 0:
794
795 return defaultMaxIdleConns
796 case n < 0:
797 return 0
798 default:
799 return n
800 }
801 }
802
803
804
805
806
807
808
809
810
811
812
813 func (db *DB) SetMaxIdleConns(n int) {
814 db.mu.Lock()
815 if n > 0 {
816 db.maxIdle = n
817 } else {
818
819 db.maxIdle = -1
820 }
821
822 if db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen {
823 db.maxIdle = db.maxOpen
824 }
825 var closing []*driverConn
826 idleCount := len(db.freeConn)
827 maxIdle := db.maxIdleConnsLocked()
828 if idleCount > maxIdle {
829 closing = db.freeConn[maxIdle:]
830 db.freeConn = db.freeConn[:maxIdle]
831 }
832 db.maxIdleClosed += int64(len(closing))
833 db.mu.Unlock()
834 for _, c := range closing {
835 c.Close()
836 }
837 }
838
839
840
841
842
843
844
845
846
847 func (db *DB) SetMaxOpenConns(n int) {
848 db.mu.Lock()
849 db.maxOpen = n
850 if n < 0 {
851 db.maxOpen = 0
852 }
853 syncMaxIdle := db.maxOpen > 0 && db.maxIdleConnsLocked() > db.maxOpen
854 db.mu.Unlock()
855 if syncMaxIdle {
856 db.SetMaxIdleConns(n)
857 }
858 }
859
860
861
862
863
864
865 func (db *DB) SetConnMaxLifetime(d time.Duration) {
866 if d < 0 {
867 d = 0
868 }
869 db.mu.Lock()
870
871 if d > 0 && d < db.maxLifetime && db.cleanerCh != nil {
872 select {
873 case db.cleanerCh <- struct{}{}:
874 default:
875 }
876 }
877 db.maxLifetime = d
878 db.startCleanerLocked()
879 db.mu.Unlock()
880 }
881
882
883 func (db *DB) startCleanerLocked() {
884 if db.maxLifetime > 0 && db.numOpen > 0 && db.cleanerCh == nil {
885 db.cleanerCh = make(chan struct{}, 1)
886 go db.connectionCleaner(db.maxLifetime)
887 }
888 }
889
890 func (db *DB) connectionCleaner(d time.Duration) {
891 const minInterval = time.Second
892
893 if d < minInterval {
894 d = minInterval
895 }
896 t := time.NewTimer(d)
897
898 for {
899 select {
900 case <-t.C:
901 case <-db.cleanerCh:
902 }
903
904 db.mu.Lock()
905 d = db.maxLifetime
906 if db.closed || db.numOpen == 0 || d <= 0 {
907 db.cleanerCh = nil
908 db.mu.Unlock()
909 return
910 }
911
912 expiredSince := nowFunc().Add(-d)
913 var closing []*driverConn
914 for i := 0; i < len(db.freeConn); i++ {
915 c := db.freeConn[i]
916 if c.createdAt.Before(expiredSince) {
917 closing = append(closing, c)
918 last := len(db.freeConn) - 1
919 db.freeConn[i] = db.freeConn[last]
920 db.freeConn[last] = nil
921 db.freeConn = db.freeConn[:last]
922 i--
923 }
924 }
925 db.maxLifetimeClosed += int64(len(closing))
926 db.mu.Unlock()
927
928 for _, c := range closing {
929 c.Close()
930 }
931
932 if d < minInterval {
933 d = minInterval
934 }
935 t.Reset(d)
936 }
937 }
938
939
940 type DBStats struct {
941 MaxOpenConnections int
942
943
944 OpenConnections int
945 InUse int
946 Idle int
947
948
949 WaitCount int64
950 WaitDuration time.Duration
951 MaxIdleClosed int64
952 MaxLifetimeClosed int64
953 }
954
955
956 func (db *DB) Stats() DBStats {
957 wait := atomic.LoadInt64(&db.waitDuration)
958
959 db.mu.Lock()
960 defer db.mu.Unlock()
961
962 stats := DBStats{
963 MaxOpenConnections: db.maxOpen,
964
965 Idle: len(db.freeConn),
966 OpenConnections: db.numOpen,
967 InUse: db.numOpen - len(db.freeConn),
968
969 WaitCount: db.waitCount,
970 WaitDuration: time.Duration(wait),
971 MaxIdleClosed: db.maxIdleClosed,
972 MaxLifetimeClosed: db.maxLifetimeClosed,
973 }
974 return stats
975 }
976
977
978
979
980 func (db *DB) maybeOpenNewConnections() {
981 numRequests := len(db.connRequests)
982 if db.maxOpen > 0 {
983 numCanOpen := db.maxOpen - db.numOpen
984 if numRequests > numCanOpen {
985 numRequests = numCanOpen
986 }
987 }
988 for numRequests > 0 {
989 db.numOpen++
990 numRequests--
991 if db.closed {
992 return
993 }
994 db.openerCh <- struct{}{}
995 }
996 }
997
998
999 func (db *DB) connectionOpener(ctx context.Context) {
1000 for {
1001 select {
1002 case <-ctx.Done():
1003 return
1004 case <-db.openerCh:
1005 db.openNewConnection(ctx)
1006 }
1007 }
1008 }
1009
1010
1011
// connectionResetter runs on its own goroutine (started by OpenDB). It
// resets the session of each connection that putConn queues on
// db.resetterCh. Connections arrive locked (putConn locks them before
// queuing), and resetSession unlocks them.
//
// When ctx is done (DB.Close calls db.stop), it closes resetterCh and
// unlocks any connections still queued.
//
// NOTE(review): putConn checks db.closed under db.mu but sends on
// resetterCh after releasing it, while this goroutine closes resetterCh
// once ctx is done. A putConn racing with Close looks able to send on the
// closed channel — confirm whether something else rules this out.
func (db *DB) connectionResetter(ctx context.Context) {
	for {
		select {
		case <-ctx.Done():
			close(db.resetterCh)
			// Drain: these connections were locked by putConn and will never
			// be reset, so release them.
			for dc := range db.resetterCh {
				dc.Unlock()
			}
			return
		case dc := <-db.resetterCh:
			dc.resetSession(ctx)
		}
	}
}
1026
1027
// openNewConnection opens one connection on behalf of connectionOpener.
// It hands the connection to a waiting request or puts it in the idle pool.
// If neither accepts it, or the DB was closed, the connection is closed.
func (db *DB) openNewConnection(ctx context.Context) {
	// maybeOpenNewConnections has already executed db.numOpen++ before it sent
	// on db.openerCh. This function must execute db.numOpen-- if the
	// connection fails or is closed before returning.
	ci, err := db.connector.Connect(ctx)
	db.mu.Lock()
	defer db.mu.Unlock()
	if db.closed {
		if err == nil {
			ci.Close()
		}
		db.numOpen--
		return
	}
	if err != nil {
		db.numOpen--
		// Deliver the open error to one waiting request (if any), then try
		// again for any remaining requests.
		db.putConnDBLocked(nil, err)
		db.maybeOpenNewConnections()
		return
	}
	dc := &driverConn{
		db:        db,
		createdAt: nowFunc(),
		ci:        ci,
	}
	if db.putConnDBLocked(dc, err) {
		db.addDepLocked(dc, dc)
	} else {
		db.numOpen--
		ci.Close()
	}
}
1060
1061
1062
1063
1064 type connRequest struct {
1065 conn *driverConn
1066 err error
1067 }
1068
1069 var errDBClosed = errors.New("sql: database is closed")
1070
1071
1072
1073 func (db *DB) nextRequestKeyLocked() uint64 {
1074 next := db.nextRequest
1075 db.nextRequest++
1076 return next
1077 }
1078
1079
// conn returns a newly opened or cached *driverConn, marked in use.
//
// Errors:
//   - errDBClosed if the DB is closed (or is closed while waiting);
//   - ctx.Err() if ctx is already done or is cancelled while waiting;
//   - driver.ErrBadConn if the cached connection has exceeded maxLifetime
//     or its last session reset reported driver.ErrBadConn. The connection
//     is closed first, and callers are expected to retry.
//
// When maxOpen connections are already open, conn registers a connRequest
// and blocks until a connection (or an open error) is handed over.
func (db *DB) conn(ctx context.Context, strategy connReuseStrategy) (*driverConn, error) {
	db.mu.Lock()
	if db.closed {
		db.mu.Unlock()
		return nil, errDBClosed
	}
	// Check if the context is expired.
	select {
	default:
	case <-ctx.Done():
		db.mu.Unlock()
		return nil, ctx.Err()
	}
	lifetime := db.maxLifetime

	// Prefer a free connection, if possible.
	numFree := len(db.freeConn)
	if strategy == cachedOrNewConn && numFree > 0 {
		conn := db.freeConn[0]
		copy(db.freeConn, db.freeConn[1:])
		db.freeConn = db.freeConn[:numFree-1]
		conn.inUse = true
		db.mu.Unlock()
		if conn.expired(lifetime) {
			conn.Close()
			return nil, driver.ErrBadConn
		}
		// Lock around reading lastErr: putConn keeps the connection locked
		// until connectionResetter has finished resetting it.
		conn.Lock()
		err := conn.lastErr
		conn.Unlock()
		if err == driver.ErrBadConn {
			conn.Close()
			return nil, driver.ErrBadConn
		}
		return conn, nil
	}

	// Out of free connections or we were asked not to use one. If we're not
	// allowed to open any more connections, make a request and wait.
	if db.maxOpen > 0 && db.numOpen >= db.maxOpen {
		// Make the connRequest channel. It's buffered so that
		// putConnDBLocked's send (done while holding db.mu) never blocks.
		req := make(chan connRequest, 1)
		reqKey := db.nextRequestKeyLocked()
		db.connRequests[reqKey] = req
		db.waitCount++
		db.mu.Unlock()

		waitStart := time.Now()

		// Timeout the connection request with the context.
		select {
		case <-ctx.Done():
			// Remove the connection request and ensure no value has been sent
			// on it after removing.
			db.mu.Lock()
			delete(db.connRequests, reqKey)
			db.mu.Unlock()

			atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))

			// A connection may have been handed over just before the delete;
			// return it to the pool instead of leaking it.
			select {
			default:
			case ret, ok := <-req:
				if ok && ret.conn != nil {
					db.putConn(ret.conn, ret.err, false)
				}
			}
			return nil, ctx.Err()
		case ret, ok := <-req:
			atomic.AddInt64(&db.waitDuration, int64(time.Since(waitStart)))

			if !ok {
				// Close closed the request channel.
				return nil, errDBClosed
			}
			if ret.err == nil && ret.conn.expired(lifetime) {
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			if ret.conn == nil {
				return nil, ret.err
			}
			// Lock around reading lastErr to ensure the session resetter finished.
			ret.conn.Lock()
			err := ret.conn.lastErr
			ret.conn.Unlock()
			if err == driver.ErrBadConn {
				ret.conn.Close()
				return nil, driver.ErrBadConn
			}
			return ret.conn, ret.err
		}
	}

	db.numOpen++ // optimistically
	db.mu.Unlock()
	ci, err := db.connector.Connect(ctx)
	if err != nil {
		db.mu.Lock()
		db.numOpen-- // correct for earlier optimism
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	dc := &driverConn{
		db:        db,
		createdAt: nowFunc(),
		ci:        ci,
		inUse:     true,
	}
	db.addDepLocked(dc, dc)
	db.mu.Unlock()
	return dc, nil
}
1196
1197
1198 var putConnHook func(*DB, *driverConn)
1199
1200
1201
1202
1203 func (db *DB) noteUnusedDriverStatement(c *driverConn, ds *driverStmt) {
1204 db.mu.Lock()
1205 defer db.mu.Unlock()
1206 if c.inUse {
1207 c.onPut = append(c.onPut, func() {
1208 ds.Close()
1209 })
1210 } else {
1211 c.Lock()
1212 fc := c.finalClosed
1213 c.Unlock()
1214 if !fc {
1215 ds.Close()
1216 }
1217 }
1218 }
1219
1220
1221
// debugGetPut enables recording of the stack of each putConn in db.lastPut,
// so that a duplicate return can report where the previous one happened.
const debugGetPut = false

// putConn returns dc to the DB: it hands dc to a waiting request or adds it
// to the idle pool, and closes it if neither accepts it.
//
// err is the last error seen on the connection. driver.ErrBadConn causes
// the connection to be discarded instead of reused.
//
// If resetSession is true and the driver implements
// driver.SessionResetter, dc is queued for connectionResetter while still
// locked. conn reads lastErr under that lock, so it sees the reset outcome
// before reusing dc.
func (db *DB) putConn(dc *driverConn, err error, resetSession bool) {
	db.mu.Lock()
	if !dc.inUse {
		if debugGetPut {
			fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", dc, stack(), db.lastPut[dc])
		}
		panic("sql: connection returned that was never out")
	}
	if debugGetPut {
		db.lastPut[dc] = stack()
	}
	dc.inUse = false

	// Run callbacks registered while dc was checked out (e.g. by
	// noteUnusedDriverStatement).
	for _, fn := range dc.onPut {
		fn()
	}
	dc.onPut = nil

	if err == driver.ErrBadConn {
		// Don't reuse bad connections. Don't decrement numOpen here:
		// finalClose (reached through dc.Close) does that. A waiter may now
		// get a new connection.
		db.maybeOpenNewConnections()
		db.mu.Unlock()
		dc.Close()
		return
	}
	if putConnHook != nil {
		putConnHook(db, dc)
	}
	if db.closed {
		// Connections do not need to be reset if they will be closed.
		// Prevents writing to resetterCh after the DB has closed.
		resetSession = false
	}
	if resetSession {
		if _, resetSession = dc.ci.(driver.SessionResetter); resetSession {
			// Lock the driverConn here so it isn't released until
			// the connection is reset.
			// The lock must be taken before the connection is put into
			// the pool to prevent it from being taken out before it is reset.
			dc.Lock()
		}
	}
	added := db.putConnDBLocked(dc, nil)
	db.mu.Unlock()

	if !added {
		if resetSession {
			dc.Unlock()
		}
		dc.Close()
		return
	}
	if !resetSession {
		return
	}
	select {
	default:
		// If the resetterCh is blocking then mark the connection
		// as bad and continue on.
		dc.lastErr = driver.ErrBadConn
		dc.Unlock()
	case db.resetterCh <- dc:
	}
}
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
// putConnDBLocked offers dc (or, if err is non-nil, the error err) to the
// pool. The caller must hold db.mu. It returns true if dc/err was handed to
// a waiting connRequest or dc was added to the idle pool.
//
// A waiter that receives dc with a nil err gets it marked in use. With a
// non-nil err, only a waiting request can receive it; dc is never added to
// the idle pool.
//
// It returns false in any of these cases:
//   - the DB is closed;
//   - more connections are open than maxOpen allows;
//   - the idle pool is already full. With a nil err, this also increments
//     maxIdleClosed.
//
// On false, the caller is responsible for closing dc.
func (db *DB) putConnDBLocked(dc *driverConn, err error) bool {
	if db.closed {
		return false
	}
	if db.maxOpen > 0 && db.numOpen > db.maxOpen {
		return false
	}
	if c := len(db.connRequests); c > 0 {
		// Pick an arbitrary waiting request (map iteration order).
		var req chan connRequest
		var reqKey uint64
		for reqKey, req = range db.connRequests {
			break
		}
		delete(db.connRequests, reqKey) // Remove from pending requests.
		if err == nil {
			dc.inUse = true
		}
		// req has a buffer of 1 (see conn), so this send does not block.
		req <- connRequest{
			conn: dc,
			err:  err,
		}
		return true
	} else if err == nil && !db.closed {
		if db.maxIdleConnsLocked() > len(db.freeConn) {
			db.freeConn = append(db.freeConn, dc)
			db.startCleanerLocked()
			return true
		}
		db.maxIdleClosed++
	}
	return false
}
1335
1336
1337
1338
1339 const maxBadConnRetries = 2
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349 func (db *DB) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1350 var stmt *Stmt
1351 var err error
1352 for i := 0; i < maxBadConnRetries; i++ {
1353 stmt, err = db.prepare(ctx, query, cachedOrNewConn)
1354 if err != driver.ErrBadConn {
1355 break
1356 }
1357 }
1358 if err == driver.ErrBadConn {
1359 return db.prepare(ctx, query, alwaysNewConn)
1360 }
1361 return stmt, err
1362 }
1363
1364
1365
1366
1367
1368
1369 func (db *DB) Prepare(query string) (*Stmt, error) {
1370 return db.PrepareContext(context.Background(), query)
1371 }
1372
1373 func (db *DB) prepare(ctx context.Context, query string, strategy connReuseStrategy) (*Stmt, error) {
1374
1375
1376
1377
1378
1379
1380 dc, err := db.conn(ctx, strategy)
1381 if err != nil {
1382 return nil, err
1383 }
1384 return db.prepareDC(ctx, dc, dc.releaseConn, nil, query)
1385 }
1386
1387
1388
1389
1390 func (db *DB) prepareDC(ctx context.Context, dc *driverConn, release func(error), cg stmtConnGrabber, query string) (*Stmt, error) {
1391 var ds *driverStmt
1392 var err error
1393 defer func() {
1394 release(err)
1395 }()
1396 withLock(dc, func() {
1397 ds, err = dc.prepareLocked(ctx, cg, query)
1398 })
1399 if err != nil {
1400 return nil, err
1401 }
1402 stmt := &Stmt{
1403 db: db,
1404 query: query,
1405 cg: cg,
1406 cgds: ds,
1407 }
1408
1409
1410
1411
1412 if cg == nil {
1413 stmt.css = []connStmt{{dc, ds}}
1414 stmt.lastNumClosed = atomic.LoadUint64(&db.numClosed)
1415 db.addDep(stmt, stmt)
1416 }
1417 return stmt, nil
1418 }
1419
1420
1421
1422 func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
1423 var res Result
1424 var err error
1425 for i := 0; i < maxBadConnRetries; i++ {
1426 res, err = db.exec(ctx, query, args, cachedOrNewConn)
1427 if err != driver.ErrBadConn {
1428 break
1429 }
1430 }
1431 if err == driver.ErrBadConn {
1432 return db.exec(ctx, query, args, alwaysNewConn)
1433 }
1434 return res, err
1435 }
1436
1437
1438
1439 func (db *DB) Exec(query string, args ...interface{}) (Result, error) {
1440 return db.ExecContext(context.Background(), query, args...)
1441 }
1442
1443 func (db *DB) exec(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (Result, error) {
1444 dc, err := db.conn(ctx, strategy)
1445 if err != nil {
1446 return nil, err
1447 }
1448 return db.execDC(ctx, dc, dc.releaseConn, query, args)
1449 }
1450
// execDC executes query with args on dc. The deferred call always invokes
// release exactly once, with the final returned error.
//
// If the driver connection implements ExecerContext or Execer, the query
// runs directly. If it does not, or if that call returns driver.ErrSkip,
// the query is prepared as a temporary statement, executed, and the
// statement closed.
func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), query string, args []interface{}) (res Result, err error) {
	defer func() {
		release(err)
	}()
	execerCtx, ok := dc.ci.(driver.ExecerContext)
	var execer driver.Execer
	if !ok {
		execer, ok = dc.ci.(driver.Execer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var resi driver.Result
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			resi, err = ctxDriverExec(ctx, execerCtx, execer, query, nvdargs)
		})
		// driver.ErrSkip means the driver declined the fast path; fall
		// through to prepare/exec below.
		if err != driver.ErrSkip {
			if err != nil {
				return nil, err
			}
			return driverResult{dc, resi}, nil
		}
	}

	var si driver.Stmt
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		return nil, err
	}
	ds := &driverStmt{Locker: dc, si: si}
	// Deferred after release, so it runs first: the statement is closed
	// before the connection is released.
	defer ds.Close()
	return resultFromStatement(ctx, dc.ci, ds, args...)
}
1489
1490
1491
1492 func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
1493 var rows *Rows
1494 var err error
1495 for i := 0; i < maxBadConnRetries; i++ {
1496 rows, err = db.query(ctx, query, args, cachedOrNewConn)
1497 if err != driver.ErrBadConn {
1498 break
1499 }
1500 }
1501 if err == driver.ErrBadConn {
1502 return db.query(ctx, query, args, alwaysNewConn)
1503 }
1504 return rows, err
1505 }
1506
1507
1508
1509 func (db *DB) Query(query string, args ...interface{}) (*Rows, error) {
1510 return db.QueryContext(context.Background(), query, args...)
1511 }
1512
1513 func (db *DB) query(ctx context.Context, query string, args []interface{}, strategy connReuseStrategy) (*Rows, error) {
1514 dc, err := db.conn(ctx, strategy)
1515 if err != nil {
1516 return nil, err
1517 }
1518
1519 return db.queryDC(ctx, nil, dc, dc.releaseConn, query, args)
1520 }
1521
1522
1523
1524
1525
// queryDC executes a query on the given connection. ctx comes from the
// query method, and txctx is the context of an enclosing transaction (nil
// when there is none).
//
// On error, queryDC calls releaseConn itself. On success, releaseConn is
// stored in the returned Rows. NOTE(review): presumably Rows calls it when
// closed — confirm in Rows (outside this chunk).
//
// If the driver connection implements QueryerContext or Queryer, the query
// runs directly. If it does not, or if that call returns driver.ErrSkip,
// the query is prepared as a statement that the Rows closes (closeStmt).
func (db *DB) queryDC(ctx, txctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
	queryerCtx, ok := dc.ci.(driver.QueryerContext)
	var queryer driver.Queryer
	if !ok {
		queryer, ok = dc.ci.(driver.Queryer)
	}
	if ok {
		var nvdargs []driver.NamedValue
		var rowsi driver.Rows
		var err error
		withLock(dc, func() {
			nvdargs, err = driverArgsConnLocked(dc.ci, nil, args)
			if err != nil {
				return
			}
			rowsi, err = ctxDriverQuery(ctx, queryerCtx, queryer, query, nvdargs)
		})
		// driver.ErrSkip means the driver declined the fast path; fall
		// through to prepare/query below.
		if err != driver.ErrSkip {
			if err != nil {
				releaseConn(err)
				return nil, err
			}
			// Note: ownership of dc passes to the *Rows, to be freed
			// with releaseConn.
			rows := &Rows{
				dc:          dc,
				releaseConn: releaseConn,
				rowsi:       rowsi,
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}
	}

	var si driver.Stmt
	var err error
	withLock(dc, func() {
		si, err = ctxDriverPrepare(ctx, dc.ci, query)
	})
	if err != nil {
		releaseConn(err)
		return nil, err
	}

	ds := &driverStmt{Locker: dc, si: si}
	rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
	if err != nil {
		ds.Close()
		releaseConn(err)
		return nil, err
	}

	// Note: ownership of ci passes to the *Rows, to be freed
	// with releaseConn.
	rows := &Rows{
		dc:          dc,
		releaseConn: releaseConn,
		rowsi:       rowsi,
		closeStmt:   ds,
	}
	rows.initContextClose(ctx, txctx)
	return rows, nil
}
1589
1590
1591
1592
1593
1594
1595
1596 func (db *DB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
1597 rows, err := db.QueryContext(ctx, query, args...)
1598 return &Row{rows: rows, err: err}
1599 }
1600
1601
1602
1603
1604
1605
1606
1607 func (db *DB) QueryRow(query string, args ...interface{}) *Row {
1608 return db.QueryRowContext(context.Background(), query, args...)
1609 }
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621 func (db *DB) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1622 var tx *Tx
1623 var err error
1624 for i := 0; i < maxBadConnRetries; i++ {
1625 tx, err = db.begin(ctx, opts, cachedOrNewConn)
1626 if err != driver.ErrBadConn {
1627 break
1628 }
1629 }
1630 if err == driver.ErrBadConn {
1631 return db.begin(ctx, opts, alwaysNewConn)
1632 }
1633 return tx, err
1634 }
1635
1636
1637
1638 func (db *DB) Begin() (*Tx, error) {
1639 return db.BeginTx(context.Background(), nil)
1640 }
1641
1642 func (db *DB) begin(ctx context.Context, opts *TxOptions, strategy connReuseStrategy) (tx *Tx, err error) {
1643 dc, err := db.conn(ctx, strategy)
1644 if err != nil {
1645 return nil, err
1646 }
1647 return db.beginDC(ctx, dc, dc.releaseConn, opts)
1648 }
1649
1650
// beginDC starts a transaction on dc. On failure it calls release with the
// error. On success the Tx takes ownership of dc and release, and a
// goroutine running tx.awaitDone is started.
func (db *DB) beginDC(ctx context.Context, dc *driverConn, release func(error), opts *TxOptions) (tx *Tx, err error) {
	var txi driver.Tx
	withLock(dc, func() {
		txi, err = ctxDriverBegin(ctx, opts, dc.ci)
	})
	if err != nil {
		release(err)
		return nil, err
	}

	// Derive a cancellable context for the transaction; the Tx keeps both
	// the context and its cancel func. NOTE(review): awaitDone (outside this
	// chunk) presumably rolls back when this context is done — confirm.
	ctx, cancel := context.WithCancel(ctx)
	tx = &Tx{
		db:          db,
		dc:          dc,
		releaseConn: release,
		txi:         txi,
		cancel:      cancel,
		ctx:         ctx,
	}
	go tx.awaitDone()
	return tx, nil
}
1675
1676
1677 func (db *DB) Driver() driver.Driver {
1678 return db.connector.Driver()
1679 }
1680
1681
1682
1683 var ErrConnDone = errors.New("sql: connection is already closed")
1684
1685
1686
1687
1688
1689
1690
1691
1692 func (db *DB) Conn(ctx context.Context) (*Conn, error) {
1693 var dc *driverConn
1694 var err error
1695 for i := 0; i < maxBadConnRetries; i++ {
1696 dc, err = db.conn(ctx, cachedOrNewConn)
1697 if err != driver.ErrBadConn {
1698 break
1699 }
1700 }
1701 if err == driver.ErrBadConn {
1702 dc, err = db.conn(ctx, cachedOrNewConn)
1703 }
1704 if err != nil {
1705 return nil, err
1706 }
1707
1708 conn := &Conn{
1709 db: db,
1710 dc: dc,
1711 }
1712 return conn, nil
1713 }
1714
// releaseConn is the function type grabConn returns. It is called with the
// operation's final error once the operation is done with the connection.
type releaseConn func(error)

// Conn represents a single database connection rather than a pool of
// connections. Obtain one with DB.Conn and return it with Close.
type Conn struct {
	db *DB

	// closemu is read-locked by grabConn for the duration of each
	// operation (until its release function runs) and write-locked by
	// close. Closing therefore waits for in-flight operations.
	closemu sync.RWMutex

	// dc is owned until close, at which point it's returned to the
	// connection pool.
	dc *driverConn

	// done is set to 1 (atomically, by close) once the Conn is closed.
	// grabConn checks it and returns ErrConnDone afterwards.
	done int32
}
1743
1744 func (c *Conn) grabConn(context.Context) (*driverConn, releaseConn, error) {
1745 if atomic.LoadInt32(&c.done) != 0 {
1746 return nil, nil, ErrConnDone
1747 }
1748 c.closemu.RLock()
1749 return c.dc, c.closemuRUnlockCondReleaseConn, nil
1750 }
1751
1752
1753 func (c *Conn) PingContext(ctx context.Context) error {
1754 dc, release, err := c.grabConn(ctx)
1755 if err != nil {
1756 return err
1757 }
1758 return c.db.pingDC(ctx, dc, release)
1759 }
1760
1761
1762
1763 func (c *Conn) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
1764 dc, release, err := c.grabConn(ctx)
1765 if err != nil {
1766 return nil, err
1767 }
1768 return c.db.execDC(ctx, dc, release, query, args)
1769 }
1770
1771
1772
1773 func (c *Conn) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
1774 dc, release, err := c.grabConn(ctx)
1775 if err != nil {
1776 return nil, err
1777 }
1778 return c.db.queryDC(ctx, nil, dc, release, query, args)
1779 }
1780
1781
1782
1783
1784
1785
1786
1787 func (c *Conn) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
1788 rows, err := c.QueryContext(ctx, query, args...)
1789 return &Row{rows: rows, err: err}
1790 }
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800 func (c *Conn) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
1801 dc, release, err := c.grabConn(ctx)
1802 if err != nil {
1803 return nil, err
1804 }
1805 return c.db.prepareDC(ctx, dc, release, c, query)
1806 }
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818 func (c *Conn) BeginTx(ctx context.Context, opts *TxOptions) (*Tx, error) {
1819 dc, release, err := c.grabConn(ctx)
1820 if err != nil {
1821 return nil, err
1822 }
1823 return c.db.beginDC(ctx, dc, release, opts)
1824 }
1825
1826
1827
1828 func (c *Conn) closemuRUnlockCondReleaseConn(err error) {
1829 c.closemu.RUnlock()
1830 if err == driver.ErrBadConn {
1831 c.close(err)
1832 }
1833 }
1834
1835 func (c *Conn) txCtx() context.Context {
1836 return nil
1837 }
1838
1839 func (c *Conn) close(err error) error {
1840 if !atomic.CompareAndSwapInt32(&c.done, 0, 1) {
1841 return ErrConnDone
1842 }
1843
1844
1845
1846 c.closemu.Lock()
1847 defer c.closemu.Unlock()
1848
1849 c.dc.releaseConn(err)
1850 c.dc = nil
1851 c.db = nil
1852 return err
1853 }
1854
1855
1856
1857
1858
1859
1860 func (c *Conn) Close() error {
1861 return c.close(nil)
1862 }
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
// Tx is an in-progress database transaction. It ends with a call to Commit
// or Rollback, or automatically (via awaitDone) when its context is done.
// Once it has ended, operations on it return ErrTxDone.
type Tx struct {
	// db is the DB the transaction was started from.
	db *DB

	// closemu is read-locked by grabConn for the duration of each
	// operation and write-locked by close. The connection is therefore not
	// released while a query is still using it.
	closemu sync.RWMutex

	// dc is the connection the transaction runs on, and txi is the
	// driver's transaction. close sets both to nil.
	dc  *driverConn
	txi driver.Tx

	// releaseConn is called exactly once, by close, with the final error to
	// hand dc back to its owner.
	releaseConn func(error)

	// done is flipped from 0 to 1 (atomic CompareAndSwap) by whichever of
	// Commit or rollback runs first; isDone reads it.
	done int32

	// stmts holds statements prepared within this transaction. They are
	// closed by closePrepared when the transaction ends.
	stmts struct {
		sync.Mutex
		v []*Stmt
	}

	// cancel cancels ctx; close calls it.
	cancel func()

	// ctx is the transaction's context. awaitDone rolls the transaction
	// back when it is done.
	ctx context.Context
}
1910
1911
1912
1913 func (tx *Tx) awaitDone() {
1914
1915
1916 <-tx.ctx.Done()
1917
1918
1919
1920
1921
1922 tx.rollback(true)
1923 }
1924
1925 func (tx *Tx) isDone() bool {
1926 return atomic.LoadInt32(&tx.done) != 0
1927 }
1928
1929
1930
// ErrTxDone is returned by any operation performed on a transaction that
// has already been committed or rolled back (for example by Tx.grabConn,
// Commit, and rollback).
var ErrTxDone = errors.New("sql: transaction has already been committed or rolled back")
1932
1933
1934
1935 func (tx *Tx) close(err error) {
1936 tx.cancel()
1937
1938 tx.closemu.Lock()
1939 defer tx.closemu.Unlock()
1940
1941 tx.releaseConn(err)
1942 tx.dc = nil
1943 tx.txi = nil
1944 }
1945
1946
1947
// hookTxGrabConn, when non-nil, is invoked by Tx.grabConn after the
// transaction's read lock has been taken. It is nil unless assigned.
// NOTE(review): presumably a test-only hook; confirm no production code
// sets it.
var hookTxGrabConn func()
1949
1950 func (tx *Tx) grabConn(ctx context.Context) (*driverConn, releaseConn, error) {
1951 select {
1952 default:
1953 case <-ctx.Done():
1954 return nil, nil, ctx.Err()
1955 }
1956
1957
1958
1959 tx.closemu.RLock()
1960 if tx.isDone() {
1961 tx.closemu.RUnlock()
1962 return nil, nil, ErrTxDone
1963 }
1964 if hookTxGrabConn != nil {
1965 hookTxGrabConn()
1966 }
1967 return tx.dc, tx.closemuRUnlockRelease, nil
1968 }
1969
1970 func (tx *Tx) txCtx() context.Context {
1971 return tx.ctx
1972 }
1973
1974
1975
1976
1977
1978 func (tx *Tx) closemuRUnlockRelease(error) {
1979 tx.closemu.RUnlock()
1980 }
1981
1982
1983 func (tx *Tx) closePrepared() {
1984 tx.stmts.Lock()
1985 defer tx.stmts.Unlock()
1986 for _, stmt := range tx.stmts.v {
1987 stmt.Close()
1988 }
1989 }
1990
1991
// Commit commits the transaction.
//
// It returns ErrTxDone if the transaction already ended. If tx.ctx was done
// before Commit got to run, it returns the context's error instead. In all
// other cases it returns the driver's Commit error.
func (tx *Tx) Commit() error {
	// Check the context before flipping tx.done. Once tx.ctx is done,
	// awaitDone is (or will be) rolling the transaction back, so Commit must
	// not proceed. If done is already set, the transaction ended first and
	// ErrTxDone is the more accurate error.
	select {
	default:
	case <-tx.ctx.Done():
		if atomic.LoadInt32(&tx.done) == 1 {
			return ErrTxDone
		}
		return tx.ctx.Err()
	}
	// Only one of Commit or rollback may get past this point.
	if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
		return ErrTxDone
	}
	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Commit()
	})
	// Prepared statements are closed unless the connection went bad.
	// NOTE(review): presumably statements on a bad connection are cleaned
	// up when the connection itself is discarded; confirm.
	if err != driver.ErrBadConn {
		tx.closePrepared()
	}
	tx.close(err)
	return err
}
2017
2018
2019
// rollback aborts the transaction. It returns ErrTxDone if the transaction
// already ended.
//
// When discardConn is true (as awaitDone uses it), the connection is
// released with driver.ErrBadConn, and that is also what rollback returns.
// Otherwise it returns the driver's Rollback error.
func (tx *Tx) rollback(discardConn bool) error {
	// Only one of Commit or rollback may get past this point.
	if !atomic.CompareAndSwapInt32(&tx.done, 0, 1) {
		return ErrTxDone
	}
	var err error
	withLock(tx.dc, func() {
		err = tx.txi.Rollback()
	})
	if err != driver.ErrBadConn {
		tx.closePrepared()
	}
	// Report the connection as bad so it is not reused.
	if discardConn {
		err = driver.ErrBadConn
	}
	tx.close(err)
	return err
}
2037
2038
2039 func (tx *Tx) Rollback() error {
2040 return tx.rollback(false)
2041 }
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053 func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
2054 dc, release, err := tx.grabConn(ctx)
2055 if err != nil {
2056 return nil, err
2057 }
2058
2059 stmt, err := tx.db.prepareDC(ctx, dc, release, tx, query)
2060 if err != nil {
2061 return nil, err
2062 }
2063 tx.stmts.Lock()
2064 tx.stmts.v = append(tx.stmts.v, stmt)
2065 tx.stmts.Unlock()
2066 return stmt, nil
2067 }
2068
2069
2070
2071
2072
2073
2074
2075 func (tx *Tx) Prepare(query string) (*Stmt, error) {
2076 return tx.PrepareContext(context.Background(), query)
2077 }
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
// StmtContext returns a transaction-specific prepared statement built from
// an existing statement. The returned Stmt runs on the transaction's
// connection and is added to tx.stmts, so it is closed when the transaction
// ends.
//
// Errors are not returned directly. They are stored as the returned Stmt's
// stickyErr and surface on its first use. Such errors include a done
// context, a finished transaction, a statement from another DB, or a
// failed prepare.
func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
	dc, release, err := tx.grabConn(ctx)
	if err != nil {
		return &Stmt{stickyErr: err}
	}
	defer release(nil)

	if tx.db != stmt.db {
		return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
	}
	var si driver.Stmt
	var parentStmt *Stmt
	stmt.mu.Lock()
	if stmt.closed || stmt.cg != nil {
		// A closed statement, or one already bound to a Tx or Conn, cannot
		// share its per-connection driver statements. Re-prepare the query
		// directly on the transaction's connection instead. The result is
		// not linked to stmt (parentStmt stays nil).
		stmt.mu.Unlock()
		withLock(dc, func() {
			si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
		})
		if err != nil {
			return &Stmt{stickyErr: err}
		}
	} else {
		stmt.removeClosedStmtLocked()
		// Reuse the driver statement if stmt was already prepared on this
		// connection.
		for _, v := range stmt.css {
			if v.dc == dc {
				si = v.ds.si
				break
			}
		}

		stmt.mu.Unlock()

		if si == nil {
			// Not prepared here yet. prepareOnConnLocked also records the
			// driver statement in stmt.css, so it is tracked by the parent.
			var ds *driverStmt
			withLock(dc, func() {
				ds, err = stmt.prepareOnConnLocked(ctx, dc)
			})
			if err != nil {
				return &Stmt{stickyErr: err}
			}
			si = ds.si
		}
		parentStmt = stmt
	}

	txs := &Stmt{
		db: tx.db,
		cg: tx,
		cgds: &driverStmt{
			Locker: dc,
			si:     si,
		},
		parentStmt: parentStmt,
		query:      stmt.query,
	}
	// Link txs to its parent. Stmt.Close undoes this with removeDep.
	// NOTE(review): presumably this keeps the parent's driver statements
	// alive while txs is open; confirm against addDep.
	if parentStmt != nil {
		tx.db.addDep(parentStmt, txs)
	}
	tx.stmts.Lock()
	tx.stmts.v = append(tx.stmts.v, txs)
	tx.stmts.Unlock()
	return txs
}
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178 func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
2179 return tx.StmtContext(context.Background(), stmt)
2180 }
2181
2182
2183
2184 func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (Result, error) {
2185 dc, release, err := tx.grabConn(ctx)
2186 if err != nil {
2187 return nil, err
2188 }
2189 return tx.db.execDC(ctx, dc, release, query, args)
2190 }
2191
2192
2193
2194 func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
2195 return tx.ExecContext(context.Background(), query, args...)
2196 }
2197
2198
2199 func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*Rows, error) {
2200 dc, release, err := tx.grabConn(ctx)
2201 if err != nil {
2202 return nil, err
2203 }
2204
2205 return tx.db.queryDC(ctx, tx.ctx, dc, release, query, args)
2206 }
2207
2208
2209 func (tx *Tx) Query(query string, args ...interface{}) (*Rows, error) {
2210 return tx.QueryContext(context.Background(), query, args...)
2211 }
2212
2213
2214
2215
2216
2217
2218
2219 func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *Row {
2220 rows, err := tx.QueryContext(ctx, query, args...)
2221 return &Row{rows: rows, err: err}
2222 }
2223
2224
2225
2226
2227
2228
2229
2230 func (tx *Tx) QueryRow(query string, args ...interface{}) *Row {
2231 return tx.QueryRowContext(context.Background(), query, args...)
2232 }
2233
2234
// connStmt pairs a driver connection with the driver statement prepared on
// it. A Stmt keeps one per connection it has been prepared on (Stmt.css).
type connStmt struct {
	dc *driverConn
	ds *driverStmt
}
2239
2240
2241
// stmtConnGrabber is implemented by the owners a Stmt can be bound to
// (*Tx and *Conn). When Stmt.cg is set, the statement always runs on the
// owner's connection instead of one taken from the DB pool.
type stmtConnGrabber interface {
	// grabConn returns the owner's driver connection and a release function
	// that must be called with the operation's error when the operation
	// finishes.
	grabConn(context.Context) (*driverConn, releaseConn, error)

	// txCtx returns the owner's transaction context, or nil if the owner is
	// not a transaction. Rows produced through a bound Stmt are also closed
	// when this context is done.
	txCtx() context.Context
}
2252
2253 var (
2254 _ stmtConnGrabber = &Tx{}
2255 _ stmtConnGrabber = &Conn{}
2256 )
2257
2258
2259
// Stmt is a prepared statement. An unbound Stmt lazily prepares itself on
// whichever pooled connection it runs on and caches one driver statement
// per connection in css. A Stmt bound to a Tx or Conn (cg != nil) always
// uses that owner's connection and the single driver statement cgds.
type Stmt struct {
	// db, query, and stickyErr are set at construction. A non-nil
	// stickyErr is returned by every operation on the statement.
	db        *DB
	query     string
	stickyErr error

	// closemu is read-locked by ExecContext and QueryContext and
	// write-locked by Close, so Close waits for running executions.
	closemu sync.RWMutex

	// cg is the Tx or Conn the statement is bound to, or nil for a
	// pool-backed statement. cgds is the driver statement used with cg;
	// Close sets it to nil.
	cg   stmtConnGrabber
	cgds *driverStmt

	// parentStmt is set when a Tx-bound statement was derived (via
	// Tx.StmtContext) from a pool-backed statement whose css owns cgds.
	// Close must then not close cgds itself.
	parentStmt *Stmt

	// mu guards closed and css (and lastNumClosed, which is updated by
	// removeClosedStmtLocked while mu is held).
	mu     sync.Mutex
	closed bool

	// css holds the driver statements prepared on each connection for a
	// pool-backed statement.
	css []connStmt

	// lastNumClosed is the value of db.numClosed observed at the last
	// cleanup of css in removeClosedStmtLocked.
	lastNumClosed uint64
}
2297
2298
2299
// ExecContext executes the prepared statement with the given arguments and
// returns its Result.
//
// On driver.ErrBadConn it retries: maxBadConnRetries attempts may reuse a
// cached connection, and one final attempt forces a new connection. If
// every attempt fails with driver.ErrBadConn, that error is returned.
func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, error) {
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var res Result
	strategy := cachedOrNewConn
	for i := 0; i < maxBadConnRetries+1; i++ {
		// Last attempt: do not risk another stale cached connection.
		if i == maxBadConnRetries {
			strategy = alwaysNewConn
		}
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			if err == driver.ErrBadConn {
				continue
			}
			return nil, err
		}

		res, err = resultFromStatement(ctx, dc.ci, ds, args...)
		// The connection goes back (or is discarded) before deciding
		// whether to retry.
		releaseConn(err)
		if err != driver.ErrBadConn {
			return res, err
		}
	}
	return nil, driver.ErrBadConn
}
2326
2327
2328
2329 func (s *Stmt) Exec(args ...interface{}) (Result, error) {
2330 return s.ExecContext(context.Background(), args...)
2331 }
2332
2333 func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
2334 ds.Lock()
2335 defer ds.Unlock()
2336
2337 dargs, err := driverArgsConnLocked(ci, ds, args)
2338 if err != nil {
2339 return nil, err
2340 }
2341
2342 resi, err := ctxDriverStmtExec(ctx, ds.si, dargs)
2343 if err != nil {
2344 return nil, err
2345 }
2346 return driverResult{ds.Locker, resi}, nil
2347 }
2348
2349
2350
2351
2352
2353 func (s *Stmt) removeClosedStmtLocked() {
2354 t := len(s.css)/2 + 1
2355 if t > 10 {
2356 t = 10
2357 }
2358 dbClosed := atomic.LoadUint64(&s.db.numClosed)
2359 if dbClosed-s.lastNumClosed < uint64(t) {
2360 return
2361 }
2362
2363 s.db.mu.Lock()
2364 for i := 0; i < len(s.css); i++ {
2365 if s.css[i].dc.dbmuClosed {
2366 s.css[i] = s.css[len(s.css)-1]
2367 s.css = s.css[:len(s.css)-1]
2368 i--
2369 }
2370 }
2371 s.db.mu.Unlock()
2372 s.lastNumClosed = dbClosed
2373 }
2374
2375
2376
2377
// connStmt returns a connection, its release function, and the driver
// statement to execute on it.
//
// Resolution order:
//   - A sticky error or a closed statement fails immediately.
//   - A statement bound to a Tx or Conn (s.cg != nil) always uses that
//     owner's connection and s.cgds.
//   - Otherwise a connection is taken from the pool using strategy. The
//     statement is then reused from s.css if it was already prepared on
//     that connection, or prepared on it now.
//
// If preparing fails, the connection is released with the error before
// returning.
func (s *Stmt) connStmt(ctx context.Context, strategy connReuseStrategy) (dc *driverConn, releaseConn func(error), ds *driverStmt, err error) {
	if err = s.stickyErr; err != nil {
		return
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		err = errors.New("sql: statement is closed")
		return
	}

	// Bound to a transaction or Conn: always use the owner's connection.
	// s.mu is released first because grabConn may block.
	if s.cg != nil {
		s.mu.Unlock()
		dc, releaseConn, err = s.cg.grabConn(ctx)
		if err != nil {
			return
		}
		return dc, releaseConn, s.cgds, nil
	}

	s.removeClosedStmtLocked()
	s.mu.Unlock()

	dc, err = s.db.conn(ctx, strategy)
	if err != nil {
		return nil, nil, nil, err
	}

	// Reuse the driver statement if s was already prepared on dc.
	s.mu.Lock()
	for _, v := range s.css {
		if v.dc == dc {
			s.mu.Unlock()
			return dc, dc.releaseConn, v.ds, nil
		}
	}
	s.mu.Unlock()

	// Not prepared on this connection yet: prepare it now (this also
	// records it in s.css).
	withLock(dc, func() {
		ds, err = s.prepareOnConnLocked(ctx, dc)
	})
	if err != nil {
		dc.releaseConn(err)
		return nil, nil, nil, err
	}

	return dc, dc.releaseConn, ds, nil
}
2428
2429
2430
2431 func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
2432 si, err := dc.prepareLocked(ctx, s.cg, s.query)
2433 if err != nil {
2434 return nil, err
2435 }
2436 cs := connStmt{dc, si}
2437 s.mu.Lock()
2438 s.css = append(s.css, cs)
2439 s.mu.Unlock()
2440 return cs.ds, nil
2441 }
2442
2443
2444
// QueryContext executes the prepared query statement with the given
// arguments and returns the results as *Rows.
//
// It uses the same driver.ErrBadConn retry policy as ExecContext.
// Ownership of the connection passes to the returned Rows, which release
// it when closed.
func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, error) {
	s.closemu.RLock()
	defer s.closemu.RUnlock()

	var rowsi driver.Rows
	strategy := cachedOrNewConn
	for i := 0; i < maxBadConnRetries+1; i++ {
		// Last attempt: force a fresh connection.
		if i == maxBadConnRetries {
			strategy = alwaysNewConn
		}
		dc, releaseConn, ds, err := s.connStmt(ctx, strategy)
		if err != nil {
			if err == driver.ErrBadConn {
				continue
			}
			return nil, err
		}

		rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
		if err == nil {
			// The connection is now owned by rows and is released through
			// rows.releaseConn, which is set below.
			rows := &Rows{
				dc:    dc,
				rowsi: rowsi,
				// releaseConn is set below
			}
			// addDep must run before initContextClose. Otherwise a context
			// that is already done could trigger removeDep before the
			// dependency exists.
			s.db.addDep(s, rows)

			// releaseConn must likewise be set before initContextClose.
			// The awaitDone goroutine may close rows (and call
			// releaseConn) immediately.
			rows.releaseConn = func(err error) {
				releaseConn(err)
				s.db.removeDep(s, rows)
			}
			var txctx context.Context
			if s.cg != nil {
				txctx = s.cg.txCtx()
			}
			rows.initContextClose(ctx, txctx)
			return rows, nil
		}

		releaseConn(err)
		if err != driver.ErrBadConn {
			return nil, err
		}
	}
	return nil, driver.ErrBadConn
}
2497
2498
2499
2500 func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
2501 return s.QueryContext(context.Background(), args...)
2502 }
2503
2504 func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
2505 ds.Lock()
2506 defer ds.Unlock()
2507 dargs, err := driverArgsConnLocked(ci, ds, args)
2508 if err != nil {
2509 return nil, err
2510 }
2511 return ctxDriverStmtQuery(ctx, ds.si, dargs)
2512 }
2513
2514
2515
2516
2517
2518
2519
2520 func (s *Stmt) QueryRowContext(ctx context.Context, args ...interface{}) *Row {
2521 rows, err := s.QueryContext(ctx, args...)
2522 if err != nil {
2523 return &Row{err: err}
2524 }
2525 return &Row{rows: rows}
2526 }
2527
2528
2529
2530
2531
2532
2533
2534
2535
2536
2537
2538
2539 func (s *Stmt) QueryRow(args ...interface{}) *Row {
2540 return s.QueryRowContext(context.Background(), args...)
2541 }
2542
2543
// Close closes the statement. It waits for running ExecContext and
// QueryContext calls (which hold closemu for reading). Closing an
// already-closed statement returns nil. A statement with a sticky error
// returns that error.
func (s *Stmt) Close() error {
	s.closemu.Lock()
	defer s.closemu.Unlock()

	if s.stickyErr != nil {
		return s.stickyErr
	}
	s.mu.Lock()
	if s.closed {
		s.mu.Unlock()
		return nil
	}
	s.closed = true
	txds := s.cgds
	s.cgds = nil

	s.mu.Unlock()

	// Pool-backed statement: drop its self-dependency.
	if s.cg == nil {
		return s.db.removeDep(s, s)
	}

	if s.parentStmt != nil {
		// txds is stored in the parent's css, so the parent owns it and
		// it must not be closed here. Only unlink this statement.
		return s.db.removeDep(s.parentStmt, s)
	}
	// Bound statement that owns its driver statement outright.
	return txds.Close()
}
2573
2574 func (s *Stmt) finalClose() error {
2575 s.mu.Lock()
2576 defer s.mu.Unlock()
2577 if s.css != nil {
2578 for _, v := range s.css {
2579 s.db.noteUnusedDriverStatement(v.dc, v.ds)
2580 v.dc.removeOpenStmt(v.ds)
2581 }
2582 s.css = nil
2583 }
2584 return nil
2585 }
2586
2587
2588
// Rows is the result of a query. Its cursor starts before the first row;
// use Next to advance from row to row.
type Rows struct {
	// dc is the connection the rows are read from. Driver calls on rowsi
	// are made while holding dc's lock.
	dc *driverConn
	// releaseConn is called once, by close, with the driver's Close error.
	releaseConn func(error)
	rowsi       driver.Rows
	// cancel cancels the internal context created by initContextClose;
	// it is nil if no watcher goroutine was started.
	cancel func()
	// closeStmt, when non-nil, is closed together with the Rows.
	closeStmt *driverStmt

	// closemu is read-locked by Next, NextResultSet, Err, Columns,
	// ColumnTypes, and Scan, and write-locked by close. It guards closed
	// and lasterr.
	closemu sync.RWMutex
	closed  bool
	// lasterr is the first error recorded while iterating or closing;
	// io.EOF marks a normal end and is reported as nil by Err.
	lasterr error

	// lastcols holds the current row's values as filled by Next and read
	// by Scan.
	lastcols []driver.Value
}
2609
2610 func (rs *Rows) initContextClose(ctx, txctx context.Context) {
2611 if ctx.Done() == nil && (txctx == nil || txctx.Done() == nil) {
2612 return
2613 }
2614 ctx, rs.cancel = context.WithCancel(ctx)
2615 go rs.awaitDone(ctx, txctx)
2616 }
2617
2618
2619
2620
2621
2622 func (rs *Rows) awaitDone(ctx, txctx context.Context) {
2623 var txctxDone <-chan struct{}
2624 if txctx != nil {
2625 txctxDone = txctx.Done()
2626 }
2627 select {
2628 case <-ctx.Done():
2629 case <-txctxDone:
2630 }
2631 rs.close(ctx.Err())
2632 }
2633
2634
2635
2636
2637
2638
2639
2640 func (rs *Rows) Next() bool {
2641 var doClose, ok bool
2642 withLock(rs.closemu.RLocker(), func() {
2643 doClose, ok = rs.nextLocked()
2644 })
2645 if doClose {
2646 rs.Close()
2647 }
2648 return ok
2649 }
2650
// nextLocked advances the driver cursor into rs.lastcols. It reports
// whether a row was read (ok) and whether the caller should close rs
// afterwards (doClose). The caller must hold closemu for reading.
func (rs *Rows) nextLocked() (doClose, ok bool) {
	if rs.closed {
		return false, false
	}

	// Hold the driver connection's lock while calling into rowsi, so the
	// connection is not used (e.g. rolled back) concurrently.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	if rs.lastcols == nil {
		rs.lastcols = make([]driver.Value, len(rs.rowsi.Columns()))
	}

	rs.lasterr = rs.rowsi.Next(rs.lastcols)
	if rs.lasterr != nil {
		// Any error other than io.EOF ends iteration: close.
		if rs.lasterr != io.EOF {
			return true, false
		}
		nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
		if !ok {
			return true, false
		}
		// End of the current result set. Keep the Rows open only if the
		// driver reports another result set for NextResultSet to read.
		if !nextResultSet.HasNextResultSet() {
			doClose = true
		}
		return doClose, false
	}
	return false, true
}
2685
2686
2687
2688
2689
2690
2691
2692
2693
// NextResultSet prepares the next result set for reading. It reports
// whether a further result set exists. When it returns false, rs has been
// (or is being) closed; Err reports any error that occurred. After a true
// result, call Next to advance to the first row of the new set.
func (rs *Rows) NextResultSet() bool {
	// Close (which takes closemu for writing) must run after the deferred
	// RUnlock below, so it is deferred first and therefore runs last.
	var doClose bool
	defer func() {
		if doClose {
			rs.Close()
		}
	}()
	rs.closemu.RLock()
	defer rs.closemu.RUnlock()

	if rs.closed {
		return false
	}

	// The next result set may have a different column count.
	rs.lastcols = nil
	nextResultSet, ok := rs.rowsi.(driver.RowsNextResultSet)
	if !ok {
		doClose = true
		return false
	}

	// Hold the driver connection's lock while calling into rowsi.
	rs.dc.Lock()
	defer rs.dc.Unlock()

	rs.lasterr = nextResultSet.NextResultSet()
	if rs.lasterr != nil {
		doClose = true
		return false
	}
	return true
}
2727
2728
2729
2730 func (rs *Rows) Err() error {
2731 rs.closemu.RLock()
2732 defer rs.closemu.RUnlock()
2733 if rs.lasterr == io.EOF {
2734 return nil
2735 }
2736 return rs.lasterr
2737 }
2738
2739
2740
2741
2742 func (rs *Rows) Columns() ([]string, error) {
2743 rs.closemu.RLock()
2744 defer rs.closemu.RUnlock()
2745 if rs.closed {
2746 return nil, errors.New("sql: Rows are closed")
2747 }
2748 if rs.rowsi == nil {
2749 return nil, errors.New("sql: no Rows available")
2750 }
2751 rs.dc.Lock()
2752 defer rs.dc.Unlock()
2753
2754 return rs.rowsi.Columns(), nil
2755 }
2756
2757
2758
2759 func (rs *Rows) ColumnTypes() ([]*ColumnType, error) {
2760 rs.closemu.RLock()
2761 defer rs.closemu.RUnlock()
2762 if rs.closed {
2763 return nil, errors.New("sql: Rows are closed")
2764 }
2765 if rs.rowsi == nil {
2766 return nil, errors.New("sql: no Rows available")
2767 }
2768 rs.dc.Lock()
2769 defer rs.dc.Unlock()
2770
2771 return rowsColumnInfoSetupConnLocked(rs.rowsi), nil
2772 }
2773
2774
// ColumnType describes one column of a result set. Each optional property
// has a has* flag recording whether the driver reported it (see
// rowsColumnInfoSetupConnLocked).
type ColumnType struct {
	name string

	// Whether the driver reported the corresponding optional property.
	hasNullable       bool
	hasLength         bool
	hasPrecisionScale bool

	nullable     bool
	length       int64
	databaseType string
	precision    int64
	scale        int64
	// scanType defaults to the empty-interface type when the driver does
	// not report one.
	scanType reflect.Type
}
2789
2790
2791 func (ci *ColumnType) Name() string {
2792 return ci.name
2793 }
2794
2795
2796
2797
2798
2799
2800 func (ci *ColumnType) Length() (length int64, ok bool) {
2801 return ci.length, ci.hasLength
2802 }
2803
2804
2805
2806 func (ci *ColumnType) DecimalSize() (precision, scale int64, ok bool) {
2807 return ci.precision, ci.scale, ci.hasPrecisionScale
2808 }
2809
2810
2811
2812
2813 func (ci *ColumnType) ScanType() reflect.Type {
2814 return ci.scanType
2815 }
2816
2817
2818
2819 func (ci *ColumnType) Nullable() (nullable, ok bool) {
2820 return ci.nullable, ci.hasNullable
2821 }
2822
2823
2824
2825
2826
2827
2828 func (ci *ColumnType) DatabaseTypeName() string {
2829 return ci.databaseType
2830 }
2831
2832 func rowsColumnInfoSetupConnLocked(rowsi driver.Rows) []*ColumnType {
2833 names := rowsi.Columns()
2834
2835 list := make([]*ColumnType, len(names))
2836 for i := range list {
2837 ci := &ColumnType{
2838 name: names[i],
2839 }
2840 list[i] = ci
2841
2842 if prop, ok := rowsi.(driver.RowsColumnTypeScanType); ok {
2843 ci.scanType = prop.ColumnTypeScanType(i)
2844 } else {
2845 ci.scanType = reflect.TypeOf(new(interface{})).Elem()
2846 }
2847 if prop, ok := rowsi.(driver.RowsColumnTypeDatabaseTypeName); ok {
2848 ci.databaseType = prop.ColumnTypeDatabaseTypeName(i)
2849 }
2850 if prop, ok := rowsi.(driver.RowsColumnTypeLength); ok {
2851 ci.length, ci.hasLength = prop.ColumnTypeLength(i)
2852 }
2853 if prop, ok := rowsi.(driver.RowsColumnTypeNullable); ok {
2854 ci.nullable, ci.hasNullable = prop.ColumnTypeNullable(i)
2855 }
2856 if prop, ok := rowsi.(driver.RowsColumnTypePrecisionScale); ok {
2857 ci.precision, ci.scale, ci.hasPrecisionScale = prop.ColumnTypePrecisionScale(i)
2858 }
2859 }
2860 return list
2861 }
2862
2863
2864
2865
2866
2867
2868
2869
2870
2871
2872
2873
2874
2875
2876
2877
2878
2879
2880
2881
2882
2883
2884
2885
2886
2887
2888
2889
2890
2891
2892
2893
2894
2895
2896
2897
2898
2899
2900
2901
2902
2903
2904
2905
2906
2907
2908
2909
2910
2911
2912
2913
2914 func (rs *Rows) Scan(dest ...interface{}) error {
2915 rs.closemu.RLock()
2916
2917 if rs.lasterr != nil && rs.lasterr != io.EOF {
2918 rs.closemu.RUnlock()
2919 return rs.lasterr
2920 }
2921 if rs.closed {
2922 rs.closemu.RUnlock()
2923 return errors.New("sql: Rows are closed")
2924 }
2925 rs.closemu.RUnlock()
2926
2927 if rs.lastcols == nil {
2928 return errors.New("sql: Scan called without calling Next")
2929 }
2930 if len(dest) != len(rs.lastcols) {
2931 return fmt.Errorf("sql: expected %d destination arguments in Scan, not %d", len(rs.lastcols), len(dest))
2932 }
2933 for i, sv := range rs.lastcols {
2934 err := convertAssign(dest[i], sv)
2935 if err != nil {
2936 return fmt.Errorf(`sql: Scan error on column index %d, name %q: %v`, i, rs.rowsi.Columns()[i], err)
2937 }
2938 }
2939 return nil
2940 }
2941
2942
2943
// rowsCloseHook returns an optional callback that Rows.close invokes with
// the Rows and a pointer to the error it is about to return. The callback
// can therefore inspect or replace that error. By default it returns nil,
// meaning no callback. NOTE(review): presumably overridden only by tests.
var rowsCloseHook = func() func(*Rows, *error) { return nil }
2945
2946
2947
2948
2949
2950 func (rs *Rows) Close() error {
2951 return rs.close(nil)
2952 }
2953
// close closes rs exactly once. err is the reason for closing; it is
// recorded as lasterr only if no earlier error was recorded. close then
// closes the driver rows, cancels the context watcher, closes closeStmt if
// set, and releases the connection.
//
// It returns the driver's Close error (possibly adjusted by rowsCloseHook),
// not the err argument. Calls after the first return nil.
func (rs *Rows) close(err error) error {
	rs.closemu.Lock()
	defer rs.closemu.Unlock()

	if rs.closed {
		return nil
	}
	rs.closed = true

	if rs.lasterr == nil {
		rs.lasterr = err
	}

	// From here on err holds the driver's Close result.
	withLock(rs.dc, func() {
		err = rs.rowsi.Close()
	})
	if fn := rowsCloseHook(); fn != nil {
		fn(rs, &err)
	}
	// Stop the awaitDone goroutine, if one was started.
	if rs.cancel != nil {
		rs.cancel()
	}

	if rs.closeStmt != nil {
		rs.closeStmt.Close()
	}
	rs.releaseConn(err)
	return err
}
2983
2984
// Row is the result of a QueryRow-style call that selects a single row.
type Row struct {
	// err is a deferred query error; Scan returns it before anything else.
	err  error
	rows *Rows
}
2990
2991
2992
2993
2994
2995
2996 func (r *Row) Scan(dest ...interface{}) error {
2997 if r.err != nil {
2998 return r.err
2999 }
3000
3001
3002
3003
3004
3005
3006
3007
3008
3009
3010
3011
3012
3013
3014 defer r.rows.Close()
3015 for _, dp := range dest {
3016 if _, ok := dp.(*RawBytes); ok {
3017 return errors.New("sql: RawBytes isn't allowed on Row.Scan")
3018 }
3019 }
3020
3021 if !r.rows.Next() {
3022 if err := r.rows.Err(); err != nil {
3023 return err
3024 }
3025 return ErrNoRows
3026 }
3027 err := r.rows.Scan(dest...)
3028 if err != nil {
3029 return err
3030 }
3031
3032 return r.rows.Close()
3033 }
3034
3035
// Result summarizes an executed SQL command.
type Result interface {
	// LastInsertId returns the integer the database generated in response
	// to the command, as reported by the driver (typically an
	// auto-increment key). Drivers that do not support this return an
	// error.
	LastInsertId() (int64, error)

	// RowsAffected returns the number of rows affected by the command, as
	// reported by the driver. Drivers that do not support this return an
	// error.
	RowsAffected() (int64, error)
}
3049
// driverResult adapts a driver.Result to Result. The embedded Locker (the
// statement's connection lock) is held around every call into resi.
type driverResult struct {
	sync.Locker
	resi driver.Result
}
3054
3055 func (dr driverResult) LastInsertId() (int64, error) {
3056 dr.Lock()
3057 defer dr.Unlock()
3058 return dr.resi.LastInsertId()
3059 }
3060
3061 func (dr driverResult) RowsAffected() (int64, error) {
3062 dr.Lock()
3063 defer dr.Unlock()
3064 return dr.resi.RowsAffected()
3065 }
3066
// stack returns the calling goroutine's stack trace, truncated to 2 KiB.
func stack() string {
	buf := make([]byte, 2<<10)
	n := runtime.Stack(buf, false)
	return string(buf[:n])
}
3071
3072
// withLock runs critical while holding locker. The lock is released even if
// critical panics.
func withLock(locker sync.Locker, critical func()) {
	locker.Lock()
	defer func() { locker.Unlock() }()
	critical()
}
3078
View as plain text