diff options
-rw-r--r-- | controller/recipe.go | 6 | ||||
-rw-r--r-- | model/database.go | 62 | ||||
-rw-r--r-- | model/object.go | 15 | ||||
-rw-r--r-- | model/recipe_test.go | 129 | ||||
-rw-r--r-- | view/recipe.go | 2 |
5 files changed, 100 insertions, 114 deletions
diff --git a/controller/recipe.go b/controller/recipe.go index 9427b0d..9529b2a 100644 --- a/controller/recipe.go +++ b/controller/recipe.go @@ -19,7 +19,7 @@ func RecipeCreate(w http.ResponseWriter, r *http.Request) { recipe.Created = recipe.LastChanged var obj model.Object = &recipe - err := model.SafeCrud(obj.Create) + err := model.Transaction(obj.Create) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -50,7 +50,7 @@ func RecipeUpdate(w http.ResponseWriter, r *http.Request) { recipe.LastChanged = fmt.Sprint(time.Now().Unix()) var obj model.Object = &recipe - err = model.SafeCrud(obj.Update) + err = model.Transaction(obj.Update) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return @@ -64,7 +64,7 @@ func RecipeDelete(w http.ResponseWriter, r *http.Request) { recipe.Id = mux.Vars(r)[`id`] var obj model.Object = &recipe - err := model.SafeCrud(obj.Delete) + err := model.Transaction(obj.Delete) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/model/database.go b/model/database.go index d816163..44792fe 100644 --- a/model/database.go +++ b/model/database.go @@ -4,6 +4,7 @@ import ( "database/sql" "embed" "log" + "fmt" _ "github.com/mattn/go-sqlite3" ) @@ -35,37 +36,53 @@ func InitDatabase(path string) { log.Fatal(err) } - // FIXME roll back migration on error query, err := GetSql(`migrate`) if err != nil { log.Fatal(err) } - _, err = db.Exec(query) + + err = Transaction(func(tx *sql.Tx) error { + _, err := tx.Exec(query) + if err != nil { + return err + } + return nil + }) if err != nil { - log.Fatal(err) + log.Fatalf("Migration failed: %v\n", err) } } -func InjectTestRecipes() { - recipes := RecipeTestData() - - tx, err := NewTx() +func Transaction(f func(*sql.Tx) error) error { + tx, err := db.Begin() if err != nil { - log.Fatalf("Failed to inject test recipes: %v\n", err) + return err } - for _, recipe := range recipes { - err = recipe.Create(tx) - if err != nil { - Rollback(tx) - log.Fatalf("Failed to inject test recipe: %v\n", err) + err = f(tx) + if err != nil { + rollbackErr := tx.Rollback() + if rollbackErr != nil { + return fmt.Errorf("Failed rollback '%w' after failed transaction '%w'", rollbackErr, err) } } - err = tx.Commit() - if err != nil { - log.Fatalf("Failed to inject test recipe: %v\n", err) - } + return tx.Commit() +} + +func InjectTestRecipes() { + Transaction(func(tx *sql.Tx) error { + recipes := RecipeTestData() + + for _, recipe := range recipes { + err := recipe.Create(tx) + if err != nil { + return err + } + } + + return nil + }) } func CloseDatabase() { @@ -76,14 +93,3 @@ func CloseDatabase() { log.Println("Closed database") } } - -func NewTx() (*sql.Tx, error) { - return db.Begin() -} - -func Rollback(tx *sql.Tx) { - err := tx.Rollback() - if err != nil { - log.Printf("Failed to rollback transaction: %v\n", err) - } -} diff --git a/model/object.go b/model/object.go index bcbba3e..63ef419 100644 --- a/model/object.go +++ b/model/object.go @@ -10,18 +10,3 @@ type Object interface { Update(tx *sql.Tx) error Delete(tx *sql.Tx) error } - -func SafeCrud(crud func(tx *sql.Tx) error) error { - tx, err := NewTx() - if err != nil { - return err - } - - err = crud(tx) - if err != nil { - Rollback(tx) - return err - } - - return tx.Commit() -} diff --git a/model/recipe_test.go b/model/recipe_test.go index 7b6c14c..74ccfa2 100644 --- a/model/recipe_test.go +++ b/model/recipe_test.go @@ -1,11 +1,12 @@ package model import ( + "database/sql" + "fmt" + "os" + "path/filepath" "reflect" "testing" - "path/filepath" - "os" - "fmt" ) func TestRecipeCrud(t *testing.T) { @@ -25,68 +26,62 @@ func TestRecipeCrud(t *testing.T) { InitDatabase(filepath.Join(storage.Path, "ceres.sqlite3")) defer CloseDatabase() - tx, err := NewTx() - if err != nil { - t.Fatalf("Failed to inject test recipes: %v\n", err) - } - - var original, readback, update, updated, deleted Recipe - - recipes := RecipeTestData() - original = recipes[0] - update = recipes[1] - - err = original.Create(tx) - if err != nil { - t.Fatalf("Failed to create test recipe in DB: %v\n", err) - } - - readback.Id = original.Id - err = readback.Read(tx) - if err != nil { - t.Fatalf("Failed to read test recipe from DB: %v\n", err) - } - - if !reflect.DeepEqual(original, readback) { - t.Fatalf("Recipes did not match after create / read cycle:\n"+ - "Before: %s\nAfter: %s\n", original, readback) - } - - update.Id = original.Id - - err = update.Update(tx) - if err != nil { - t.Fatalf("Failed to update recipe: %v\n", err) - } - - updated.Id = original.Id - err = updated.Read(tx) - if err != nil { - t.Fatalf("Failed to read back updated recipe: %v\n", err) - } - - if !reflect.DeepEqual(update, updated) { - t.Fatalf("Recipes did not match after update / read cycle:\n"+ - "Update: %s\nUpdated: %s\n", update, updated) - } - - if reflect.DeepEqual(updated, original) { - t.Fatalf("Updated and original recipe match") - } - - err = updated.Delete(tx) - if err != nil { - t.Fatalf("Failed to delete updated recipe: %v\n", err) - } - - deleted.Id = updated.Id - err = deleted.Read(tx) - if err == nil { - t.Fatalf("Was able to read back deleted recipe") - } - - err = tx.Commit() - if err != nil { - t.Fatalf("Unable to commit test transaction") - } + Transaction(func(tx *sql.Tx) error { + var original, readback, update, updated, deleted Recipe + + recipes := RecipeTestData() + original = recipes[0] + update = recipes[1] + + err = original.Create(tx) + if err != nil { + t.Fatalf("Failed to create test recipe in DB: %v\n", err) + } + + readback.Id = original.Id + err = readback.Read(tx) + if err != nil { + t.Fatalf("Failed to read test recipe from DB: %v\n", err) + } + + if !reflect.DeepEqual(original, readback) { + t.Fatalf("Recipes did not match after create / read cycle:\n"+ + "Before: %s\nAfter: %s\n", original, readback) + } + + update.Id = original.Id + + err = update.Update(tx) + if err != nil { + t.Fatalf("Failed to update recipe: %v\n", err) + } + + updated.Id = original.Id + err = updated.Read(tx) + if err != nil { + t.Fatalf("Failed to read back updated recipe: %v\n", err) + } + + if !reflect.DeepEqual(update, updated) { + t.Fatalf("Recipes did not match after update / read cycle:\n"+ + "Update: %s\nUpdated: %s\n", update, updated) + } + + if reflect.DeepEqual(updated, original) { + t.Fatalf("Updated and original recipe match") + } + + err = updated.Delete(tx) + if err != nil { + t.Fatalf("Failed to delete updated recipe: %v\n", err) + } + + deleted.Id = updated.Id + err = deleted.Read(tx) + if err == nil { + t.Fatalf("Was able to read back deleted recipe") + } + + return nil + }) } diff --git a/view/recipe.go b/view/recipe.go index ba670a2..7b9980d 100644 --- a/view/recipe.go +++ b/view/recipe.go @@ -13,7 +13,7 @@ func RecipeRead(w http.ResponseWriter, r *http.Request) { recipe.Id = mux.Vars(r)[`id`] var obj model.Object = &recipe - err := model.SafeCrud(obj.Read) + err := model.Transaction(obj.Read) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return |