379 lines
7.4 KiB
Go
379 lines
7.4 KiB
Go
// sqliteclidriver is a database/sql driver for SQLite implemented on top of
|
|
// the sqlite3 command-line tool. This allows it to be used over remote SSH
|
|
// connections.
|
|
//
|
|
// Functionality is limited.
|
|
//
|
|
// Known caveats:
|
|
// - Bad error handling
|
|
// - Few supported types
|
|
// - Parameters are escaped and spliced into the SQL text sent to the CLI instead of being bound in a prepared statement, so it is not safe for untrusted input
|
|
// - No context handling
|
|
// - No way to configure the sqlite3 command-line binary path (always /usr/bin/sqlite3)
|
|
package sqliteclidriver
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"os/exec"
|
|
|
|
"yvbolt/lexer"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// ErrNotSupported is returned by operations this driver does not implement,
// such as transactions and result metadata (LastInsertId, RowsAffected).
var ErrNotSupported error = errors.New("Not supported")
|
|
|
|
//
|
|
|
|
// SCDriver is the database/sql driver. It carries no state: every call to
// Open starts an independent sqlite3 process.
type SCDriver struct{}
|
|
|
|
func (d *SCDriver) Open(connectionString string) (driver.Conn, error) {
|
|
// TODO support custom binpath from our connection string
|
|
|
|
cmd := exec.Command(`/usr/bin/sqlite3`, `-noheader`, `-json`, connectionString) // n.b. doesn't support `--`
|
|
|
|
chEvents, pw, err := ExecEvents(cmd)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &SCConn{
|
|
listen: chEvents,
|
|
w: pw,
|
|
}, nil
|
|
}
|
|
|
|
var _ driver.Driver = (*SCDriver)(nil) // compile-time check that *SCDriver implements driver.Driver
|
|
|
|
func init() {
|
|
sql.Register("sqliteclidriver", &SCDriver{})
|
|
}
|
|
|
|
//
|
|
|
|
type SCConnector struct {
|
|
connectionString string
|
|
driver *SCDriver
|
|
}
|
|
|
|
func (c *SCConnector) Connect(context.Context) (driver.Conn, error) {
|
|
return c.driver.Open(c.connectionString)
|
|
}
|
|
|
|
func (c *SCConnector) Driver() driver.Driver {
|
|
return c.driver
|
|
}
|
|
|
|
var _ driver.Connector = (*SCConnector)(nil) // compile-time check that *SCConnector implements driver.Connector
|
|
|
|
//
|
|
|
|
type SCConn struct {
|
|
listen <-chan processEvent
|
|
w io.WriteCloser
|
|
}
|
|
|
|
func (c *SCConn) Prepare(query string) (driver.Stmt, error) {
|
|
f, err := lexer.Fields(query) // Rely on the query's escaping still being present
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(f) == 0 {
|
|
return nil, errors.New("Empty query")
|
|
}
|
|
|
|
if f[len(f)-1] != ";" {
|
|
f = append(f, ";") // Query must end in semicolon
|
|
}
|
|
|
|
return &SCStmt{
|
|
conn: c,
|
|
query: f,
|
|
}, nil
|
|
}
|
|
|
|
func (c *SCConn) Close() error {
|
|
return c.w.Close()
|
|
}
|
|
|
|
func (c *SCConn) Begin() (driver.Tx, error) {
|
|
return nil, ErrNotSupported
|
|
}
|
|
|
|
var _ driver.Conn = (*SCConn)(nil) // compile-time check that *SCConn implements driver.Conn
|
|
|
|
//
|
|
|
|
type SCStmt struct {
|
|
conn *SCConn
|
|
query []string
|
|
}
|
|
|
|
func (s *SCStmt) Close() error {
|
|
return nil
|
|
}
|
|
|
|
func (s *SCStmt) NumInput() int {
|
|
var ct int = 0
|
|
for _, token := range s.query {
|
|
if token == `?` {
|
|
ct++
|
|
}
|
|
}
|
|
|
|
return ct
|
|
}
|
|
|
|
func (s *SCStmt) buildQuery(args []driver.Value) ([]byte, error) {
|
|
|
|
// Embed query params
|
|
// WARNING: Not secure against injection? That's not a security boundary
|
|
// for the purposes of yvbolt, but maybe for other package users
|
|
|
|
var querybuilder bytes.Buffer
|
|
|
|
for _, token := range s.query {
|
|
if token == `?` {
|
|
if len(args) == 0 {
|
|
return nil, errors.New("Not enough arguments")
|
|
}
|
|
|
|
arg := args[0]
|
|
args = args[1:]
|
|
|
|
// Format the argument
|
|
if szarg, ok := arg.(string); ok {
|
|
querybuilder.WriteString(lexer.Quote(szarg))
|
|
|
|
} else if intarg, ok := arg.(int64); ok {
|
|
querybuilder.WriteString(fmt.Sprintf("%d", intarg))
|
|
|
|
} else if flarg, ok := arg.(float64); ok {
|
|
querybuilder.WriteString(fmt.Sprintf("%f", flarg))
|
|
|
|
} else {
|
|
return nil, fmt.Errorf("Parameter %v has unsupported type", arg)
|
|
}
|
|
|
|
} else {
|
|
// Normal token
|
|
querybuilder.WriteString(token)
|
|
}
|
|
|
|
// Whitespace
|
|
querybuilder.WriteByte(' ')
|
|
}
|
|
|
|
if len(args) != 0 {
|
|
return nil, errors.New("Too many arguments")
|
|
}
|
|
|
|
// Swap final whitespace for a \n
|
|
submit := querybuilder.Bytes()
|
|
submit[len(submit)-1] = '\n'
|
|
|
|
return submit, nil
|
|
}
|
|
|
|
func (s *SCStmt) Exec(args []driver.Value) (driver.Result, error) {
|
|
|
|
submit, err := s.buildQuery(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit)))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Write: %w", err)
|
|
}
|
|
|
|
// Drain stdout
|
|
// TODO
|
|
|
|
// Drain stderr
|
|
// TODO
|
|
|
|
return &SCResult{}, nil
|
|
}
|
|
|
|
func (s *SCStmt) Query(args []driver.Value) (driver.Rows, error) {
|
|
ctx := context.Background()
|
|
|
|
submit, err := s.buildQuery(args)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// If there are no results to the query, the sqlite3 -json mode does not
|
|
// print anything on stdout at all and we would hang forever
|
|
// Add a followup sentinel query that we can detect
|
|
const sentinelKey = `__sqliteclidriver_sentinel`
|
|
sentinelVal := uuid.Must(uuid.NewRandom()).String()
|
|
submit = append(submit, []byte(fmt.Sprintf("SELECT \"%s\" AS %s;\n", sentinelVal, sentinelKey))...)
|
|
|
|
//
|
|
|
|
_, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit)))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Write: %w", err)
|
|
}
|
|
|
|
// Consume process events until either error or the json decoder is satisfied
|
|
pr, pw := io.Pipe()
|
|
|
|
listenContext, listenContextCancel := context.WithCancel(ctx) // Use to stop signalling once json decoder is satisfied
|
|
|
|
go func() {
|
|
defer pw.Close()
|
|
for {
|
|
select {
|
|
case msg, ok := <-s.conn.listen:
|
|
if !ok {
|
|
pw.CloseWithError(fmt.Errorf("process already closed"))
|
|
return
|
|
}
|
|
if msg.err != nil {
|
|
pw.CloseWithError(msg.err)
|
|
return
|
|
}
|
|
|
|
if msg.evtype == evtypeStdout {
|
|
_, err := io.CopyN(pw, bytes.NewReader(msg.data), int64(len(msg.data)))
|
|
if err != nil {
|
|
pw.CloseWithError(msg.err)
|
|
return
|
|
}
|
|
|
|
} else {
|
|
// Anything else (process event / stderr)
|
|
// Throw
|
|
pw.CloseWithError(msg)
|
|
return
|
|
}
|
|
|
|
case <-listenContext.Done():
|
|
return
|
|
}
|
|
}
|
|
}()
|
|
|
|
// We expect some kind of thing on stdout
|
|
// If something happens on stderr, or to the process, pr will read an error
|
|
ret := []OrderedKV{}
|
|
decoder := json.NewDecoder(pr)
|
|
err = decoder.Decode(&ret)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check if this was the data or the sentinel
|
|
wasSentinel := false
|
|
if len(ret) > 0 && len(ret[0]) > 0 && ret[0][0].Key == sentinelKey {
|
|
if check, ok := ret[0][0].Value.(string); ok && check == sentinelVal {
|
|
// It was the sentinel
|
|
wasSentinel = true
|
|
// Nothing more to parse
|
|
}
|
|
}
|
|
|
|
if wasSentinel {
|
|
// There was no data.
|
|
// Wipe out `ret`
|
|
ret = nil
|
|
|
|
} else {
|
|
// There was data.
|
|
// Need to decode again (from the same decoder reader) until we find the sentinel
|
|
surplus := []map[string]any{}
|
|
err = decoder.Decode(&surplus)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
listenContextCancel()
|
|
|
|
// Drain stderr
|
|
// TODO
|
|
|
|
// Come up with a canonical ordering for the columns after json.NewDecoder
|
|
// wiped it out for us
|
|
// FIXME use a different json parsing library
|
|
|
|
var columnNames []string
|
|
if len(ret) > 0 {
|
|
for _, cell := range ret[0] {
|
|
columnNames = append(columnNames, cell.Key)
|
|
}
|
|
}
|
|
|
|
//
|
|
|
|
return &SCRows{
|
|
idx: 0,
|
|
columns: columnNames,
|
|
data: ret,
|
|
}, nil
|
|
}
|
|
|
|
var _ driver.Stmt = (*SCStmt)(nil) // compile-time check that *SCStmt implements driver.Stmt
|
|
|
|
//
|
|
|
|
// SCResult is the driver.Result returned by Exec. It carries no data: the
// CLI does not report insert IDs or affected-row counts back to this driver.
type SCResult struct{}
|
|
|
|
func (r *SCResult) LastInsertId() (int64, error) {
|
|
return 0, ErrNotSupported
|
|
}
|
|
|
|
func (r *SCResult) RowsAffected() (int64, error) {
|
|
return 0, ErrNotSupported
|
|
}
|
|
|
|
var _ driver.Result = (*SCResult)(nil) // compile-time check that *SCResult implements driver.Result
|
|
|
|
//
|
|
|
|
type SCRows struct {
|
|
idx int
|
|
columns []string
|
|
data []OrderedKV
|
|
}
|
|
|
|
func (r *SCRows) Columns() []string {
|
|
return r.columns
|
|
}
|
|
|
|
func (r *SCRows) Close() error {
|
|
r.idx = 0
|
|
r.data = nil
|
|
return nil
|
|
}
|
|
|
|
func (r *SCRows) Next(dest []driver.Value) error {
|
|
if r.idx >= len(r.data) {
|
|
return io.EOF
|
|
}
|
|
|
|
if len(dest) != len(r.data[r.idx]) {
|
|
return errors.New("Wrong number of arguments to Next()")
|
|
}
|
|
|
|
for i := 0; i < len(dest); i++ {
|
|
cell := r.data[r.idx][i]
|
|
|
|
dest[i] = cell.Value
|
|
}
|
|
|
|
r.idx++
|
|
return nil
|
|
}
|