diff options
author | xengineering <me@xengineering.eu> | 2024-10-21 21:37:54 +0200 |
---|---|---|
committer | xengineering <me@xengineering.eu> | 2024-10-21 21:37:54 +0200 |
commit | f917bc9e1973bfc1c6ab56bf1362295f89d687c3 (patch) | |
tree | ff7af6b33332033987fe71d8a5f788e15119727b /model/database.go | |
parent | 72601b87ef040a3c6882368ac85c12c1ae705cd2 (diff) | |
download | ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.tar ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.tar.zst ceres-f917bc9e1973bfc1c6ab56bf1362295f89d687c3.zip |
model: Rename to db.go
The old name database.go did not match the type name DB.
Diffstat (limited to 'model/database.go')
-rw-r--r-- | model/database.go | 188 |
1 files changed, 0 insertions, 188 deletions
diff --git a/model/database.go b/model/database.go deleted file mode 100644 index f3fb607..0000000 --- a/model/database.go +++ /dev/null @@ -1,188 +0,0 @@ -package model - -import ( - "database/sql" - "embed" - "fmt" - "log" - - _ "github.com/mattn/go-sqlite3" -) - -type DB sql.DB - -func OpenDB(path string) (*DB, error) { - db, err := sql.Open("sqlite3", path) - if err != nil { - return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err) - } - - err = db.Ping() - if err != nil { - return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err) - } - - return (*DB)(db), nil -} - -func (db *DB) Transaction(f func(*sql.Tx) error) error { - tx, err := (*sql.DB)(db).Begin() - if err != nil { - log.Printf("Failed to start database transaction: %v", err) - return err - } - defer func() { - if tx.Rollback() == nil { - log.Println("Rolled back transaction") - } - }() - - err = f(tx) - if err != nil { - log.Printf("Failed transaction: %v", err) - return err - } - - return tx.Commit() -} - -func (db *DB) isEmpty(tx *sql.Tx) (bool, error) { - cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` - rows, err := tx.Query(cmd) - if err != nil { - return false, fmt.Errorf("Select call failed: %w", err) - } - defer rows.Close() - - if !rows.Next() { - return false, fmt.Errorf("Result set is empty") - } - - var number int - err = rows.Scan(&number) - if err != nil { - return false, fmt.Errorf("Failed to scan numerical value: %w", err) - } - - return number == 0, nil -} - -func (db *DB) schemaVersion(tx *sql.Tx) (int, error) { - empty, err := db.isEmpty(tx) - if err != nil { - return 0, fmt.Errorf("Failed to check if DB is empty: %w", err) - } - if empty { - return 0, nil - } - - rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) - if err != nil { - return 0, fmt.Errorf("Select call for version failed: %w", err) - } - defer rows.Close() - if rows.Next() { - return 1, nil // version field was only present in one schema version - } - - rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`) - if err != nil { - return 0, fmt.Errorf("Select call for schema_version failed: %w", err) - } - defer rows.Close() - if !rows.Next() { - return 0, fmt.Errorf("No schema_version entry in metadata table") - } - var number int - err = rows.Scan(&number) - if err != nil { - return 0, fmt.Errorf("Failed to scan schema_version: %w", err) - } - - return number, nil -} - -//go:embed sql/migration*.sql -var migrationSQL embed.FS - -func getMigrations() ([]func(tx *sql.Tx) error, error) { - migrations := make([]func(tx *sql.Tx) error, 0) - - entries, err := migrationSQL.ReadDir("sql") - if err != nil { - return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err) - } - amount := len(entries) - - for i := 0; i < amount; i++ { - file := fmt.Sprintf("sql/migration%03d.sql", i) - data, err := migrationSQL.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("Failed to read migration SQL code: %w", err) - } - migrations = append( - migrations, - func(tx *sql.Tx) error { - _, err := tx.Exec(string(data)) - return err - }, - ) - } - - return migrations, nil -} - -func (db *DB) Migrate() error { - migrations, err := getMigrations() - if err != nil { - return fmt.Errorf("Failed to get migrations: %w", err) - } - - return db.Transaction(func(tx *sql.Tx) error { - var version int - for index, migration := range migrations { - var err error - version, err = db.schemaVersion(tx) - if err != nil { - return fmt.Errorf("Failed to get DB schema version: %w", err) - } - - if version == index { - log.Printf("Starting database migration for schema version %d", version) - err = migration(tx) - if err != nil { - return err - } - } - } - version, err := db.schemaVersion(tx) - if err != nil { - return fmt.Errorf("Failed to get DB schema version: %w", err) - } - target := len(migrations) - if version != target { - return fmt.Errorf("Expected schema version %d but detected %d", target, version) - } - log.Printf("Database schema version: %d\n", version) - return nil - }) -} - -func (db *DB) CreateExamples() error { - return db.Transaction(func(tx *sql.Tx) error { - recipes := RecipeTestData() - - for _, recipe := range recipes { - err := recipe.Create(tx) - if err != nil { - return err - } - } - - return nil - }) -} - -func (db *DB) Close() error { - return (*sql.DB)(db).Close() -} |