Commit 5f739d9d authored by Marko Tiikkaja's avatar Marko Tiikkaja Committed by Brad Fitzpatrick

database/sql: Close per-tx prepared statements when the associated tx ends

LGTM=bradfitz
R=golang-codereviews, bradfitz, mattn.jp
CC=golang-codereviews
https://golang.org/cl/131650043
parent 93e5cc22
...@@ -1043,6 +1043,13 @@ type Tx struct { ...@@ -1043,6 +1043,13 @@ type Tx struct {
// or Rollback. once done, all operations fail with // or Rollback. once done, all operations fail with
// ErrTxDone. // ErrTxDone.
done bool done bool
// All Stmts prepared for this transaction. These will be closed after the
// transaction has been committed or rolled back.
stmts struct {
sync.Mutex
v []*Stmt
}
} }
var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back") var ErrTxDone = errors.New("sql: Transaction has already been committed or rolled back")
...@@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) { ...@@ -1064,6 +1071,15 @@ func (tx *Tx) grabConn() (*driverConn, error) {
return tx.dc, nil return tx.dc, nil
} }
// Closes all Stmts prepared for this transaction.
func (tx *Tx) closePrepared() {
tx.stmts.Lock()
for _, stmt := range tx.stmts.v {
stmt.Close()
}
tx.stmts.Unlock()
}
// Commit commits the transaction. // Commit commits the transaction.
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
if tx.done { if tx.done {
...@@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error { ...@@ -1071,8 +1087,12 @@ func (tx *Tx) Commit() error {
} }
defer tx.close() defer tx.close()
tx.dc.Lock() tx.dc.Lock()
defer tx.dc.Unlock() err := tx.txi.Commit()
return tx.txi.Commit() tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
} }
// Rollback aborts the transaction. // Rollback aborts the transaction.
...@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error { ...@@ -1082,8 +1102,12 @@ func (tx *Tx) Rollback() error {
} }
defer tx.close() defer tx.close()
tx.dc.Lock() tx.dc.Lock()
defer tx.dc.Unlock() err := tx.txi.Rollback()
return tx.txi.Rollback() tx.dc.Unlock()
if err != driver.ErrBadConn {
tx.closePrepared()
}
return err
} }
// Prepare creates a prepared statement for use within a transaction. // Prepare creates a prepared statement for use within a transaction.
...@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { ...@@ -1127,6 +1151,9 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
}, },
query: query, query: query,
} }
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, stmt)
tx.stmts.Unlock()
return stmt, nil return stmt, nil
} }
...@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { ...@@ -1155,7 +1182,7 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
dc.Lock() dc.Lock()
si, err := dc.ci.Prepare(stmt.query) si, err := dc.ci.Prepare(stmt.query)
dc.Unlock() dc.Unlock()
return &Stmt{ txs := &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
txsi: &driverStmt{ txsi: &driverStmt{
...@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt { ...@@ -1165,6 +1192,10 @@ func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
query: stmt.query, query: stmt.query,
stickyErr: err, stickyErr: err,
} }
tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs)
tx.stmts.Unlock()
return txs
} }
// Exec executes a query that doesn't return rows. // Exec executes a query that doesn't return rows.
......
...@@ -441,6 +441,33 @@ func TestExec(t *testing.T) { ...@@ -441,6 +441,33 @@ func TestExec(t *testing.T) {
} }
} }
func TestTxPrepare(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
stmt, err := tx.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
defer stmt.Close()
_, err = stmt.Exec("Bobby", 7)
if err != nil {
t.Fatalf("Exec = %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit = %v", err)
}
// Commit() should have closed the statement
if !stmt.closed {
t.Fatal("Stmt not closed after Commit")
}
}
func TestTxStmt(t *testing.T) { func TestTxStmt(t *testing.T) {
db := newTestDB(t, "") db := newTestDB(t, "")
defer closeDB(t, db) defer closeDB(t, db)
...@@ -464,6 +491,10 @@ func TestTxStmt(t *testing.T) { ...@@ -464,6 +491,10 @@ func TestTxStmt(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Commit = %v", err) t.Fatalf("Commit = %v", err)
} }
// Commit() should have closed the statement
if !txs.closed {
t.Fatal("Stmt not closed after Commit")
}
} }
// Issue: http://golang.org/issue/2784 // Issue: http://golang.org/issue/2784
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment