package main

import (
	"database/sql"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"sort"
	"strconv"
	"strings"

	_ "github.com/mattn/go-sqlite3"
)

// migrationFilenameRx matches migration filenames: a leading run of digits and
// dashes (the schema version), an optional description, and a ".sql" suffix.
var migrationFilenameRx = regexp.MustCompile(`^([0-9\-]+).*\.sql$`)

// connectToDB opens the SQLite database at this.c.DBPath, stores the handle in
// this.db, reads the current schema version and applies any newer migrations
// found in this.c.SchemaDir.
//
// A missing `schema` table (or an empty result) is treated as schema version 0.
// Note that this.db remains set even if a later step returns an error.
func (this *Application) connectToDB() error {
	db, err := sql.Open("sqlite3", this.c.DBPath)
	if err != nil {
		return err
	}
	this.db = db

	// Check current schema version. A brand-new database has no `schema` table
	// yet; that is not a failure, it just means nothing has been applied.
	currentSchemaVer, err := this.getCurrentSchemaVersion()
	if err != nil {
		freshDB := errors.Is(err, sql.ErrNoRows) ||
			strings.Contains(err.Error(), `no such table: schema`)
		if !freshDB {
			return err
		}
		currentSchemaVer = 0
	}

	// Perform migrations
	if err := this.applyMigrations(currentSchemaVer); err != nil {
		return fmt.Errorf("applying migrations: %w", err)
	}

	return nil
}

// getCurrentSchemaVersion returns the highest id recorded in the `schema`
// table. It returns 0 when the table exists but holds no rows, and a non-nil
// error when the query itself fails (for example when the table is missing).
func (this *Application) getCurrentSchemaVersion() (int, error) {
	// MAX() over zero rows yields a single row containing NULL, which cannot
	// be scanned into a plain int; scan into a nullable value instead.
	var maxID sql.NullInt64
	if err := this.db.QueryRow(`SELECT MAX(id) m FROM schema`).Scan(&maxID); err != nil {
		return 0, err
	}
	if !maxID.Valid {
		return 0, nil
	}
	return int(maxID.Int64), nil
}

// applyMigrations executes every migration file in this.c.SchemaDir whose
// version is greater than currentSchemaVer.
//
// The version of a file is the leading run of digits and dashes in its name
// with the dashes removed (e.g. "2020-01-31-init.sql" -> 20200131). Every entry
// in the directory must match that naming scheme; otherwise an error is
// returned before any migration is executed. Pending migrations run in
// ascending version order, with filename order breaking ties.
func (this *Application) applyMigrations(currentSchemaVer int) error {
	dh, err := os.Open(this.c.SchemaDir)
	if err != nil {
		return err
	}
	filenames, err := dh.Readdirnames(-1)
	// The handle is only needed for the listing. A close error on a directory
	// opened read-only carries no actionable information, so it is ignored.
	_ = dh.Close()
	if err != nil {
		return err
	}

	// Lexicographic order gives a deterministic tie-break for the stable
	// version sort below.
	sort.Strings(filenames)

	log.Printf("DB Schema version %d, searching for migrations...", currentSchemaVer)

	type migration struct {
		filename string
		version  int
	}

	// Validate all filenames up front so that a stray file cannot abort the
	// run half-way through, after some migrations were already executed.
	pending := make([]migration, 0, len(filenames))
	for _, filename := range filenames {
		parts := migrationFilenameRx.FindStringSubmatch(filename)
		if parts == nil {
			return fmt.Errorf("found file '%s' in %s directory not matching expected file format, aborting", filename, this.c.SchemaDir)
		}

		version, err := strconv.Atoi(strings.ReplaceAll(parts[1], `-`, ""))
		if err != nil {
			return fmt.Errorf("parsing schema version of '%s': %w", filename, err)
		}

		if currentSchemaVer >= version {
			continue // already applied
		}
		pending = append(pending, migration{filename: filename, version: version})
	}

	// The "already applied" check above is numeric, so the execution order
	// must be numeric as well ("10-x.sql" sorts before "2-y.sql" as a string).
	sort.SliceStable(pending, func(i, j int) bool {
		return pending[i].version < pending[j].version
	})

	for _, m := range pending {
		fpath := filepath.Join(this.c.SchemaDir, m.filename)
		sqlFile, err := ioutil.ReadFile(fpath)
		if err != nil {
			return fmt.Errorf("loading '%s' for schema migration: %w", fpath, err)
		}

		sqlStmts := splitSQLStatements(string(sqlFile))

		log.Printf("Applying schema migration '%s' (%d statement(s))", fpath, len(sqlStmts))

		for _, stmt := range sqlStmts {
			if _, err := this.db.Exec(stmt); err != nil {
				return fmt.Errorf("applying schema '%s': %w", fpath, err)
			}
		}
	}

	if len(pending) > 0 {
		log.Printf("Successfully applied %d schema migration(s)", len(pending))
	} else {
		log.Println("No new schema migrations to apply.")
	}

	return nil
}

// splitSQLStatements breaks a migration script into individual statements.
//
// The SQLite driver does not support multiple SQL statements in a single
// Exec() call, so the script is split on `);` and the terminator is re-attached
// to each piece. A trailing piece that is only whitespace is dropped.
func splitSQLStatements(script string) []string {
	stmts := strings.Split(script, `);`)
	last := len(stmts) - 1

	// Split removed the separator; restore it on every piece that had one.
	for i := 0; i < last; i++ {
		stmts[i] += `);`
	}

	// Don't need to call Exec() if the trailing part is just blank.
	if strings.TrimSpace(stmts[last]) == "" {
		stmts = stmts[:last]
	}
	return stmts
}