From f917bc9e1973bfc1c6ab56bf1362295f89d687c3 Mon Sep 17 00:00:00 2001 From: xengineering Date: Mon, 21 Oct 2024 21:37:54 +0200 Subject: model: Rename to db.go The old name database.go did not match the type name DB. --- model/database.go | 188 ------------------------------------------------------ model/db.go | 188 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 188 insertions(+), 188 deletions(-) delete mode 100644 model/database.go create mode 100644 model/db.go (limited to 'model') diff --git a/model/database.go b/model/database.go deleted file mode 100644 index f3fb607..0000000 --- a/model/database.go +++ /dev/null @@ -1,188 +0,0 @@ -package model - -import ( - "database/sql" - "embed" - "fmt" - "log" - - _ "github.com/mattn/go-sqlite3" -) - -type DB sql.DB - -func OpenDB(path string) (*DB, error) { - db, err := sql.Open("sqlite3", path) - if err != nil { - return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err) - } - - err = db.Ping() - if err != nil { - return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err) - } - - return (*DB)(db), nil -} - -func (db *DB) Transaction(f func(*sql.Tx) error) error { - tx, err := (*sql.DB)(db).Begin() - if err != nil { - log.Printf("Failed to start database transaction: %v", err) - return err - } - defer func() { - if tx.Rollback() == nil { - log.Println("Rolled back transaction") - } - }() - - err = f(tx) - if err != nil { - log.Printf("Failed transaction: %v", err) - return err - } - - return tx.Commit() -} - -func (db *DB) isEmpty(tx *sql.Tx) (bool, error) { - cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` - rows, err := tx.Query(cmd) - if err != nil { - return false, fmt.Errorf("Select call failed: %w", err) - } - defer rows.Close() - - if !rows.Next() { - return false, fmt.Errorf("Result set is empty") - } - - var number int - err = rows.Scan(&number) - if err != nil { - return false, fmt.Errorf("Failed to scan numerical value: %w", err) - } - - return number == 0, nil -} - -func (db *DB) schemaVersion(tx *sql.Tx) (int, error) { - empty, err := db.isEmpty(tx) - if err != nil { - return 0, fmt.Errorf("Failed to check if DB is empty: %w", err) - } - if empty { - return 0, nil - } - - rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) - if err != nil { - return 0, fmt.Errorf("Select call for version failed: %w", err) - } - defer rows.Close() - if rows.Next() { - return 1, nil // version field was only present in one schema version - } - - rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`) - if err != nil { - return 0, fmt.Errorf("Select call for schema_version failed: %w", err) - } - defer rows.Close() - if !rows.Next() { - return 0, fmt.Errorf("No schema_version entry in metadata table") - } - var number int - err = rows.Scan(&number) - if err != nil { - return 0, fmt.Errorf("Failed to scan schema_version: %w", err) - } - - return number, nil -} - -//go:embed sql/migration*.sql -var migrationSQL embed.FS - -func getMigrations() ([]func(tx *sql.Tx) error, error) { - migrations := make([]func(tx *sql.Tx) error, 0) - - entries, err := migrationSQL.ReadDir("sql") - if err != nil { - return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err) - } - amount := len(entries) - - for i := 0; i < amount; i++ { - file := fmt.Sprintf("sql/migration%03d.sql", i) - data, err := migrationSQL.ReadFile(file) - if err != nil { - return nil, fmt.Errorf("Failed to read migration SQL code: %w", err) - } - migrations = append( - migrations, - func(tx *sql.Tx) error { - _, err := tx.Exec(string(data)) - return err - }, - ) - } - - return migrations, nil -} - -func (db *DB) Migrate() error { - migrations, err := getMigrations() - if err != nil { - return fmt.Errorf("Failed to get migrations: %w", err) - } - - return db.Transaction(func(tx *sql.Tx) error { - var version int - for index, migration := range migrations { - var err error - version, err = db.schemaVersion(tx) - if err != nil { - return fmt.Errorf("Failed to get DB schema version: %w", err) - } - - if version == index { - log.Printf("Starting database migration for schema version %d", version) - err = migration(tx) - if err != nil { - return err - } - } - } - version, err := db.schemaVersion(tx) - if err != nil { - return fmt.Errorf("Failed to get DB schema version: %w", err) - } - target := len(migrations) - if version != target { - return fmt.Errorf("Expected schema version %d but detected %d", target, version) - } - log.Printf("Database schema version: %d\n", version) - return nil - }) -} - -func (db *DB) CreateExamples() error { - return db.Transaction(func(tx *sql.Tx) error { - recipes := RecipeTestData() - - for _, recipe := range recipes { - err := recipe.Create(tx) - if err != nil { - return err - } - } - - return nil - }) -} - -func (db *DB) Close() error { - return (*sql.DB)(db).Close() -} diff --git a/model/db.go b/model/db.go new file mode 100644 index 0000000..f3fb607 --- /dev/null +++ b/model/db.go @@ -0,0 +1,188 @@ +package model + +import ( + "database/sql" + "embed" + "fmt" + "log" + + _ "github.com/mattn/go-sqlite3" +) + +type DB sql.DB + +func OpenDB(path string) (*DB, error) { + db, err := sql.Open("sqlite3", path) + if err != nil { + return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err) + } + + err = db.Ping() + if err != nil { + return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err) + } + + return (*DB)(db), nil +} + +func (db *DB) Transaction(f func(*sql.Tx) error) error { + tx, err := (*sql.DB)(db).Begin() + if err != nil { + log.Printf("Failed to start database transaction: %v", err) + return err + } + defer func() { + if tx.Rollback() == nil { + log.Println("Rolled back transaction") + } + }() + + err = f(tx) + if err != nil { + log.Printf("Failed transaction: %v", err) + return err + } + + return tx.Commit() +} + +func (db *DB) isEmpty(tx *sql.Tx) (bool, error) { + cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` + rows, err := tx.Query(cmd) + if err != nil { + return false, fmt.Errorf("Select call failed: %w", err) + } + defer rows.Close() + + if !rows.Next() { + return false, fmt.Errorf("Result set is empty") + } + + var number int + err = rows.Scan(&number) + if err != nil { + return false, fmt.Errorf("Failed to scan numerical value: %w", err) + } + + return number == 0, nil +} + +func (db *DB) schemaVersion(tx *sql.Tx) (int, error) { + empty, err := db.isEmpty(tx) + if err != nil { + return 0, fmt.Errorf("Failed to check if DB is empty: %w", err) + } + if empty { + return 0, nil + } + + rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) + if err != nil { + return 0, fmt.Errorf("Select call for version failed: %w", err) + } + defer rows.Close() + if rows.Next() { + return 1, nil // version field was only present in one schema version + } + + rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`) + if err != nil { + return 0, fmt.Errorf("Select call for schema_version failed: %w", err) + } + defer rows.Close() + if !rows.Next() { + return 0, fmt.Errorf("No schema_version entry in metadata table") + } + var number int + err = rows.Scan(&number) + if err != nil { + return 0, fmt.Errorf("Failed to scan schema_version: %w", err) + } + + return number, nil +} + +//go:embed sql/migration*.sql +var migrationSQL embed.FS + +func getMigrations() ([]func(tx *sql.Tx) error, error) { + migrations := make([]func(tx *sql.Tx) error, 0) + + entries, err := migrationSQL.ReadDir("sql") + if err != nil { + return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err) + } + amount := len(entries) + + for i := 0; i < amount; i++ { + file := fmt.Sprintf("sql/migration%03d.sql", i) + data, err := migrationSQL.ReadFile(file) + if err != nil { + return nil, fmt.Errorf("Failed to read migration SQL code: %w", err) + } + migrations = append( + migrations, + func(tx *sql.Tx) error { + _, err := tx.Exec(string(data)) + return err + }, + ) + } + + return migrations, nil +} + +func (db *DB) Migrate() error { + migrations, err := getMigrations() + if err != nil { + return fmt.Errorf("Failed to get migrations: %w", err) + } + + return db.Transaction(func(tx *sql.Tx) error { + var version int + for index, migration := range migrations { + var err error + version, err = db.schemaVersion(tx) + if err != nil { + return fmt.Errorf("Failed to get DB schema version: %w", err) + } + + if version == index { + log.Printf("Starting database migration for schema version %d", version) + err = migration(tx) + if err != nil { + return err + } + } + } + version, err := db.schemaVersion(tx) + if err != nil { + return fmt.Errorf("Failed to get DB schema version: %w", err) + } + target := len(migrations) + if version != target { + return fmt.Errorf("Expected schema version %d but detected %d", target, version) + } + log.Printf("Database schema version: %d\n", version) + return nil + }) +} + +func (db *DB) CreateExamples() error { + return db.Transaction(func(tx *sql.Tx) error { + recipes := RecipeTestData() + + for _, recipe := range recipes { + err := recipe.Create(tx) + if err != nil { + return err + } + } + + return nil + }) +} + +func (db *DB) Close() error { + return (*sql.DB)(db).Close() +} -- cgit v1.2.3-70-g09d2