Commit 9f7bc846 authored by gwenn's avatar gwenn

Go 1.8: fix ExecContext

parent 6dfbed85
...@@ -174,12 +174,48 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name ...@@ -174,12 +174,48 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if c.c.IsClosed() { if c.c.IsClosed() {
return nil, driver.ErrBadConn return nil, driver.ErrBadConn
} }
s, err := c.c.Prepare(query) c.c.ProgressHandler(progressHandler, 100, ctx)
if err != nil { defer c.c.ProgressHandler(nil, 0, nil)
return nil, err if len(args) == 0 {
if query == "unwrap" {
return nil, ConnError{c: c.c}
}
if err := c.c.FastExec(query); err != nil {
return nil, err
}
return c.c.result(), nil
} }
st := stmt{s: s} for len(query) > 0 {
return st.ExecContext(ctx, args) s, err := c.c.Prepare(query)
if err != nil {
return nil, err
} else if s.stmt == nil {
// this happens for a comment or white-space
query = s.tail
continue
}
var subargs []driver.NamedValue
count := s.BindParameterCount()
if len(s.tail) > 0 && len(args) >= count {
subargs = args[:count]
args = args[count:]
} else {
subargs = args
}
if err = s.bindNamedValue(subargs); err != nil {
return nil, err
}
err = s.exec()
if err != nil {
s.finalize()
return nil, err
}
if err = s.finalize(); err != nil {
return nil, err
}
query = s.tail
}
return c.c.result(), nil
} }
func (c *conn) Close() error { func (c *conn) Close() error {
...@@ -262,7 +298,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { ...@@ -262,7 +298,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
} }
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if err := s.bindNamedValue(args); err != nil { if err := s.s.bindNamedValue(args); err != nil {
return nil, err return nil, err
} }
s.s.c.ProgressHandler(progressHandler, 100, ctx) s.s.c.ProgressHandler(progressHandler, 100, ctx)
...@@ -274,45 +310,13 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive ...@@ -274,45 +310,13 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
} }
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if err := s.bindNamedValue(args); err != nil { if err := s.s.bindNamedValue(args); err != nil {
return nil, err return nil, err
} }
s.rowsRef = true s.rowsRef = true
return &rowsImpl{s, nil, ctx}, nil return &rowsImpl{s, nil, ctx}, nil
} }
func (s *stmt) bindNamedValue(args []driver.NamedValue) error {
for _, v := range args {
if len(v.Name) == 0 {
if err := s.s.BindByIndex(v.Ordinal, v.Value); err != nil {
return err
}
} else {
index, err := s.s.BindParameterIndex(v.Name)
if err != nil {
return err
}
if err = s.s.BindByIndex(index, v.Value); err != nil {
return err
}
}
}
return nil
}
func progressHandler(p interface{}) bool {
if ctx, ok := p.(context.Context); ok {
select {
case <-ctx.Done():
// Cancelled
return true
default:
return false
}
}
return false
}
func (s *stmt) bind(args []driver.Value) error { func (s *stmt) bind(args []driver.Value) error {
for i, v := range args { for i, v := range args {
if err := s.s.BindByIndex(i+1, v); err != nil { if err := s.s.BindByIndex(i+1, v); err != nil {
...@@ -388,3 +392,35 @@ func (c *Conn) result() driver.Result { ...@@ -388,3 +392,35 @@ func (c *Conn) result() driver.Result {
rows := int64(c.Changes()) rows := int64(c.Changes())
return &result{id, rows} // FIXME RowAffected/noRows return &result{id, rows} // FIXME RowAffected/noRows
} }
func (s *Stmt) bindNamedValue(args []driver.NamedValue) error {
for _, v := range args {
if len(v.Name) == 0 {
if err := s.BindByIndex(v.Ordinal, v.Value); err != nil {
return err
}
} else {
index, err := s.BindParameterIndex(v.Name)
if err != nil {
return err
}
if err = s.BindByIndex(index, v.Value); err != nil {
return err
}
}
}
return nil
}
func progressHandler(p interface{}) bool {
if ctx, ok := p.(context.Context); ok {
select {
case <-ctx.Done():
// Cancelled
return true
default:
return false
}
}
return false
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment