Source file src/database/sql/fakedb_test.go

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package sql
     6  
     7  import (
     8  	"context"
     9  	"database/sql/driver"
    10  	"errors"
    11  	"fmt"
    12  	"io"
    13  	"reflect"
    14  	"sort"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  // fakeDriver is a fake database that implements Go's driver.Driver
    24  // interface, just for testing.
    25  //
    26  // It speaks a query language that's semantically similar to but
    27  // syntactically different and simpler than SQL.  The syntax is as
    28  // follows:
    29  //
    30  //	WIPE
    31  //	CREATE|<tablename>|<col>=<type>,<col>=<type>,...
    32  //	  where types are: "string", [u]int{8,16,32,64}, "bool"
    33  //	INSERT|<tablename>|col=val,col2=val2,col3=?
    34  //	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?,filtercol2=?
    35  //	SELECT|<tablename>|projectcol1,projectcol2|filtercol=?param1,filtercol2=?param2
    36  //
    37  // Any of these can be preceded by PANIC|<method>|, to cause the
    38  // named method on fakeStmt to panic.
    39  //
    40  // Any of these can be proceeded by WAIT|<duration>|, to cause the
    41  // named method on fakeStmt to sleep for the specified duration.
    42  //
    43  // Multiple of these can be combined when separated with a semicolon.
    44  //
    45  // When opening a fakeDriver's database, it starts empty with no
    46  // tables. All tables and data are stored in memory only.
    47  type fakeDriver struct {
    48  	mu         sync.Mutex // guards 3 following fields
    49  	openCount  int        // conn opens
    50  	closeCount int        // conn closes
    51  	waitCh     chan struct{}
    52  	waitingCh  chan struct{}
    53  	dbs        map[string]*fakeDB
    54  }
    55  
    56  type fakeConnector struct {
    57  	name string
    58  
    59  	waiter func(context.Context)
    60  	closed bool
    61  }
    62  
    63  func (c *fakeConnector) Connect(context.Context) (driver.Conn, error) {
    64  	conn, err := fdriver.Open(c.name)
    65  	conn.(*fakeConn).waiter = c.waiter
    66  	return conn, err
    67  }
    68  
    69  func (c *fakeConnector) Driver() driver.Driver {
    70  	return fdriver
    71  }
    72  
    73  func (c *fakeConnector) Close() error {
    74  	if c.closed {
    75  		return errors.New("fakedb: connector is closed")
    76  	}
    77  	c.closed = true
    78  	return nil
    79  }
    80  
    81  type fakeDriverCtx struct {
    82  	fakeDriver
    83  }
    84  
    85  var _ driver.DriverContext = &fakeDriverCtx{}
    86  
    87  func (cc *fakeDriverCtx) OpenConnector(name string) (driver.Connector, error) {
    88  	return &fakeConnector{name: name}, nil
    89  }
    90  
    91  type fakeDB struct {
    92  	name string
    93  
    94  	useRawBytes atomic.Bool
    95  
    96  	mu       sync.Mutex
    97  	tables   map[string]*table
    98  	badConn  bool
    99  	allowAny bool
   100  }
   101  
   102  type fakeError struct {
   103  	Message string
   104  	Wrapped error
   105  }
   106  
   107  func (err fakeError) Error() string {
   108  	return err.Message
   109  }
   110  
   111  func (err fakeError) Unwrap() error {
   112  	return err.Wrapped
   113  }
   114  
   115  type table struct {
   116  	mu      sync.Mutex
   117  	colname []string
   118  	coltype []string
   119  	rows    []*row
   120  }
   121  
   122  func (t *table) columnIndex(name string) int {
   123  	for n, nname := range t.colname {
   124  		if name == nname {
   125  			return n
   126  		}
   127  	}
   128  	return -1
   129  }
   130  
   131  type row struct {
   132  	cols []any // must be same size as its table colname + coltype
   133  }
   134  
   135  type memToucher interface {
   136  	// touchMem reads & writes some memory, to help find data races.
   137  	touchMem()
   138  }
   139  
   140  type fakeConn struct {
   141  	db *fakeDB // where to return ourselves to
   142  
   143  	currTx *fakeTx
   144  
   145  	// Every operation writes to line to enable the race detector
   146  	// check for data races.
   147  	line int64
   148  
   149  	// Stats for tests:
   150  	mu          sync.Mutex
   151  	stmtsMade   int
   152  	stmtsClosed int
   153  	numPrepare  int
   154  
   155  	// bad connection tests; see isBad()
   156  	bad       bool
   157  	stickyBad bool
   158  
   159  	skipDirtySession bool // tests that use Conn should set this to true.
   160  
   161  	// dirtySession tests ResetSession, true if a query has executed
   162  	// until ResetSession is called.
   163  	dirtySession bool
   164  
   165  	// The waiter is called before each query. May be used in place of the "WAIT"
   166  	// directive.
   167  	waiter func(context.Context)
   168  }
   169  
   170  func (c *fakeConn) touchMem() {
   171  	c.line++
   172  }
   173  
   174  func (c *fakeConn) incrStat(v *int) {
   175  	c.mu.Lock()
   176  	*v++
   177  	c.mu.Unlock()
   178  }
   179  
   180  type fakeTx struct {
   181  	c *fakeConn
   182  }
   183  
   184  type boundCol struct {
   185  	Column      string
   186  	Placeholder string
   187  	Ordinal     int
   188  }
   189  
   190  type fakeStmt struct {
   191  	memToucher
   192  	c *fakeConn
   193  	q string // just for debugging
   194  
   195  	cmd   string
   196  	table string
   197  	panic string
   198  	wait  time.Duration
   199  
   200  	next *fakeStmt // used for returning multiple results.
   201  
   202  	closed bool
   203  
   204  	colName      []string // used by CREATE, INSERT, SELECT (selected columns)
   205  	colType      []string // used by CREATE
   206  	colValue     []any    // used by INSERT (mix of strings and "?" for bound params)
   207  	placeholders int      // used by INSERT/SELECT: number of ? params
   208  
   209  	whereCol []boundCol // used by SELECT (all placeholders)
   210  
   211  	placeholderConverter []driver.ValueConverter // used by INSERT
   212  }
   213  
   214  var fdriver driver.Driver = &fakeDriver{}
   215  
   216  func init() {
   217  	Register("test", fdriver)
   218  }
   219  
   220  func contains(list []string, y string) bool {
   221  	for _, x := range list {
   222  		if x == y {
   223  			return true
   224  		}
   225  	}
   226  	return false
   227  }
   228  
   229  type Dummy struct {
   230  	driver.Driver
   231  }
   232  
   233  func TestDrivers(t *testing.T) {
   234  	unregisterAllDrivers()
   235  	Register("test", fdriver)
   236  	Register("invalid", Dummy{})
   237  	all := Drivers()
   238  	if len(all) < 2 || !sort.StringsAreSorted(all) || !contains(all, "test") || !contains(all, "invalid") {
   239  		t.Fatalf("Drivers = %v, want sorted list with at least [invalid, test]", all)
   240  	}
   241  }
   242  
   243  // hook to simulate connection failures
   244  var hookOpenErr struct {
   245  	sync.Mutex
   246  	fn func() error
   247  }
   248  
   249  func setHookOpenErr(fn func() error) {
   250  	hookOpenErr.Lock()
   251  	defer hookOpenErr.Unlock()
   252  	hookOpenErr.fn = fn
   253  }
   254  
   255  // Supports dsn forms:
   256  //
   257  //	<dbname>
   258  //	<dbname>;<opts>  (only currently supported option is `badConn`,
   259  //	                  which causes driver.ErrBadConn to be returned on
   260  //	                  every other conn.Begin())
   261  func (d *fakeDriver) Open(dsn string) (driver.Conn, error) {
   262  	hookOpenErr.Lock()
   263  	fn := hookOpenErr.fn
   264  	hookOpenErr.Unlock()
   265  	if fn != nil {
   266  		if err := fn(); err != nil {
   267  			return nil, err
   268  		}
   269  	}
   270  	parts := strings.Split(dsn, ";")
   271  	if len(parts) < 1 {
   272  		return nil, errors.New("fakedb: no database name")
   273  	}
   274  	name := parts[0]
   275  
   276  	db := d.getDB(name)
   277  
   278  	d.mu.Lock()
   279  	d.openCount++
   280  	d.mu.Unlock()
   281  	conn := &fakeConn{db: db}
   282  
   283  	if len(parts) >= 2 && parts[1] == "badConn" {
   284  		conn.bad = true
   285  	}
   286  	if d.waitCh != nil {
   287  		d.waitingCh <- struct{}{}
   288  		<-d.waitCh
   289  		d.waitCh = nil
   290  		d.waitingCh = nil
   291  	}
   292  	return conn, nil
   293  }
   294  
   295  func (d *fakeDriver) getDB(name string) *fakeDB {
   296  	d.mu.Lock()
   297  	defer d.mu.Unlock()
   298  	if d.dbs == nil {
   299  		d.dbs = make(map[string]*fakeDB)
   300  	}
   301  	db, ok := d.dbs[name]
   302  	if !ok {
   303  		db = &fakeDB{name: name}
   304  		d.dbs[name] = db
   305  	}
   306  	return db
   307  }
   308  
   309  func (db *fakeDB) wipe() {
   310  	db.mu.Lock()
   311  	defer db.mu.Unlock()
   312  	db.tables = nil
   313  }
   314  
   315  func (db *fakeDB) createTable(name string, columnNames, columnTypes []string) error {
   316  	db.mu.Lock()
   317  	defer db.mu.Unlock()
   318  	if db.tables == nil {
   319  		db.tables = make(map[string]*table)
   320  	}
   321  	if _, exist := db.tables[name]; exist {
   322  		return fmt.Errorf("fakedb: table %q already exists", name)
   323  	}
   324  	if len(columnNames) != len(columnTypes) {
   325  		return fmt.Errorf("fakedb: create table of %q len(names) != len(types): %d vs %d",
   326  			name, len(columnNames), len(columnTypes))
   327  	}
   328  	db.tables[name] = &table{colname: columnNames, coltype: columnTypes}
   329  	return nil
   330  }
   331  
   332  // must be called with db.mu lock held
   333  func (db *fakeDB) table(table string) (*table, bool) {
   334  	if db.tables == nil {
   335  		return nil, false
   336  	}
   337  	t, ok := db.tables[table]
   338  	return t, ok
   339  }
   340  
   341  func (db *fakeDB) columnType(table, column string) (typ string, ok bool) {
   342  	db.mu.Lock()
   343  	defer db.mu.Unlock()
   344  	t, ok := db.table(table)
   345  	if !ok {
   346  		return
   347  	}
   348  	for n, cname := range t.colname {
   349  		if cname == column {
   350  			return t.coltype[n], true
   351  		}
   352  	}
   353  	return "", false
   354  }
   355  
   356  func (c *fakeConn) isBad() bool {
   357  	if c.stickyBad {
   358  		return true
   359  	} else if c.bad {
   360  		if c.db == nil {
   361  			return false
   362  		}
   363  		// alternate between bad conn and not bad conn
   364  		c.db.badConn = !c.db.badConn
   365  		return c.db.badConn
   366  	} else {
   367  		return false
   368  	}
   369  }
   370  
   371  func (c *fakeConn) isDirtyAndMark() bool {
   372  	if c.skipDirtySession {
   373  		return false
   374  	}
   375  	if c.currTx != nil {
   376  		c.dirtySession = true
   377  		return false
   378  	}
   379  	if c.dirtySession {
   380  		return true
   381  	}
   382  	c.dirtySession = true
   383  	return false
   384  }
   385  
   386  func (c *fakeConn) Begin() (driver.Tx, error) {
   387  	if c.isBad() {
   388  		return nil, fakeError{Wrapped: driver.ErrBadConn}
   389  	}
   390  	if c.currTx != nil {
   391  		return nil, errors.New("fakedb: already in a transaction")
   392  	}
   393  	c.touchMem()
   394  	c.currTx = &fakeTx{c: c}
   395  	return c.currTx, nil
   396  }
   397  
   398  var hookPostCloseConn struct {
   399  	sync.Mutex
   400  	fn func(*fakeConn, error)
   401  }
   402  
   403  func setHookpostCloseConn(fn func(*fakeConn, error)) {
   404  	hookPostCloseConn.Lock()
   405  	defer hookPostCloseConn.Unlock()
   406  	hookPostCloseConn.fn = fn
   407  }
   408  
   409  var testStrictClose *testing.T
   410  
   411  // setStrictFakeConnClose sets the t to Errorf on when fakeConn.Close
   412  // fails to close. If nil, the check is disabled.
   413  func setStrictFakeConnClose(t *testing.T) {
   414  	testStrictClose = t
   415  }
   416  
   417  func (c *fakeConn) ResetSession(ctx context.Context) error {
   418  	c.dirtySession = false
   419  	c.currTx = nil
   420  	if c.isBad() {
   421  		return fakeError{Message: "Reset Session: bad conn", Wrapped: driver.ErrBadConn}
   422  	}
   423  	return nil
   424  }
   425  
   426  var _ driver.Validator = (*fakeConn)(nil)
   427  
   428  func (c *fakeConn) IsValid() bool {
   429  	return !c.isBad()
   430  }
   431  
   432  func (c *fakeConn) Close() (err error) {
   433  	drv := fdriver.(*fakeDriver)
   434  	defer func() {
   435  		if err != nil && testStrictClose != nil {
   436  			testStrictClose.Errorf("failed to close a test fakeConn: %v", err)
   437  		}
   438  		hookPostCloseConn.Lock()
   439  		fn := hookPostCloseConn.fn
   440  		hookPostCloseConn.Unlock()
   441  		if fn != nil {
   442  			fn(c, err)
   443  		}
   444  		if err == nil {
   445  			drv.mu.Lock()
   446  			drv.closeCount++
   447  			drv.mu.Unlock()
   448  		}
   449  	}()
   450  	c.touchMem()
   451  	if c.currTx != nil {
   452  		return errors.New("fakedb: can't close fakeConn; in a Transaction")
   453  	}
   454  	if c.db == nil {
   455  		return errors.New("fakedb: can't close fakeConn; already closed")
   456  	}
   457  	if c.stmtsMade > c.stmtsClosed {
   458  		return errors.New("fakedb: can't close; dangling statement(s)")
   459  	}
   460  	c.db = nil
   461  	return nil
   462  }
   463  
   464  func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
   465  	for _, arg := range args {
   466  		switch arg.Value.(type) {
   467  		case int64, float64, bool, nil, []byte, string, time.Time:
   468  		default:
   469  			if !allowAny {
   470  				return fmt.Errorf("fakedb: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
   471  			}
   472  		}
   473  	}
   474  	return nil
   475  }
   476  
   477  func (c *fakeConn) Exec(query string, args []driver.Value) (driver.Result, error) {
   478  	// Ensure that ExecContext is called if available.
   479  	panic("ExecContext was not called.")
   480  }
   481  
   482  func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) {
   483  	// This is an optional interface, but it's implemented here
   484  	// just to check that all the args are of the proper types.
   485  	// ErrSkip is returned so the caller acts as if we didn't
   486  	// implement this at all.
   487  	err := checkSubsetTypes(c.db.allowAny, args)
   488  	if err != nil {
   489  		return nil, err
   490  	}
   491  	return nil, driver.ErrSkip
   492  }
   493  
   494  func (c *fakeConn) Query(query string, args []driver.Value) (driver.Rows, error) {
   495  	// Ensure that ExecContext is called if available.
   496  	panic("QueryContext was not called.")
   497  }
   498  
   499  func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) {
   500  	// This is an optional interface, but it's implemented here
   501  	// just to check that all the args are of the proper types.
   502  	// ErrSkip is returned so the caller acts as if we didn't
   503  	// implement this at all.
   504  	err := checkSubsetTypes(c.db.allowAny, args)
   505  	if err != nil {
   506  		return nil, err
   507  	}
   508  	return nil, driver.ErrSkip
   509  }
   510  
   511  func errf(msg string, args ...any) error {
   512  	return errors.New("fakedb: " + fmt.Sprintf(msg, args...))
   513  }
   514  
   515  // parts are table|selectCol1,selectCol2|whereCol=?,whereCol2=?
   516  // (note that where columns must always contain ? marks,
   517  // just a limitation for fakedb)
   518  func (c *fakeConn) prepareSelect(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   519  	if len(parts) != 3 {
   520  		stmt.Close()
   521  		return nil, errf("invalid SELECT syntax with %d parts; want 3", len(parts))
   522  	}
   523  	stmt.table = parts[0]
   524  
   525  	stmt.colName = strings.Split(parts[1], ",")
   526  	for n, colspec := range strings.Split(parts[2], ",") {
   527  		if colspec == "" {
   528  			continue
   529  		}
   530  		nameVal := strings.Split(colspec, "=")
   531  		if len(nameVal) != 2 {
   532  			stmt.Close()
   533  			return nil, errf("SELECT on table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   534  		}
   535  		column, value := nameVal[0], nameVal[1]
   536  		_, ok := c.db.columnType(stmt.table, column)
   537  		if !ok {
   538  			stmt.Close()
   539  			return nil, errf("SELECT on table %q references non-existent column %q", stmt.table, column)
   540  		}
   541  		if !strings.HasPrefix(value, "?") {
   542  			stmt.Close()
   543  			return nil, errf("SELECT on table %q has pre-bound value for where column %q; need a question mark",
   544  				stmt.table, column)
   545  		}
   546  		stmt.placeholders++
   547  		stmt.whereCol = append(stmt.whereCol, boundCol{Column: column, Placeholder: value, Ordinal: stmt.placeholders})
   548  	}
   549  	return stmt, nil
   550  }
   551  
   552  // parts are table|col=type,col2=type2
   553  func (c *fakeConn) prepareCreate(stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   554  	if len(parts) != 2 {
   555  		stmt.Close()
   556  		return nil, errf("invalid CREATE syntax with %d parts; want 2", len(parts))
   557  	}
   558  	stmt.table = parts[0]
   559  	for n, colspec := range strings.Split(parts[1], ",") {
   560  		nameType := strings.Split(colspec, "=")
   561  		if len(nameType) != 2 {
   562  			stmt.Close()
   563  			return nil, errf("CREATE table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   564  		}
   565  		stmt.colName = append(stmt.colName, nameType[0])
   566  		stmt.colType = append(stmt.colType, nameType[1])
   567  	}
   568  	return stmt, nil
   569  }
   570  
   571  // parts are table|col=?,col2=val
   572  func (c *fakeConn) prepareInsert(ctx context.Context, stmt *fakeStmt, parts []string) (*fakeStmt, error) {
   573  	if len(parts) != 2 {
   574  		stmt.Close()
   575  		return nil, errf("invalid INSERT syntax with %d parts; want 2", len(parts))
   576  	}
   577  	stmt.table = parts[0]
   578  	for n, colspec := range strings.Split(parts[1], ",") {
   579  		nameVal := strings.Split(colspec, "=")
   580  		if len(nameVal) != 2 {
   581  			stmt.Close()
   582  			return nil, errf("INSERT table %q has invalid column spec of %q (index %d)", stmt.table, colspec, n)
   583  		}
   584  		column, value := nameVal[0], nameVal[1]
   585  		ctype, ok := c.db.columnType(stmt.table, column)
   586  		if !ok {
   587  			stmt.Close()
   588  			return nil, errf("INSERT table %q references non-existent column %q", stmt.table, column)
   589  		}
   590  		stmt.colName = append(stmt.colName, column)
   591  
   592  		if !strings.HasPrefix(value, "?") {
   593  			var subsetVal any
   594  			// Convert to driver subset type
   595  			switch ctype {
   596  			case "string":
   597  				subsetVal = []byte(value)
   598  			case "blob":
   599  				subsetVal = []byte(value)
   600  			case "int32":
   601  				i, err := strconv.Atoi(value)
   602  				if err != nil {
   603  					stmt.Close()
   604  					return nil, errf("invalid conversion to int32 from %q", value)
   605  				}
   606  				subsetVal = int64(i) // int64 is a subset type, but not int32
   607  			case "table": // For testing cursor reads.
   608  				c.skipDirtySession = true
   609  				vparts := strings.Split(value, "!")
   610  
   611  				substmt, err := c.PrepareContext(ctx, fmt.Sprintf("SELECT|%s|%s|", vparts[0], strings.Join(vparts[1:], ",")))
   612  				if err != nil {
   613  					return nil, err
   614  				}
   615  				cursor, err := (substmt.(driver.StmtQueryContext)).QueryContext(ctx, []driver.NamedValue{})
   616  				substmt.Close()
   617  				if err != nil {
   618  					return nil, err
   619  				}
   620  				subsetVal = cursor
   621  			default:
   622  				stmt.Close()
   623  				return nil, errf("unsupported conversion for pre-bound parameter %q to type %q", value, ctype)
   624  			}
   625  			stmt.colValue = append(stmt.colValue, subsetVal)
   626  		} else {
   627  			stmt.placeholders++
   628  			stmt.placeholderConverter = append(stmt.placeholderConverter, converterForType(ctype))
   629  			stmt.colValue = append(stmt.colValue, value)
   630  		}
   631  	}
   632  	return stmt, nil
   633  }
   634  
   635  // hook to simulate broken connections
   636  var hookPrepareBadConn func() bool
   637  
   638  func (c *fakeConn) Prepare(query string) (driver.Stmt, error) {
   639  	panic("use PrepareContext")
   640  }
   641  
   642  func (c *fakeConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) {
   643  	c.numPrepare++
   644  	if c.db == nil {
   645  		panic("nil c.db; conn = " + fmt.Sprintf("%#v", c))
   646  	}
   647  
   648  	if c.stickyBad || (hookPrepareBadConn != nil && hookPrepareBadConn()) {
   649  		return nil, fakeError{Message: "Prepare: Sticky Bad", Wrapped: driver.ErrBadConn}
   650  	}
   651  
   652  	c.touchMem()
   653  	var firstStmt, prev *fakeStmt
   654  	for _, query := range strings.Split(query, ";") {
   655  		parts := strings.Split(query, "|")
   656  		if len(parts) < 1 {
   657  			return nil, errf("empty query")
   658  		}
   659  		stmt := &fakeStmt{q: query, c: c, memToucher: c}
   660  		if firstStmt == nil {
   661  			firstStmt = stmt
   662  		}
   663  		if len(parts) >= 3 {
   664  			switch parts[0] {
   665  			case "PANIC":
   666  				stmt.panic = parts[1]
   667  				parts = parts[2:]
   668  			case "WAIT":
   669  				wait, err := time.ParseDuration(parts[1])
   670  				if err != nil {
   671  					return nil, errf("expected section after WAIT to be a duration, got %q %v", parts[1], err)
   672  				}
   673  				parts = parts[2:]
   674  				stmt.wait = wait
   675  			}
   676  		}
   677  		cmd := parts[0]
   678  		stmt.cmd = cmd
   679  		parts = parts[1:]
   680  
   681  		if c.waiter != nil {
   682  			c.waiter(ctx)
   683  			if err := ctx.Err(); err != nil {
   684  				return nil, err
   685  			}
   686  		}
   687  
   688  		if stmt.wait > 0 {
   689  			wait := time.NewTimer(stmt.wait)
   690  			select {
   691  			case <-wait.C:
   692  			case <-ctx.Done():
   693  				wait.Stop()
   694  				return nil, ctx.Err()
   695  			}
   696  		}
   697  
   698  		c.incrStat(&c.stmtsMade)
   699  		var err error
   700  		switch cmd {
   701  		case "WIPE":
   702  			// Nothing
   703  		case "USE_RAWBYTES":
   704  			c.db.useRawBytes.Store(true)
   705  		case "SELECT":
   706  			stmt, err = c.prepareSelect(stmt, parts)
   707  		case "CREATE":
   708  			stmt, err = c.prepareCreate(stmt, parts)
   709  		case "INSERT":
   710  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   711  		case "NOSERT":
   712  			// Do all the prep-work like for an INSERT but don't actually insert the row.
   713  			// Used for some of the concurrent tests.
   714  			stmt, err = c.prepareInsert(ctx, stmt, parts)
   715  		default:
   716  			stmt.Close()
   717  			return nil, errf("unsupported command type %q", cmd)
   718  		}
   719  		if err != nil {
   720  			return nil, err
   721  		}
   722  		if prev != nil {
   723  			prev.next = stmt
   724  		}
   725  		prev = stmt
   726  	}
   727  	return firstStmt, nil
   728  }
   729  
   730  func (s *fakeStmt) ColumnConverter(idx int) driver.ValueConverter {
   731  	if s.panic == "ColumnConverter" {
   732  		panic(s.panic)
   733  	}
   734  	if len(s.placeholderConverter) == 0 {
   735  		return driver.DefaultParameterConverter
   736  	}
   737  	return s.placeholderConverter[idx]
   738  }
   739  
   740  func (s *fakeStmt) Close() error {
   741  	if s.panic == "Close" {
   742  		panic(s.panic)
   743  	}
   744  	if s.c == nil {
   745  		panic("nil conn in fakeStmt.Close")
   746  	}
   747  	if s.c.db == nil {
   748  		panic("in fakeStmt.Close, conn's db is nil (already closed)")
   749  	}
   750  	s.touchMem()
   751  	if !s.closed {
   752  		s.c.incrStat(&s.c.stmtsClosed)
   753  		s.closed = true
   754  	}
   755  	if s.next != nil {
   756  		s.next.Close()
   757  	}
   758  	return nil
   759  }
   760  
   761  var errClosed = errors.New("fakedb: statement has been closed")
   762  
   763  // hook to simulate broken connections
   764  var hookExecBadConn func() bool
   765  
   766  func (s *fakeStmt) Exec(args []driver.Value) (driver.Result, error) {
   767  	panic("Using ExecContext")
   768  }
   769  
   770  var errFakeConnSessionDirty = errors.New("fakedb: session is dirty")
   771  
   772  func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
   773  	if s.panic == "Exec" {
   774  		panic(s.panic)
   775  	}
   776  	if s.closed {
   777  		return nil, errClosed
   778  	}
   779  
   780  	if s.c.stickyBad || (hookExecBadConn != nil && hookExecBadConn()) {
   781  		return nil, fakeError{Message: "Exec: Sticky Bad", Wrapped: driver.ErrBadConn}
   782  	}
   783  	if s.c.isDirtyAndMark() {
   784  		return nil, errFakeConnSessionDirty
   785  	}
   786  
   787  	err := checkSubsetTypes(s.c.db.allowAny, args)
   788  	if err != nil {
   789  		return nil, err
   790  	}
   791  	s.touchMem()
   792  
   793  	if s.wait > 0 {
   794  		time.Sleep(s.wait)
   795  	}
   796  
   797  	select {
   798  	default:
   799  	case <-ctx.Done():
   800  		return nil, ctx.Err()
   801  	}
   802  
   803  	db := s.c.db
   804  	switch s.cmd {
   805  	case "WIPE":
   806  		db.wipe()
   807  		return driver.ResultNoRows, nil
   808  	case "USE_RAWBYTES":
   809  		s.c.db.useRawBytes.Store(true)
   810  		return driver.ResultNoRows, nil
   811  	case "CREATE":
   812  		if err := db.createTable(s.table, s.colName, s.colType); err != nil {
   813  			return nil, err
   814  		}
   815  		return driver.ResultNoRows, nil
   816  	case "INSERT":
   817  		return s.execInsert(args, true)
   818  	case "NOSERT":
   819  		// Do all the prep-work like for an INSERT but don't actually insert the row.
   820  		// Used for some of the concurrent tests.
   821  		return s.execInsert(args, false)
   822  	}
   823  	return nil, fmt.Errorf("fakedb: unimplemented statement Exec command type of %q", s.cmd)
   824  }
   825  
   826  // When doInsert is true, add the row to the table.
   827  // When doInsert is false do prep-work and error checking, but don't
   828  // actually add the row to the table.
   829  func (s *fakeStmt) execInsert(args []driver.NamedValue, doInsert bool) (driver.Result, error) {
   830  	db := s.c.db
   831  	if len(args) != s.placeholders {
   832  		panic("error in pkg db; should only get here if size is correct")
   833  	}
   834  	db.mu.Lock()
   835  	t, ok := db.table(s.table)
   836  	db.mu.Unlock()
   837  	if !ok {
   838  		return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   839  	}
   840  
   841  	t.mu.Lock()
   842  	defer t.mu.Unlock()
   843  
   844  	var cols []any
   845  	if doInsert {
   846  		cols = make([]any, len(t.colname))
   847  	}
   848  	argPos := 0
   849  	for n, colname := range s.colName {
   850  		colidx := t.columnIndex(colname)
   851  		if colidx == -1 {
   852  			return nil, fmt.Errorf("fakedb: column %q doesn't exist or dropped since prepared statement was created", colname)
   853  		}
   854  		var val any
   855  		if strvalue, ok := s.colValue[n].(string); ok && strings.HasPrefix(strvalue, "?") {
   856  			if strvalue == "?" {
   857  				val = args[argPos].Value
   858  			} else {
   859  				// Assign value from argument placeholder name.
   860  				for _, a := range args {
   861  					if a.Name == strvalue[1:] {
   862  						val = a.Value
   863  						break
   864  					}
   865  				}
   866  			}
   867  			argPos++
   868  		} else {
   869  			val = s.colValue[n]
   870  		}
   871  		if doInsert {
   872  			cols[colidx] = val
   873  		}
   874  	}
   875  
   876  	if doInsert {
   877  		t.rows = append(t.rows, &row{cols: cols})
   878  	}
   879  	return driver.RowsAffected(1), nil
   880  }
   881  
   882  // hook to simulate broken connections
   883  var hookQueryBadConn func() bool
   884  
   885  func (s *fakeStmt) Query(args []driver.Value) (driver.Rows, error) {
   886  	panic("Use QueryContext")
   887  }
   888  
   889  func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
   890  	if s.panic == "Query" {
   891  		panic(s.panic)
   892  	}
   893  	if s.closed {
   894  		return nil, errClosed
   895  	}
   896  
   897  	if s.c.stickyBad || (hookQueryBadConn != nil && hookQueryBadConn()) {
   898  		return nil, fakeError{Message: "Query: Sticky Bad", Wrapped: driver.ErrBadConn}
   899  	}
   900  	if s.c.isDirtyAndMark() {
   901  		return nil, errFakeConnSessionDirty
   902  	}
   903  
   904  	err := checkSubsetTypes(s.c.db.allowAny, args)
   905  	if err != nil {
   906  		return nil, err
   907  	}
   908  
   909  	s.touchMem()
   910  	db := s.c.db
   911  	if len(args) != s.placeholders {
   912  		panic("error in pkg db; should only get here if size is correct")
   913  	}
   914  
   915  	setMRows := make([][]*row, 0, 1)
   916  	setColumns := make([][]string, 0, 1)
   917  	setColType := make([][]string, 0, 1)
   918  
   919  	for {
   920  		db.mu.Lock()
   921  		t, ok := db.table(s.table)
   922  		db.mu.Unlock()
   923  		if !ok {
   924  			return nil, fmt.Errorf("fakedb: table %q doesn't exist", s.table)
   925  		}
   926  
   927  		if s.table == "magicquery" {
   928  			if len(s.whereCol) == 2 && s.whereCol[0].Column == "op" && s.whereCol[1].Column == "millis" {
   929  				if args[0].Value == "sleep" {
   930  					time.Sleep(time.Duration(args[1].Value.(int64)) * time.Millisecond)
   931  				}
   932  			}
   933  		}
   934  		if s.table == "tx_status" && s.colName[0] == "tx_status" {
   935  			txStatus := "autocommit"
   936  			if s.c.currTx != nil {
   937  				txStatus = "transaction"
   938  			}
   939  			cursor := &rowsCursor{
   940  				db:        s.c.db,
   941  				parentMem: s.c,
   942  				posRow:    -1,
   943  				rows: [][]*row{
   944  					{
   945  						{
   946  							cols: []any{
   947  								txStatus,
   948  							},
   949  						},
   950  					},
   951  				},
   952  				cols: [][]string{
   953  					{
   954  						"tx_status",
   955  					},
   956  				},
   957  				colType: [][]string{
   958  					{
   959  						"string",
   960  					},
   961  				},
   962  				errPos: -1,
   963  			}
   964  			return cursor, nil
   965  		}
   966  
   967  		t.mu.Lock()
   968  
   969  		colIdx := make(map[string]int) // select column name -> column index in table
   970  		for _, name := range s.colName {
   971  			idx := t.columnIndex(name)
   972  			if idx == -1 {
   973  				t.mu.Unlock()
   974  				return nil, fmt.Errorf("fakedb: unknown column name %q", name)
   975  			}
   976  			colIdx[name] = idx
   977  		}
   978  
   979  		mrows := []*row{}
   980  	rows:
   981  		for _, trow := range t.rows {
   982  			// Process the where clause, skipping non-match rows. This is lazy
   983  			// and just uses fmt.Sprintf("%v") to test equality. Good enough
   984  			// for test code.
   985  			for _, wcol := range s.whereCol {
   986  				idx := t.columnIndex(wcol.Column)
   987  				if idx == -1 {
   988  					t.mu.Unlock()
   989  					return nil, fmt.Errorf("fakedb: invalid where clause column %q", wcol)
   990  				}
   991  				tcol := trow.cols[idx]
   992  				if bs, ok := tcol.([]byte); ok {
   993  					// lazy hack to avoid sprintf %v on a []byte
   994  					tcol = string(bs)
   995  				}
   996  				var argValue any
   997  				if wcol.Placeholder == "?" {
   998  					argValue = args[wcol.Ordinal-1].Value
   999  				} else {
  1000  					// Assign arg value from placeholder name.
  1001  					for _, a := range args {
  1002  						if a.Name == wcol.Placeholder[1:] {
  1003  							argValue = a.Value
  1004  							break
  1005  						}
  1006  					}
  1007  				}
  1008  				if fmt.Sprintf("%v", tcol) != fmt.Sprintf("%v", argValue) {
  1009  					continue rows
  1010  				}
  1011  			}
  1012  			mrow := &row{cols: make([]any, len(s.colName))}
  1013  			for seli, name := range s.colName {
  1014  				mrow.cols[seli] = trow.cols[colIdx[name]]
  1015  			}
  1016  			mrows = append(mrows, mrow)
  1017  		}
  1018  
  1019  		var colType []string
  1020  		for _, column := range s.colName {
  1021  			colType = append(colType, t.coltype[t.columnIndex(column)])
  1022  		}
  1023  
  1024  		t.mu.Unlock()
  1025  
  1026  		setMRows = append(setMRows, mrows)
  1027  		setColumns = append(setColumns, s.colName)
  1028  		setColType = append(setColType, colType)
  1029  
  1030  		if s.next == nil {
  1031  			break
  1032  		}
  1033  		s = s.next
  1034  	}
  1035  
  1036  	cursor := &rowsCursor{
  1037  		db:        s.c.db,
  1038  		parentMem: s.c,
  1039  		posRow:    -1,
  1040  		rows:      setMRows,
  1041  		cols:      setColumns,
  1042  		colType:   setColType,
  1043  		errPos:    -1,
  1044  	}
  1045  	return cursor, nil
  1046  }
  1047  
  1048  func (s *fakeStmt) NumInput() int {
  1049  	if s.panic == "NumInput" {
  1050  		panic(s.panic)
  1051  	}
  1052  	return s.placeholders
  1053  }
  1054  
  1055  // hook to simulate broken connections
  1056  var hookCommitBadConn func() bool
  1057  
  1058  func (tx *fakeTx) Commit() error {
  1059  	tx.c.currTx = nil
  1060  	if hookCommitBadConn != nil && hookCommitBadConn() {
  1061  		return fakeError{Message: "Commit: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1062  	}
  1063  	tx.c.touchMem()
  1064  	return nil
  1065  }
  1066  
  1067  // hook to simulate broken connections
  1068  var hookRollbackBadConn func() bool
  1069  
  1070  func (tx *fakeTx) Rollback() error {
  1071  	tx.c.currTx = nil
  1072  	if hookRollbackBadConn != nil && hookRollbackBadConn() {
  1073  		return fakeError{Message: "Rollback: Hook Bad Conn", Wrapped: driver.ErrBadConn}
  1074  	}
  1075  	tx.c.touchMem()
  1076  	return nil
  1077  }
  1078  
// rowsCursor is the fakedb implementation of the driver's row iterator.
// It holds one or more result sets entirely in memory: rows, cols and
// colType are indexed first by result set (posSet); rows is then indexed
// by row (posRow). Next increments posRow before reading, so -1 means
// "before the first row".
type rowsCursor struct {
	db        *fakeDB
	parentMem memToucher
	cols      [][]string
	colType   [][]string
	posSet    int
	posRow    int
	rows      [][]*row
	closed    bool

	// errPos and err are for making Next return early with error:
	// when posRow reaches errPos, Next returns err instead of a row.
	// Cursors built with errPos -1 never trigger this.
	errPos int
	err    error

	// bytesClone maps the address of a stored []byte value's first byte
	// to the copy handed out by Next, so each stored slice is copied at
	// most once per cursor and callers never alias the stored row data.
	// NOTE(review): these clones were originally meant to be corrupted on
	// close to catch use-after-close, but Close does not currently do so.
	bytesClone map[*byte][]byte

	// Every operation writes to line to enable the race detector
	// check for data races.
	// This is separate from the fakeConn.line to allow for drivers that
	// can start multiple queries on the same transaction at the same time.
	line int64

	// closeErr is the error returned by rowsCursor.Close.
	closeErr error
}
  1107  
  1108  func (rc *rowsCursor) touchMem() {
  1109  	rc.parentMem.touchMem()
  1110  	rc.line++
  1111  }
  1112  
  1113  func (rc *rowsCursor) Close() error {
  1114  	rc.touchMem()
  1115  	rc.parentMem.touchMem()
  1116  	rc.closed = true
  1117  	return rc.closeErr
  1118  }
  1119  
  1120  func (rc *rowsCursor) Columns() []string {
  1121  	return rc.cols[rc.posSet]
  1122  }
  1123  
  1124  func (rc *rowsCursor) ColumnTypeScanType(index int) reflect.Type {
  1125  	return colTypeToReflectType(rc.colType[rc.posSet][index])
  1126  }
  1127  
  1128  var rowsCursorNextHook func(dest []driver.Value) error
  1129  
  1130  func (rc *rowsCursor) Next(dest []driver.Value) error {
  1131  	if rowsCursorNextHook != nil {
  1132  		return rowsCursorNextHook(dest)
  1133  	}
  1134  
  1135  	if rc.closed {
  1136  		return errors.New("fakedb: cursor is closed")
  1137  	}
  1138  	rc.touchMem()
  1139  	rc.posRow++
  1140  	if rc.posRow == rc.errPos {
  1141  		return rc.err
  1142  	}
  1143  	if rc.posRow >= len(rc.rows[rc.posSet]) {
  1144  		return io.EOF // per interface spec
  1145  	}
  1146  	for i, v := range rc.rows[rc.posSet][rc.posRow].cols {
  1147  		// TODO(bradfitz): convert to subset types? naah, I
  1148  		// think the subset types should only be input to
  1149  		// driver, but the sql package should be able to handle
  1150  		// a wider range of types coming out of drivers. all
  1151  		// for ease of drivers, and to prevent drivers from
  1152  		// messing up conversions or doing them differently.
  1153  		dest[i] = v
  1154  
  1155  		if bs, ok := v.([]byte); ok && !rc.db.useRawBytes.Load() {
  1156  			if rc.bytesClone == nil {
  1157  				rc.bytesClone = make(map[*byte][]byte)
  1158  			}
  1159  			clone, ok := rc.bytesClone[&bs[0]]
  1160  			if !ok {
  1161  				clone = make([]byte, len(bs))
  1162  				copy(clone, bs)
  1163  				rc.bytesClone[&bs[0]] = clone
  1164  			}
  1165  			dest[i] = clone
  1166  		}
  1167  	}
  1168  	return nil
  1169  }
  1170  
  1171  func (rc *rowsCursor) HasNextResultSet() bool {
  1172  	rc.touchMem()
  1173  	return rc.posSet < len(rc.rows)-1
  1174  }
  1175  
  1176  func (rc *rowsCursor) NextResultSet() error {
  1177  	rc.touchMem()
  1178  	if rc.HasNextResultSet() {
  1179  		rc.posSet++
  1180  		rc.posRow = -1
  1181  		return nil
  1182  	}
  1183  	return io.EOF // Per interface spec.
  1184  }
  1185  
  1186  // fakeDriverString is like driver.String, but indirects pointers like
  1187  // DefaultValueConverter.
  1188  //
  1189  // This could be surprising behavior to retroactively apply to
  1190  // driver.String now that Go1 is out, but this is convenient for
  1191  // our TestPointerParamsAndScans.
  1192  type fakeDriverString struct{}
  1193  
  1194  func (fakeDriverString) ConvertValue(v any) (driver.Value, error) {
  1195  	switch c := v.(type) {
  1196  	case string, []byte:
  1197  		return v, nil
  1198  	case *string:
  1199  		if c == nil {
  1200  			return nil, nil
  1201  		}
  1202  		return *c, nil
  1203  	}
  1204  	return fmt.Sprintf("%v", v), nil
  1205  }
  1206  
  1207  type anyTypeConverter struct{}
  1208  
  1209  func (anyTypeConverter) ConvertValue(v any) (driver.Value, error) {
  1210  	return v, nil
  1211  }
  1212  
  1213  func converterForType(typ string) driver.ValueConverter {
  1214  	switch typ {
  1215  	case "bool":
  1216  		return driver.Bool
  1217  	case "nullbool":
  1218  		return driver.Null{Converter: driver.Bool}
  1219  	case "byte", "int16":
  1220  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1221  	case "int32":
  1222  		return driver.Int32
  1223  	case "nullbyte", "nullint32", "nullint16":
  1224  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1225  	case "string":
  1226  		return driver.NotNull{Converter: fakeDriverString{}}
  1227  	case "nullstring":
  1228  		return driver.Null{Converter: fakeDriverString{}}
  1229  	case "int64":
  1230  		// TODO(coopernurse): add type-specific converter
  1231  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1232  	case "nullint64":
  1233  		// TODO(coopernurse): add type-specific converter
  1234  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1235  	case "float64":
  1236  		// TODO(coopernurse): add type-specific converter
  1237  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1238  	case "nullfloat64":
  1239  		// TODO(coopernurse): add type-specific converter
  1240  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1241  	case "datetime":
  1242  		return driver.NotNull{Converter: driver.DefaultParameterConverter}
  1243  	case "nulldatetime":
  1244  		return driver.Null{Converter: driver.DefaultParameterConverter}
  1245  	case "any":
  1246  		return anyTypeConverter{}
  1247  	}
  1248  	panic("invalid fakedb column type of " + typ)
  1249  }
  1250  
  1251  func colTypeToReflectType(typ string) reflect.Type {
  1252  	switch typ {
  1253  	case "bool":
  1254  		return reflect.TypeFor[bool]()
  1255  	case "nullbool":
  1256  		return reflect.TypeFor[NullBool]()
  1257  	case "int16":
  1258  		return reflect.TypeFor[int16]()
  1259  	case "nullint16":
  1260  		return reflect.TypeFor[NullInt16]()
  1261  	case "int32":
  1262  		return reflect.TypeFor[int32]()
  1263  	case "nullint32":
  1264  		return reflect.TypeFor[NullInt32]()
  1265  	case "string":
  1266  		return reflect.TypeFor[string]()
  1267  	case "nullstring":
  1268  		return reflect.TypeFor[NullString]()
  1269  	case "int64":
  1270  		return reflect.TypeFor[int64]()
  1271  	case "nullint64":
  1272  		return reflect.TypeFor[NullInt64]()
  1273  	case "float64":
  1274  		return reflect.TypeFor[float64]()
  1275  	case "nullfloat64":
  1276  		return reflect.TypeFor[NullFloat64]()
  1277  	case "datetime":
  1278  		return reflect.TypeFor[time.Time]()
  1279  	case "any":
  1280  		return reflect.TypeFor[any]()
  1281  	}
  1282  	panic("invalid fakedb column type of " + typ)
  1283  }
  1284  

View as plain text