summaryrefslogtreecommitdiff
path: root/model/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'model/db.go')
-rw-r--r--model/db.go188
1 files changed, 188 insertions, 0 deletions
diff --git a/model/db.go b/model/db.go
new file mode 100644
index 0000000..f3fb607
--- /dev/null
+++ b/model/db.go
@@ -0,0 +1,188 @@
+package model
+
+import (
+ "database/sql"
+ "embed"
+ "fmt"
+ "log"
+
+ _ "github.com/mattn/go-sqlite3"
+)
+
+type DB sql.DB
+
+func OpenDB(path string) (*DB, error) {
+ db, err := sql.Open("sqlite3", path)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err)
+ }
+
+ err = db.Ping()
+ if err != nil {
+ return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err)
+ }
+
+ return (*DB)(db), nil
+}
+
+func (db *DB) Transaction(f func(*sql.Tx) error) error {
+ tx, err := (*sql.DB)(db).Begin()
+ if err != nil {
+ log.Printf("Failed to start database transaction: %v", err)
+ return err
+ }
+ defer func() {
+ if tx.Rollback() == nil {
+ log.Println("Rolled back transaction")
+ }
+ }()
+
+ err = f(tx)
+ if err != nil {
+ log.Printf("Failed transaction: %v", err)
+ return err
+ }
+
+ return tx.Commit()
+}
+
+func (db *DB) isEmpty(tx *sql.Tx) (bool, error) {
+ cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`
+ rows, err := tx.Query(cmd)
+ if err != nil {
+ return false, fmt.Errorf("Select call failed: %w", err)
+ }
+ defer rows.Close()
+
+ if !rows.Next() {
+ return false, fmt.Errorf("Result set is empty")
+ }
+
+ var number int
+ err = rows.Scan(&number)
+ if err != nil {
+ return false, fmt.Errorf("Failed to scan numerical value: %w", err)
+ }
+
+ return number == 0, nil
+}
+
+func (db *DB) schemaVersion(tx *sql.Tx) (int, error) {
+ empty, err := db.isEmpty(tx)
+ if err != nil {
+ return 0, fmt.Errorf("Failed to check if DB is empty: %w", err)
+ }
+ if empty {
+ return 0, nil
+ }
+
+ rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`)
+ if err != nil {
+ return 0, fmt.Errorf("Select call for version failed: %w", err)
+ }
+ defer rows.Close()
+ if rows.Next() {
+ return 1, nil // version field was only present in one schema version
+ }
+
+ rows, err = tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`)
+ if err != nil {
+ return 0, fmt.Errorf("Select call for schema_version failed: %w", err)
+ }
+ defer rows.Close()
+ if !rows.Next() {
+ return 0, fmt.Errorf("No schema_version entry in metadata table")
+ }
+ var number int
+ err = rows.Scan(&number)
+ if err != nil {
+ return 0, fmt.Errorf("Failed to scan schema_version: %w", err)
+ }
+
+ return number, nil
+}
+
+//go:embed sql/migration*.sql
+var migrationSQL embed.FS
+
+func getMigrations() ([]func(tx *sql.Tx) error, error) {
+ migrations := make([]func(tx *sql.Tx) error, 0)
+
+ entries, err := migrationSQL.ReadDir("sql")
+ if err != nil {
+ return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err)
+ }
+ amount := len(entries)
+
+ for i := 0; i < amount; i++ {
+ file := fmt.Sprintf("sql/migration%03d.sql", i)
+ data, err := migrationSQL.ReadFile(file)
+ if err != nil {
+ return nil, fmt.Errorf("Failed to read migration SQL code: %w", err)
+ }
+ migrations = append(
+ migrations,
+ func(tx *sql.Tx) error {
+ _, err := tx.Exec(string(data))
+ return err
+ },
+ )
+ }
+
+ return migrations, nil
+}
+
+func (db *DB) Migrate() error {
+ migrations, err := getMigrations()
+ if err != nil {
+ return fmt.Errorf("Failed to get migrations: %w", err)
+ }
+
+ return db.Transaction(func(tx *sql.Tx) error {
+ var version int
+ for index, migration := range migrations {
+ var err error
+ version, err = db.schemaVersion(tx)
+ if err != nil {
+ return fmt.Errorf("Failed to get DB schema version: %w", err)
+ }
+
+ if version == index {
+ log.Printf("Starting database migration for schema version %d", version)
+ err = migration(tx)
+ if err != nil {
+ return err
+ }
+ }
+ }
+ version, err := db.schemaVersion(tx)
+ if err != nil {
+ return fmt.Errorf("Failed to get DB schema version: %w", err)
+ }
+ target := len(migrations)
+ if version != target {
+ return fmt.Errorf("Expected schema version %d but detected %d", target, version)
+ }
+ log.Printf("Database schema version: %d\n", version)
+ return nil
+ })
+}
+
+func (db *DB) CreateExamples() error {
+ return db.Transaction(func(tx *sql.Tx) error {
+ recipes := RecipeTestData()
+
+ for _, recipe := range recipes {
+ err := recipe.Create(tx)
+ if err != nil {
+ return err
+ }
+ }
+
+ return nil
+ })
+}
+
+func (db *DB) Close() error {
+ return (*sql.DB)(db).Close()
+}