diff options
author | xengineering <me@xengineering.eu> | 2024-10-21 21:37:54 +0200 |
---|---|---|
committer | xengineering <me@xengineering.eu> | 2024-10-21 21:37:54 +0200 |
commit | f917bc9e1973bfc1c6ab56bf1362295f89d687c3 (patch) | |
tree | ff7af6b33332033987fe71d8a5f788e15119727b /model/db.go | |
parent | 72601b87ef040a3c6882368ac85c12c1ae705cd2 (diff) | |
download | ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.tar ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.tar.zst ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.zip |
model: Rename to db.go
The old name database.go did not match the type name DB.
Diffstat (limited to 'model/db.go')
-rw-r--r-- | model/db.go | 188 |
1 files changed, 188 insertions, 0 deletions
diff --git a/model/db.go b/model/db.go new file mode 100644 index 0000000..f3fb607 --- /dev/null +++ b/model/db.go @@ -0,0 +1,188 @@ +package model + +import ( + "database/sql" + "embed" + "fmt" + "log" + + _ "github.com/mattn/go-sqlite3" +) + +type DB sql.DB + +func OpenDB(path string) (*DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err) + } + + err = db.Ping() + if err != nil { + return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err) + } + + return (*DB)(db), nil +} + +func (db *DB) Transaction(f func(*sql.Tx) error) error { + tx, err := (*sql.DB)(db).Begin() + if err != nil { + log.Printf("Failed to start database transaction: %v", err) + return err + } + defer func() { + if tx.Rollback() == nil { + log.Println("Rolled back transaction") + } + }() + + err = f(tx) + if err != nil { + log.Printf("Failed transaction: %v", err) + return err + } + + return tx.Commit() +} + +func (db *DB) isEmpty(tx *sql.Tx) (bool, error) { + cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` + rows, err := tx.Query(cmd) + if err != nil { + return false, fmt.Errorf("Select call failed: %w", err) + } + defer rows.Close() + + if !rows.Next() { + return false, fmt.Errorf("Result set is empty") + } + + var number int + err = rows.Scan(&number) + if err != nil { + return false, fmt.Errorf("Failed to scan numerical value: %w", err) + } + + return number == 0, nil +} + +func (db *DB) schemaVersion(tx *sql.Tx) (int, error) { + empty, err := db.isEmpty(tx) + if err != nil { + return 0, fmt.Errorf("Failed to check if DB is empty: %w", err) + } + if empty { + return 0, nil + } + + rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) + if err != nil { + return 0, fmt.Errorf("Select call for version failed: %w", err) + } + defer rows.Close() + if rows.Next() { + return 1, nil // version field was only present in one schema version + } + + rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`) + if err != nil { + return 0, fmt.Errorf("Select call for schema_version failed: %w", err) + } + defer rows.Close() + if !rows.Next() { + return 0, fmt.Errorf("No schema_version entry in metadata table") + } + var number int + err = rows.Scan(&number) + if err != nil { + return 0, fmt.Errorf("Failed to scan schema_version: %w", err) + } + + return number, nil +} + +//go:embed sql/migration*.sql +var migrationSQL embed.FS + +func getMigrations() ([]func(tx *sql.Tx) error, error) { + migrations := make([]func(tx *sql.Tx) error, 0) + + entries, err := migrationSQL.ReadDir("sql") + if err != nil { + return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err) + } + amount := len(entries) + + for i := 0; i < amount; i++ { + file := fmt.Sprintf("sql/migration%03d.sql", i) + data, err := migrationSQL.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("Failed to read migration SQL code: %w", err) + } + migrations = append( + migrations, + func(tx *sql.Tx) error { + _, err := tx.Exec(string(data)) + return err + }, + ) + } + + return migrations, nil +} + +func (db *DB) Migrate() error { + migrations, err := getMigrations() + if err != nil { + return fmt.Errorf("Failed to get migrations: %w", err) + } + + return db.Transaction(func(tx *sql.Tx) error { + var version int + for index, migration := range migrations { + var err error + version, err = db.schemaVersion(tx) + if err != nil { + return fmt.Errorf("Failed to get DB schema version: %w", err) + } + + if version == index { + log.Printf("Starting database migration for schema version %d", version) + err = migration(tx) + if err != nil { + return err + } + } + } + version, err := db.schemaVersion(tx) + if err != nil { + return fmt.Errorf("Failed to get DB schema version: %w", err) + } + target := len(migrations) + if version != target { + return fmt.Errorf("Expected schema version %d but detected %d", target, version) + } + log.Printf("Database schema version: %d\n", version) + return nil + }) +} + +func (db *DB) CreateExamples() error { + return db.Transaction(func(tx *sql.Tx) error { + recipes := RecipeTestData() + + for _, recipe := range recipes { + err := recipe.Create(tx) + if err != nil { + return err + } + } + + return nil + }) +} + +func (db *DB) Close() error { + return (*sql.DB)(db).Close() +} |