Commit d0334ea5 authored by gwenn's avatar gwenn

Use FTS module to autocomplete database, table, view, column names.

parent 96a298b9
......@@ -11,4 +11,4 @@ before_script:
- go get code.google.com/p/go.tools/cmd/cover
- go get -tags all github.com/gwenn/gosqlite
script:
- go test -tags all -cover ./...
\ No newline at end of file
- go test -tags all -cover
\ No newline at end of file
......@@ -4,25 +4,20 @@
package shell
import (
"strings"
import "github.com/gwenn/gosqlite"
"github.com/gwenn/gosqlite"
)
type CompletionCache struct {
memDb *sqlite.Conn
dbNames []string // "main", "temp", ...
dbCaches map[string]*databaseCache
type pendingAction struct {
action sqlite.Action
dbName string
tblName string
typ string
}
type databaseCache struct {
schemaVersion int //
tableNames map[string]string // lowercase name => original name
viewNames map[string]string
columnNames map[string][]string // lowercase table name => column name
// idxNames []string // indexed by dbName (seems useful only in DROP INDEX statement)
// trigNames []string // trigger by dbName (seems useful only in DROP TRIGGER statement)
type CompletionCache struct {
memDb *sqlite.Conn
insert *sqlite.Stmt
schemaVersions map[string]int
pendingActions []pendingAction
}
func CreateCache() (*CompletionCache, error) {
......@@ -30,19 +25,25 @@ func CreateCache() (*CompletionCache, error) {
if err != nil {
return nil, err
}
cc := &CompletionCache{memDb: db, dbNames: make([]string, 0, 2), dbCaches: make(map[string]*databaseCache)}
cc := &CompletionCache{memDb: db, schemaVersions: make(map[string]int), pendingActions: make([]pendingAction, 0, 5)}
if err = cc.init(); err != nil {
db.Close()
return nil, err
}
s, err := cc.memDb.Prepare("INSERT INTO col_names (db_name, tbl_name, type, col_name) VALUES (?, ?, ?, ?)")
if err != nil {
return nil, err
}
cc.insert = s
return cc, nil
}
func (cc *CompletionCache) init() error {
cmd := `CREATE VIRTUAL TABLE pragmaNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed="args");
CREATE VIRTUAL TABLE funcNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed="args");
CREATE VIRTUAL TABLE moduleNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed="args");
CREATE VIRTUAL TABLE cmdNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed="args");
cmd := `CREATE VIRTUAL TABLE pragma_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE func_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE moduleNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE cmd_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE col_names USING fts4(db_name, tbl_name, type, col_name, tokenize=porter, matchinfo=fts3, notindexed=type);
`
var err error
if err = cc.memDb.FastExec(cmd); err != nil {
......@@ -58,7 +59,7 @@ func (cc *CompletionCache) init() error {
err = cc.memDb.Commit()
}
}()
s, err := cc.memDb.Prepare("INSERT INTO pragmaNames (name, args) VALUES (?, ?)")
s, err := cc.memDb.Prepare("INSERT INTO pragma_names (name, args) VALUES (?, ?)")
if err != nil {
return err
}
......@@ -137,7 +138,7 @@ func (cc *CompletionCache) init() error {
}
// Only built-in functions are supported.
// TODO make possible to register extended/user-defined functions
s, err = cc.memDb.Prepare("INSERT INTO funcNames (name, args) VALUES (?, ?)")
s, err = cc.memDb.Prepare("INSERT INTO func_names (name, args) VALUES (?, ?)")
if err != nil {
return err
}
......@@ -228,7 +229,7 @@ func (cc *CompletionCache) init() error {
if err = s.Finalize(); err != nil {
return err
}
s, err = cc.memDb.Prepare("INSERT INTO cmdNames (name, args) VALUES (?, ?)")
s, err = cc.memDb.Prepare("INSERT INTO cmd_names (name, args) VALUES (?, ?)")
if err != nil {
return err
}
......@@ -284,25 +285,38 @@ func (cc *CompletionCache) init() error {
}
func (cc *CompletionCache) Close() error {
if err := cc.insert.Finalize(); err != nil {
return err
}
return cc.memDb.Close()
}
func (cc *CompletionCache) Update(db *sqlite.Conn) error {
// update database list (TODO only on ATTACH ...)
cc.dbNames = cc.dbNames[:0]
func (cc *CompletionCache) Cache(db *sqlite.Conn) error {
db.SetAuthorizer(func(udp interface{}, action sqlite.Action, arg1, arg2, dbName, triggerName string) sqlite.Auth {
switch action {
case sqlite.Detach:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1})
case sqlite.Attach:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName})
case sqlite.DropTable, sqlite.DropTempTable, sqlite.DropView, sqlite.DropTempView, sqlite.DropVTable:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName, tblName: arg1})
case sqlite.CreateTable, sqlite.CreateTempTable, sqlite.CreateVTable:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName, tblName: arg1, typ: "table"})
case sqlite.CreateView, sqlite.CreateTempView:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName, tblName: arg1, typ: "view"})
case sqlite.AlterTable:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1, tblName: arg2, typ: "table"})
// TODO trigger, index
}
return sqlite.AuthOk
}, nil)
dbNames, err := db.Databases()
if err != nil {
return err
}
// update databases cache
// cache databases schema
for dbName := range dbNames {
cc.dbNames = append(cc.dbNames, dbName)
dbc := cc.dbCaches[dbName]
if dbc == nil {
dbc = &databaseCache{schemaVersion: -1, tableNames: make(map[string]string), viewNames: make(map[string]string), columnNames: make(map[string][]string)}
cc.dbCaches[dbName] = dbc
}
err = dbc.update(db, dbName)
err = cc.cache(db, dbName)
if err != nil {
return err
}
......@@ -310,11 +324,11 @@ func (cc *CompletionCache) Update(db *sqlite.Conn) error {
return nil
}
func (dc *databaseCache) update(db *sqlite.Conn, dbName string) error {
func (cc *CompletionCache) cache(db *sqlite.Conn, dbName string) error {
var sv int
if sv, err := db.SchemaVersion(dbName); err != nil {
return err
} else if sv == dc.schemaVersion { // up to date
} else if osv, ok := cc.schemaVersions[dbName]; ok && osv == sv { // up to date
return nil
}
......@@ -327,73 +341,86 @@ func (dc *databaseCache) update(db *sqlite.Conn, dbName string) error {
} else {
ts = append(ts, "sqlite_master")
}
// clear
for table := range dc.tableNames {
delete(dc.tableNames, table)
}
for _, table := range ts {
dc.tableNames[strings.ToLower(table)] = table // TODO unicode
if err = cc.cacheTable(db, dbName, table, "table"); err != nil {
return err
}
}
vs, err := db.Views(dbName)
if err != nil {
return err
}
// clear
for view := range dc.viewNames {
delete(dc.viewNames, view)
}
for _, view := range vs {
dc.viewNames[strings.ToLower(view)] = view // TODO unicode
}
// drop
for table := range dc.columnNames {
if _, ok := dc.tableNames[table]; ok {
continue
} else if _, ok := dc.viewNames[table]; ok {
continue
if err = cc.cacheTable(db, dbName, view, "view"); err != nil {
return err
}
delete(dc.columnNames, table)
}
for table := range dc.tableNames {
cs, err := db.Columns(dbName, table)
if err != nil {
return err
}
columnNames := dc.columnNames[table]
columnNames = columnNames[:0]
for _, c := range cs {
columnNames = append(columnNames, c.Name)
}
dc.columnNames[table] = columnNames
cc.schemaVersions[dbName] = sv
return nil
}
func (cc *CompletionCache) cacheTable(db *sqlite.Conn, dbName, tblName, typ string) error {
cs, err := db.Columns(dbName, tblName)
if err != nil {
return err
}
for view := range dc.viewNames {
cs, err := db.Columns(dbName, view)
if err != nil {
for _, c := range cs {
if err = cc.insert.Exec(dbName, tblName, typ, c.Name); err != nil {
return err
}
columnNames := dc.columnNames[view]
columnNames = columnNames[:0]
for _, c := range cs {
columnNames = append(columnNames, c.Name)
}
dc.columnNames[view] = columnNames
}
return nil
}
dc.schemaVersion = sv
func (cc *CompletionCache) Update(db *sqlite.Conn) error {
for _, pa := range cc.pendingActions {
switch pa.action {
case sqlite.Attach:
if err := cc.cache(db, pa.dbName); err != nil {
return err
}
case sqlite.Detach:
if err := cc.memDb.Exec("DELETE FROM col_names WHERE db_name = ?", pa.dbName); err != nil {
return err
}
case sqlite.AlterTable:
if err := cc.memDb.Exec("DELETE FROM col_names WHERE db_name = ? AND tbl_name = ?", pa.dbName, pa.tblName); err != nil {
return err
}
fallthrough
case sqlite.CreateTable, sqlite.CreateTempTable, sqlite.CreateView, sqlite.CreateTempView, sqlite.CreateVTable:
if err := cc.cacheTable(db, pa.dbName, pa.tblName, pa.typ); err != nil {
return err
}
case sqlite.DropTable, sqlite.DropTempTable, sqlite.DropView, sqlite.DropTempView, sqlite.DropVTable:
if err := cc.memDb.Exec("DELETE FROM col_names WHERE db_name = ? AND tbl_name = ?", pa.dbName, pa.tblName); err != nil {
return err
}
}
}
cc.pendingActions = cc.pendingActions[:0]
return nil
}
func (cc *CompletionCache) Flush(db *sqlite.Conn) error {
for dbName := range cc.schemaVersions {
delete(cc.schemaVersions, dbName)
}
cc.pendingActions = cc.pendingActions[:0]
return cc.memDb.FastExec("DELETE FROM col_names")
}
func (cc *CompletionCache) CompletePragma(prefix string) ([]string, error) {
return cc.complete("pragmaNames", prefix)
return cc.complete("pragma_names", prefix)
}
func (cc *CompletionCache) CompleteFunc(prefix string) ([]string, error) {
return cc.complete("funcNames", prefix)
return cc.complete("func_names", prefix)
}
func (cc *CompletionCache) CompleteCmd(prefix string) ([]string, error) {
return cc.complete("cmdNames", prefix)
return cc.complete("cmd_names", prefix)
}
func (cc *CompletionCache) complete(tbl, prefix string) ([]string, error) {
......
......@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// +build ignore
package shell_test
import (
......@@ -50,6 +48,8 @@ func TestCache(t *testing.T) {
defer db.Close()
cc := createCache(t)
defer cc.Close()
err = cc.Update(db)
err = cc.Cache(db)
assert.Tf(t, err == nil, "%v", err)
err = cc.Flush(db)
assert.Tf(t, err == nil, "%v", err)
}
......@@ -178,6 +178,9 @@ func main() {
catchInterrupt()
err = completionCache.Cache(db)
check(err)
// TODO .mode MODE ?TABLE? Set output mode where MODE is one of:
// TODO .separator STRING Change separator used by output mode and .import
tw := tabwriter.NewWriter(os.Stdout, 0, 8, 0, '\t', 0)
......@@ -252,6 +255,7 @@ func main() {
cmd = s.Tail()
} // exec
b.Reset()
completionCache.Update(db)
}
}
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment