yvbolt/sqliteclidriver/sqliteclidriver.go

315 lines
5.7 KiB
Go

// sqliteclidriver is a database/sql driver for SQLite implemented on top of
// the sqlite3 command-line tool. This allows it to be used over remote SSH
// connections.
//
// Functionality is limited.
//
// Known caveats:
// - Lexer only understands ? if it's separated by spaces
// - Bad error handling: sqlite3's stderr is never read, so errors it reports are not surfaced
// - Few supported types
// - Has to escape parameters for CLI instead of preparing them, so not safe for untrusted usage
// - No context handling
// - No way to configure sqlite3 command line path
package sqliteclidriver
import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"io"
"os/exec"
"yvbolt/lexer"
)
// ErrNotSupported is the sentinel error returned by driver operations that
// this package does not implement.
var ErrNotSupported = errors.New("Not supported")
//
// SCDriver is the database/sql driver entry point. It carries no state; every
// call to Open launches its own sqlite3 process.
type SCDriver struct{}

// Open starts /usr/bin/sqlite3 in JSON output mode against connectionString
// and returns a connection wired to the child's stdin, stdout and stderr.
//
// connectionString is passed to sqlite3 as a plain positional argument (the
// original author notes that `--` is not supported by the tool).
//
// An error is returned if any of the three pipes cannot be created or if the
// process fails to start.
//
// NOTE(review): the exec.Cmd is not retained by the returned connection, so
// nothing in this function arranges for the child to be waited on.
func (d *SCDriver) Open(connectionString string) (driver.Conn, error) {
	// TODO support custom binpath from our connection string
	proc := exec.Command(`/usr/bin/sqlite3`, `-noheader`, `-json`, connectionString) // n.b. doesn't support `--`

	// The pipes must be requested in this order and before Start.
	childStdin, err := proc.StdinPipe()
	if err != nil {
		return nil, err
	}

	childStdout, err := proc.StdoutPipe()
	if err != nil {
		return nil, err
	}

	childStderr, err := proc.StderrPipe()
	if err != nil {
		return nil, err
	}

	if err := proc.Start(); err != nil {
		return nil, err
	}

	conn := &SCConn{
		w:      childStdin,
		stdout: childStdout,
		stderr: childStderr,
	}
	return conn, nil
}
// Compile-time check that *SCDriver implements driver.Driver.
var _ driver.Driver = &SCDriver{} // interface assertion
// init registers an SCDriver with database/sql under the driver name
// "sqliteclidriver".
func init() {
sql.Register("sqliteclidriver", &SCDriver{})
}
//
// SCConnector pairs a connection string with the SCDriver that will open it.
type SCConnector struct {
// connectionString is handed unchanged to SCDriver.Open by Connect.
connectionString string
// driver opens the connections and is the value returned by Driver.
driver *SCDriver
}
// Connect opens a new connection by calling the stored driver's Open with the
// stored connection string. The context argument is unnamed and ignored.
func (c *SCConnector) Connect(context.Context) (driver.Conn, error) {
return c.driver.Open(c.connectionString)
}
// Driver returns the SCDriver held by this connector.
func (c *SCConnector) Driver() driver.Driver {
return c.driver
}
// Compile-time check that *SCConnector implements driver.Connector.
var _ driver.Connector = &SCConnector{} // interface assertion
//
// SCConn is one connection: the three pipes of a running sqlite3 process as
// set up by SCDriver.Open.
type SCConn struct {
// stdout is the child's standard output; Query decodes JSON results from it.
stdout io.Reader
// stderr is the child's standard error; the draining of it is still a TODO
// in Exec and Query.
stderr io.Reader
// w is the child's standard input; statements are written to it and Close
// closes it.
w io.WriteCloser
}
// Prepare tokenizes query with lexer.Fields and returns a statement bound to
// this connection. Nothing is sent to sqlite3 at this point.
//
// The tokens are stored as returned by the lexer; per the original author's
// note, the statement relies on the query's escaping still being present in
// them.
//
// It returns the lexer's error if tokenizing fails, or an "Empty query" error
// if the query yields no tokens.
func (c *SCConn) Prepare(query string) (driver.Stmt, error) {
	tokens, err := lexer.Fields(query) // Rely on the query's escaping still being present
	switch {
	case err != nil:
		return nil, err
	case len(tokens) == 0:
		return nil, errors.New("Empty query")
	}

	stmt := &SCStmt{conn: c, query: tokens}
	return stmt, nil
}
// Close shuts the connection down by closing the write side (sqlite3's stdin)
// and then releasing the stdout and stderr readers when they implement
// io.Closer, as the pipes created by exec.Cmd do. Previously only the write
// side was closed and the two read ends were left open.
//
// The error from closing the write side takes precedence; otherwise the first
// error from closing a reader is returned.
//
// NOTE(review): this does not wait for the sqlite3 process; the connection
// holds no handle to it.
func (c *SCConn) Close() error {
	err := c.w.Close()

	for _, r := range []io.Reader{c.stdout, c.stderr} {
		closer, ok := r.(io.Closer)
		if !ok {
			// Plain readers (or nil) have nothing to release.
			continue
		}
		if cerr := closer.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}

	return err
}
// Begin always fails with ErrNotSupported and a nil transaction.
func (c *SCConn) Begin() (driver.Tx, error) {
return nil, ErrNotSupported
}
// Compile-time check that *SCConn implements driver.Conn.
var _ driver.Conn = &SCConn{} // interface assertion
//
// SCStmt is a prepared statement: a tokenized query plus the connection it
// will be written to.
type SCStmt struct {
// conn is the connection whose pipes Exec and Query use.
conn *SCConn
// query holds the tokens produced by lexer.Fields in Prepare; a token that is
// exactly `?` is treated as a parameter placeholder.
query []string
}
// Close is a no-op that always returns nil.
func (s *SCStmt) Close() error {
return nil
}
// NumInput reports how many parameters the statement expects, which is the
// number of tokens that are exactly `?`.
func (s *SCStmt) NumInput() int {
	placeholders := 0
	for i := range s.query {
		if s.query[i] != `?` {
			continue
		}
		placeholders++
	}
	return placeholders
}
// buildQuery renders the statement's tokens into one line of SQL text ready
// to be written to sqlite3's stdin. Tokens are joined with single spaces, each
// `?` token is replaced by the literal form of the next argument, and the
// result ends with a newline.
//
// WARNING: parameters are embedded as text rather than bound, so this is not
// secure against injection. That is not a security boundary for yvbolt, but it
// may be for other users of the package.
//
// Supported argument types are nil, string, int64, float64, bool and []byte.
//
// Errors: "Not enough arguments" / "Too many arguments" when the argument
// count does not match the number of placeholders, an unsupported-type error
// for any other argument type, a not-finite error for NaN/±Inf floats, and
// "Empty query" when there are no tokens at all (previously a panic).
func (s *SCStmt) buildQuery(args []driver.Value) ([]byte, error) {
	var querybuilder bytes.Buffer

	for _, token := range s.query {
		if token != `?` {
			// Normal token, followed by separating whitespace.
			querybuilder.WriteString(token)
			querybuilder.WriteByte(' ')
			continue
		}

		if len(args) == 0 {
			return nil, errors.New("Not enough arguments")
		}
		literal, err := scParamLiteral(args[0])
		if err != nil {
			return nil, err
		}
		args = args[1:]

		querybuilder.WriteString(literal)
		querybuilder.WriteByte(' ')
	}

	if len(args) != 0 {
		return nil, errors.New("Too many arguments")
	}
	if querybuilder.Len() == 0 {
		// Without this guard the index below would be -1.
		return nil, errors.New("Empty query")
	}

	// Swap final whitespace for a \n
	submit := querybuilder.Bytes()
	submit[len(submit)-1] = '\n'
	return submit, nil
}

// scParamLiteral converts one driver argument into SQL literal text.
func scParamLiteral(arg driver.Value) (string, error) {
	switch v := arg.(type) {
	case nil:
		return "NULL", nil
	case string:
		return lexer.Quote(v), nil
	case int64:
		return fmt.Sprintf("%d", v), nil
	case float64:
		return scFloatLiteral(v)
	case bool:
		// SQLite has no boolean type; true and false are the integers 1 and 0.
		if v {
			return "1", nil
		}
		return "0", nil
	case []byte:
		// BLOB literal: hexadecimal digits inside X'...'.
		return fmt.Sprintf("X'%x'", v), nil
	default:
		return "", fmt.Errorf("Parameter %v has unsupported type", arg)
	}
}

// scFloatLiteral renders v as a numeric literal that parses back to exactly v
// and is always read as a real number, never as an integer.
func scFloatLiteral(v float64) (string, error) {
	// %g without a precision uses the fewest digits that identify v uniquely,
	// unlike %f which always cuts off after six decimals.
	literal := fmt.Sprintf("%g", v)

	// Finite numbers always end in a digit; "NaN", "+Inf" and "-Inf" do not
	// and have no literal form.
	if last := literal[len(literal)-1]; last < '0' || last > '9' {
		return "", fmt.Errorf("Parameter %v is not a finite number", v)
	}

	// A bare "3" would be an integer literal; keep the value a real.
	if !bytes.ContainsAny([]byte(literal), ".e") {
		literal += ".0"
	}
	return literal, nil
}
// Exec renders the statement with args and writes it to the connection's
// stdin. It does not read anything back: draining stdout and stderr is still
// TODO, so the returned SCResult carries no information about the outcome.
//
// It returns buildQuery's error unchanged, or the write error wrapped with a
// "Write: " prefix.
func (s *SCStmt) Exec(args []driver.Value) (driver.Result, error) {
	payload, buildErr := s.buildQuery(args)
	if buildErr != nil {
		return nil, buildErr
	}

	source := bytes.NewReader(payload)
	if _, writeErr := io.CopyN(s.conn.w, source, int64(len(payload))); writeErr != nil {
		return nil, fmt.Errorf("Write: %w", writeErr)
	}

	// Drain stdout
	// TODO

	// Drain stderr
	// TODO

	result := &SCResult{}
	return result, nil
}
// Query renders the statement with args, writes it to the connection's stdin,
// then reads one JSON value (an array of row objects) from stdout and returns
// it as a fully-buffered row set.
//
// Column names are taken from the first row in the order they appear in the
// JSON text. Previously they came from ranging over a Go map, which made the
// column order random from call to call.
//
// Numbers are decoded without going through float64 first: values that parse
// as integers are returned as int64, everything else as float64. Previously
// every number became a float64, losing precision for large integers.
//
// It returns buildQuery's error unchanged, the write error wrapped with a
// "Write: " prefix, or any JSON decoding error.
//
// NOTE(review): a new json.Decoder is created on every call, so bytes it
// buffers beyond the first value are dropped. Also, if sqlite3 prints nothing
// for a statement, the read below blocks — confirm how empty results are
// emitted.
func (s *SCStmt) Query(args []driver.Value) (driver.Rows, error) {
	submit, err := s.buildQuery(args)
	if err != nil {
		return nil, err
	}

	_, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit)))
	if err != nil {
		return nil, fmt.Errorf("Write: %w", err)
	}

	// We expect some kind of thing on stdout. Keep the raw bytes so they can
	// be decoded once for the values and once for the key order.
	var raw json.RawMessage
	if err := json.NewDecoder(s.conn.stdout).Decode(&raw); err != nil {
		return nil, err
	}

	rowDecoder := json.NewDecoder(bytes.NewReader(raw))
	rowDecoder.UseNumber()
	ret := []map[string]any{}
	if err := rowDecoder.Decode(&ret); err != nil {
		return nil, err
	}
	for _, row := range ret {
		for name, cell := range row {
			if num, ok := cell.(json.Number); ok {
				row[name] = scNumberValue(num)
			}
		}
	}

	// Drain stderr
	// TODO

	var columnNames []string
	if len(ret) > 0 {
		columnNames, err = scFirstRowColumns(raw)
		if err != nil {
			return nil, err
		}
	}

	return &SCRows{
		idx:     0,
		columns: columnNames,
		data:    ret,
	}, nil
}

// scNumberValue converts a JSON number to int64 when its text is an integer
// that fits, and to float64 otherwise.
func scNumberValue(num json.Number) any {
	if i, err := num.Int64(); err == nil {
		return i
	}
	// The decoder has already validated the number's syntax, so the only
	// possible failure here is out-of-range, for which f is already ±Inf.
	f, _ := num.Float64()
	return f
}

// scFirstRowColumns returns the keys of the first object in a JSON array, in
// the order they are written. A repeated key is listed once, matching the
// decoded map. It returns nil if raw is not an array or its first element is
// not an object.
func scFirstRowColumns(raw []byte) ([]string, error) {
	dec := json.NewDecoder(bytes.NewReader(raw))

	tok, err := dec.Token()
	if err != nil {
		return nil, err
	}
	if d, ok := tok.(json.Delim); !ok || d != '[' {
		return nil, nil
	}
	if !dec.More() {
		return nil, nil
	}

	tok, err = dec.Token()
	if err != nil {
		return nil, err
	}
	if d, ok := tok.(json.Delim); !ok || d != '{' {
		return nil, nil
	}

	var names []string
	seen := map[string]struct{}{}
	for dec.More() {
		keyTok, err := dec.Token()
		if err != nil {
			return nil, err
		}
		name, ok := keyTok.(string)
		if !ok {
			return nil, fmt.Errorf("Unexpected JSON token %v where a column name was expected", keyTok)
		}

		// Step over the value, whatever its shape.
		var skipped json.RawMessage
		if err := dec.Decode(&skipped); err != nil {
			return nil, err
		}

		if _, dup := seen[name]; dup {
			continue
		}
		seen[name] = struct{}{}
		names = append(names, name)
	}
	return names, nil
}

var _ driver.Stmt = &SCStmt{} // interface assertion
//
// SCResult is the empty result returned by SCStmt.Exec. It holds no data.
type SCResult struct{}
// LastInsertId is not implemented; it always returns 0 and ErrNotSupported.
func (r *SCResult) LastInsertId() (int64, error) {
return 0, ErrNotSupported
}
// RowsAffected is not implemented; it always returns 0 and ErrNotSupported.
func (r *SCResult) RowsAffected() (int64, error) {
return 0, ErrNotSupported
}
// Compile-time check that *SCResult implements driver.Result.
var _ driver.Result = &SCResult{} // interface assertion
//
// SCRows is a result set held entirely in memory: a list of rows keyed by
// column name, a fixed column order, and a cursor.
type SCRows struct {
	// idx is the index in data of the row the next call to Next will return.
	idx int
	// columns is the column order used to fill the destination slice in Next.
	columns []string
	// data holds every row, each mapping column name to value.
	data []map[string]any
}

// Columns returns the column names in the order Next fills them.
func (r *SCRows) Columns() []string {
	return r.columns
}

// Close drops the buffered rows and rewinds the cursor. It always returns
// nil; a later Next reports io.EOF.
func (r *SCRows) Close() error {
	r.data, r.idx = nil, 0
	return nil
}

// Next copies the current row into dest in column order and advances the
// cursor.
//
// It returns io.EOF once every row has been consumed, an error if dest's
// length differs from the number of entries in the current row, and an error
// if the current row lacks one of the expected columns. The cursor only
// advances on success.
func (r *SCRows) Next(dest []driver.Value) error {
	if r.idx >= len(r.data) {
		return io.EOF
	}

	row := r.data[r.idx]
	if len(row) != len(dest) {
		return errors.New("Wrong number of arguments to Next()")
	}

	for pos := range dest {
		name := r.columns[pos]
		value, present := row[name]
		if !present {
			return fmt.Errorf("Result row %d is missing column #%d %q, unexpected", r.idx, pos, name)
		}
		dest[pos] = value
	}

	r.idx++
	return nil
}