131 lines
2.8 KiB
Go
131 lines
2.8 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"errors"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"log"
|
|
"os"
|
|
"path/filepath"
|
|
"regexp"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
func (this *Application) connectToDB() error {
|
|
|
|
db, err := sql.Open("sqlite3", this.c.DBPath)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
this.db = db
|
|
|
|
// Check current schema version
|
|
currentSchemaVer, err := this.getCurrentSchemaVersion()
|
|
if err != nil && (errors.Is(err, sql.ErrNoRows) || strings.Contains(err.Error(), `no such table: schema`)) {
|
|
currentSchemaVer = 0
|
|
err = nil
|
|
}
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Perform migrations
|
|
err = this.applyMigrations(currentSchemaVer)
|
|
if err != nil {
|
|
return fmt.Errorf("applying migrations: %w", err)
|
|
}
|
|
|
|
// Done
|
|
return nil
|
|
}
|
|
|
|
func (this *Application) getCurrentSchemaVersion() (int, error) {
|
|
var maxId int
|
|
err := this.db.QueryRow(`SELECT MAX(id) m FROM schema`).Scan(&maxId)
|
|
return maxId, err
|
|
}
|
|
|
|
func (this *Application) applyMigrations(currentSchemaVer int) error {
|
|
|
|
rx := regexp.MustCompile(`^([0-9\-]+).*\.sql$`)
|
|
|
|
dh, err := os.Open(this.c.SchemaDir)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
filenames, err := dh.Readdirnames(-1)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
sort.Strings(filenames)
|
|
|
|
log.Printf("DB Schema version %d, searching for migrations...", currentSchemaVer)
|
|
|
|
applied := 0
|
|
|
|
for _, filename := range filenames {
|
|
|
|
parts := rx.FindStringSubmatch(filename)
|
|
if parts == nil {
|
|
return fmt.Errorf("found file '%s' in %s directory not matching expected file format, aborting", filename, this.c.SchemaDir)
|
|
}
|
|
|
|
schemaVer, err := strconv.Atoi(strings.Replace(parts[1], `-`, "", -1))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if currentSchemaVer >= schemaVer {
|
|
continue // already applied
|
|
}
|
|
|
|
// Need to apply this schema
|
|
|
|
fpath := filepath.Join(this.c.SchemaDir, filename)
|
|
sqlFile, err := ioutil.ReadFile(fpath)
|
|
if err != nil {
|
|
return fmt.Errorf("loading '%s' for schema migration: %w", fpath, err)
|
|
}
|
|
|
|
// The SQLite driver does not support multiple SQL statements in a single Exec() call
|
|
// Try to break it up into multiple by splitting on `);`
|
|
sqlStmts := strings.Split(string(sqlFile), `);`)
|
|
for i := 0; i < len(sqlStmts)-1; i++ {
|
|
sqlStmts[i] += `);`
|
|
}
|
|
|
|
// Don't need to call Exec() if the trailing part is just blank
|
|
if strings.TrimSpace(sqlStmts[len(sqlStmts)-1]) == "" {
|
|
sqlStmts = sqlStmts[0 : len(sqlStmts)-1]
|
|
}
|
|
|
|
log.Printf("Applying schema migration '%s' (%d statement(s))", fpath, len(sqlStmts))
|
|
|
|
for _, stmt := range sqlStmts {
|
|
_, err = this.db.Exec(stmt)
|
|
if err != nil {
|
|
return fmt.Errorf("applying schema '%s': %w", fpath, err)
|
|
}
|
|
}
|
|
|
|
applied += 1
|
|
}
|
|
|
|
if applied > 0 {
|
|
log.Printf("Successfully applied %d schema migration(s)", applied)
|
|
} else {
|
|
log.Println("No new schema migrations to apply.")
|
|
}
|
|
|
|
// Done
|
|
return nil
|
|
}
|