package model

import (
	"database/sql"
	"embed"
	"fmt"
	"log"

	_ "github.com/mattn/go-sqlite3" // registers the "sqlite3" database/sql driver used by OpenDB
)

// DB is a SQLite3 database handle. It shares its underlying type with
// sql.DB, so methods convert it back with (*sql.DB)(db).
type DB sql.DB

// OpenDB opens the SQLite3 database at path and verifies the connection
// with a ping before returning it. If the ping fails, the handle created
// by sql.Open is closed so it is not leaked.
func OpenDB(path string) (*DB, error) {
	db, err := sql.Open("sqlite3", path)
	if err != nil {
		return nil, fmt.Errorf("Failed to open SQLite3 database: %w", err)
	}
	if err := db.Ping(); err != nil {
		// Release the handle from sql.Open. The ping error is what the
		// caller needs, so a secondary Close error is deliberately dropped.
		_ = db.Close()
		return nil, fmt.Errorf("Failed to ping SQLite3 database: %w", err)
	}
	return (*DB)(db), nil
}

// Transaction runs f inside a database transaction. The transaction is
// committed when f returns nil. It is rolled back when f returns an error
// (wrapped as "Failed transaction: ...") or when f panics, because the
// rollback is deferred.
func (db *DB) Transaction(f func(*sql.Tx) error) error {
	tx, err := (*sql.DB)(db).Begin()
	if err != nil {
		return fmt.Errorf("Failed to start database transaction: %w", err)
	}
	// After a successful Commit, Rollback returns an error (the transaction
	// is already done). So the log line only appears when a rollback
	// actually happened.
	defer func() {
		if tx.Rollback() == nil {
			log.Println("Rolled back transaction")
		}
	}()
	if err := f(tx); err != nil {
		return fmt.Errorf("Failed transaction: %w", err)
	}
	return tx.Commit()
}

// isEmpty reports whether the database contains no tables. It counts the
// entries of type 'table' in sqlite_master.
func (db *DB) isEmpty(tx *sql.Tx) (bool, error) {
	cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`
	rows, err := tx.Query(cmd)
	if err != nil {
		return false, fmt.Errorf("Select call failed: %w", err)
	}
	defer rows.Close()
	if !rows.Next() {
		// Next also returns false when iteration fails. Surface that error
		// instead of misreporting it as an empty result set.
		if err := rows.Err(); err != nil {
			return false, fmt.Errorf("Select call failed: %w", err)
		}
		return false, fmt.Errorf("Result set is empty")
	}
	var number int
	if err := rows.Scan(&number); err != nil {
		return false, fmt.Errorf("Failed to scan numerical value: %w", err)
	}
	return number == 0, nil
}

// schemaVersion returns the database's schema version:
//   - 0 if the database contains no tables;
//   - 1 if the metadata table has a row with key 'version' (that key was
//     only used by schema version 1);
//   - otherwise the integer stored under key 'schema_version'.
//
// It returns an error if none of these can be determined.
func (db *DB) schemaVersion(tx *sql.Tx) (int, error) {
	empty, err := db.isEmpty(tx)
	if err != nil {
		return 0, fmt.Errorf("Failed to check if DB is empty: %w", err)
	}
	if empty {
		return 0, nil
	}

	legacyRows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`)
	if err != nil {
		return 0, fmt.Errorf("Select call for version failed: %w", err)
	}
	defer legacyRows.Close()
	if legacyRows.Next() {
		return 1, nil // version field was only present in one schema version
	}
	// A false Next means either "no such row" or an iteration error. Only
	// the former should fall through to the schema_version lookup.
	if err := legacyRows.Err(); err != nil {
		return 0, fmt.Errorf("Select call for version failed: %w", err)
	}

	versionRows, err := tx.Query(`SELECT value FROM metadata WHERE key='schema_version';`)
	if err != nil {
		return 0, fmt.Errorf("Select call for schema_version failed: %w", err)
	}
	defer versionRows.Close()
	if !versionRows.Next() {
		if err := versionRows.Err(); err != nil {
			return 0, fmt.Errorf("Select call for schema_version failed: %w", err)
		}
		return 0, fmt.Errorf("No schema_version entry in metadata table")
	}
	var number int
	if err := versionRows.Scan(&number); err != nil {
		return 0, fmt.Errorf("Failed to scan schema_version: %w", err)
	}
	return number, nil
}

//go:embed sql/migration*.sql
var migrationSQL embed.FS

// getMigrations returns one migration per embedded SQL file.
//
// Migration i is read from sql/migrationNNN.sql, where NNN is i
// zero-padded to three digits. The files must therefore be numbered
// consecutively from 000; a gap makes this function fail. Each returned
// function executes its file's SQL on the given transaction.
func getMigrations() ([]func(tx *sql.Tx) error, error) {
	entries, err := migrationSQL.ReadDir("sql")
	if err != nil {
		return nil, fmt.Errorf("Failed to read embedded migration SQL FS: %w", err)
	}
	migrations := make([]func(tx *sql.Tx) error, 0, len(entries))
	for i := range entries {
		file := fmt.Sprintf("sql/migration%03d.sql", i)
		data, err := migrationSQL.ReadFile(file)
		if err != nil {
			return nil, fmt.Errorf("Failed to read migration SQL code: %w", err)
		}
		// Declared inside the loop body, so each closure captures its own copy.
		query := string(data)
		migrations = append(migrations, func(tx *sql.Tx) error {
			_, err := tx.Exec(query)
			return err
		})
	}
	return migrations, nil
}

// Migrate upgrades the database schema inside a single transaction.
//
// Migrations run in order. Migration i is applied only when the current
// schema version equals i, and the version is re-read before each step.
// Afterwards the schema version must equal the number of embedded
// migrations. Otherwise an error is returned and the whole transaction is
// rolled back.
// NOTE(review): this relies on each migration advancing schema_version by
// one — confirm against the SQL files.
func (db *DB) Migrate() error {
	migrations, err := getMigrations()
	if err != nil {
		return fmt.Errorf("Failed to get migrations: %w", err)
	}
	return db.Transaction(func(tx *sql.Tx) error {
		for index, migration := range migrations {
			version, err := db.schemaVersion(tx)
			if err != nil {
				return fmt.Errorf("Failed to get DB schema version: %w", err)
			}
			if version != index {
				continue
			}
			log.Printf("Starting database migration for schema version %d", version)
			if err := migration(tx); err != nil {
				return fmt.Errorf("Failed to apply migration %03d: %w", index, err)
			}
		}

		version, err := db.schemaVersion(tx)
		if err != nil {
			return fmt.Errorf("Failed to get DB schema version: %w", err)
		}
		target := len(migrations)
		if version != target {
			return fmt.Errorf("Expected schema version %d but detected %d", target, version)
		}
		log.Printf("Database schema version: %d\n", version)
		return nil
	})
}

// CreateExamples inserts the recipes returned by RecipeTestData within a
// single transaction. The first failing Create aborts and rolls back all
// inserts.
func (db *DB) CreateExamples() error {
	return db.Transaction(func(tx *sql.Tx) error {
		for _, recipe := range RecipeTestData() {
			if err := recipe.Create(tx); err != nil {
				return err
			}
		}
		return nil
	})
}

// Close closes the underlying database handle.
func (db *DB) Close() error {
	return (*sql.DB)(db).Close()
}