Commit f7a77163 authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

database/sql: refcounting and lifetime fixes

Simplifies the contract for Driver.Stmt.Close in
the process of fixing issue 3865.

Fixes #3865
Update #4459 (maybe fixes it; uninvestigated)

R=golang-dev, rsc
CC=golang-dev
https://golang.org/cl/7363043
parent 6c976393
...@@ -115,23 +115,8 @@ type Result interface { ...@@ -115,23 +115,8 @@ type Result interface {
type Stmt interface { type Stmt interface {
// Close closes the statement. // Close closes the statement.
// //
// Closing a statement should not interrupt any outstanding // As of Go 1.1, a Stmt will not be closed if it's in use
// query created from that statement. That is, the following // by any queries.
// order of operations is valid:
//
// * create a driver statement
// * call Query on statement, returning Rows
// * close the statement
// * read from Rows
//
// If closing a statement invalidates currently-running
// queries, the final step above will incorrectly fail.
//
// TODO(bradfitz): possibly remove the restriction above, if
// enough driver authors object and find it complicates their
// code too much. The sql package could be smarter about
// refcounting the statement and closing it at the appropriate
// time.
Close() error Close() error
// NumInput returns the number of placeholder parameters. // NumInput returns the number of placeholder parameters.
......
...@@ -11,6 +11,7 @@ import ( ...@@ -11,6 +11,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"runtime"
"sync" "sync"
) )
...@@ -189,9 +190,66 @@ type DB struct { ...@@ -189,9 +190,66 @@ type DB struct {
driver driver.Driver driver driver.Driver
dsn string dsn string
mu sync.Mutex // protects freeConn and closed mu sync.Mutex // protects following fields
freeConn []driver.Conn outConn map[driver.Conn]bool // whether the conn is in use
closed bool freeConn []driver.Conn
closed bool
dep map[finalCloser]depSet
onConnPut map[driver.Conn][]func() // code (with mu held) run when conn is next returned
lastPut map[driver.Conn]string // stacktrace of last conn's put; debug only
}
// depSet is a finalCloser's outstanding dependencies
type depSet map[interface{}]bool // set of true bools
// The finalCloser interface is used by (*DB).addDep and (*DB).get
type finalCloser interface {
// finalClose is called when the reference count of an object
// goes to zero. (*DB).mu is not held while calling it.
finalClose() error
}
// addDep notes that x now depends on dep, and x's finalClose won't be
// called until all of x's dependencies are removed with removeDep.
func (db *DB) addDep(x finalCloser, dep interface{}) {
//println(fmt.Sprintf("addDep(%T %p, %T %p)", x, x, dep, dep))
db.mu.Lock()
defer db.mu.Unlock()
if db.dep == nil {
db.dep = make(map[finalCloser]depSet)
}
xdep := db.dep[x]
if xdep == nil {
xdep = make(depSet)
db.dep[x] = xdep
}
xdep[dep] = true
}
// removeDep notes that x no longer depends on dep.
// If x still has dependencies, nil is returned.
// If x no longer has any dependencies, its finalClose method will be
// called and its error value will be returned.
func (db *DB) removeDep(x finalCloser, dep interface{}) error {
//println(fmt.Sprintf("removeDep(%T %p, %T %p)", x, x, dep, dep))
done := false
db.mu.Lock()
xdep := db.dep[x]
if xdep != nil {
delete(xdep, dep)
if len(xdep) == 0 {
delete(db.dep, x)
done = true
}
}
db.mu.Unlock()
if !done {
return nil
}
//println(fmt.Sprintf("calling final close on %T %v (%#v)", x, x, x))
return x.finalClose()
} }
// Open opens a database specified by its database driver name and a // Open opens a database specified by its database driver name and a
...@@ -201,11 +259,20 @@ type DB struct { ...@@ -201,11 +259,20 @@ type DB struct {
// Most users will open a database via a driver-specific connection // Most users will open a database via a driver-specific connection
// helper function that returns a *DB. // helper function that returns a *DB.
func Open(driverName, dataSourceName string) (*DB, error) { func Open(driverName, dataSourceName string) (*DB, error) {
driver, ok := drivers[driverName] driveri, ok := drivers[driverName]
if !ok { if !ok {
return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName) return nil, fmt.Errorf("sql: unknown driver %q (forgotten import?)", driverName)
} }
return &DB{driver: driver, dsn: dataSourceName}, nil // TODO: optionally proactively connect to a Conn to check
// the dataSourceName: golang.org/issue/4804
db := &DB{
driver: driveri,
dsn: dataSourceName,
outConn: make(map[driver.Conn]bool),
lastPut: make(map[driver.Conn]string),
onConnPut: make(map[driver.Conn][]func()),
}
return db, nil
} }
// Close closes the database, releasing any open resources. // Close closes the database, releasing any open resources.
...@@ -241,22 +308,38 @@ func (db *DB) conn() (driver.Conn, error) { ...@@ -241,22 +308,38 @@ func (db *DB) conn() (driver.Conn, error) {
if n := len(db.freeConn); n > 0 { if n := len(db.freeConn); n > 0 {
conn := db.freeConn[n-1] conn := db.freeConn[n-1]
db.freeConn = db.freeConn[:n-1] db.freeConn = db.freeConn[:n-1]
db.outConn[conn] = true
db.mu.Unlock() db.mu.Unlock()
return conn, nil return conn, nil
} }
db.mu.Unlock() db.mu.Unlock()
return db.driver.Open(db.dsn) conn, err := db.driver.Open(db.dsn)
if err == nil {
db.mu.Lock()
db.outConn[conn] = true
db.mu.Unlock()
}
return conn, err
} }
// connIfFree returns (wanted, true) if wanted is still a valid conn and
// isn't in use.
//
// If wanted is valid but in use, connIfFree returns (wanted, false).
// If wanted is invalid, connIfFre returns (nil, false).
func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()
if db.outConn[wanted] {
return conn, false
}
for i, conn := range db.freeConn { for i, conn := range db.freeConn {
if conn != wanted { if conn != wanted {
continue continue
} }
db.freeConn[i] = db.freeConn[len(db.freeConn)-1] db.freeConn[i] = db.freeConn[len(db.freeConn)-1]
db.freeConn = db.freeConn[:len(db.freeConn)-1] db.freeConn = db.freeConn[:len(db.freeConn)-1]
db.outConn[wanted] = true
return wanted, true return wanted, true
} }
return nil, false return nil, false
...@@ -265,14 +348,52 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) { ...@@ -265,14 +348,52 @@ func (db *DB) connIfFree(wanted driver.Conn) (conn driver.Conn, ok bool) {
// putConnHook is a hook for testing. // putConnHook is a hook for testing.
var putConnHook func(*DB, driver.Conn) var putConnHook func(*DB, driver.Conn)
// noteUnusedDriverStatement notes that si is no longer used and should
// be closed whenever possible (when c is next not in use), unless c is
// already closed.
func (db *DB) noteUnusedDriverStatement(c driver.Conn, si driver.Stmt) {
db.mu.Lock()
defer db.mu.Unlock()
if db.outConn[c] {
db.onConnPut[c] = append(db.onConnPut[c], func() {
si.Close()
})
} else {
si.Close()
}
}
// debugGetPut determines whether getConn & putConn calls' stack traces
// are returned for more verbose crashes.
const debugGetPut = false
// putConn adds a connection to the db's free pool. // putConn adds a connection to the db's free pool.
// err is optionally the last error that occurred on this connection. // err is optionally the last error that occurred on this connection.
func (db *DB) putConn(c driver.Conn, err error) { func (db *DB) putConn(c driver.Conn, err error) {
db.mu.Lock()
if !db.outConn[c] {
if debugGetPut {
fmt.Printf("putConn(%v) DUPLICATE was: %s\n\nPREVIOUS was: %s", c, stack(), db.lastPut[c])
}
panic("sql: connection returned that was never out")
}
if debugGetPut {
db.lastPut[c] = stack()
}
delete(db.outConn, c)
if fns, ok := db.onConnPut[c]; ok {
for _, fn := range fns {
fn()
}
delete(db.onConnPut, c)
}
if err == driver.ErrBadConn { if err == driver.ErrBadConn {
// Don't reuse bad connections. // Don't reuse bad connections.
db.mu.Unlock()
return return
} }
db.mu.Lock()
if putConnHook != nil { if putConnHook != nil {
putConnHook(db, c) putConnHook(db, c)
} }
...@@ -300,7 +421,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) { ...@@ -300,7 +421,7 @@ func (db *DB) Prepare(query string) (*Stmt, error) {
return stmt, err return stmt, err
} }
func (db *DB) prepare(query string) (stmt *Stmt, err error) { func (db *DB) prepare(query string) (*Stmt, error) {
// TODO: check if db.driver supports an optional // TODO: check if db.driver supports an optional
// driver.Preparer interface and call that instead, if so, // driver.Preparer interface and call that instead, if so,
// otherwise we make a prepared statement that's bound // otherwise we make a prepared statement that's bound
...@@ -311,19 +432,17 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) { ...@@ -311,19 +432,17 @@ func (db *DB) prepare(query string) (stmt *Stmt, err error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
db.putConn(ci, err)
}()
si, err := ci.Prepare(query) si, err := ci.Prepare(query)
if err != nil { if err != nil {
db.putConn(ci, err)
return nil, err return nil, err
} }
stmt = &Stmt{ stmt := &Stmt{
db: db, db: db,
query: query, query: query,
css: []connStmt{{ci, si}}, css: []connStmt{{ci, si}},
} }
db.addDep(stmt, stmt)
return stmt, nil return stmt, nil
} }
...@@ -698,6 +817,8 @@ type Stmt struct { ...@@ -698,6 +817,8 @@ type Stmt struct {
query string // that created the Stmt query string // that created the Stmt
stickyErr error // if non-nil, this error is returned for all operations stickyErr error // if non-nil, this error is returned for all operations
closemu sync.RWMutex // held exclusively during close, for read otherwise.
// If in a transaction, else both nil: // If in a transaction, else both nil:
tx *Tx tx *Tx
txsi driver.Stmt txsi driver.Stmt
...@@ -715,6 +836,8 @@ type Stmt struct { ...@@ -715,6 +836,8 @@ type Stmt struct {
// Exec executes a prepared statement with the given arguments and // Exec executes a prepared statement with the given arguments and
// returns a Result summarizing the effect of the statement. // returns a Result summarizing the effect of the statement.
func (s *Stmt) Exec(args ...interface{}) (Result, error) { func (s *Stmt) Exec(args ...interface{}) (Result, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
_, releaseConn, si, err := s.connStmt() _, releaseConn, si, err := s.connStmt()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -813,6 +936,9 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St ...@@ -813,6 +936,9 @@ func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(error), si driver.St
// Query executes a prepared query statement with the given arguments // Query executes a prepared query statement with the given arguments
// and returns the query results as a *Rows. // and returns the query results as a *Rows.
func (s *Stmt) Query(args ...interface{}) (*Rows, error) { func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
s.closemu.RLock()
defer s.closemu.RUnlock()
ci, releaseConn, si, err := s.connStmt() ci, releaseConn, si, err := s.connStmt()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -827,10 +953,15 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) { ...@@ -827,10 +953,15 @@ func (s *Stmt) Query(args ...interface{}) (*Rows, error) {
// Note: ownership of ci passes to the *Rows, to be freed // Note: ownership of ci passes to the *Rows, to be freed
// with releaseConn. // with releaseConn.
rows := &Rows{ rows := &Rows{
db: s.db, db: s.db,
ci: ci, ci: ci,
releaseConn: releaseConn, rowsi: rowsi,
rowsi: rowsi, // releaseConn set below
}
s.db.addDep(s, rows)
rows.releaseConn = func(err error) {
releaseConn(err)
s.db.removeDep(s, rows)
} }
return rows, nil return rows, nil
} }
...@@ -876,6 +1007,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row { ...@@ -876,6 +1007,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
// Close closes the statement. // Close closes the statement.
func (s *Stmt) Close() error { func (s *Stmt) Close() error {
s.closemu.Lock()
defer s.closemu.Unlock()
if s.stickyErr != nil { if s.stickyErr != nil {
return s.stickyErr return s.stickyErr
} }
...@@ -888,18 +1022,17 @@ func (s *Stmt) Close() error { ...@@ -888,18 +1022,17 @@ func (s *Stmt) Close() error {
if s.tx != nil { if s.tx != nil {
s.txsi.Close() s.txsi.Close()
} else { return nil
for _, v := range s.css { }
if ci, match := s.db.connIfFree(v.ci); match {
v.si.Close() return s.db.removeDep(s, s)
s.db.putConn(ci, nil) }
} else {
// TODO(bradfitz): care that we can't close func (s *Stmt) finalClose() error {
// this statement because the statement's for _, v := range s.css {
// connection is in use? s.db.noteUnusedDriverStatement(v.ci, v.si)
}
}
} }
s.css = nil
return nil return nil
} }
...@@ -918,7 +1051,7 @@ func (s *Stmt) Close() error { ...@@ -918,7 +1051,7 @@ func (s *Stmt) Close() error {
// ... // ...
type Rows struct { type Rows struct {
db *DB db *DB
ci driver.Conn // owned; must call putconn when closed to release ci driver.Conn // owned; must call releaseConn when closed to release
releaseConn func(error) releaseConn func(error)
rowsi driver.Rows rowsi driver.Rows
...@@ -1094,3 +1227,8 @@ type Result interface { ...@@ -1094,3 +1227,8 @@ type Result interface {
type result struct { type result struct {
driver.Result driver.Result
} }
func stack() string {
var buf [1024]byte
return string(buf[:runtime.Stack(buf[:], false)])
}
...@@ -8,7 +8,6 @@ import ( ...@@ -8,7 +8,6 @@ import (
"database/sql/driver" "database/sql/driver"
"fmt" "fmt"
"reflect" "reflect"
"runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
...@@ -63,6 +62,10 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) { ...@@ -63,6 +62,10 @@ func exec(t *testing.T, db *DB, query string, args ...interface{}) {
} }
func closeDB(t *testing.T, db *DB) { func closeDB(t *testing.T, db *DB) {
if e := recover(); e != nil {
fmt.Printf("Panic: %v\n", e)
panic(e)
}
err := db.Close() err := db.Close()
if err != nil { if err != nil {
t.Fatalf("error closing DB: %v", err) t.Fatalf("error closing DB: %v", err)
...@@ -448,10 +451,8 @@ func TestIssue2542Deadlock(t *testing.T) { ...@@ -448,10 +451,8 @@ func TestIssue2542Deadlock(t *testing.T) {
} }
} }
// From golang.org/issue/3865
func TestCloseStmtBeforeRows(t *testing.T) { func TestCloseStmtBeforeRows(t *testing.T) {
t.Skip("known broken test; golang.org/issue/3865")
return
db := newTestDB(t, "people") db := newTestDB(t, "people")
defer closeDB(t, db) defer closeDB(t, db)
...@@ -666,8 +667,3 @@ func nullTestRun(t *testing.T, spec nullTestSpec) { ...@@ -666,8 +667,3 @@ func nullTestRun(t *testing.T, spec nullTestSpec) {
} }
} }
} }
func stack() string {
buf := make([]byte, 1024)
return string(buf[:runtime.Stack(buf, false)])
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment