Commit a23f901d authored by gwenn's avatar gwenn

Test and fix cancel/interrupt

parent 85370baf
......@@ -147,14 +147,14 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
return nil, ConnError{c: c.c}
}
if err := c.c.FastExec(query); err != nil {
return nil, err
return nil, ctxError(ctx, err)
}
return c.c.result(), nil
}
for len(query) > 0 {
s, err := c.c.Prepare(query)
if err != nil {
return nil, err
return nil, ctxError(ctx, err)
} else if s.stmt == nil {
// this happens for a comment or white-space
query = s.tail
......@@ -169,15 +169,15 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
subargs = args
}
if err = s.bindNamedValue(subargs); err != nil {
return nil, err
return nil, ctxError(ctx, err)
}
err = s.exec()
if err != nil {
s.finalize()
return nil, err
return nil, ctxError(ctx, err)
}
if err = s.finalize(); err != nil {
return nil, err
return nil, ctxError(ctx, err)
}
query = s.tail
}
......@@ -257,7 +257,7 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive
s.s.c.ProgressHandler(progressHandler, 100, ctx)
defer s.s.c.ProgressHandler(nil, 0, nil)
if err := s.s.exec(); err != nil {
return nil, err
return nil, ctxError(ctx, err)
}
return s.s.c.result(), nil
}
......@@ -270,6 +270,7 @@ func (s *stmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driv
return nil, err
}
s.rowsRef = true
s.s.c.ProgressHandler(progressHandler, 100, ctx)
return &rowsImpl{s, nil, ctx}, nil
}
......@@ -290,13 +291,9 @@ func (r *rowsImpl) Columns() []string {
}
func (r *rowsImpl) Next(dest []driver.Value) error {
if r.ctx != nil {
r.s.s.c.ProgressHandler(progressHandler, 100, r.ctx)
defer r.s.s.c.ProgressHandler(nil, 0, nil)
}
ok, err := r.s.s.Next()
if err != nil {
return err
return ctxError(r.ctx, err)
}
if !ok {
return io.EOF
......@@ -311,6 +308,7 @@ func (r *rowsImpl) Next(dest []driver.Value) error {
}
func (r *rowsImpl) Close() error {
r.s.s.c.ProgressHandler(nil, 0, nil)
r.s.rowsRef = false
if r.s.pendingClose {
return r.s.Close()
......@@ -380,3 +378,11 @@ func progressHandler(p interface{}) bool {
}
return false
}
func ctxError(ctx context.Context, err error) error {
ctxErr := ctx.Err()
if ctxErr != nil {
return ctxErr
}
return err
}
......@@ -5,6 +5,7 @@
package sqlite_test
import (
"context"
"database/sql"
"math/rand"
"testing"
......@@ -419,3 +420,35 @@ func _TestPreparedStmt(t *testing.T) {
<-ch
}
}
func TestCancel(t *testing.T) {
db := sqlOpen(t)
defer checkSqlDbClose(db, t)
/*conn := sqlite.Unwrap(db)
assert.Tf(t, conn != nil, "got %#v; want *sqlite.Conn", conn)
conn.CreateScalarFunction("sleep", 0, false, nil, func(ctx *sqlite.ScalarContext, nArg int) {
time.Sleep(500 * time.Millisecond)
ctx.ResultText("ok")
}, nil)*/
ctx, cancel := context.WithCancel(context.Background())
go func() {
time.Sleep(200 * time.Millisecond)
cancel()
}()
_, err := db.ExecContext(ctx, `
WITH RECURSIVE
cnt(x) AS (VALUES(1) UNION ALL SELECT x+1 FROM cnt WHERE x<1000000)
SELECT x FROM cnt;`)
if err != context.Canceled {
t.Errorf("ExecContext expected to fail with Cancelled but it returned %v", err)
}
// connection should be usable after timeout
row := db.QueryRow("SELECT 1")
var val int64
err = row.Scan(&val)
if err != nil {
t.Fatal("Scan failed with", err)
}
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment