webscaffold/db.go

131 lines
2.8 KiB
Go

package main
import (
"database/sql"
"errors"
"fmt"
"io/ioutil"
"log"
"os"
"path/filepath"
"regexp"
"sort"
"strconv"
"strings"
_ "github.com/mattn/go-sqlite3"
)
// connectToDB opens the SQLite database at this.c.DBPath, stores the handle
// in this.db, and then applies any pending schema migrations.
//
// If reading the current schema version fails with sql.ErrNoRows or with an
// error mentioning "no such table: schema", the database is treated as being
// at schema version 0, so every migration is considered pending. Any other
// error is returned as-is; migration failures are wrapped with context.
func (this *Application) connectToDB() error {
	db, err := sql.Open("sqlite3", this.c.DBPath)
	if err != nil {
		return err
	}
	this.db = db

	// Work out which migrations are already in place.
	version, err := this.getCurrentSchemaVersion()
	switch {
	case err == nil:
		// version was read successfully.
	case errors.Is(err, sql.ErrNoRows), strings.Contains(err.Error(), `no such table: schema`):
		// Fresh database: nothing has been applied yet.
		version = 0
	default:
		return err
	}

	if err := this.applyMigrations(version); err != nil {
		return fmt.Errorf("applying migrations: %w", err)
	}
	return nil
}
// getCurrentSchemaVersion returns the highest id stored in the schema table,
// which connectToDB treats as the current schema version.
//
// MAX() over zero rows yields a single NULL row rather than sql.ErrNoRows, and
// scanning NULL into a plain int fails. The value is therefore read through
// sql.NullInt64, so an existing but empty schema table reports version 0
// instead of a conversion error.
// Query errors (for example a missing schema table) are returned unchanged,
// together with version 0, so the caller can inspect them.
func (this *Application) getCurrentSchemaVersion() (int, error) {
	var maxID sql.NullInt64
	if err := this.db.QueryRow(`SELECT MAX(id) m FROM schema`).Scan(&maxID); err != nil {
		return 0, err
	}
	if !maxID.Valid {
		return 0, nil // schema table exists but holds no rows yet
	}
	return int(maxID.Int64), nil
}
// applyMigrations runs every SQL migration file in this.c.SchemaDir whose
// version is greater than currentSchemaVer, in ascending version order.
//
// Every entry in the directory must be named "<version>...sql", where
// <version> is a run of digits and dashes. The dashes are removed and the
// remainder is parsed as an int (e.g. "2021-03-04-users.sql" -> 20210304).
// All names are validated before any migration runs, so a stray file aborts
// the run without touching the database.
//
// Migrations are ordered by numeric version, with ties broken by filename.
// This matches the numeric comparison used to skip already-applied files;
// plain lexical ordering would run "10.sql" before "9.sql".
//
// Each file is executed statement by statement with no surrounding
// transaction. A failure part-way through a file therefore leaves the
// statements before it applied.
func (this *Application) applyMigrations(currentSchemaVer int) error {
	// migration pairs a schema file with the version parsed from its name.
	type migration struct {
		version  int
		filename string
	}

	rx := regexp.MustCompile(`^([0-9\-]+).*\.sql$`)

	dh, err := os.Open(this.c.SchemaDir)
	if err != nil {
		return err
	}
	// Read-only directory handle: a Close error carries no useful information
	// here, so it is deliberately ignored.
	defer dh.Close()

	filenames, err := dh.Readdirnames(-1)
	if err != nil {
		return err
	}

	// Validate and parse every name before applying anything.
	migrations := make([]migration, 0, len(filenames))
	for _, filename := range filenames {
		parts := rx.FindStringSubmatch(filename)
		if parts == nil {
			return fmt.Errorf("found file '%s' in %s directory not matching expected file format, aborting", filename, this.c.SchemaDir)
		}
		version, err := strconv.Atoi(strings.Replace(parts[1], `-`, "", -1))
		if err != nil {
			return fmt.Errorf("parsing schema version from '%s': %w", filename, err)
		}
		migrations = append(migrations, migration{version: version, filename: filename})
	}
	sort.Slice(migrations, func(i, j int) bool {
		if migrations[i].version != migrations[j].version {
			return migrations[i].version < migrations[j].version
		}
		return migrations[i].filename < migrations[j].filename
	})

	log.Printf("DB Schema version %d, searching for migrations...", currentSchemaVer)
	applied := 0
	for _, m := range migrations {
		if m.version <= currentSchemaVer {
			continue // already applied
		}

		fpath := filepath.Join(this.c.SchemaDir, m.filename)
		sqlFile, err := ioutil.ReadFile(fpath)
		if err != nil {
			return fmt.Errorf("loading '%s' for schema migration: %w", fpath, err)
		}

		// Per the original author, the SQLite driver does not support multiple
		// SQL statements in a single Exec() call. The file is split on `);`
		// and the delimiter is re-appended to each piece. This is a heuristic:
		// a `);` inside a string literal or trigger body is split too, and a
		// statement that does not end in `);` is sent together with the next.
		sqlStmts := strings.Split(string(sqlFile), `);`)
		for i := 0; i < len(sqlStmts)-1; i++ {
			sqlStmts[i] += `);`
		}
		// strings.Split always returns at least one element; drop the trailing
		// piece when it is only whitespace so no empty Exec() is issued.
		if strings.TrimSpace(sqlStmts[len(sqlStmts)-1]) == "" {
			sqlStmts = sqlStmts[:len(sqlStmts)-1]
		}

		log.Printf("Applying schema migration '%s' (%d statement(s))", fpath, len(sqlStmts))
		for _, stmt := range sqlStmts {
			if _, err := this.db.Exec(stmt); err != nil {
				return fmt.Errorf("applying schema '%s': %w", fpath, err)
			}
		}
		applied++
	}

	if applied > 0 {
		log.Printf("Successfully applied %d schema migration(s)", applied)
	} else {
		log.Println("No new schema migrations to apply.")
	}
	return nil
}