summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorxengineering <me@xengineering.eu>2024-10-13 19:52:28 +0200
committerxengineering <me@xengineering.eu>2024-10-13 19:55:38 +0200
commit473052ed8f2c83052ed5b47a7f4cec68ac2621a6 (patch)
tree2d5da088c6879317734277350c873a258b4d1dac
parented19b82335345833c5b8f5446237d559a3657a35 (diff)
downloadceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.tar
ceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.tar.zst
ceres-473052ed8f2c83052ed5b47a7f4cec68ac2621a6.zip
model: Replace global db variable by custom type
Reducing global variables makes it easier to understand functions independently of the rest of the code. Adding the new model.DB type as a custom variant of the sql.DB type makes it possible to write methods for the database which makes the code way more readable.
-rw-r--r--controller/recipe.go130
-rw-r--r--main.go12
-rw-r--r--model/database.go38
-rw-r--r--server.go13
-rw-r--r--view/recipe.go72
-rw-r--r--view/recipes.go26
6 files changed, 151 insertions, 140 deletions
diff --git a/controller/recipe.go b/controller/recipe.go
index b0a81d9..aec7144 100644
--- a/controller/recipe.go
+++ b/controller/recipe.go
@@ -12,74 +12,80 @@ import (
"github.com/gorilla/mux"
)
-func RecipeCreate(w http.ResponseWriter, r *http.Request) {
- buf, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
+func RecipeCreate(db *model.DB) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ buf, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ recipe := model.Recipe{}
+ err = json.Unmarshal(buf, &recipe)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ recipe.LastChanged = fmt.Sprint(time.Now().Unix())
+ recipe.Created = recipe.LastChanged
+
+ var obj model.Object = &recipe
+ err = db.Transaction(obj.Create)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther)
}
-
- recipe := model.Recipe{}
- err = json.Unmarshal(buf, &recipe)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
- }
-
- recipe.LastChanged = fmt.Sprint(time.Now().Unix())
- recipe.Created = recipe.LastChanged
-
- var obj model.Object = &recipe
- err = model.Transaction(obj.Create)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther)
}
-func RecipeUpdate(w http.ResponseWriter, r *http.Request) {
- buf, err := io.ReadAll(r.Body)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- recipe := model.Recipe{}
- err = json.Unmarshal(buf, &recipe)
- if err != nil {
- http.Error(w, err.Error(), http.StatusBadRequest)
- return
+func RecipeUpdate(db *model.DB) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ buf, err := io.ReadAll(r.Body)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ recipe := model.Recipe{}
+ err = json.Unmarshal(buf, &recipe)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusBadRequest)
+ return
+ }
+
+ if recipe.Id != mux.Vars(r)[`id`] {
+ http.Error(w, "IDs in URL and JSON do not match", http.StatusBadRequest)
+ return
+ }
+
+ recipe.LastChanged = fmt.Sprint(time.Now().Unix())
+
+ var obj model.Object = &recipe
+ err = db.Transaction(obj.Update)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
+
+ http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther)
}
-
- if recipe.Id != mux.Vars(r)[`id`] {
- http.Error(w, "IDs in URL and JSON do not match", http.StatusBadRequest)
- return
- }
-
- recipe.LastChanged = fmt.Sprint(time.Now().Unix())
-
- var obj model.Object = &recipe
- err = model.Transaction(obj.Update)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
-
- http.Redirect(w, r, "/recipe/"+recipe.Id, http.StatusSeeOther)
}
-func RecipeDelete(w http.ResponseWriter, r *http.Request) {
- recipe := model.Recipe{}
- recipe.Id = mux.Vars(r)[`id`]
+func RecipeDelete(db *model.DB) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ recipe := model.Recipe{}
+ recipe.Id = mux.Vars(r)[`id`]
- var obj model.Object = &recipe
- err := model.Transaction(obj.Delete)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
+ var obj model.Object = &recipe
+ err := db.Transaction(obj.Delete)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
- http.Redirect(w, r, "/recipes", http.StatusSeeOther)
+ http.Redirect(w, r, "/recipes", http.StatusSeeOther)
+ }
}
diff --git a/main.go b/main.go
index f744206..f648a2d 100644
--- a/main.go
+++ b/main.go
@@ -34,16 +34,16 @@ func main() {
}
log.Printf("Storage directory: %s\n", storage.Path)
- model.ConnectDatabase(filepath.Join(storage.Path, "ceres.sqlite3"))
- defer model.DisconnectDatabase()
- model.MigrateDatabase(version)
+ db := model.OpenDB(filepath.Join(storage.Path, "ceres.sqlite3"))
+ defer db.Close()
+ db.Migrate(version)
if flags.examples {
- model.InjectExampleRecipes()
- log.Println("Added example recipes")
+ db.CreateExamples()
+ log.Println("Created example recipes")
}
- server := NewServer(config.HttpAddress)
+ server := NewServer(config.HttpAddress, db)
go server.Start()
defer server.Stop()
diff --git a/model/database.go b/model/database.go
index 4740899..0e31882 100644
--- a/model/database.go
+++ b/model/database.go
@@ -9,12 +9,10 @@ import (
"xengineering.eu/ceres/model/migrations"
)
-var db *sql.DB
+type DB sql.DB
-func ConnectDatabase(path string) {
- var err error
-
- db, err = sql.Open("sqlite3", path)
+func OpenDB(path string) *DB {
+ db, err := sql.Open("sqlite3", path)
if err != nil {
log.Fatal(err)
}
@@ -23,10 +21,12 @@ func ConnectDatabase(path string) {
if err != nil {
log.Fatal(err)
}
+
+ return (*DB)(db)
}
-func Transaction(f func(*sql.Tx) error) error {
- tx, err := db.Begin()
+func (db *DB) Transaction(f func(*sql.Tx) error) error {
+ tx, err := (*sql.DB)(db).Begin()
if err != nil {
log.Printf("Failed to start database transaction: %v", err)
return err
@@ -46,7 +46,7 @@ func Transaction(f func(*sql.Tx) error) error {
return tx.Commit()
}
-func isDatabaseEmpty(tx *sql.Tx) bool {
+func (db *DB) IsEmpty(tx *sql.Tx) bool {
cmd := `SELECT COUNT(*) FROM sqlite_master WHERE type='table'`
rows, err := tx.Query(cmd)
if err != nil {
@@ -67,7 +67,7 @@ func isDatabaseEmpty(tx *sql.Tx) bool {
return number == 0
}
-func setupMinimalDatabase(tx *sql.Tx, execVersion string) error {
+func (db *DB) setupMinimal(tx *sql.Tx, execVersion string) error {
cmd := `
CREATE TABLE metadata (
key TEXT PRIMARY KEY,
@@ -82,7 +82,7 @@ VALUES
return err
}
-func getDatabaseVersion(tx *sql.Tx) string {
+func (db *DB) Version(tx *sql.Tx) string {
rows, err := tx.Query(`SELECT value FROM metadata WHERE key='version';`)
if err != nil {
log.Fatal(err)
@@ -102,11 +102,11 @@ func getDatabaseVersion(tx *sql.Tx) string {
return version
}
-func MigrateDatabase(execVersion string) {
- err := Transaction(func(tx *sql.Tx) error {
- if isDatabaseEmpty(tx) {
+func (db *DB) Migrate(execVersion string) {
+ err := db.Transaction(func(tx *sql.Tx) error {
+ if db.IsEmpty(tx) {
log.Println("Starting with empty database")
- err := setupMinimalDatabase(tx, execVersion)
+ err := db.setupMinimal(tx, execVersion)
if err != nil {
log.Fatalf("Failed to setup minimal database schema: %v", err)
}
@@ -118,7 +118,7 @@ func MigrateDatabase(execVersion string) {
}
}
- dbVersion := getDatabaseVersion(tx)
+ dbVersion := db.Version(tx)
if dbVersion != execVersion {
log.Fatalf(
"Database version '%s' does not match executable version '%s'",
@@ -134,8 +134,8 @@ func MigrateDatabase(execVersion string) {
}
}
-func InjectExampleRecipes() {
- err := Transaction(func(tx *sql.Tx) error {
+func (db *DB) CreateExamples() {
+ err := db.Transaction(func(tx *sql.Tx) error {
recipes := RecipeTestData()
for _, recipe := range recipes {
@@ -152,8 +152,8 @@ func InjectExampleRecipes() {
}
}
-func DisconnectDatabase() {
- var err error = db.Close()
+func (db *DB) Close() {
+ err := (*sql.DB)(db).Close()
if err != nil {
log.Printf("Failed to close database: %v\n", err)
} else {
diff --git a/server.go b/server.go
index 7caf328..0c26188 100644
--- a/server.go
+++ b/server.go
@@ -7,6 +7,7 @@ import (
"net/http"
"xengineering.eu/ceres/controller"
+ "xengineering.eu/ceres/model"
"xengineering.eu/ceres/view"
"github.com/gorilla/mux"
@@ -19,7 +20,7 @@ type Server struct {
//go:embed view/static/simple.css/simple.css view/static/ceres.js
var static embed.FS
-func NewServer(addr string) Server {
+func NewServer(addr string, db *model.DB) Server {
var r *mux.Router = mux.NewRouter()
r.PathPrefix("/static/").
@@ -27,13 +28,13 @@ func NewServer(addr string) Server {
r.HandleFunc("/version", view.VersionRead(version)).Methods(`GET`)
- r.HandleFunc("/recipes", view.RecipesRead).Methods(`GET`)
+ r.HandleFunc("/recipes", view.RecipesRead(db)).Methods(`GET`)
r.HandleFunc("/recipe/create", view.RecipeCreate).Methods(`GET`)
- r.HandleFunc("/recipe", controller.RecipeCreate).Methods(`POST`)
- r.HandleFunc("/recipe/{id:[0-9]+}", view.RecipeRead).Methods(`GET`)
- r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeUpdate).Methods(`POST`)
- r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeDelete).Methods(`DELETE`)
+ r.HandleFunc("/recipe", controller.RecipeCreate(db)).Methods(`POST`)
+ r.HandleFunc("/recipe/{id:[0-9]+}", view.RecipeRead(db)).Methods(`GET`)
+ r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeUpdate(db)).Methods(`POST`)
+ r.HandleFunc("/recipe/{id:[0-9]+}", controller.RecipeDelete(db)).Methods(`DELETE`)
r.HandleFunc("/favicon.ico", view.FaviconRead).Methods(`GET`)
diff --git a/view/recipe.go b/view/recipe.go
index d77d771..0578107 100644
--- a/view/recipe.go
+++ b/view/recipe.go
@@ -8,47 +8,49 @@ import (
"github.com/gorilla/mux"
)
-func RecipeRead(w http.ResponseWriter, r *http.Request) {
- recipe := model.Recipe{}
- recipe.Id = mux.Vars(r)[`id`]
-
- var obj model.Object = &recipe
- err := model.Transaction(obj.Read)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
+func RecipeRead(db *model.DB) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ recipe := model.Recipe{}
+ recipe.Id = mux.Vars(r)[`id`]
- template := "recipe"
- view, ok := r.URL.Query()["view"]
- if ok {
- if len(view) > 1 {
- http.Error(w, "More than one 'view' parameter given in URL", http.StatusBadRequest)
+ var obj model.Object = &recipe
+ err := db.Transaction(obj.Read)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
- template = view[0]
- }
- is_valid := false
- valid_templates := []string{
- "recipe",
- "recipe-edit",
- "recipe-confirm-deletion",
- }
- for _, v := range valid_templates {
- if template == v {
- is_valid = true
+ template := "recipe"
+ view, ok := r.URL.Query()["view"]
+ if ok {
+ if len(view) > 1 {
+ http.Error(w, "More than one 'view' parameter given in URL", http.StatusBadRequest)
+ return
+ }
+ template = view[0]
}
- }
- if !is_valid {
- http.Error(w, "Unsupported view: "+template, http.StatusBadRequest)
- return
- }
- err = html.ExecuteTemplate(w, template, recipe)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
+ is_valid := false
+ valid_templates := []string{
+ "recipe",
+ "recipe-edit",
+ "recipe-confirm-deletion",
+ }
+ for _, v := range valid_templates {
+ if template == v {
+ is_valid = true
+ }
+ }
+ if !is_valid {
+ http.Error(w, "Unsupported view: "+template, http.StatusBadRequest)
+ return
+ }
+
+ err = html.ExecuteTemplate(w, template, recipe)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
}
}
diff --git a/view/recipes.go b/view/recipes.go
index fe995b2..e7153cd 100644
--- a/view/recipes.go
+++ b/view/recipes.go
@@ -6,19 +6,21 @@ import (
"xengineering.eu/ceres/model"
)
-func RecipesRead(w http.ResponseWriter, r *http.Request) {
- recipes := make(model.Recipes, 0)
+func RecipesRead(db *model.DB) func(http.ResponseWriter, *http.Request) {
+ return func(w http.ResponseWriter, r *http.Request) {
+ recipes := make(model.Recipes, 0)
- var obj model.Object = &recipes
- err := model.Transaction(obj.Read)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
- }
+ var obj model.Object = &recipes
+ err := db.Transaction(obj.Read)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
- err = html.ExecuteTemplate(w, "recipes", recipes)
- if err != nil {
- http.Error(w, err.Error(), http.StatusInternalServerError)
- return
+ err = html.ExecuteTemplate(w, "recipes", recipes)
+ if err != nil {
+ http.Error(w, err.Error(), http.StatusInternalServerError)
+ return
+ }
}
}