Commit a9bf3b2e authored by Daniel Theophanes's avatar Daniel Theophanes

database/sql: allow drivers to support custom arg types

Previously all arguments were passed through driver.IsValid.
This checked arguments against a few fundamental go types and
prevented others from being passed in as arguments.

The new interface driver.NamedValueChecker may be implemented
by both driver.Stmt and driver.Conn. This allows
this new interface to completely supersede the
driver.ColumnConverter interface as it can be used for
checking arguments known to a prepared statement and
arbitrary query arguments. The NamedValueChecker may be
skipped with driver.ErrSkip after all special cases are
exhausted to use the default argument converter.

In addition if driver.ErrRemoveArgument is returned
the argument will not be passed to the query at all,
useful for passing in driver specific per-query options.

Add a canonical Out argument wrapper to be passed
to OUTPUT parameters. This will unify checks that need to
be written in the NameValueChecker.

The statement number check is also moved to the argument
converter so the NamedValueChecker may remove arguments
passed to the query.

Fixes #13567
Fixes #18079
Updates #18417
Updates #17834
Updates #16235
Updates #13067
Updates #19797

Change-Id: I89088bd9cca4596a48bba37bfd20d987453ef237
Reviewed-on: https://go-review.googlesource.com/38533Reviewed-by: default avatarBrad Fitzpatrick <bradfitz@golang.org>
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Gobot Gobot <gobot@golang.org>
parent 9044cb04
...@@ -12,6 +12,7 @@ import ( ...@@ -12,6 +12,7 @@ import (
"fmt" "fmt"
"reflect" "reflect"
"strconv" "strconv"
"sync"
"time" "time"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
...@@ -37,86 +38,180 @@ func validateNamedValueName(name string) error { ...@@ -37,86 +38,180 @@ func validateNamedValueName(name string) error {
return fmt.Errorf("name %q does not begin with a letter", name) return fmt.Errorf("name %q does not begin with a letter", name)
} }
func driverNumInput(ds *driverStmt) int {
ds.Lock()
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
// ccChecker wraps the driver.ColumnConverter and allows it to be used
// as if it were a NamedValueChecker. If the driver ColumnConverter
// is not present then the NamedValueChecker will return driver.ErrSkip.
type ccChecker struct {
sync.Locker
cci driver.ColumnConverter
want int
}
func (c ccChecker) CheckNamedValue(nv *driver.NamedValue) error {
if c.cci == nil {
return driver.ErrSkip
}
// The column converter shouldn't be called on any index
// it isn't expecting. The final error will be thrown
// in the argument converter loop.
index := nv.Ordinal - 1
if c.want <= index {
return nil
}
// First, see if the value itself knows how to convert
// itself to a driver type. For example, a NullString
// struct changing into a string or nil.
if vr, ok := nv.Value.(driver.Valuer); ok {
sv, err := callValuerValue(vr)
if err != nil {
return err
}
if !driver.IsValue(sv) {
return fmt.Errorf("non-subset type %T returned from Value", sv)
}
nv.Value = sv
}
// Second, ask the column to sanity check itself. For
// example, drivers might use this to make sure that
// an int64 values being inserted into a 16-bit
// integer field is in range (before getting
// truncated), or that a nil can't go into a NOT NULL
// column before going across the network to get the
// same error.
var err error
arg := nv.Value
c.Lock()
nv.Value, err = c.cci.ColumnConverter(index).ConvertValue(arg)
c.Unlock()
if err != nil {
return err
}
if !driver.IsValue(nv.Value) {
return fmt.Errorf("driver ColumnConverter error converted %T to unsupported type %T", arg, nv.Value)
}
return nil
}
// defaultCheckNamedValue wraps the default ColumnConverter to have the same
// function signature as the CheckNamedValue in the driver.NamedValueChecker
// interface.
func defaultCheckNamedValue(nv *driver.NamedValue) (err error) {
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(nv.Value)
return err
}
// driverArgs converts arguments from callers of Stmt.Exec and // driverArgs converts arguments from callers of Stmt.Exec and
// Stmt.Query into driver Values. // Stmt.Query into driver Values.
// //
// The statement ds may be nil, if no statement is available. // The statement ds may be nil, if no statement is available.
func driverArgs(ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) { func driverArgs(ci driver.Conn, ds *driverStmt, args []interface{}) ([]driver.NamedValue, error) {
nvargs := make([]driver.NamedValue, len(args)) nvargs := make([]driver.NamedValue, len(args))
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
want := -1
var si driver.Stmt var si driver.Stmt
var cc ccChecker
if ds != nil { if ds != nil {
si = ds.si si = ds.si
want = driverNumInput(ds)
cc.Locker = ds.Locker
cc.want = want
} }
cc, ok := si.(driver.ColumnConverter)
// Normal path, for a driver.Stmt that is not a ColumnConverter. // Check all types of interfaces from the start.
// Drivers may opt to use the NamedValueChecker for special
// argument types, then return driver.ErrSkip to pass it along
// to the column converter.
nvc, ok := si.(driver.NamedValueChecker)
if !ok { if !ok {
for n, arg := range args { nvc, ok = ci.(driver.NamedValueChecker)
var err error }
nv := &nvargs[n] cci, ok := si.(driver.ColumnConverter)
nv.Ordinal = n + 1 if ok {
if np, ok := arg.(NamedArg); ok { cc.cci = cci
if err := validateNamedValueName(np.Name); err != nil {
return nil, err
}
arg = np.Value
nvargs[n].Name = np.Name
}
nv.Value, err = driver.DefaultParameterConverter.ConvertValue(arg)
if err != nil {
return nil, fmt.Errorf("sql: converting Exec argument %s type: %v", describeNamedValue(nv), err)
}
}
return nvargs, nil
} }
// Let the Stmt convert its own arguments. // Loop through all the arguments, checking each one.
for n, arg := range args { // If no error is returned simply increment the index
// and continue. However if driver.ErrRemoveArgument
// is returned the argument is not included in the query
// argument list.
var err error
var n int
for _, arg := range args {
nv := &nvargs[n] nv := &nvargs[n]
nv.Ordinal = n + 1
if np, ok := arg.(NamedArg); ok { if np, ok := arg.(NamedArg); ok {
if err := validateNamedValueName(np.Name); err != nil { if err = validateNamedValueName(np.Name); err != nil {
return nil, err return nil, err
} }
arg = np.Value arg = np.Value
nv.Name = np.Name nv.Name = np.Name
} }
// First, see if the value itself knows how to convert nv.Ordinal = n + 1
// itself to a driver type. For example, a NullString nv.Value = arg
// struct changing into a string or nil.
if vr, ok := arg.(driver.Valuer); ok { // Checking sequence has four routes:
sv, err := callValuerValue(vr) // A: 1. Default
if err != nil { // B: 1. NamedValueChecker 2. Column Converter 3. Default
return nil, fmt.Errorf("sql: argument %s from Value: %v", describeNamedValue(nv), err) // C: 1. NamedValueChecker 3. Default
} // D: 1. Column Converter 2. Default
if !driver.IsValue(sv) { //
return nil, fmt.Errorf("sql: argument %s: non-subset type %T returned from Value", describeNamedValue(nv), sv) // The only time a Column Converter is called is first
} // or after NamedValueConverter. If first it is handled before
arg = sv // the nextCheck label. Thus for repeats tries only when the
// NamedValueConverter is selected should the Column Converter
// be used in the retry.
checker := defaultCheckNamedValue
nextCC := false
switch {
case nvc != nil:
nextCC = cci != nil
checker = nvc.CheckNamedValue
case cci != nil:
checker = cc.CheckNamedValue
} }
// Second, ask the column to sanity check itself. For nextCheck:
// example, drivers might use this to make sure that err = checker(nv)
// an int64 values being inserted into a 16-bit switch err {
// integer field is in range (before getting case nil:
// truncated), or that a nil can't go into a NOT NULL n++
// column before going across the network to get the continue
// same error. case driver.ErrRemoveArgument:
var err error nvargs = nvargs[:len(nvargs)-1]
ds.Lock() continue
nv.Value, err = cc.ColumnConverter(n).ConvertValue(arg) case driver.ErrSkip:
ds.Unlock() if nextCC {
if err != nil { nextCC = false
checker = cc.CheckNamedValue
} else {
checker = defaultCheckNamedValue
}
goto nextCheck
default:
return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err) return nil, fmt.Errorf("sql: converting argument %s type: %v", describeNamedValue(nv), err)
} }
if !driver.IsValue(nv.Value) { }
return nil, fmt.Errorf("sql: for argument %s, driver ColumnConverter error converted %T to unsupported type %T",
describeNamedValue(nv), arg, nv.Value) // Check the length of arguments after convertion to allow for omitted
} // arguments.
if want != -1 && len(nvargs) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(nvargs))
} }
return nvargs, nil return nvargs, nil
} }
// convertAssign copies to dest the value in src, converting it if possible. // convertAssign copies to dest the value in src, converting it if possible.
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"reflect" "reflect"
"runtime" "runtime"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
) )
...@@ -468,8 +469,8 @@ func TestDriverArgs(t *testing.T) { ...@@ -468,8 +469,8 @@ func TestDriverArgs(t *testing.T) {
}, },
} }
for i, tt := range tests { for i, tt := range tests {
ds := new(driverStmt) ds := &driverStmt{Locker: &sync.Mutex{}, si: stubDriverStmt{nil}}
got, err := driverArgs(ds, tt.args) got, err := driverArgs(nil, ds, tt.args)
if err != nil { if err != nil {
t.Errorf("test[%d]: %v", i, err) t.Errorf("test[%d]: %v", i, err)
continue continue
......
...@@ -262,9 +262,39 @@ type StmtQueryContext interface { ...@@ -262,9 +262,39 @@ type StmtQueryContext interface {
QueryContext(ctx context.Context, args []NamedValue) (Rows, error) QueryContext(ctx context.Context, args []NamedValue) (Rows, error)
} }
// ErrRemoveArgument may be returned from NamedValueChecker to instruct the
// sql package to not pass the argument to the driver query interface.
// Return when accepting query specific options or structures that aren't
// SQL query arguments.
var ErrRemoveArgument = errors.New("driver: remove argument from query")
// NamedValueChecker may be optionally implemented by Conn or Stmt. It provides
// the driver more control to handle Go and database types beyond the default
// Values types allowed.
//
// The sql package checks for value checkers in the following order,
// stopping at the first found match: Stmt.NamedValueChecker, Conn.NamedValueChecker,
// Stmt.ColumnConverter, DefaultParameterConverter.
//
// If CheckNamedValue returns ErrRemoveArgument, the NamedValue will not be included in
// the final query arguments. This may be used to pass special options to
// the query itself.
//
// If ErrSkip is returned the column converter error checking
// path is used for the argument. Drivers may wish to return ErrSkip after
// they have exhausted their own special cases.
type NamedValueChecker interface {
// CheckNamedValue is called before passing arguments to the driver
// and is called in place of any ColumnConverter. CheckNamedValue must do type
// validation and conversion as appropriate for the driver.
CheckNamedValue(*NamedValue) error
}
// ColumnConverter may be optionally implemented by Stmt if the // ColumnConverter may be optionally implemented by Stmt if the
// statement is aware of its own columns' types and can convert from // statement is aware of its own columns' types and can convert from
// any type to a driver Value. // any type to a driver Value.
//
// Deprecated: Drivers should implement NamedValueChecker.
type ColumnConverter interface { type ColumnConverter interface {
// ColumnConverter returns a ValueConverter for the provided // ColumnConverter returns a ValueConverter for the provided
// column index. If the type of a specific column isn't known // column index. If the type of a specific column isn't known
......
...@@ -58,9 +58,10 @@ type fakeDriver struct { ...@@ -58,9 +58,10 @@ type fakeDriver struct {
type fakeDB struct { type fakeDB struct {
name string name string
mu sync.Mutex mu sync.Mutex
tables map[string]*table tables map[string]*table
badConn bool badConn bool
allowAny bool
} }
type table struct { type table struct {
...@@ -352,12 +353,14 @@ func (c *fakeConn) Close() (err error) { ...@@ -352,12 +353,14 @@ func (c *fakeConn) Close() (err error) {
return nil return nil
} }
func checkSubsetTypes(args []driver.NamedValue) error { func checkSubsetTypes(allowAny bool, args []driver.NamedValue) error {
for _, arg := range args { for _, arg := range args {
switch arg.Value.(type) { switch arg.Value.(type) {
case int64, float64, bool, nil, []byte, string, time.Time: case int64, float64, bool, nil, []byte, string, time.Time:
default: default:
return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value) if !allowAny {
return fmt.Errorf("fakedb_test: invalid argument ordinal %[1]d: %[2]v, type %[2]T", arg.Ordinal, arg.Value)
}
} }
} }
return nil return nil
...@@ -373,7 +376,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver. ...@@ -373,7 +376,7 @@ func (c *fakeConn) ExecContext(ctx context.Context, query string, args []driver.
// just to check that all the args are of the proper types. // just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't // ErrSkip is returned so the caller acts as if we didn't
// implement this at all. // implement this at all.
err := checkSubsetTypes(args) err := checkSubsetTypes(c.db.allowAny, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -390,7 +393,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver ...@@ -390,7 +393,7 @@ func (c *fakeConn) QueryContext(ctx context.Context, query string, args []driver
// just to check that all the args are of the proper types. // just to check that all the args are of the proper types.
// ErrSkip is returned so the caller acts as if we didn't // ErrSkip is returned so the caller acts as if we didn't
// implement this at all. // implement this at all.
err := checkSubsetTypes(args) err := checkSubsetTypes(c.db.allowAny, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -642,7 +645,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d ...@@ -642,7 +645,7 @@ func (s *fakeStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (d
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
err := checkSubsetTypes(args) err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -753,7 +756,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) ( ...@@ -753,7 +756,7 @@ func (s *fakeStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
err := checkSubsetTypes(args) err := checkSubsetTypes(s.c.db.allowAny, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1004,6 +1007,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) { ...@@ -1004,6 +1007,12 @@ func (fakeDriverString) ConvertValue(v interface{}) (driver.Value, error) {
return fmt.Sprintf("%v", v), nil return fmt.Sprintf("%v", v), nil
} }
type anyTypeConverter struct{}
func (anyTypeConverter) ConvertValue(v interface{}) (driver.Value, error) {
return v, nil
}
func converterForType(typ string) driver.ValueConverter { func converterForType(typ string) driver.ValueConverter {
switch typ { switch typ {
case "bool": case "bool":
...@@ -1030,6 +1039,8 @@ func converterForType(typ string) driver.ValueConverter { ...@@ -1030,6 +1039,8 @@ func converterForType(typ string) driver.ValueConverter {
return driver.Null{Converter: driver.DefaultParameterConverter} return driver.Null{Converter: driver.DefaultParameterConverter}
case "datetime": case "datetime":
return driver.DefaultParameterConverter return driver.DefaultParameterConverter
case "any":
return anyTypeConverter{}
} }
panic("invalid fakedb column type of " + typ) panic("invalid fakedb column type of " + typ)
} }
...@@ -1056,6 +1067,8 @@ func colTypeToReflectType(typ string) reflect.Type { ...@@ -1056,6 +1067,8 @@ func colTypeToReflectType(typ string) reflect.Type {
return reflect.TypeOf(NullFloat64{}) return reflect.TypeOf(NullFloat64{})
case "datetime": case "datetime":
return reflect.TypeOf(time.Time{}) return reflect.TypeOf(time.Time{})
case "any":
return reflect.TypeOf(new(interface{})).Elem()
} }
panic("invalid fakedb column type of " + typ) panic("invalid fakedb column type of " + typ)
} }
...@@ -278,6 +278,27 @@ type Scanner interface { ...@@ -278,6 +278,27 @@ type Scanner interface {
Scan(src interface{}) error Scan(src interface{}) error
} }
// Out may be used to retrieve OUTPUT value parameters from stored procedures.
//
// Not all drivers and databases support OUTPUT value parameters.
//
// Example usage:
//
// var outArg string
// _, err := db.ExecContext(ctx, "ProcName", sql.Named("Arg1", Out{Dest: &outArg}))
type Out struct {
_Named_Fields_Required struct{}
// Dest is a pointer to the value that will be set to the result of the
// stored procedure's OUTPUT parameter.
Dest interface{}
// In is whether the parameter is an INOUT parameter. If so, the input value to the stored
// procedure is the dereferenced value of Dest's pointer, which is then replaced with
// the output value.
In bool
}
// ErrNoRows is returned by Scan when QueryRow doesn't return a // ErrNoRows is returned by Scan when QueryRow doesn't return a
// row. In such a case, QueryRow returns a placeholder *Row value that // row. In such a case, QueryRow returns a placeholder *Row value that
// defers this error until a Scan. // defers this error until a Scan.
...@@ -1206,7 +1227,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q ...@@ -1206,7 +1227,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
}() }()
if execer, ok := dc.ci.(driver.Execer); ok { if execer, ok := dc.ci.(driver.Execer); ok {
var dargs []driver.NamedValue var dargs []driver.NamedValue
dargs, err = driverArgs(nil, args) dargs, err = driverArgs(dc.ci, nil, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -1231,7 +1252,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q ...@@ -1231,7 +1252,7 @@ func (db *DB) execDC(ctx context.Context, dc *driverConn, release func(error), q
} }
ds := &driverStmt{Locker: dc, si: si} ds := &driverStmt{Locker: dc, si: si}
defer ds.Close() defer ds.Close()
return resultFromStatement(ctx, ds, args...) return resultFromStatement(ctx, dc.ci, ds, args...)
} }
// QueryContext executes a query that returns rows, typically a SELECT. // QueryContext executes a query that returns rows, typically a SELECT.
...@@ -1270,7 +1291,7 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat ...@@ -1270,7 +1291,7 @@ func (db *DB) query(ctx context.Context, query string, args []interface{}, strat
// The connection gets released by the releaseConn function. // The connection gets released by the releaseConn function.
func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) { func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(error), query string, args []interface{}) (*Rows, error) {
if queryer, ok := dc.ci.(driver.Queryer); ok { if queryer, ok := dc.ci.(driver.Queryer); ok {
dargs, err := driverArgs(nil, args) dargs, err := driverArgs(dc.ci, nil, args)
if err != nil { if err != nil {
releaseConn(err) releaseConn(err)
return nil, err return nil, err
...@@ -1307,7 +1328,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro ...@@ -1307,7 +1328,7 @@ func (db *DB) queryDC(ctx context.Context, dc *driverConn, releaseConn func(erro
} }
ds := &driverStmt{Locker: dc, si: si} ds := &driverStmt{Locker: dc, si: si}
rowsi, err := rowsiFromStatement(ctx, ds, args...) rowsi, err := rowsiFromStatement(ctx, dc.ci, ds, args...)
if err != nil { if err != nil {
ds.Close() ds.Close()
releaseConn(err) releaseConn(err)
...@@ -2009,7 +2030,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er ...@@ -2009,7 +2030,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
var res Result var res Result
for i := 0; i < maxBadConnRetries; i++ { for i := 0; i < maxBadConnRetries; i++ {
_, releaseConn, ds, err := s.connStmt(ctx) dc, releaseConn, ds, err := s.connStmt(ctx)
if err != nil { if err != nil {
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
continue continue
...@@ -2017,7 +2038,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er ...@@ -2017,7 +2038,7 @@ func (s *Stmt) ExecContext(ctx context.Context, args ...interface{}) (Result, er
return nil, err return nil, err
} }
res, err = resultFromStatement(ctx, ds, args...) res, err = resultFromStatement(ctx, dc.ci, ds, args...)
releaseConn(err) releaseConn(err)
if err != driver.ErrBadConn { if err != driver.ErrBadConn {
return res, err return res, err
...@@ -2032,23 +2053,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -2032,23 +2053,8 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
return s.ExecContext(context.Background(), args...) return s.ExecContext(context.Background(), args...)
} }
func driverNumInput(ds *driverStmt) int { func resultFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (Result, error) {
ds.Lock() dargs, err := driverArgs(ci, ds, args)
defer ds.Unlock() // in case NumInput panics
return ds.si.NumInput()
}
func resultFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (Result, error) {
want := driverNumInput(ds)
// -1 means the driver doesn't know how to count the number of
// placeholders, so we won't sanity check input here and instead let the
// driver deal with errors.
if want != -1 && len(args) != want {
return nil, fmt.Errorf("sql: expected %d arguments, got %d", want, len(args))
}
dargs, err := driverArgs(ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -2174,7 +2180,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er ...@@ -2174,7 +2180,7 @@ func (s *Stmt) QueryContext(ctx context.Context, args ...interface{}) (*Rows, er
return nil, err return nil, err
} }
rowsi, err = rowsiFromStatement(ctx, ds, args...) rowsi, err = rowsiFromStatement(ctx, dc.ci, ds, args...)
if err == nil { if err == nil {
// Note: ownership of ci passes to the *Rows, to be freed // Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn. // with releaseConn.
...@@ -2211,7 +2217,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -2211,7 +2217,7 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
return s.QueryContext(context.Background(), args...) return s.QueryContext(context.Background(), args...)
} }
func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}) (driver.Rows, error) { func rowsiFromStatement(ctx context.Context, ci driver.Conn, ds *driverStmt, args ...interface{}) (driver.Rows, error) {
var want int var want int
withLock(ds, func() { withLock(ds, func() {
want = ds.si.NumInput() want = ds.si.NumInput()
...@@ -2224,7 +2230,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{} ...@@ -2224,7 +2230,7 @@ func rowsiFromStatement(ctx context.Context, ds *driverStmt, args ...interface{}
return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args)) return nil, fmt.Errorf("sql: statement expects %d inputs; got %d", want, len(args))
} }
dargs, err := driverArgs(ds, args) dargs, err := driverArgs(ci, ds, args)
if err != nil { if err != nil {
return nil, err return nil, err
} }
......
...@@ -3191,6 +3191,131 @@ func TestConnectionLeak(t *testing.T) { ...@@ -3191,6 +3191,131 @@ func TestConnectionLeak(t *testing.T) {
wg.Wait() wg.Wait()
} }
type nvcDriver struct {
fakeDriver
skipNamedValueCheck bool
}
func (d *nvcDriver) Open(dsn string) (driver.Conn, error) {
c, err := d.fakeDriver.Open(dsn)
fc := c.(*fakeConn)
fc.db.allowAny = true
return &nvcConn{fc, d.skipNamedValueCheck}, err
}
type nvcConn struct {
*fakeConn
skipNamedValueCheck bool
}
type decimal struct {
value int
}
type doNotInclude struct{}
var _ driver.NamedValueChecker = &nvcConn{}
func (c *nvcConn) CheckNamedValue(nv *driver.NamedValue) error {
if c.skipNamedValueCheck {
return driver.ErrSkip
}
switch v := nv.Value.(type) {
default:
return driver.ErrSkip
case Out:
switch ov := v.Dest.(type) {
default:
return errors.New("unkown NameValueCheck OUTPUT type")
case *string:
*ov = "from-server"
nv.Value = "OUT:*string"
}
return nil
case decimal, []int64:
return nil
case doNotInclude:
return driver.ErrRemoveArgument
}
}
func TestNamedValueChecker(t *testing.T) {
Register("NamedValueCheck", &nvcDriver{})
db, err := Open("NamedValueCheck", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = db.ExecContext(ctx, "WIPE")
if err != nil {
t.Fatal("exec wipe", err)
}
_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any,str1=string,out1=string,array1=any")
if err != nil {
t.Fatal("exec create", err)
}
o1 := ""
_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A,str1=?,out1=?O1,array1=?", Named("A", decimal{123}), "hello", Named("O1", Out{Dest: &o1}), []int64{42, 128, 707}, doNotInclude{})
if err != nil {
t.Fatal("exec insert", err)
}
var (
str1 string
dec1 decimal
arr1 []int64
)
err = db.QueryRowContext(ctx, "SELECT|keys|dec1,str1,array1|").Scan(&dec1, &str1, &arr1)
if err != nil {
t.Fatal("select", err)
}
list := []struct{ got, want interface{} }{
{o1, "from-server"},
{dec1, decimal{123}},
{str1, "hello"},
{arr1, []int64{42, 128, 707}},
}
for index, item := range list {
if !reflect.DeepEqual(item.got, item.want) {
t.Errorf("got %#v wanted %#v for index %d", item.got, item.want, index)
}
}
}
func TestNamedValueCheckerSkip(t *testing.T) {
Register("NamedValueCheckSkip", &nvcDriver{skipNamedValueCheck: true})
db, err := Open("NamedValueCheckSkip", "")
if err != nil {
t.Fatal(err)
}
defer db.Close()
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
_, err = db.ExecContext(ctx, "WIPE")
if err != nil {
t.Fatal("exec wipe", err)
}
_, err = db.ExecContext(ctx, "CREATE|keys|dec1=any")
if err != nil {
t.Fatal("exec create", err)
}
_, err = db.ExecContext(ctx, "INSERT|keys|dec1=?A", Named("A", decimal{123}))
if err == nil {
t.Fatalf("expected error with bad argument, got %v", err)
}
}
// badConn implements a bad driver.Conn, for TestBadDriver. // badConn implements a bad driver.Conn, for TestBadDriver.
// The Exec method panics. // The Exec method panics.
type badConn struct{} type badConn struct{}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment