Commit e77099da authored by Brad Fitzpatrick's avatar Brad Fitzpatrick

sql: add Tx.Stmt to use an existing prepared stmt in a transaction

R=rsc
CC=golang-dev
https://golang.org/cl/5433059
parent 23227f3d
...@@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error { ...@@ -344,25 +344,26 @@ func (tx *Tx) Rollback() error {
return tx.txi.Rollback() return tx.txi.Rollback()
} }
// Prepare creates a prepared statement. // Prepare creates a prepared statement for use within a transaction.
// //
// The statement is only valid within the scope of this transaction. // The returned statement operates within the transaction and can no longer
// be used once the transaction has been committed or rolled back.
//
// To use an existing prepared statement on this transaction, see Tx.Stmt.
func (tx *Tx) Prepare(query string) (*Stmt, error) { func (tx *Tx) Prepare(query string) (*Stmt, error) {
// TODO(bradfitz): the restriction that the returned statement // TODO(bradfitz): We could be more efficient here and either
// is only valid for this Transaction is lame and negates a // provide a method to take an existing Stmt (created on
// lot of the benefit of prepared statements. We could be // perhaps a different Conn), and re-create it on this Conn if
// more efficient here and either provide a method to take an // necessary. Or, better: keep a map in DB of query string to
// existing Stmt (created on perhaps a different Conn), and // Stmts, and have Stmt.Execute do the right thing and
// re-create it on this Conn if necessary. Or, better: keep a // re-prepare if the Conn in use doesn't have that prepared
// map in DB of query string to Stmts, and have Stmt.Execute // statement. But we'll want to avoid caching the statement
// do the right thing and re-prepare if the Conn in use // in the case where we only call conn.Prepare implicitly
// doesn't have that prepared statement. But we'll want to // (such as in db.Exec or tx.Exec), but the caller package
// avoid caching the statement in the case where we only call // can't be holding a reference to the returned statement.
// conn.Prepare implicitly (such as in db.Exec or tx.Exec), // Perhaps just looking at the reference count (by noting
// but the caller package can't be holding a reference to the // Stmt.Close) would be enough. We might also want a finalizer
// returned statement. Perhaps just looking at the reference // on Stmt to drop the reference count.
// count (by noting Stmt.Close) would be enough. We might also
// want a finalizer on Stmt to drop the reference count.
ci, err := tx.grabConn() ci, err := tx.grabConn()
if err != nil { if err != nil {
return nil, err return nil, err
...@@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) { ...@@ -383,6 +384,39 @@ func (tx *Tx) Prepare(query string) (*Stmt, error) {
return stmt, nil return stmt, nil
} }
// Stmt returns a transaction-specific prepared statement from
// an existing statement.
//
// Example:
// updateMoney, err := db.Prepare("UPDATE balance SET money=money+? WHERE id=?")
// ...
// tx, err := db.Begin()
// ...
// res, err := tx.Stmt(updateMoney).Exec(123.45, 98293203)
func (tx *Tx) Stmt(stmt *Stmt) *Stmt {
// TODO(bradfitz): optimize this. Currently this re-prepares
// each time. This is fine for now to illustrate the API but
// we should really cache already-prepared statements
// per-Conn. See also the big comment in Tx.Prepare.
if tx.db != stmt.db {
return &Stmt{stickyErr: errors.New("sql: Tx.Stmt: statement from different database used")}
}
ci, err := tx.grabConn()
if err != nil {
return &Stmt{stickyErr: err}
}
defer tx.releaseConn()
si, err := ci.Prepare(stmt.query)
return &Stmt{
db: tx.db,
tx: tx,
txsi: si,
query: stmt.query,
stickyErr: err,
}
}
// Exec executes a query that doesn't return rows. // Exec executes a query that doesn't return rows.
// For example: an INSERT and UPDATE. // For example: an INSERT and UPDATE.
func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) { func (tx *Tx) Exec(query string, args ...interface{}) (Result, error) {
...@@ -448,8 +482,9 @@ type connStmt struct { ...@@ -448,8 +482,9 @@ type connStmt struct {
// Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines. // Stmt is a prepared statement. Stmt is safe for concurrent use by multiple goroutines.
type Stmt struct { type Stmt struct {
// Immutable: // Immutable:
db *DB // where we came from db *DB // where we came from
query string // that created the Sttm query string // that created the Stmt
stickyErr error // if non-nil, this error is returned for all operations
// If in a transaction, else both nil: // If in a transaction, else both nil:
tx *Tx tx *Tx
...@@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) { ...@@ -513,6 +548,9 @@ func (s *Stmt) Exec(args ...interface{}) (Result, error) {
// statement, a function to call to release the connection, and a // statement, a function to call to release the connection, and a
// statement bound to that connection. // statement bound to that connection.
func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) { func (s *Stmt) connStmt() (ci driver.Conn, releaseConn func(), si driver.Stmt, err error) {
if s.stickyErr != nil {
return nil, nil, nil, s.stickyErr
}
s.mu.Lock() s.mu.Lock()
if s.closed { if s.closed {
s.mu.Unlock() s.mu.Unlock()
...@@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row { ...@@ -621,6 +659,9 @@ func (s *Stmt) QueryRow(args ...interface{}) *Row {
// Close closes the statement. // Close closes the statement.
func (s *Stmt) Close() error { func (s *Stmt) Close() error {
if s.stickyErr != nil {
return s.stickyErr
}
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
if s.closed { if s.closed {
......
...@@ -166,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) { ...@@ -166,7 +166,7 @@ func TestBogusPreboundParameters(t *testing.T) {
} }
} }
func TestDb(t *testing.T) { func TestExec(t *testing.T) {
db := newTestDB(t, "foo") db := newTestDB(t, "foo")
defer closeDB(t, db) defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool") exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
...@@ -206,3 +206,25 @@ func TestDb(t *testing.T) { ...@@ -206,3 +206,25 @@ func TestDb(t *testing.T) {
} }
} }
} }
func TestTxStmt(t *testing.T) {
db := newTestDB(t, "")
defer closeDB(t, db)
exec(t, db, "CREATE|t1|name=string,age=int32,dead=bool")
stmt, err := db.Prepare("INSERT|t1|name=?,age=?")
if err != nil {
t.Fatalf("Stmt, err = %v, %v", stmt, err)
}
tx, err := db.Begin()
if err != nil {
t.Fatalf("Begin = %v", err)
}
_, err = tx.Stmt(stmt).Exec("Bobby", 7)
if err != nil {
t.Fatalf("Exec = %v", err)
}
err = tx.Commit()
if err != nil {
t.Fatalf("Commit = %v", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment