diff options
Diffstat (limited to 'model/database.go')
-rw-r--r-- | model/database.go | 111 |
1 files changed, 72 insertions, 39 deletions
diff --git a/model/database.go b/model/database.go index 7490f7b..84ed497 100644 --- a/model/database.go +++ b/model/database.go @@ -2,12 +2,11 @@ package model import ( "database/sql" + "embed" "fmt" "log" _ "github.com/mattn/go-sqlite3" - - "xengineering.eu/ceres/model/migrations" ) type DB sql.DB @@ -48,8 +47,6 @@ func (db *DB) Transaction(f func(*sql.Tx) error) error { } func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) { - var number int - cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` rows, err := tx.Query(cmd) if err != nil { @@ -61,6 +58,7 @@ func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) { return false, fmt.Errorf("Result set is empty") } + var number int err = rows.Scan(&number) if err != nil { return false, fmt.Errorf("Failed to scan numerical value: %w", err) @@ -69,69 +67,104 @@ func (db *DB) IsEmpty(tx *sql.Tx) (bool, error) { return number == 0, nil } -func (db *DB) setupMinimal(tx *sql.Tx, execVersion string) error { - cmd := ` -CREATE TABLE metadata ( - key TEXT PRIMARY KEY, - value TEXT -); -INSERT INTO metadata - (key, value) -VALUES - ('version', ?); -` - _, err := tx.Exec(cmd, execVersion) - return err -} - func (db *DB) SchemaVersion(tx *sql.Tx) (int, error) { empty, err := db.IsEmpty(tx) if err != nil { return 0, fmt.Errorf("Failed to check if DB is empty: %w", err) } - if empty { return 0, nil } rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) if err != nil { - return 0, fmt.Errorf("Select call failed: %w", err) + return 0, fmt.Errorf("Select call for version failed: %w", err) } defer rows.Close() - if rows.Next() { return 1, nil // version field was only present in one schema version } - return 0, fmt.Errorf("Unknown schema version") + rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`) + if err != nil { + return 0, fmt.Errorf("Select call for schema_version failed: %w", err) + } + defer rows.Close() + if !rows.Next() { + return 0, fmt.Errorf("No schema_version entry in metadata table") + } + var number int + err = rows.Scan(&number) + if err != nil { + return 0, fmt.Errorf("Failed to scan schema_version: %w", err) + } + + return number, nil +} + +//go:embed sql/migration*.sql +var migrationSQL embed.FS + +func getMigrations() ([]func(tx *sql.Tx) error, error) { + migrations := make([]func(tx *sql.Tx) error, 0) + + entries, err := migrationSQL.ReadDir("sql") + if err != nil { + return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err) + } + amount := len(entries) + + for i := 0; i < amount; i++ { + file := fmt.Sprintf("sql/migration%03d.sql", i) + data, err := migrationSQL.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("Failed to read migration SQL code: %w", err) + } + migrations = append( + migrations, + func(tx *sql.Tx) error { + _, err := tx.Exec(string(data)) + return err + }, + ) + } + + return migrations, nil } -func (db *DB) Migrate(execVersion string) error { +func (db *DB) Migrate() error { + migrations, err := getMigrations() + if err != nil { + return fmt.Errorf("Failed to get migrations: %w", err) + } + return db.Transaction(func(tx *sql.Tx) error { - for { - schema, err := db.SchemaVersion(tx) + var version int + for index, migration := range migrations { + var err error + version, err = db.SchemaVersion(tx) if err != nil { return fmt.Errorf("Failed to get DB schema version: %w", err) } - switch schema { - case 0: - log.Println("Starting with empty database") - err := db.setupMinimal(tx, execVersion) - if err != nil { - return fmt.Errorf("Failed to setup minimal database schema: %w", err) - } - log.Println("Executing initial migration") - err = migrations.Migration001(tx) + + if version == index { + log.Printf("Starting database migration for schema version %d", version) + err = migration(tx) if err != nil { return err } - case 1: - return nil - default: - return fmt.Errorf("Cannot migrate database to a matching schema version") } } + version, err := db.SchemaVersion(tx) + if err != nil { + return fmt.Errorf("Failed to get DB schema version: %w", err) + } + target := len(migrations) + if version != target { + return fmt.Errorf("Expected schema version %d but detected %d", target, version) + } + log.Printf("Database schema version: %d\n", version) + return nil }) } |