Commit 9f7bc846 authored by gwenn's avatar gwenn

Go 1.8: fix ExecContext

parent 6dfbed85
......@@ -174,12 +174,48 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
if c.c.IsClosed() {
return nil, driver.ErrBadConn
}
c.c.ProgressHandler(progressHandler, 100, ctx)
defer c.c.ProgressHandler(nil, 0, nil)
if len(args) == 0 {
if query == "unwrap" {
return nil, ConnError{c: c.c}
}
if err := c.c.FastExec(query); err != nil {
return nil, err
}
return c.c.result(), nil
}
for len(query) > 0 {
s, err := c.c.Prepare(query)
if err != nil {
return nil, err
} else if s.stmt == nil {
// this happens for a comment or white-space
query = s.tail
continue
}
var subargs []driver.NamedValue
count := s.BindParameterCount()
if len(s.tail) > 0 && len(args) >= count {
subargs = args[:count]
args = args[count:]
} else {
subargs = args
}
st := stmt{s: s}
return st.ExecContext(ctx, args)
if err = s.bindNamedValue(subargs); err != nil {
return nil, err
}
err = s.exec()
if err != nil {
s.finalize()
return nil, err
}
if err = s.finalize(); err != nil {
return nil, err
}
query = s.tail
}
return c.c.result(), nil
}
func (c *conn) Close() error {
......@@ -262,7 +298,7 @@ func (s *stmt) Query(args []driver.Value) (driver.Rows, error) {
}
func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) {
if err := s.bindNamedValue(args); err != nil {
if err := s.s.bindNamedValue(args); err != nil {
return nil, err
}
s.s.c.ProgressHandler(progressHandler, 100, ctx)
......@@ -274,45 +310,13 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
}
func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) {
if err := s.bindNamedValue(args); err != nil {
if err := s.s.bindNamedValue(args); err != nil {
return nil, err
}
s.rowsRef = true
return &rowsImpl{s, nil, ctx}, nil
}
func (s *stmt) bindNamedValue(args []driver.NamedValue) error {
for _, v := range args {
if len(v.Name) == 0 {
if err := s.s.BindByIndex(v.Ordinal, v.Value); err != nil {
return err
}
} else {
index, err := s.s.BindParameterIndex(v.Name)
if err != nil {
return err
}
if err = s.s.BindByIndex(index, v.Value); err != nil {
return err
}
}
}
return nil
}
func progressHandler(p interface{}) bool {
if ctx, ok := p.(context.Context); ok {
select {
case <-ctx.Done():
// Cancelled
return true
default:
return false
}
}
return false
}
func (s *stmt) bind(args []driver.Value) error {
for i, v := range args {
if err := s.s.BindByIndex(i+1, v); err != nil {
......@@ -388,3 +392,35 @@ func (c *Conn) result() driver.Result {
rows := int64(c.Changes())
return &result{id, rows} // FIXME RowAffected/noRows
}
func (s *Stmt) bindNamedValue(args []driver.NamedValue) error {
for _, v := range args {
if len(v.Name) == 0 {
if err := s.BindByIndex(v.Ordinal, v.Value); err != nil {
return err
}
} else {
index, err := s.BindParameterIndex(v.Name)
if err != nil {
return err
}
if err = s.BindByIndex(index, v.Value); err != nil {
return err
}
}
}
return nil
}
func progressHandler(p interface{}) bool {
if ctx, ok := p.(context.Context); ok {
select {
case <-ctx.Done():
// Cancelled
return true
default:
return false
}
}
return false
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment