Commit 9c0804e7 authored by gwenn's avatar gwenn

Improve User function support.

parent 2469a2f1
...@@ -71,7 +71,8 @@ type Context struct { ...@@ -71,7 +71,8 @@ type Context struct {
} }
type ScalarContext struct { type ScalarContext struct {
Context Context
ad map[int]interface{} // Function Auxiliary Data ad map[int]interface{} // Function Auxiliary Data
udf *sqliteFunction
} }
type AggregateContext struct { type AggregateContext struct {
Context Context
...@@ -308,25 +309,22 @@ type FinalFunction func(ctx *AggregateContext) ...@@ -308,25 +309,22 @@ type FinalFunction func(ctx *AggregateContext)
type DestroyFunctionData func(pApp interface{}) type DestroyFunctionData func(pApp interface{})
type sqliteFunction struct { type sqliteFunction struct {
scalar ScalarFunction scalar ScalarFunction
step StepFunction step StepFunction
final FinalFunction final FinalFunction
d DestroyFunctionData d DestroyFunctionData
pApp interface{} pApp interface{}
contexts map[*C.sqlite3_context]*AggregateContext scalarCtxs map[*ScalarContext]bool
aggrCtxs map[*AggregateContext]bool
} }
// To prevent Context from being gced
// TODO Retry to put this in the sqliteFunction
var contexts map[*C.sqlite3_context]*ScalarContext = make(map[*C.sqlite3_context]*ScalarContext)
//export goXAuxDataDestroy //export goXAuxDataDestroy
func goXAuxDataDestroy(ad unsafe.Pointer) { func goXAuxDataDestroy(ad unsafe.Pointer) {
c := (*ScalarContext)(ad) c := (*ScalarContext)(ad)
if c != nil { if c != nil {
delete(contexts, c.sc) delete(c.udf.scalarCtxs, c)
} }
// fmt.Printf("Contexts: %v\n", contexts) // fmt.Printf("Contexts: %v\n", c.udf.scalarCtxs)
} }
//export goXFunc //export goXFunc
...@@ -337,9 +335,10 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) { ...@@ -337,9 +335,10 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
if c == nil { if c == nil {
c = new(ScalarContext) c = new(ScalarContext)
c.sc = (*C.sqlite3_context)(scp) c.sc = (*C.sqlite3_context)(scp)
c.udf = udf
C.goSqlite3SetAuxdata(c.sc, 0, unsafe.Pointer(c)) C.goSqlite3SetAuxdata(c.sc, 0, unsafe.Pointer(c))
// To make sure it is not cged // To make sure it is not cged
contexts[c.sc] = c udf.scalarCtxs[c] = true
} }
c.argv = (**C.sqlite3_value)(argv) c.argv = (**C.sqlite3_value)(argv)
udf.scalar(c, argc) udf.scalar(c, argc)
...@@ -359,7 +358,7 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) { ...@@ -359,7 +358,7 @@ func goXStep(scp, udfp unsafe.Pointer, argc int, argv unsafe.Pointer) {
c.sc = (*C.sqlite3_context)(scp) c.sc = (*C.sqlite3_context)(scp)
*(*unsafe.Pointer)(cp) = unsafe.Pointer(c) *(*unsafe.Pointer)(cp) = unsafe.Pointer(c)
// To make sure it is not cged // To make sure it is not cged
udf.contexts[c.sc] = c udf.aggrCtxs[c] = true
} else { } else {
c = (*AggregateContext)(p) c = (*AggregateContext)(p)
} }
...@@ -378,12 +377,12 @@ func goXFinal(scp, udfp unsafe.Pointer) { ...@@ -378,12 +377,12 @@ func goXFinal(scp, udfp unsafe.Pointer) {
p := *(*unsafe.Pointer)(cp) p := *(*unsafe.Pointer)(cp)
if p != nil { if p != nil {
c := (*AggregateContext)(p) c := (*AggregateContext)(p)
delete(udf.contexts, c.sc) delete(udf.aggrCtxs, c)
c.sc = (*C.sqlite3_context)(scp) c.sc = (*C.sqlite3_context)(scp)
udf.final(c) udf.final(c)
} }
} }
// fmt.Printf("Contexts: %v\n", udf.contexts) // fmt.Printf("Contexts: %v\n", udf.aggrCtxts)
} }
//export goXDestroy //export goXDestroy
...@@ -407,7 +406,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac ...@@ -407,7 +406,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil)) return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
} }
// To make sure it is not gced, keep a reference in the connection. // To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, nil, nil, d, pApp, nil} udf := &sqliteFunction{f, nil, nil, d, pApp, make(map[*ScalarContext]bool), nil}
if len(c.udfs) == 0 { if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction) c.udfs = make(map[string]*sqliteFunction)
} }
...@@ -429,7 +428,7 @@ func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp inter ...@@ -429,7 +428,7 @@ func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp inter
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil)) return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
} }
// To make sure it is not gced, keep a reference in the connection. // To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{nil, step, final, d, pApp, make(map[*C.sqlite3_context]*AggregateContext)} udf := &sqliteFunction{nil, step, final, d, pApp, nil, make(map[*AggregateContext]bool)}
if len(c.udfs) == 0 { if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction) c.udfs = make(map[string]*sqliteFunction)
} }
......
...@@ -32,10 +32,13 @@ func TestScalarFunction(t *testing.T) { ...@@ -32,10 +32,13 @@ func TestScalarFunction(t *testing.T) {
checkNoError(t, err, "couldn't destroy function: %s") checkNoError(t, err, "couldn't destroy function: %s")
} }
var reused bool
func re(ctx *ScalarContext, nArg int) { func re(ctx *ScalarContext, nArg int) {
ad := ctx.GetAuxData(0) ad := ctx.GetAuxData(0)
var re *regexp.Regexp var re *regexp.Regexp
if ad == nil { if ad == nil {
reused = false
//println("Compile") //println("Compile")
var err error var err error
re, err = regexp.Compile(ctx.Text(0)) re, err = regexp.Compile(ctx.Text(0))
...@@ -45,6 +48,7 @@ func re(ctx *ScalarContext, nArg int) { ...@@ -45,6 +48,7 @@ func re(ctx *ScalarContext, nArg int) {
} }
ctx.SetAuxData(0, re) ctx.SetAuxData(0, re)
} else { } else {
reused = true
//println("Reuse") //println("Reuse")
var ok bool var ok bool
if re, ok = ad.(*regexp.Regexp); !ok { if re, ok = ad.(*regexp.Regexp); !ok {
...@@ -71,6 +75,7 @@ func TestRegexpFunction(t *testing.T) { ...@@ -71,6 +75,7 @@ func TestRegexpFunction(t *testing.T) {
s, err := db.Prepare("select regexp('l.s[aeiouy]', name) from (select 'lisa' as name union all select 'bart')") s, err := db.Prepare("select regexp('l.s[aeiouy]', name) from (select 'lisa' as name union all select 'bart')")
checkNoError(t, err, "couldn't prepare statement: %s") checkNoError(t, err, "couldn't prepare statement: %s")
defer s.Finalize() defer s.Finalize()
if b := Must(s.Next()); !b { if b := Must(s.Next()); !b {
t.Fatalf("No result") t.Fatalf("No result")
} }
...@@ -79,6 +84,10 @@ func TestRegexpFunction(t *testing.T) { ...@@ -79,6 +84,10 @@ func TestRegexpFunction(t *testing.T) {
if i != 1 { if i != 1 {
t.Errorf("Expected %d but got %d", 1, i) t.Errorf("Expected %d but got %d", 1, i)
} }
if reused {
t.Errorf("unexpected reused state")
}
if b := Must(s.Next()); !b { if b := Must(s.Next()); !b {
t.Fatalf("No result") t.Fatalf("No result")
} }
...@@ -87,6 +96,9 @@ func TestRegexpFunction(t *testing.T) { ...@@ -87,6 +96,9 @@ func TestRegexpFunction(t *testing.T) {
if i != 0 { if i != 0 {
t.Errorf("Expected %d but got %d", 0, i) t.Errorf("Expected %d but got %d", 0, i)
} }
if !reused {
t.Errorf("unexpected reused state")
}
} }
func sumStep(ctx *AggregateContext, nArg int) { func sumStep(ctx *AggregateContext, nArg int) {
...@@ -158,8 +170,8 @@ func BenchmarkHalf(b *testing.B) { ...@@ -158,8 +170,8 @@ func BenchmarkHalf(b *testing.B) {
db, _ := Open("") db, _ := Open("")
defer db.Close() defer db.Close()
randomFill(db, 1) randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
db.CreateScalarFunction("half", 1, nil, half, nil) db.CreateScalarFunction("half", 1, nil, half, nil)
cs, _ := db.Prepare("SELECT count(1) FROM test where half(rank) > 20")
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
...@@ -173,8 +185,8 @@ func BenchmarkRegexp(b *testing.B) { ...@@ -173,8 +185,8 @@ func BenchmarkRegexp(b *testing.B) {
db, _ := Open("") db, _ := Open("")
defer db.Close() defer db.Close()
randomFill(db, 1) randomFill(db, 1)
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
db.CreateScalarFunction("regexp", 2, nil, re, reDestroy) db.CreateScalarFunction("regexp", 2, nil, re, reDestroy)
cs, _ := db.Prepare("SELECT count(1) FROM test where name regexp '(?i)\\blisa\\b'")
b.StartTimer() b.StartTimer()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment