Commit 222ec3c2 authored by gwenn's avatar gwenn

Add completion methods for database, table and column names

parent 767f75c7
...@@ -4,7 +4,11 @@ ...@@ -4,7 +4,11 @@
package shell package shell
import "github.com/gwenn/gosqlite" import (
"strings"
"github.com/gwenn/gosqlite"
)
type pendingAction struct { type pendingAction struct {
action sqlite.Action action sqlite.Action
...@@ -40,7 +44,7 @@ func CreateCache() (*CompletionCache, error) { ...@@ -40,7 +44,7 @@ func CreateCache() (*CompletionCache, error) {
func (cc *CompletionCache) init() error { func (cc *CompletionCache) init() error {
cmd := `CREATE VIRTUAL TABLE pragma_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args); cmd := `CREATE VIRTUAL TABLE pragma_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE func_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args); CREATE VIRTUAL TABLE func_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE moduleNames USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args); CREATE VIRTUAL TABLE module_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE cmd_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args); CREATE VIRTUAL TABLE cmd_names USING fts4(name, args, tokenize=porter, matchinfo=fts3, notindexed=args);
CREATE VIRTUAL TABLE col_names USING fts4(db_name, tbl_name, type, col_name, tokenize=porter, matchinfo=fts3, notindexed=type); CREATE VIRTUAL TABLE col_names USING fts4(db_name, tbl_name, type, col_name, tokenize=porter, matchinfo=fts3, notindexed=type);
` `
...@@ -208,7 +212,7 @@ func (cc *CompletionCache) init() error { ...@@ -208,7 +212,7 @@ func (cc *CompletionCache) init() error {
} }
// Only built-in modules are supported. // Only built-in modules are supported.
// TODO make possible to register extended/user-defined modules // TODO make possible to register extended/user-defined modules
s, err = cc.memDb.Prepare("INSERT INTO moduleNames (name, args) VALUES (?, ?)") s, err = cc.memDb.Prepare("INSERT INTO module_names (name, args) VALUES (?, ?)")
if err != nil { if err != nil {
return err return err
} }
...@@ -303,7 +307,9 @@ func (cc *CompletionCache) Cache(db *sqlite.Conn) error { ...@@ -303,7 +307,9 @@ func (cc *CompletionCache) Cache(db *sqlite.Conn) error {
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1}) cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1})
case sqlite.Attach: case sqlite.Attach:
// database name is not available, only the path... // database name is not available, only the path...
if arg1 != "" && arg1 != ":memory:" { // temporary db: "" and memory db: ":memory:" are empty when attached.
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1}) cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: arg1})
}
case sqlite.DropTable, sqlite.DropTempTable, sqlite.DropView, sqlite.DropTempView, sqlite.DropVTable: case sqlite.DropTable, sqlite.DropTempTable, sqlite.DropView, sqlite.DropTempView, sqlite.DropVTable:
cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName, tblName: arg1}) cc.pendingActions = append(cc.pendingActions, pendingAction{action: action, dbName: dbName, tblName: arg1})
case sqlite.CreateTable, sqlite.CreateTempTable, sqlite.CreateVTable: case sqlite.CreateTable, sqlite.CreateTempTable, sqlite.CreateVTable:
...@@ -427,7 +433,97 @@ func (cc *CompletionCache) CompleteCmd(prefix string) ([]string, error) { ...@@ -427,7 +433,97 @@ func (cc *CompletionCache) CompleteCmd(prefix string) ([]string, error) {
} }
func (cc *CompletionCache) complete(tbl, prefix string) ([]string, error) { func (cc *CompletionCache) complete(tbl, prefix string) ([]string, error) {
s, err := cc.memDb.Prepare("SELECT name FROM " + tbl + " WHERE name MATCH ?||'*' ORDER BY 1") s, err := cc.memDb.Prepare("SELECT name FROM "+tbl+" WHERE name MATCH ?||'*' ORDER BY 1", prefix)
if err != nil {
return nil, err
}
defer s.Finalize()
var names []string
if err = s.Select(func(s *sqlite.Stmt) error {
name, _ := s.ScanText(0)
names = append(names, name)
return nil
}); err != nil {
return nil, err
}
return names, nil
}
func (cc *CompletionCache) CompleteDbName(prefix string) ([]string, error) {
s, err := cc.memDb.Prepare("SELECT DISTINCT db_name FROM col_names WHERE db_name MATCH ?||'*' ORDER BY 1", prefix)
if err != nil {
return nil, err
}
defer s.Finalize()
var names []string
if err = s.Select(func(s *sqlite.Stmt) error {
name, _ := s.ScanText(0)
names = append(names, name)
return nil
}); err != nil {
return nil, err
}
return names, nil
}
func (cc *CompletionCache) CompleteTableName(dbName, prefix, typ string) ([]string, error) {
args := make([]interface{}, 0, 3)
if dbName != "" {
args = append(args, dbName)
}
args = append(args, prefix)
if typ != "" {
args = append(args, typ)
}
var sql string
if dbName == "" {
if typ == "" {
sql = "SELECT DISTINCT tbl_name FROM col_names WHERE tbl_name MATCH ?||'*' ORDER BY 1"
} else {
sql = "SELECT DISTINCT tbl_name FROM col_names WHERE tbl_name MATCH ?||'*' AND type = ? ORDER BY 1"
}
} else {
if typ == "" {
sql = "SELECT DISTINCT tbl_name FROM col_names WHERE db_name = ? AND tbl_name MATCH ?||'*' ORDER BY 1"
} else {
sql = "SELECT DISTINCT tbl_name FROM col_names WHERE db_name = ? AND tbl_name MATCH ?||'*' AND type = ? ORDER BY 1"
}
}
s, err := cc.memDb.Prepare(sql, args...)
if err != nil {
return nil, err
}
defer s.Finalize()
var names []string
if err = s.Select(func(s *sqlite.Stmt) error {
name, _ := s.ScanText(0)
names = append(names, name)
return nil
}); err != nil {
return nil, err
}
return names, nil
}
// tbl_names is mandatory
func (cc *CompletionCache) CompleteColName(dbName string, tbl_names []string, prefix string) ([]string, error) {
args := make([]interface{}, 0, 10)
if dbName != "" {
args = append(args, dbName)
}
phs := make([]string, 0, 10)
for _, tbl_name := range tbl_names {
args = append(args, tbl_name)
phs = append(phs, "?")
}
args = append(args, prefix)
var sql string
if dbName == "" {
sql = "SELECT DISTINCT col_name FROM col_names WHERE tbl_name IN (" + strings.Join(phs, ",") + ") AND col_name MATCH ?||'*' ORDER BY 1"
} else {
sql = "SELECT DISTINCT col_name FROM col_names WHERE db_name = ? AND tbl_name IN (" + strings.Join(phs, ",") + ") AND col_name MATCH ?||'*' ORDER BY 1"
}
s, err := cc.memDb.Prepare(sql, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
...@@ -437,7 +533,7 @@ func (cc *CompletionCache) complete(tbl, prefix string) ([]string, error) { ...@@ -437,7 +533,7 @@ func (cc *CompletionCache) complete(tbl, prefix string) ([]string, error) {
name, _ := s.ScanText(0) name, _ := s.ScanText(0)
names = append(names, name) names = append(names, name)
return nil return nil
}, prefix); err != nil { }); err != nil {
return nil, err return nil, err
} }
return names, nil return names, nil
......
...@@ -26,6 +26,7 @@ func TestPragmaNames(t *testing.T) { ...@@ -26,6 +26,7 @@ func TestPragmaNames(t *testing.T) {
assert.Equalf(t, 3, len(pragmas), "got %d pragmas; expected %d", len(pragmas), 3) assert.Equalf(t, 3, len(pragmas), "got %d pragmas; expected %d", len(pragmas), 3)
assert.Equal(t, []string{"foreign_key_check", "foreign_key_list(", "foreign_keys"}, pragmas, "unexpected pragmas") assert.Equal(t, []string{"foreign_key_check", "foreign_key_list(", "foreign_keys"}, pragmas, "unexpected pragmas")
} }
func TestFuncNames(t *testing.T) { func TestFuncNames(t *testing.T) {
cc := createCache(t) cc := createCache(t)
defer cc.Close() defer cc.Close()
...@@ -34,6 +35,7 @@ func TestFuncNames(t *testing.T) { ...@@ -34,6 +35,7 @@ func TestFuncNames(t *testing.T) {
assert.Equal(t, 2, len(funcs), "got %d functions; expected %d", len(funcs), 2) assert.Equal(t, 2, len(funcs), "got %d functions; expected %d", len(funcs), 2)
assert.Equal(t, []string{"substr(", "sum("}, funcs, "unexpected functions") assert.Equal(t, []string{"substr(", "sum("}, funcs, "unexpected functions")
} }
func TestCmdNames(t *testing.T) { func TestCmdNames(t *testing.T) {
cc := createCache(t) cc := createCache(t)
defer cc.Close() defer cc.Close()
...@@ -42,6 +44,7 @@ func TestCmdNames(t *testing.T) { ...@@ -42,6 +44,7 @@ func TestCmdNames(t *testing.T) {
assert.Equal(t, 2, len(cmds), "got %d commands; expected %d", len(cmds), 2) assert.Equal(t, 2, len(cmds), "got %d commands; expected %d", len(cmds), 2)
assert.Equal(t, []string{".headers", ".help"}, cmds, "unexpected commands") assert.Equal(t, []string{".headers", ".help"}, cmds, "unexpected commands")
} }
func TestCache(t *testing.T) { func TestCache(t *testing.T) {
db, err := sqlite.Open(":memory:") db, err := sqlite.Open(":memory:")
assert.Tf(t, err == nil, "%v", err) assert.Tf(t, err == nil, "%v", err)
...@@ -53,3 +56,36 @@ func TestCache(t *testing.T) { ...@@ -53,3 +56,36 @@ func TestCache(t *testing.T) {
err = cc.Flush(db) err = cc.Flush(db)
assert.Tf(t, err == nil, "%v", err) assert.Tf(t, err == nil, "%v", err)
} }
func TestCacheUpdate(t *testing.T) {
db, err := sqlite.Open(":memory:")
assert.Tf(t, err == nil, "%v", err)
defer db.Close()
cc := createCache(t)
defer cc.Close()
err = cc.Cache(db)
assert.Tf(t, err == nil, "%v", err)
err = db.FastExec("CREATE TABLE test (id INTEGER PRIMARY KEY NOT NULL, name TEXT UNIQUE NOT NULL)")
assert.Tf(t, err == nil, "%v", err)
cc.Update(db)
db_names, err := cc.CompleteDbName("m")
assert.Tf(t, err == nil, "%v", err)
assert.Equal(t, 1, len(db_names), "got %d database names; expected %d", len(db_names), 1)
assert.Equal(t, []string{"main"}, db_names, "unexpected database names")
tbl_names, err := cc.CompleteTableName("", "te", "")
assert.Tf(t, err == nil, "%v", err)
assert.Equal(t, 1, len(tbl_names), "got %d table names; expected %d", len(tbl_names), 1)
assert.Equal(t, []string{"test"}, tbl_names, "unexpected table names")
tbl_names, err = cc.CompleteTableName("main", "te", "table")
assert.Tf(t, err == nil, "%v", err)
assert.Equal(t, 1, len(tbl_names), "got %d table names; expected %d", len(tbl_names), 1)
assert.Equal(t, []string{"test"}, tbl_names, "unexpected table names")
col_names, err := cc.CompleteColName("", []string{"test"}, "n")
assert.Tf(t, err == nil, "%v", err)
assert.Equal(t, 1, len(col_names), "got %d column names; expected %d", len(col_names), 1)
assert.Equal(t, []string{"name"}, col_names, "unexpected column names")
}
...@@ -705,6 +705,6 @@ func Parse(line string) { ...@@ -705,6 +705,6 @@ func Parse(line string) {
if item.typ == itemEOF { if item.typ == itemEOF {
break break
} }
fmt.Printf("%s\n", item) //fmt.Printf("%s\n", item)
} }
} }
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment