Commit 3f258c0f authored by gwenn's avatar gwenn

First draft of aggregation function.

parent ab027b97
...@@ -52,17 +52,27 @@ static void goSqlite3SetAuxdata(sqlite3_context *ctx, int N, void *ad) { ...@@ -52,17 +52,27 @@ static void goSqlite3SetAuxdata(sqlite3_context *ctx, int N, void *ad) {
sqlite3_set_auxdata(ctx, N, ad, goXAuxDataDestroy); sqlite3_set_auxdata(ctx, N, ad, goXAuxDataDestroy);
} }
extern void goXFunc(sqlite3_context *ctx, void *udf, void *goctx, int argc, sqlite3_value **argv); extern void goXFuncOrStep(sqlite3_context *ctx, void *udf, void *goctx, int argc, sqlite3_value **argv);
extern void goXFinal(void *udf, void *goctx);
extern void goXDestroy(void *pApp); extern void goXDestroy(void *pApp);
static void cXFunc(sqlite3_context *ctx, int argc, sqlite3_value **argv) { static void cXFuncOrStep(sqlite3_context *ctx, int argc, sqlite3_value **argv) {
void *udf = sqlite3_user_data(ctx); void *udf = sqlite3_user_data(ctx);
void *goctx = sqlite3_get_auxdata(ctx, 0); void *goctx = sqlite3_get_auxdata(ctx, 0);
goXFunc(ctx, udf, goctx, argc, argv); goXFuncOrStep(ctx, udf, goctx, argc, argv);
}
static void cXFinal(sqlite3_context *ctx) {
void *udf = sqlite3_user_data(ctx);
void *goctx = sqlite3_get_auxdata(ctx, 0);
goXFinal(udf, goctx);
} }
static int goSqlite3CreateScalarFunction(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp) { static int goSqlite3CreateScalarFunction(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp) {
return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, cXFunc, NULL, NULL, goXDestroy); return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, cXFuncOrStep, NULL, NULL, goXDestroy);
}
static int goSqlite3CreateAggregateFunction(sqlite3 *db, const char *zFunctionName, int nArg, int eTextRep, void *pApp) {
return sqlite3_create_function_v2(db, zFunctionName, nArg, eTextRep, pApp, NULL, cXFuncOrStep, cXFinal, goXDestroy);
} }
*/ */
import "C" import "C"
...@@ -81,9 +91,10 @@ sqlite3 *sqlite3_context_db_handle(sqlite3_context*); ...@@ -81,9 +91,10 @@ sqlite3 *sqlite3_context_db_handle(sqlite3_context*);
*/ */
type Context struct { type Context struct {
sc *C.sqlite3_context sc *C.sqlite3_context
argv **C.sqlite3_value argv **C.sqlite3_value
ad map[int]interface{} // Function Auxiliary Data ad map[int]interface{} // Function Auxiliary Data
AggregateContext interface{} // Aggregate Function Context
} }
func (c *Context) Result(r interface{}) { func (c *Context) Result(r interface{}) {
...@@ -315,14 +326,14 @@ type FinalFunction func(ctx *Context) ...@@ -315,14 +326,14 @@ type FinalFunction func(ctx *Context)
type DestroyFunctionData func(pApp interface{}) type DestroyFunctionData func(pApp interface{})
/* /*
void (*xFunc)(sqlite3_context*,int,sqlite3_value**),
void (*xStep)(sqlite3_context*,int,sqlite3_value**), void (*xStep)(sqlite3_context*,int,sqlite3_value**),
*/ */
type sqliteFunction struct { type sqliteFunction struct {
f ScalarFunction funcOrStep ScalarFunction
d DestroyFunctionData final FinalFunction
pApp interface{} d DestroyFunctionData
pApp interface{}
} }
// To prevent Context from being gced // To prevent Context from being gced
...@@ -338,8 +349,8 @@ func goXAuxDataDestroy(ad unsafe.Pointer) { ...@@ -338,8 +349,8 @@ func goXAuxDataDestroy(ad unsafe.Pointer) {
//fmt.Printf("%v\n", contexts) //fmt.Printf("%v\n", contexts)
} }
//export goXFunc //export goXFuncOrStep
func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) { func goXFuncOrStep(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
udf := (*sqliteFunction)(udfp) udf := (*sqliteFunction)(udfp)
// To avoid the creation of a Context at each call, just put it in auxdata // To avoid the creation of a Context at each call, just put it in auxdata
c := (*Context)(ctxp) c := (*Context)(ctxp)
...@@ -351,7 +362,15 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) { ...@@ -351,7 +362,15 @@ func goXFunc(scp, udfp, ctxp unsafe.Pointer, argc int, argv unsafe.Pointer) {
contexts[c.sc] = c contexts[c.sc] = c
} }
c.argv = (**C.sqlite3_value)(argv) c.argv = (**C.sqlite3_value)(argv)
udf.f(c, argc) udf.funcOrStep(c, argc)
c.argv = nil
}
//export goXFinal
func goXFinal(udfp, ctxp unsafe.Pointer) {
udf := (*sqliteFunction)(udfp)
c := (*Context)(ctxp)
udf.final(c)
} }
//export goXDestroy //export goXDestroy
...@@ -375,7 +394,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac ...@@ -375,7 +394,7 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil)) return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
} }
// To make sure it is not gced, keep a reference in the connection. // To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, d, pApp} udf := &sqliteFunction{f, nil, d, pApp}
if len(c.udfs) == 0 { if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction) c.udfs = make(map[string]*sqliteFunction)
} }
...@@ -383,29 +402,31 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac ...@@ -383,29 +402,31 @@ func (c *Conn) CreateScalarFunction(functionName string, nArg int, pApp interfac
return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf))) return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)))
} }
// Obtain Aggregate Function Context /*
// Calls http://sqlite.org/c3ref/aggregate_context.html // Calls http://sqlite.org/c3ref/aggregate_context.html
func (c *Context) AggregateContext(nBytes int) interface{} { func (c *Context) AggregateContext(nBytes int) interface{} {
return C.sqlite3_aggregate_context(c.sc, C.int(nBytes)) return C.sqlite3_aggregate_context(c.sc, C.int(nBytes))
} }
*/
// Create or redefine SQL functions // Create or redefine SQL functions
// TODO Make possible to specify the preferred encoding // TODO Make possible to specify the preferred encoding
// Calls http://sqlite.org/c3ref/create_function.html // Calls http://sqlite.org/c3ref/create_function.html
func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{}, f ScalarFunction, d DestroyFunctionData) error { func (c *Conn) CreateAggregateFunction(functionName string, nArg int, pApp interface{},
step ScalarFunction, final FinalFunction, d DestroyFunctionData) error {
fname := C.CString(functionName) fname := C.CString(functionName)
defer C.free(unsafe.Pointer(fname)) defer C.free(unsafe.Pointer(fname))
if f == nil { if step == nil {
if len(c.udfs) > 0 { if len(c.udfs) > 0 {
delete(c.udfs, functionName) delete(c.udfs, functionName)
} }
return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil)) return c.error(C.sqlite3_create_function_v2(c.db, fname, C.int(nArg), C.SQLITE_UTF8, nil, nil, nil, nil, nil))
} }
// To make sure it is not gced, keep a reference in the connection. // To make sure it is not gced, keep a reference in the connection.
udf := &sqliteFunction{f, d, pApp} udf := &sqliteFunction{step, final, d, pApp}
if len(c.udfs) == 0 { if len(c.udfs) == 0 {
c.udfs = make(map[string]*sqliteFunction) c.udfs = make(map[string]*sqliteFunction)
} }
c.udfs[functionName] = udf // FIXME same function name with different args is not supported c.udfs[functionName] = udf // FIXME same function name with different args is not supported
return c.error(C.goSqlite3CreateScalarFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf))) return c.error(C.goSqlite3CreateAggregateFunction(c.db, fname, C.int(nArg), C.SQLITE_UTF8, unsafe.Pointer(udf)))
} }
...@@ -80,6 +80,7 @@ func TestRegexpFunction(t *testing.T) { ...@@ -80,6 +80,7 @@ func TestRegexpFunction(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("couldn't prepare statement: %s", err) t.Fatalf("couldn't prepare statement: %s", err)
} }
defer s.Finalize()
if b := Must(s.Next()); !b { if b := Must(s.Next()); !b {
t.Fatalf("No result") t.Fatalf("No result")
} }
...@@ -100,15 +101,54 @@ func TestRegexpFunction(t *testing.T) { ...@@ -100,15 +101,54 @@ func TestRegexpFunction(t *testing.T) {
if i != 0 { if i != 0 {
t.Errorf("Expected %d but got %d", 0, i) t.Errorf("Expected %d but got %d", 0, i)
} }
if err = s.Finalize(); err != nil { }
t.Fatalf("couldn't finalize statement: %s", err)
func sumStep(ctx *Context, nArg int) {
nt := ctx.NumericType(0)
if nt == Integer || nt == Float {
var sum float64
var ok bool
if sum, ok = (ctx.AggregateContext).(float64); !ok {
sum = 0
}
sum += ctx.Double(0)
ctx.AggregateContext = sum
}
}
func sumFinal(ctx *Context) {
if sum, ok := (ctx.AggregateContext).(float64); ok {
ctx.ResultDouble(sum)
} else {
ctx.ResultNull()
}
}
/*
func TestSumFunction(t *testing.T) {
db, err := Open("")
if err != nil {
t.Fatalf("couldn't open database file: %s", err)
}
defer db.Close()
if err = db.CreateAggregateFunction("sum", 1, nil, sumStep, sumFinal, nil); err != nil {
t.Fatalf("couldn't create function: %s", err)
}
i, err := db.OneValue("select sum(i) from (select 2 as i union all select 2 as i)")
if err != nil {
t.Fatalf("couldn't execute statement: %s", err)
}
if i != 4 {
t.Errorf("Expected %d but got %d", 4, i)
} }
} }
*/
func randomFill(db *Conn, n int) { func randomFill(db *Conn, n int) {
db.Exec("DROP TABLE IF EXISTS test") db.Exec("DROP TABLE IF EXISTS test")
db.Exec("CREATE TABLE test (name TEXT, rank int)") db.Exec("CREATE TABLE test (name TEXT, rank int)")
s, _ := db.Prepare("INSERT INTO test (name, rank) VALUES (?, ?)") s, _ := db.Prepare("INSERT INTO test (name, rank) VALUES (?, ?)")
defer s.Finalize()
names := []string{"Bart", "Homer", "Lisa", "Maggie", "Marge"} names := []string{"Bart", "Homer", "Lisa", "Maggie", "Marge"}
...@@ -116,7 +156,6 @@ func randomFill(db *Conn, n int) { ...@@ -116,7 +156,6 @@ func randomFill(db *Conn, n int) {
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
s.Exec(names[rand.Intn(len(names))], rand.Intn(100)) s.Exec(names[rand.Intn(len(names))], rand.Intn(100))
} }
s.Finalize()
db.Commit() db.Commit()
} }
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment