From d2b9618da03c86d16b9169c823d29c7baa07b245 Mon Sep 17 00:00:00 2001 From: mappu Date: Sat, 29 Jun 2024 11:56:45 +1200 Subject: [PATCH] sqliteclidriver: initial commit --- go.mod | 4 + go.sum | 3 + sqliteclidriver/sqliteclidriver.go | 305 ++++++++++++++++++++++++ sqliteclidriver/sqliteclidriver_test.go | 52 ++++ 4 files changed, 364 insertions(+) create mode 100644 sqliteclidriver/sqliteclidriver.go create mode 100644 sqliteclidriver/sqliteclidriver_test.go diff --git a/go.mod b/go.mod index 0df8405..c79891c 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.22 github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c github.com/redis/go-redis/v9 v9.5.3 + github.com/stretchr/testify v1.9.0 github.com/ying32/govcl v2.2.3+incompatible go.etcd.io/bbolt v1.4.0-alpha.1 modernc.org/sqlite v1.24.0 @@ -20,6 +21,7 @@ require ( github.com/cockroachdb/errors v1.11.3 // indirect github.com/cockroachdb/logtags v0.0.0-20230118201751-21c54148d20b // indirect github.com/cockroachdb/redact v1.1.5 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgraph-io/ristretto v0.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dustin/go-humanize v1.0.1 // indirect @@ -38,6 +40,7 @@ require ( github.com/mattn/go-isatty v0.0.17 // indirect github.com/matttproud/golang_protobuf_extensions v1.0.2-0.20181231171920-c182affec369 // indirect github.com/pkg/errors v0.9.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/prometheus/client_golang v1.12.0 // indirect github.com/prometheus/client_model v0.2.1-0.20210607210712-147c58e9608a // indirect github.com/prometheus/common v0.32.1 // indirect @@ -52,6 +55,7 @@ require ( golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.13.0 // indirect google.golang.org/protobuf v1.33.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect lukechampine.com/uint128 v1.2.0 // indirect modernc.org/cc/v3 v3.40.0 // indirect modernc.org/ccgo/v3 v3.16.13 // indirect diff --git a/go.sum b/go.sum index 8bdf476..2afee30 100644 --- a/go.sum +++ b/go.sum @@ -256,6 +256,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= +github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= github.com/ying32/govcl v2.2.3+incompatible h1:Iyfcl26yNE1USm+3uG+btQyhkoFIV18+VITrUdHu8Lw= github.com/ying32/govcl v2.2.3+incompatible/go.mod h1:yZVtbJ9Md1nAVxtHKIriKZn4K6TQYqI1en3sN/m9FJ8= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= @@ -543,6 +544,7 @@ gopkg.in/alecthomas/kingpin.v2 v2.2.6/go.mod h1:FMv+mEhP44yOT+4EoQTLFTRgOQ1FBLks gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.1/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= @@ -551,6 +553,7 @@ gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/sqliteclidriver/sqliteclidriver.go b/sqliteclidriver/sqliteclidriver.go new file mode 100644 index 0000000..cfb32f0 --- /dev/null +++ b/sqliteclidriver/sqliteclidriver.go @@ -0,0 +1,305 @@ +// sqliteclidriver is a database/sql driver for SQLite implemented on top of +// the sqlite3 command-line tool. This allows it to be used over remote SSH +// connections. +// Functionality is limited. +package sqliteclidriver + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "fmt" + "io" + "os/exec" + "yvbolt/lexer" +) + +var ErrNotSupported = errors.New("Not supported") + +// + +type SCDriver struct{} + +func (d *SCDriver) Open(connectionString string) (driver.Conn, error) { + // TODO support custom binpath from our connection string + + cmd := exec.Command(`/usr/bin/sqlite3`, `-noheader`, `-json`, connectionString) // n.b. doesn't support `--` + + pw, err := cmd.StdinPipe() + if err != nil { + return nil, err + } + + pr, err := cmd.StdoutPipe() + if err != nil { + return nil, err + } + + pe, err := cmd.StderrPipe() + if err != nil { + return nil, err + } + + err = cmd.Start() + if err != nil { + return nil, err + } + + return &SCConn{ + stdout: pr, + stderr: pe, + w: pw, + }, nil +} + +var _ driver.Driver = &SCDriver{} // interface assertion + +func init() { + sql.Register("sqliteclidriver", &SCDriver{}) +} + +// + +type SCConnector struct { + connectionString string + driver *SCDriver +} + +func (c *SCConnector) Connect(context.Context) (driver.Conn, error) { + return c.driver.Open(c.connectionString) +} + +func (c *SCConnector) Driver() driver.Driver { + return c.driver +} + +var _ driver.Connector = &SCConnector{} // interface assertion + +// + +type SCConn struct { + stdout io.Reader + stderr io.Reader + w io.WriteCloser +} + +func (c *SCConn) Prepare(query string) (driver.Stmt, error) { + f, err := lexer.Fields(query) // Rely on the query's escaping still being present + if err != nil { + return nil, err + } + + if len(f) == 0 { + return nil, errors.New("Empty query") + } + + return &SCStmt{ + conn: c, + query: f, + }, nil +} + +func (c *SCConn) Close() error { + return c.w.Close() +} + +func (c *SCConn) Begin() (driver.Tx, error) { + return nil, ErrNotSupported +} + +var _ driver.Conn = &SCConn{} // interface assertion + +// + +type SCStmt struct { + conn *SCConn + query []string +} + +func (s *SCStmt) Close() error { + return nil +} + +func (s *SCStmt) NumInput() int { + var ct int = 0 + for _, token := range s.query { + if token == `?` { + ct++ + } + } + + return ct +} + +func (s *SCStmt) buildQuery(args []driver.Value) ([]byte, error) { + + // Embed query params + // WARNING: Not secure against injection? That's not a security boundary + // for the purposes of yvbolt, but maybe for other package users + + var querybuilder bytes.Buffer + + for _, token := range s.query { + if token == `?` { + if len(args) == 0 { + return nil, errors.New("Not enough arguments") + } + + arg := args[0] + args = args[1:] + + // Format the argument + if szarg, ok := arg.(string); ok { + querybuilder.WriteString(lexer.Quote(szarg)) + + } else if intarg, ok := arg.(int64); ok { + querybuilder.WriteString(fmt.Sprintf("%d", intarg)) + + } else if flarg, ok := arg.(float64); ok { + querybuilder.WriteString(fmt.Sprintf("%f", flarg)) + + } else { + return nil, fmt.Errorf("Parameter %v has unsupported type", arg) + } + + } else { + // Normal token + querybuilder.WriteString(token) + } + + // Whitespace + querybuilder.WriteByte(' ') + } + + if len(args) != 0 { + return nil, errors.New("Too many arguments") + } + + // Swap final whitespace for a \n + submit := querybuilder.Bytes() + submit[len(submit)-1] = '\n' + + return submit, nil +} + +func (s *SCStmt) Exec(args []driver.Value) (driver.Result, error) { + + submit, err := s.buildQuery(args) + if err != nil { + return nil, err + } + + _, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit))) + if err != nil { + return nil, fmt.Errorf("Write: %w", err) + } + + // Drain stdout + // TODO + + // Drain stderr + // TODO + + return &SCResult{}, nil +} + +func (s *SCStmt) Query(args []driver.Value) (driver.Rows, error) { + + submit, err := s.buildQuery(args) + if err != nil { + return nil, err + } + + _, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit))) + if err != nil { + return nil, fmt.Errorf("Write: %w", err) + } + + // We expect some kind of thing on stdout + ret := []map[string]any{} + err = json.NewDecoder(s.conn.stdout).Decode(&ret) + if err != nil { + return nil, err + } + + // Drain stderr + // TODO + + // Come up with a canonical ordering for the columns after json.NewDecoder + // wiped it out for us + // FIXME use a different json parsing library + + var columnNames []string + if len(ret) > 0 { + for k, _ := range ret[0] { + columnNames = append(columnNames, k) + } + } + + // + + return &SCRows{ + idx: 0, + columns: columnNames, + data: ret, + }, nil +} + +var _ driver.Stmt = &SCStmt{} // interface assertion + +// + +type SCResult struct{} + +func (r *SCResult) LastInsertId() (int64, error) { + return 0, ErrNotSupported +} + +func (r *SCResult) RowsAffected() (int64, error) { + return 0, ErrNotSupported +} + +var _ driver.Result = &SCResult{} // interface assertion + +// + +type SCRows struct { + idx int + columns []string + data []map[string]any +} + +func (r *SCRows) Columns() []string { + return r.columns +} + +func (r *SCRows) Close() error { + r.idx = 0 + r.data = nil + return nil +} + +func (r *SCRows) Next(dest []driver.Value) error { + if r.idx >= len(r.data) { + return io.EOF + } + + if len(dest) != len(r.data[r.idx]) { + return errors.New("Wrong number of arguments to Next()") + } + + for i := 0; i < len(dest); i++ { + cell, ok := r.data[r.idx][r.columns[i]] + if !ok { + return fmt.Errorf("Result row %d is missing column #%d %q, unexpected", r.idx, i, r.columns[i]) + } + + dest[i] = cell + } + + r.idx++ + return nil +} diff --git a/sqliteclidriver/sqliteclidriver_test.go b/sqliteclidriver/sqliteclidriver_test.go new file mode 100644 index 0000000..b531100 --- /dev/null +++ b/sqliteclidriver/sqliteclidriver_test.go @@ -0,0 +1,52 @@ +package sqliteclidriver + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSqliteCliDriver(t *testing.T) { + db, err := sql.Open("sqliteclidriver", ":memory:") + require.NoError(t, err) + + _, err = db.Exec(`CREATE TABLE my_test_table ( id INTEGER PRIMARY KEY, extra TEXT NOT NULL );`) + require.NoError(t, err) + + _, err = db.Exec(`INSERT INTO my_test_table (id, extra) VALUES (1337, "abcdef"), (9001, "whoop");`) + require.NoError(t, err) + + res, err := db.Query(`SELECT * FROM my_test_table ORDER BY id ASC;`) + require.NoError(t, err) + + cols, err := res.Columns() + require.NoError(t, err) + require.EqualValues(t, cols, []string{"id", "extra"}) + + var rowCount int = 0 + + for res.Next() { + rowCount++ + + var idVal int + var extraVal string + err = res.Scan(&idVal, &extraVal) + if err != nil { + t.Fatal(err) + } + + switch rowCount { + case 1: + require.EqualValues(t, 1337, idVal) + require.EqualValues(t, "abcdef", extraVal) + + case 2: + require.EqualValues(t, 9001, idVal) + require.EqualValues(t, "whoop", extraVal) + + } + } + + require.Equal(t, rowCount, 2) +}