diff options
author | xengineering <me@xengineering.eu> | 2024-10-13 19:52:28 +0200 |
---|---|---|
committer | xengineering <me@xengineering.eu> | 2024-10-13 19:55:38 +0200 |
commit | 473052ed8f2c83052ed5b47a7f4cec68ac2621a6 (patch) | |
tree | 2d5da088c6879317734277350c873a258b4d1dac | |
parent | ed19b82335345833c5b8f5446237d559a3657a35 (diff) | |
download | ceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.tar ceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.tar.zst ceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.zip |
model: Replace global db variable by custom type
Reducing global variables makes it easier to understand functions
independently of the rest of the code.
Adding the new model.DB type as a custom variant of the sql.DB type
makes it possible to write methods for the database which makes the code
way more readable.
-rw-r--r-- | controller/recipe.go | 130 | ||||
-rw-r--r-- | main.go | 12 | ||||
-rw-r--r-- | model/database.go | 38 | ||||
-rw-r--r-- | server.go | 13 | ||||
-rw-r--r-- | view/recipe.go | 72 | ||||
-rw-r--r-- | view/recipes.go | 26 |
6 files changed, 151 insertions, 140 deletions
diff --git a/controller/recipe.go b/controller/recipe.go index b0a81d9..aec7144 100644 --- a/controller/recipe.go +++ b/controller/recipe.go @@ -12,74 +12,80 @@ import ( "github.com/gorilla/mux" ) -func RecipeCreate(w http.ResponseWriter, r *http.Request) { - buf, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return +func RecipeCreate(db *model.DB) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + buf, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + recipe := model.Recipe{} + err = json.Unmarshal(buf, &recipe) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + recipe.LastChanged = fmt.Sprint(time.Now().Unix()) + recipe.Created = recipe.LastChanged + + var obj model.Object = &recipe + err = db.Transaction(obj.Create) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther) } - - recipe := model.Recipe{} - err = json.Unmarshal(buf, &recipe) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return - } - - recipe.LastChanged = fmt.Sprint(time.Now().Unix()) - recipe.Created = recipe.LastChanged - - var obj model.Object = &recipe - err = model.Transaction(obj.Create) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther) } -func RecipeUpdate(w http.ResponseWriter, r *http.Request) { - buf, err := io.ReadAll(r.Body) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - recipe := model.Recipe{} - err = json.Unmarshal(buf, &recipe) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) - return +func RecipeUpdate(db *model.DB) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + buf, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + recipe := model.Recipe{} + err = json.Unmarshal(buf, &recipe) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + + if recipe.Id != mux.Vars(r)[`id`] { + http.Error(w, "IDs in URL and JSON do not match", http.StatusBadRequest) + return + } + + recipe.LastChanged = fmt.Sprint(time.Now().Unix()) + + var obj model.Object = &recipe + err = db.Transaction(obj.Update) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther) } - - if recipe.Id != mux.Vars(r)[`id`] { - http.Error(w, "IDs in URL and JSON do not match", http.StatusBadRequest) - return - } - - recipe.LastChanged = fmt.Sprint(time.Now().Unix()) - - var obj model.Object = &recipe - err = model.Transaction(obj.Update) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - - http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther) } -func RecipeDelete(w http.ResponseWriter, r *http.Request) { - recipe := model.Recipe{} - recipe.Id = mux.Vars(r)[`id`] +func RecipeDelete(db *model.DB) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + recipe := model.Recipe{} + recipe.Id = mux.Vars(r)[`id`] - var obj model.Object = &recipe - err := model.Transaction(obj.Delete) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + var obj model.Object = &recipe + err := db.Transaction(obj.Delete) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - http.Redirect(w, r, "/recipes", http.StatusSeeOther) + http.Redirect(w, r, "/recipes", http.StatusSeeOther) + } } @@ -34,16 +34,16 @@ func main() { } log.Printf("Storage directory: %s\n", storage.Path) - model.ConnectDatabase(filepath.Join(storage.Path, "ceres.sqlite3")) - defer model.DisconnectDatabase() - model.MigrateDatabase(version) + db := model.OpenDB(filepath.Join(storage.Path, "ceres.sqlite3")) + defer db.Close() + db.Migrate(version) if flags.examples { - model.InjectExampleRecipes() - log.Println("Added example recipes") + db.CreateExamples() + log.Println("Created example recipes") } - server := NewServer(config.HttpAddress) + server := NewServer(config.HttpAddress, db) go server.Start() defer server.Stop() diff --git a/model/database.go b/model/database.go index 4740899..0e31882 100644 --- a/model/database.go +++ b/model/database.go @@ -9,12 +9,10 @@ import ( "xengineering.eu/ceres/model/migrations" ) -var db *sql.DB +type DB sql.DB -func ConnectDatabase(path string) { - var err error - - db, err = sql.Open("sqlite3", path) +func OpenDB(path string) *DB { + db, err := sql.Open("sqlite3", path) if err != nil { log.Fatal(err) } @@ -23,10 +21,12 @@ func ConnectDatabase(path string) { if err != nil { log.Fatal(err) } + + return (*DB)(db) } -func Transaction(f func(*sql.Tx) error) error { - tx, err := db.Begin() +func (db *DB) Transaction(f func(*sql.Tx) error) error { + tx, err := (*sql.DB)(db).Begin() if err != nil { log.Printf("Failed to start database transaction: %v", err) return err @@ -46,7 +46,7 @@ func Transaction(f func(*sql.Tx) error) error { return tx.Commit() } -func isDatabaseEmpty(tx *sql.Tx) bool { +func (db *DB) IsEmpty(tx *sql.Tx) bool { cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'` rows, err := tx.Query(cmd) if err != nil { @@ -67,7 +67,7 @@ func isDatabaseEmpty(tx *sql.Tx) bool { return number == 0 } -func setupMinimalDatabase(tx *sql.Tx, execVersion string) error { +func (db *DB) setupMinimal(tx *sql.Tx, execVersion string) error { cmd := ` CREATE TABLE metadata ( key TEXT PRIMARY KEY, @@ -82,7 +82,7 @@ VALUES return err } -func getDatabaseVersion(tx *sql.Tx) string { +func (db *DB) Version(tx *sql.Tx) string { rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`) if err != nil { log.Fatal(err) @@ -102,11 +102,11 @@ func getDatabaseVersion(tx *sql.Tx) string { return version } -func MigrateDatabase(execVersion string) { - err := Transaction(func(tx *sql.Tx) error { - if isDatabaseEmpty(tx) { +func (db *DB) Migrate(execVersion string) { + err := db.Transaction(func(tx *sql.Tx) error { + if db.IsEmpty(tx) { log.Println("Starting with empty database") - err := setupMinimalDatabase(tx, execVersion) + err := db.setupMinimal(tx, execVersion) if err != nil { log.Fatalf("Failed to setup minimal database schema: %v", err) } @@ -118,7 +118,7 @@ func MigrateDatabase(execVersion string) { } } - dbVersion := getDatabaseVersion(tx) + dbVersion := db.Version(tx) if dbVersion != execVersion { log.Fatalf( "Database version '%s' does not match executable version '%s'", @@ -134,8 +134,8 @@ func MigrateDatabase(execVersion string) { } } -func InjectExampleRecipes() { - err := Transaction(func(tx *sql.Tx) error { +func (db *DB) CreateExamples() { + err := db.Transaction(func(tx *sql.Tx) error { recipes := RecipeTestData() for _, recipe := range recipes { @@ -152,8 +152,8 @@ func InjectExampleRecipes() { } } -func DisconnectDatabase() { - var err error = db.Close() +func (db *DB) Close() { + err := (*sql.DB)(db).Close() if err != nil { log.Printf("Failed to close database: %v\n", err) } else { @@ -7,6 +7,7 @@ import ( "net/http" "xengineering.eu/ceres/controller" + "xengineering.eu/ceres/model" "xengineering.eu/ceres/view" "github.com/gorilla/mux" @@ -19,7 +20,7 @@ type Server struct { //go:embed view/static/simple.css/simple.css view/static/ceres.js var static embed.FS -func NewServer(addr string) Server { +func NewServer(addr string, db *model.DB) Server { var r *mux.Router = mux.NewRouter() r.PathPrefix("/static/"). @@ -27,13 +28,13 @@ func NewServer(addr string) Server { r.HandleFunc("/version", view.VersionRead(version)).Methods(`GET`) - r.HandleFunc("/recipes", view.RecipesRead).Methods(`GET`) + r.HandleFunc("/recipes", view.RecipesRead(db)).Methods(`GET`) r.HandleFunc("/recipe/create", view.RecipeCreate).Methods(`GET`) - r.HandleFunc("/recipe", controller.RecipeCreate).Methods(`POST`) - r.HandleFunc("/recipe/{id:[0-9]+}", view.RecipeRead).Methods(`GET`) - r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeUpdate).Methods(`POST`) - r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeDelete).Methods(`DELETE`) + r.HandleFunc("/recipe", controller.RecipeCreate(db)).Methods(`POST`) + r.HandleFunc("/recipe/{id:[0-9]+}", view.RecipeRead(db)).Methods(`GET`) + r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeUpdate(db)).Methods(`POST`) + r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeDelete(db)).Methods(`DELETE`) r.HandleFunc("/favicon.ico", view.FaviconRead).Methods(`GET`) diff --git a/view/recipe.go b/view/recipe.go index d77d771..0578107 100644 --- a/view/recipe.go +++ b/view/recipe.go @@ -8,47 +8,49 @@ import ( "github.com/gorilla/mux" ) -func RecipeRead(w http.ResponseWriter, r *http.Request) { - recipe := model.Recipe{} - recipe.Id = mux.Vars(r)[`id`] - - var obj model.Object = &recipe - err := model.Transaction(obj.Read) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } +func RecipeRead(db *model.DB) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + recipe := model.Recipe{} + recipe.Id = mux.Vars(r)[`id`] - template := "recipe" - view, ok := r.URL.Query()["view"] - if ok { - if len(view) > 1 { - http.Error(w, "More than one 'view' parameter given in URL", http.StatusBadRequest) + var obj model.Object = &recipe + err := db.Transaction(obj.Read) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - template = view[0] - } - is_valid := false - valid_templates := []string{ - "recipe", - "recipe-edit", - "recipe-confirm-deletion", - } - for _, v := range valid_templates { - if template == v { - is_valid = true + template := "recipe" + view, ok := r.URL.Query()["view"] + if ok { + if len(view) > 1 { + http.Error(w, "More than one 'view' parameter given in URL", http.StatusBadRequest) + return + } + template = view[0] } - } - if !is_valid { - http.Error(w, "Unsupported view: "+template, http.StatusBadRequest) - return - } - err = html.ExecuteTemplate(w, template, recipe) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + is_valid := false + valid_templates := []string{ + "recipe", + "recipe-edit", + "recipe-confirm-deletion", + } + for _, v := range valid_templates { + if template == v { + is_valid = true + } + } + if !is_valid { + http.Error(w, "Unsupported view: "+template, http.StatusBadRequest) + return + } + + err = html.ExecuteTemplate(w, template, recipe) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } } diff --git a/view/recipes.go b/view/recipes.go index fe995b2..e7153cd 100644 --- a/view/recipes.go +++ b/view/recipes.go @@ -6,19 +6,21 @@ import ( "xengineering.eu/ceres/model" ) -func RecipesRead(w http.ResponseWriter, r *http.Request) { - recipes := make(model.Recipes, 0) +func RecipesRead(db *model.DB) func(http.ResponseWriter, *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + recipes := make(model.Recipes, 0) - var obj model.Object = &recipes - err := model.Transaction(obj.Read) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } + var obj model.Object = &recipes + err := db.Transaction(obj.Read) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } - err = html.ExecuteTemplate(w, "recipes", recipes) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return + err = html.ExecuteTemplate(w, "recipes", recipes) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } } } |