Commit 82e1732f authored by Sarah Adams's avatar Sarah Adams Committed by Daniel Theophanes

database/sql: proper prepared statement support in transactions

This change was originally written by Marko Tiikkaja <marko@joh.to>.
https://go-review.googlesource.com/#/c/2035/

Previously *Tx.Stmt always prepared a new statement, even if an
existing one was available on the connection the transaction was on.
Now we first see if the statement is already available on the
connection and only prepare if it isn't. Additionally, when we do
need to prepare one, we store it in the parent *Stmt to allow it to be
later reused by other calls to *Tx.Stmt on that statement or just
straight up by *Stmt.Exec et al.

To make sure that the statement doesn't disappear unexpectedly, we
record a dependency from the statement returned by *Tx.Stmt to the
*Stmt it came from and set a new field, parentStmt, to point to the
originating *Stmt. When the transaction's *Stmt is closed, we remove
the dependency. This way the "parent" *Stmt can be closed by the user
without her having to know whether any transactions are still using it
or not.

Fixes #15606

Change-Id: I41b5056847e117ac61130328b0239d1e000a4a08
Reviewed-on: https://go-review.googlesource.com/35476
Run-TryBot: Daniel Theophanes <kardianos@gmail.com>
TryBot-Result: Gobot Gobot <gobot@golang.org>
Reviewed-by: default avatarDaniel Theophanes <kardianos@gmail.com>
parent 3b988eb6
...@@ -1554,19 +1554,6 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) { ...@@ -1554,19 +1554,6 @@ func (tx *Tx) PrepareContext(ctx context.Context, query string) (*Stmt, error) {
tx.closemu.RLock() tx.closemu.RLock()
defer tx.closemu.RUnlock() defer tx.closemu.RUnlock()
// TODO(bradfitz): We could be more efficient here and either
// provide a method to take an existing Stmt (created on
// perhaps a different Conn), and re-create it on this Conn if
// necessary. Or, better: keep a map in DB of query string to
// Stmts, and have Stmt.Execute do the right thing and
// re-prepare if the Conn in use doesn't have that prepared
// statement. But we'll want to avoid caching the statement
// in the case where we only call conn.Prepare implicitly
// (such as in db.Exec or tx.Exec), but the caller package
// can't be holding a reference to the returned statement.
// Perhaps just looking at the reference count (by noting
// Stmt.Close) would be enough. We might also want a finalizer
// on Stmt to drop the reference count.
dc, err := tx.grabConn(ctx) dc, err := tx.grabConn(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -1621,11 +1608,6 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { ...@@ -1621,11 +1608,6 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
tx.closemu.RLock() tx.closemu.RLock()
defer tx.closemu.RUnlock() defer tx.closemu.RUnlock()
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
// per-Conn. See also the big comment in Tx.Prepare.
if tx.db != stmt.db { if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")} return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
} }
...@@ -1634,9 +1616,45 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { ...@@ -1634,9 +1616,45 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
return &Stmt{stickyErr: err} return &Stmt{stickyErr: err}
} }
var si driver.Stmt var si driver.Stmt
withLock(dc, func() { var parentStmt *Stmt
si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query) stmt.mu.Lock()
}) if stmt.closed || stmt.tx != nil {
// If the statement has been closed or already belongs to a
// transaction, we can't reuse it in this connection.
// Since tx.StmtContext should never need to be called with a
// Stmt already belonging to tx, we ignore this edge case and
// re-prepare the statement in this case. No need to add
// code-complexity for this.
stmt.mu.Unlock()
withLock(dc, func() {
si, err = ctxDriverPrepare(ctx, dc.ci, stmt.query)
})
if err != nil {
return &Stmt{stickyErr: err}
}
} else {
stmt.removeClosedStmtLocked()
// See if the statement has already been prepared on this connection,
// and reuse it if possible.
for _, v := range stmt.css {
if v.dc == dc {
si = v.ds.si
break
}
}
stmt.mu.Unlock()
if si == nil {
cs, err := stmt.prepareOnConnLocked(ctx, dc)
if err != nil {
return &Stmt{stickyErr: err}
}
si = cs.si
}
parentStmt = stmt
}
txs := &Stmt{ txs := &Stmt{
db: tx.db, db: tx.db,
tx: tx, tx: tx,
...@@ -1644,8 +1662,11 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt { ...@@ -1644,8 +1662,11 @@ func (tx *Tx) StmtContext(ctx context.Context, stmt *Stmt) *Stmt {
Locker: dc, Locker: dc,
si: si, si: si,
}, },
query: stmt.query, parentStmt: parentStmt,
stickyErr: err, query: stmt.query,
}
if parentStmt != nil {
tx.db.addDep(parentStmt, txs)
} }
tx.stmts.Lock() tx.stmts.Lock()
tx.stmts.v = append(tx.stmts.v, txs) tx.stmts.v = append(tx.stmts.v, txs)
...@@ -1769,13 +1790,21 @@ type Stmt struct { ...@@ -1769,13 +1790,21 @@ type Stmt struct {
tx *Tx tx *Tx
txds *driverStmt txds *driverStmt
// parentStmt is set when a transaction-specific statement
// is requested from an identical statement prepared on the same
// conn. parentStmt is used to track the dependency of this statement
// on its originating ("parent") statement so that parentStmt may
// be closed by the user without them having to know whether or not
// any transactions are still using it.
parentStmt *Stmt
mu sync.Mutex // protects the rest of the fields mu sync.Mutex // protects the rest of the fields
closed bool closed bool
// css is a list of underlying driver statement interfaces // css is a list of underlying driver statement interfaces
// that are valid on particular connections. This is only // that are valid on particular connections. This is only
// used if tx == nil and one is found that has idle // used if tx == nil and one is found that has idle
// connections. If tx != nil, txsi is always used. // connections. If tx != nil, txds is always used.
css []connStmt css []connStmt
// lastNumClosed is copied from db.numClosed when Stmt is created // lastNumClosed is copied from db.numClosed when Stmt is created
...@@ -1916,18 +1945,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e ...@@ -1916,18 +1945,28 @@ func (s *Stmt) connStmt(ctx context.Context) (ci *driverConn, releaseConn func(e
// No luck; we need to prepare the statement on this connection // No luck; we need to prepare the statement on this connection
withLock(dc, func() { withLock(dc, func() {
ds, err = dc.prepareLocked(ctx, s.query) ds, err = s.prepareOnConnLocked(ctx, dc)
}) })
if err != nil { if err != nil {
s.db.putConn(dc, err) s.db.putConn(dc, err)
return nil, nil, nil, err return nil, nil, nil, err
} }
return dc, dc.releaseConn, ds, nil
}
// prepareOnConnLocked prepares the query in Stmt s on dc and adds it to the list of
// open connStmt on the statement. It assumes the caller is holding the lock on dc.
func (s *Stmt) prepareOnConnLocked(ctx context.Context, dc *driverConn) (*driverStmt, error) {
si, err := dc.prepareLocked(ctx, s.query)
if err != nil {
return nil, err
}
cs := connStmt{dc, si}
s.mu.Lock() s.mu.Lock()
cs := connStmt{dc, ds}
s.css = append(s.css, cs) s.css = append(s.css, cs)
s.mu.Unlock() s.mu.Unlock()
return cs.ds, nil
return dc, dc.releaseConn, ds, nil
} }
// QueryContext executes a prepared query statement with the given arguments // QueryContext executes a prepared query statement with the given arguments
...@@ -2056,11 +2095,16 @@ func (s *Stmt) Close() error { ...@@ -2056,11 +2095,16 @@ func (s *Stmt) Close() error {
s.closed = true s.closed = true
s.mu.Unlock() s.mu.Unlock()
if s.tx != nil { if s.tx == nil {
return s.txds.Close() return s.db.removeDep(s, s)
} }
return s.db.removeDep(s, s) if s.parentStmt != nil {
// If parentStmt is set, we must not close s.txds since it's stored
// in the css array of the parentStmt.
return s.db.removeDep(s.parentStmt, s)
}
return s.txds.Close()
} }
func (s *Stmt) finalClose() error { func (s *Stmt) finalClose() error {
......
...@@ -1024,6 +1024,196 @@ func TestTxStmt(t *testing.T) { ...@@ -1024,6 +1024,196 @@ func TestTxStmt(t *testing.T) {
} }
} }
func TestTxStmtPreparedOnce(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32")
prepares0 := numPrepares(t, db)
// db.Prepare increments numPrepares.
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
defer stmt.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
txs1 := tx.Stmt(stmt)
txs2 := tx.Stmt(stmt)
_, err = txs1.Exec("Go", 7)
if err != nil {
t.Fatalf("Exec = %v", err)
}
txs1.Close()
_, err = txs2.Exec("Gopher", 8)
if err != nil {
t.Fatalf("Exec = %v", err)
}
txs2.Close()
err = tx.Commit()
if err != nil {
t.Fatalf("Commit = %v", err)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
func TestTxStmtClosedRePrepares(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32")
prepares0 := numPrepares(t, db)
// db.Prepare increments numPrepares.
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
err = stmt.Close()
if err != nil {
t.Fatalf("stmt.Close() = %v", err)
}
// tx.Stmt increments numPrepares because stmt is closed.
txs := tx.Stmt(stmt)
if txs.stickyErr != nil {
t.Fatal(txs.stickyErr)
}
if txs.parentStmt != nil {
t.Fatal("expected nil parentStmt")
}
_, err = txs.Exec(`Eric`, 82)
if err != nil {
t.Fatalf("txs.Exec = %v", err)
}
err = txs.Close()
if err != nil {
t.Fatalf("txs.Close = %v", err)
}
tx.Rollback()
if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
t.Errorf("executed %d Prepare statements; want 2", prepares)
}
}
func TestParentStmtOutlivesTxStmt(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32")
// Make sure everything happens on the same connection.
db.SetMaxOpenConns(1)
prepares0 := numPrepares(t, db)
// db.Prepare increments numPrepares.
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
defer stmt.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
txs := tx.Stmt(stmt)
if len(stmt.css) != 1 {
t.Fatalf("len(stmt.css) = %v; want 1", len(stmt.css))
}
err = txs.Close()
if err != nil {
t.Fatalf("txs.Close() = %v", err)
}
err = tx.Rollback()
if err != nil {
t.Fatalf("tx.Rollback() = %v", err)
}
// txs must not be valid.
_, err = txs.Exec("Suzan", 30)
if err == nil {
t.Fatalf("txs.Exec(), expected err")
}
// Stmt must still be valid.
_, err = stmt.Exec("Janina", 25)
if err != nil {
t.Fatalf("stmt.Exec() = %v", err)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 1 {
t.Errorf("executed %d Prepare statements; want 1", prepares)
}
}
// Test that tx.Stmt called with a statment already
// associated with tx as argument re-prepares the same
// statement again.
func TestTxStmtFromTxStmtRePrepares(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32")
prepares0 := numPrepares(t, db)
// db.Prepare increments numPrepares.
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
defer stmt.Close()
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
txs1 := tx.Stmt(stmt)
// tx.Stmt(txs1) increments numPrepares because txs1 already
// belongs to a transaction (albeit the same transaction).
txs2 := tx.Stmt(txs1)
if txs2.stickyErr != nil {
t.Fatal(txs2.stickyErr)
}
if txs2.parentStmt != nil {
t.Fatal("expected nil parentStmt")
}
_, err = txs2.Exec(`Eric`, 82)
if err != nil {
t.Fatal(err)
}
err = txs1.Close()
if err != nil {
t.Fatalf("txs1.Close = %v", err)
}
err = txs2.Close()
if err != nil {
t.Fatalf("txs1.Close = %v", err)
}
err = tx.Rollback()
if err != nil {
t.Fatalf("tx.Rollback = %v", err)
}
if prepares := numPrepares(t, db) - prepares0; prepares != 2 {
t.Errorf("executed %d Prepare statements; want 2", prepares)
}
}
// Issue: https://golang.org/issue/2784 // Issue: https://golang.org/issue/2784
// This test didn't fail before because we got lucky with the fakedb driver. // This test didn't fail before because we got lucky with the fakedb driver.
// It was failing, and now not, in github.com/bradfitz/go-sql-test // It was failing, and now not, in github.com/bradfitz/go-sql-test
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment