summaryrefslogtreecommitdiff
path: root/model/database.go
diff options
context:
space:
mode:
Diffstat (limited to 'model/database.go')
-rw-r--r--model/database.go111
1 files changed, 72 insertions, 39 deletions
diff --git a/model/database.go b/model/database.go
index 7490f7b..84ed497 100644
--- a/model/database.go
+++ b/model/database.go
@@ -2,12 +2,11 @@ package model
import (
"database/sql"
+ "embed"
"fmt"
"log"
_ "github.com/mattn/go-sqlite3"
-
- "xengineering.eu/ceres/model/migrations"
)
type DB sql.DB
@@ -48,8 +47,6 @@ func (db *DB) Transaction(f func(*sql.Tx) error) error {
}
func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) {
- var number int
-
cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`
rows, err := tx.Query(cmd)
if err != nil {
@@ -61,6 +58,7 @@ func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) {
return false, fmt.Errorf("Result set is empty")
}
+ var number int
err = rows.Scan(&number)
if err != nil {
return false, fmt.Errorf("Failed to scan numerical value: %w", err)
@@ -69,69 +67,104 @@ func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) {
return number == 0, nil
}
-func (db *DB) setupMinimal(tx *sql.Tx, execVersion string) error {
- cmd := `
-CREATE TABLE metadata (
- key TEXT PRIMARY KEY,
- value TEXT
-);
-INSERT INTO metadata
- (key, value)
-VALUES
- ('version', ?);
-`
- _, err := tx.Exec(cmd, execVersion)
- return err
-}
-
func (db *DB) SchemaVersion(tx *sql.Tx) (int, error) {
empty, err := db.IsEmpty(tx)
if err != nil {
return 0, fmt.Errorf("Failed to check if DB is empty: %w", err)
}
-
if empty {
return 0, nil
}
rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`)
if err != nil {
- return 0, fmt.Errorf("Select call failed: %w", err)
+ return 0, fmt.Errorf("Select call for version failed: %w", err)
}
defer rows.Close()
-
if rows.Next() {
return 1, nil // version field was only present in one schema version
}
- return 0, fmt.Errorf("Unknown schema version")
+ rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`)
+ if err != nil {
+ return 0, fmt.Errorf("Select call for schema_version failed: %w", err)
+ }
+ defer rows.Close()
+ if !rows.Next() {
+ return 0, fmt.Errorf("No schema_version entry in metadata table")
+ }
+ var number int
+ err = rows.Scan(&number)
+ if err != nil {
+ return 0, fmt.Errorf("Failed to scan schema_version: %w", err)
+ }
+
+ return number, nil
+}
+
+//go:embed sql/migration*.sql
+var migrationSQL embed.FS
+
+func getMigrations() ([]func(tx *sql.Tx) error, error) {
+ migrations := make([]func(tx *sql.Tx) error, 0)
+
+ entries, err := migrationSQL.ReadDir("sql")
+ if err != nil {
+ return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err)
+ }
+ amount := len(entries)
+
+ for i := 0; i < amount; i++ {
+ file := fmt.Sprintf("sql/migration%03d.sql", i)
+ data, err := migrationSQL.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to read migration SQL code: %w", err)
+ }
+ migrations = append(
+ migrations,
+ func(tx *sql.Tx) error {
+ _, err := tx.Exec(string(data))
+ return err
+ },
+ )
+ }
+
+ return migrations, nil
}
-func (db *DB) Migrate(execVersion string) error {
+func (db *DB) Migrate() error {
+ migrations, err := getMigrations()
+ if err != nil {
+ return fmt.Errorf("Failed to get migrations: %w", err)
+ }
+
return db.Transaction(func(tx *sql.Tx) error {
- for {
- schema, err := db.SchemaVersion(tx)
+ var version int
+ for index, migration := range migrations {
+ var err error
+ version, err = db.SchemaVersion(tx)
if err != nil {
return fmt.Errorf("Failed to get DB schema version: %w", err)
}
- switch schema {
- case 0:
- log.Println("Starting with empty database")
- err := db.setupMinimal(tx, execVersion)
- if err != nil {
- return fmt.Errorf("Failed to setup minimal database schema: %w", err)
- }
- log.Println("Executing initial migration")
- err = migrations.Migration001(tx)
+
+ if version == index {
+ log.Printf("Starting database migration for schema version %d", version)
+ err = migration(tx)
if err != nil {
return err
}
- case 1:
- return nil
- default:
- return fmt.Errorf("Cannot migrate database to a matching schema version")
}
}
+ version, err := db.SchemaVersion(tx)
+ if err != nil {
+ return fmt.Errorf("Failed to get DB schema version: %w", err)
+ }
+ target := len(migrations)
+ if version != target {
+ return fmt.Errorf("Expected schema version %d but detected %d", target, version)
+ }
+ log.Printf("Database schema version: %d\n", version)
+ return nil
})
}