sqliteclidriver: use channel events, handle no results via sentinel

This commit is contained in:
mappu 2024-06-30 12:33:47 +12:00
parent be91cd54c6
commit 7cec5cee4c
4 changed files with 108 additions and 25 deletions

2
go.mod
View File

@ -5,6 +5,7 @@ go 1.19
require (
github.com/cockroachdb/pebble v1.0.0
github.com/dgraph-io/badger/v4 v4.2.0
github.com/google/uuid v1.6.0
github.com/mattn/go-sqlite3 v1.14.22
github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c
github.com/redis/go-redis/v9 v9.5.3
@ -32,7 +33,6 @@ require (
github.com/golang/protobuf v1.5.3 // indirect
github.com/golang/snappy v0.0.4 // indirect
github.com/google/flatbuffers v1.12.1 // indirect
github.com/google/uuid v1.3.0 // indirect
github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect
github.com/klauspost/compress v1.16.0 // indirect
github.com/kr/pretty v0.3.1 // indirect

4
go.sum
View File

@ -160,8 +160,8 @@ github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg=
github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=

View File

@ -5,7 +5,6 @@
// Functionality is limited.
//
// Known caveats:
// - Lexer only understands ? if it's separated by spaces
// - Bad error handling
// - Few supported types
// - Has to escape parameters for CLI instead of preparing them, so not safe for untrusted usage
@ -23,7 +22,10 @@ import (
"fmt"
"io"
"os/exec"
"yvbolt/lexer"
"github.com/google/uuid"
)
var ErrNotSupported = errors.New("Not supported")
@ -37,29 +39,13 @@ func (d *SCDriver) Open(connectionString string) (driver.Conn, error) {
cmd := exec.Command(`/usr/bin/sqlite3`, `-noheader`, `-json`, connectionString) // n.b. doesn't support `--`
pw, err := cmd.StdinPipe()
if err != nil {
return nil, err
}
pr, err := cmd.StdoutPipe()
if err != nil {
return nil, err
}
pe, err := cmd.StderrPipe()
if err != nil {
return nil, err
}
err = cmd.Start()
chEvents, pw, err := ExecEvents(cmd)
if err != nil {
return nil, err
}
return &SCConn{
stdout: pr,
stderr: pe,
listen: chEvents,
w: pw,
}, nil
}
@ -90,8 +76,7 @@ var _ driver.Connector = &SCConnector{} // interface assertion
//
type SCConn struct {
stdout io.Reader
stderr io.Reader
listen <-chan processEvent
w io.WriteCloser
}
@ -105,6 +90,10 @@ func (c *SCConn) Prepare(query string) (driver.Stmt, error) {
return nil, errors.New("Empty query")
}
if f[len(f)-1] != ";" {
f = append(f, ";") // Query must end in semicolon
}
return &SCStmt{
conn: c,
query: f,
@ -216,24 +205,97 @@ func (s *SCStmt) Exec(args []driver.Value) (driver.Result, error) {
}
func (s *SCStmt) Query(args []driver.Value) (driver.Rows, error) {
ctx := context.Background()
submit, err := s.buildQuery(args)
if err != nil {
return nil, err
}
// If there are no results to the query, the sqlite3 -json mode does not
// print anything on stdout at all and we would hang forever
// Add a followup sentinel query that we can detect
const sentinelKey = `__sqliteclidriver_sentinel`
sentinelVal := uuid.Must(uuid.NewRandom()).String()
submit = append(submit, []byte(fmt.Sprintf("SELECT \"%s\" AS %s;\n", sentinelVal, sentinelKey))...)
//
_, err = io.CopyN(s.conn.w, bytes.NewReader(submit), int64(len(submit)))
if err != nil {
return nil, fmt.Errorf("Write: %w", err)
}
// Consume process events until either error or the json decoder is satisfied
pr, pw := io.Pipe()
listenContext, listenContextCancel := context.WithCancel(ctx) // Use to stop signalling once json decoder is satisfied
go func() {
defer pw.Close()
for {
select {
case msg, ok := <-s.conn.listen:
if !ok {
pw.CloseWithError(fmt.Errorf("process already closed"))
return
}
if msg.err != nil {
pw.CloseWithError(msg.err)
return
}
if msg.evtype == evtypeStdout {
_, err := io.CopyN(pw, bytes.NewReader(msg.data), int64(len(msg.data)))
if err != nil {
pw.CloseWithError(msg.err)
return
}
} else {
// Anything else (process event / stderr)
// Throw
pw.CloseWithError(fmt.Errorf("other thing %#v", msg))
return
}
case <-listenContext.Done():
return
}
}
}()
// We expect some kind of thing on stdout
// If something happens on stderr, or to the process, pr will read an error
ret := []map[string]any{}
err = json.NewDecoder(s.conn.stdout).Decode(&ret)
decoder := json.NewDecoder(pr)
err = decoder.Decode(&ret)
if err != nil {
return nil, err
}
// Check if this was the data or the sentinel
wasSentinel := false
if len(ret) > 0 {
if val, ok := ret[0][sentinelKey]; ok {
if check, ok := val.(string); ok && check == sentinelVal {
// It was the sentinel
wasSentinel = true
// Nothing more to parse
}
}
}
if !wasSentinel {
// Need to decode again (from the same decoder reader) until we find the sentinel
surplus := []map[string]any{}
err = decoder.Decode(&surplus)
if err != nil {
return nil, err
}
}
listenContextCancel()
// Drain stderr
// TODO

View File

@ -55,3 +55,24 @@ func TestSqliteCliDriver(t *testing.T) {
}
}
func TestSqliteCliDriverNoResults(t *testing.T) {
db, err := sql.Open("sqliteclidriver", ":memory:")
require.NoError(t, err)
// Repeat this part to ensure we can make followup queries on the same connection
for i := 0; i < 3; i++ {
_, err = db.Query(`SELECT 1 AS expect_no_result WHERE 1=2`)
require.NoError(t, err)
// Mix of results and no-results
rr := db.QueryRow(`SELECT 1 AS expect_result WHERE 1=1`)
require.NoError(t, rr.Err())
var result int64
err = rr.Scan(&result)
require.NoError(t, err)
require.EqualValues(t, result, 1)
}
}