Commit 1a72a524 authored by gwenn's avatar gwenn

Prefer double quote for identifier

parent 6bc819a9
......@@ -45,7 +45,7 @@ func (c *Conn) Tables(dbName string, temp bool) ([]string, error) {
if len(dbName) == 0 {
sql = "SELECT name FROM sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%' ORDER BY 1"
} else {
sql = Mprintf("SELECT name FROM %Q.sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%%' ORDER BY 1", dbName)
sql = fmt.Sprintf("SELECT name FROM %s.sqlite_master WHERE type = 'table' AND name NOT LIKE 'sqlite_%%' ORDER BY 1", doubleQuote(dbName))
}
if temp {
sql = strings.Replace(sql, "sqlite_master", "sqlite_temp_master", 1)
......@@ -73,7 +73,7 @@ func (c *Conn) Views(dbName string, temp bool) ([]string, error) {
if len(dbName) == 0 {
sql = "SELECT name FROM sqlite_master WHERE type = 'view' ORDER BY 1"
} else {
sql = Mprintf("SELECT name FROM %Q.sqlite_master WHERE type = 'view' ORDER BY 1", dbName)
sql = fmt.Sprintf("SELECT name FROM %s.sqlite_master WHERE type = 'view' ORDER BY 1", doubleQuote(dbName))
}
if temp {
sql = strings.Replace(sql, "sqlite_master", "sqlite_temp_master", 1)
......@@ -102,7 +102,7 @@ func (c *Conn) Indexes(dbName string, temp bool) (map[string]string, error) {
if len(dbName) == 0 {
sql = "SELECT name, tbl_name FROM sqlite_master WHERE type = 'index'"
} else {
sql = Mprintf("SELECT name, tbl_name FROM %Q.sqlite_master WHERE type = 'index'", dbName)
sql = fmt.Sprintf("SELECT name, tbl_name FROM %s.sqlite_master WHERE type = 'index'", doubleQuote(dbName))
}
if temp {
sql = strings.Replace(sql, "sqlite_master", "sqlite_temp_master", 1)
......@@ -145,9 +145,9 @@ type Column struct {
func (c *Conn) Columns(dbName, table string) ([]Column, error) {
var pragma string
if len(dbName) == 0 {
pragma = Mprintf("PRAGMA table_info(%Q)", table)
pragma = fmt.Sprintf(`PRAGMA table_info("%s")`, escapeQuote(table))
} else {
pragma = Mprintf2("PRAGMA %Q.table_info(%Q)", dbName, table)
pragma = fmt.Sprintf(`PRAGMA %s.table_info("%s")`, doubleQuote(dbName), escapeQuote(table))
}
s, err := c.prepare(pragma)
if err != nil {
......@@ -286,9 +286,9 @@ type ForeignKey struct {
func (c *Conn) ForeignKeys(dbName, table string) (map[int]*ForeignKey, error) {
var pragma string
if len(dbName) == 0 {
pragma = Mprintf("PRAGMA foreign_key_list(%Q)", table)
pragma = fmt.Sprintf(`PRAGMA foreign_key_list("%s")`, escapeQuote(table))
} else {
pragma = Mprintf2("PRAGMA %Q.foreign_key_list(%Q)", dbName, table)
pragma = fmt.Sprintf(`PRAGMA %s.foreign_key_list("%s")`, doubleQuote(dbName), escapeQuote(table))
}
s, err := c.prepare(pragma)
if err != nil {
......@@ -331,9 +331,9 @@ type Index struct {
func (c *Conn) TableIndexes(dbName, table string) ([]Index, error) {
var pragma string
if len(dbName) == 0 {
pragma = Mprintf("PRAGMA index_list(%Q)", table)
pragma = fmt.Sprintf(`PRAGMA index_list("%s")`, escapeQuote(table))
} else {
pragma = Mprintf2("PRAGMA %Q.index_list(%Q)", dbName, table)
pragma = fmt.Sprintf(`PRAGMA %s.index_list("%s")`, doubleQuote(dbName), escapeQuote(table))
}
s, err := c.prepare(pragma)
if err != nil {
......@@ -362,9 +362,9 @@ func (c *Conn) TableIndexes(dbName, table string) ([]Index, error) {
func (c *Conn) IndexColumns(dbName, index string) ([]Column, error) {
var pragma string
if len(dbName) == 0 {
pragma = Mprintf("PRAGMA index_info(%Q)", index)
pragma = fmt.Sprintf(`PRAGMA index_info("%s")`, escapeQuote(index))
} else {
pragma = Mprintf2("PRAGMA %Q.index_info(%Q)", dbName, index)
pragma = fmt.Sprintf(`PRAGMA %s.index_info("%s")`, doubleQuote(dbName), escapeQuote(index))
}
s, err := c.prepare(pragma)
if err != nil {
......
......@@ -154,13 +154,13 @@ func (c *Conn) ForeignKeyCheck(dbName, table string) ([]FkViolation, error) {
if len(table) == 0 {
pragma = "PRAGMA foreign_key_check"
} else {
pragma = Mprintf("PRAGMA foreign_key_check(%Q)", table)
pragma = fmt.Sprintf(`PRAGMA foreign_key_check("%s")`, escapeQuote(table))
}
} else {
if len(table) == 0 {
pragma = Mprintf("PRAGMA %Q.foreign_key_check", dbName)
pragma = fmt.Sprintf("PRAGMA %s.foreign_key_check", doubleQuote(dbName))
} else {
pragma = Mprintf2("PRAGMA %Q.foreign_key_check(%Q)", dbName, table)
pragma = fmt.Sprintf(`PRAGMA %s.foreign_key_check("%s")`, doubleQuote(dbName), escapeQuote(table))
}
}
s, err := c.prepare(pragma)
......@@ -226,7 +226,10 @@ func pragma(dbName, pragmaName string) string {
if len(dbName) == 0 {
return "PRAGMA " + pragmaName
}
return Mprintf("PRAGMA %Q."+pragmaName, dbName)
if dbName == "main" || dbName == "temp" {
return fmt.Sprintf("PRAGMA %s.%s", dbName, pragmaName)
}
return fmt.Sprintf("PRAGMA %s.%s", doubleQuote(dbName), pragmaName)
}
func (c *Conn) oneValue(query string, value interface{}) error { // no cache
......
......@@ -398,7 +398,7 @@ func (s *Stmt) ExplainQueryPlan(w io.Writer) error {
if len(sql) == 0 {
return s.specificError("empty statement")
}
explain := Mprintf("EXPLAIN QUERY PLAN %s", s.SQL())
explain := "EXPLAIN QUERY PLAN " + s.SQL()
sExplain, err := s.Conn().prepare(explain)
if err != nil {
......
......@@ -12,14 +12,13 @@ package sqlite
static inline char *my_mprintf(char *zFormat, char *arg) {
return sqlite3_mprintf(zFormat, arg);
}
static inline char *my_mprintf2(char *zFormat, char *arg1, char *arg2) {
return sqlite3_mprintf(zFormat, arg1, arg2);
}
*/
import "C"
import (
"fmt"
"reflect"
"strings"
"unsafe"
)
......@@ -39,21 +38,6 @@ func mPrintf(format, arg string) *C.char { // TODO may return nil when no memory
return C.my_mprintf(cf, ca)
}
// Mprintf2 is like fmt.Printf but implements some additional formatting options
// that are useful for constructing SQL statements.
// (See http://sqlite.org/c3ref/mprintf.html)
func Mprintf2(format string, arg1, arg2 string) string {
cf := C.CString(format)
defer C.free(unsafe.Pointer(cf))
ca1 := C.CString(arg1)
defer C.free(unsafe.Pointer(ca1))
ca2 := C.CString(arg2)
defer C.free(unsafe.Pointer(ca2))
zSQL := C.my_mprintf2(cf, ca1, ca2) // TODO may return nil when no memory...
defer C.sqlite3_free(unsafe.Pointer(zSQL))
return C.GoString(zSQL)
}
func btocint(b bool) C.int {
if b {
return 1
......@@ -73,3 +57,16 @@ func gostring(cs *C.char) string {
return *(*string)(unsafe.Pointer(&x))
}
*/
func escapeQuote(identifier string) string {
if strings.ContainsRune(identifier, '"') { // escape quote by doubling them
identifier = strings.Replace(identifier, `"`, `""`, -1)
}
return identifier
}
func doubleQuote(dbName string) string {
if dbName == "main" || dbName == "temp" {
return dbName
}
return fmt.Sprintf(`"%s"`, escapeQuote(dbName)) // surround identifier with quote
}
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment