summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorxengineering <me@xengineering.eu>2023-02-11 13:46:45 +0100
committerxengineering <me@xengineering.eu>2023-02-11 13:46:45 +0100
commit6b52b4bbc81c64b29daff043b83e21c38f332052 (patch)
tree3a5d41f75228fe6ece3df5c109365d54b0ba6d04
parent312c59564700da719dbafad11c2d9f647b46d912 (diff)
downloadceres-6b52b4bbc81c64b29daff043b83e21c38f332052.tar
ceres-6b52b4bbc81c64b29daff043b83e21c38f332052.tar.zst
ceres-6b52b4bbc81c64b29daff043b83e21c38f332052.zip
Switch to global database pointer
Passing the database pointer around is a lot of text and has no benefit.
-rw-r--r--database.go79
-rw-r--r--handler.go28
-rw-r--r--main.go6
-rw-r--r--mux.go20
-rw-r--r--server.go14
5 files changed, 62 insertions, 85 deletions
diff --git a/database.go b/database.go
index d19572a..1efd22e 100644
--- a/database.go
+++ b/database.go
@@ -20,17 +20,26 @@ import (
const databaseSchemaVersion int = 2 // this defines the needed version for the
// executable
-type Database struct {
- target string
- Backend *sql.DB
-}
+func setupDatabase() *sql.DB {
+
+ u,err := user.Current()
+ if err != nil {
+ log.Fatal(err)
+ }
+ target := fmt.Sprintf("%s@unix(%s)/%s", u.Username, config.Database.Socket,
+ config.Database.Database)
+
+ db,err := sql.Open("mysql", target)
+ if err != nil {
+ log.Fatal(err)
+ }
-func InitDatabase() Database {
+ err = db.Ping()
+ if err != nil {
+ log.Fatal(err)
+ }
- db := NewDatabase()
- db.Connect()
- db.Ping()
- db.Migrate()
+ migrate(db)
// allow graceful shutdown
var listener = make(chan os.Signal)
@@ -39,53 +48,19 @@ func InitDatabase() Database {
go func() {
signal := <-listener
log.Printf("\nGot signal '%+v'. Shutting down ...\n", signal)
- db.Cleanup()
+ dbCleanup(db)
os.Exit(0) // TODO this does not belong to a database - write utils file 'shutdown.go'
}()
return db
}
-func NewDatabase() Database {
-
- db := Database{}
-
- var username string
- user_ptr,err := user.Current()
- if err != nil {
- log.Fatal(err)
- }
- username = user_ptr.Username
- db.target = fmt.Sprintf("%s@unix(%s)/%s", username, config.Database.Socket,
- config.Database.Database)
-
- return db
-}
-
-func (db *Database) Connect() {
- var err error
- db.Backend,err = sql.Open("mysql", db.target)
- if err != nil {
- log.Fatal(err)
- }
- log.Printf("Connected to database '%s'\n", db.target)
-}
-
-func (db *Database) Ping() {
- err := db.Backend.Ping()
- if err != nil {
- log.Fatal(err)
- } else {
- log.Println("Database is responding")
- }
-}
-
-func (db *Database) Migrate() {
+func migrate(db *sql.DB) {
const t = databaseSchemaVersion // targeted database schema version
for {
- v := db.SchemaVersion() // read schema version from DB table
+ v := schemaVersion(db) // read schema version from DB table
// handle current database schema which is newer than targeted one
if v > t {
@@ -102,12 +77,12 @@ func (db *Database) Migrate() {
log.Printf("Starting database schema migration to version %d.\n", v+1)
path := filepath.Join(config.Database.Migrations,
fmt.Sprintf("%04d_migration.sql", v+1))
- RunSql(path)
+ RunSqlScript(path)
log.Printf("Finished database schema migration to version %d.\n", v+1)
}
}
-func RunSql(path string) {
+func RunSqlScript(path string) {
script, err := os.Open(path)
if err != nil {
@@ -133,11 +108,11 @@ func RunSql(path string) {
}
}
-func (db *Database) SchemaVersion() int {
+func schemaVersion(db *sql.DB) int {
// ask database for schema version
cmd := "SELECT value FROM meta WHERE (identifier='version');"
- rows, err := db.Backend.Query(cmd)
+ rows, err := db.Query(cmd)
// handle missing meta table
if err != nil {
@@ -165,8 +140,8 @@ func (db *Database) SchemaVersion() int {
return v
}
-func (db *Database) Cleanup() {
- err := db.Backend.Close()
+func dbCleanup(db *sql.DB) {
+ err := db.Close()
if err != nil {
log.Println("Could not close database connection")
} else {
diff --git a/handler.go b/handler.go
index b6c57c4..0899797 100644
--- a/handler.go
+++ b/handler.go
@@ -17,12 +17,12 @@ const (
VALID_ID_REGEX = `^[0-9]+$`
)
-func indexGet(w http.ResponseWriter, r *http.Request, db *Database) {
+func indexGet(w http.ResponseWriter, r *http.Request) {
// get data from database
cmd := "SELECT id,title FROM recipes ORDER BY title;"
log.Printf("Query: %s", cmd)
- rows, err := db.Backend.Query(cmd)
+ rows, err := db.Query(cmd)
if err != nil {
http.Error(w, "Failed to load recipes from database.", 500)
return
@@ -53,7 +53,7 @@ func indexGet(w http.ResponseWriter, r *http.Request, db *Database) {
ServeTemplate(w, "index", path, elements)
}
-func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) {
+func recipeGet(w http.ResponseWriter, r *http.Request) {
// get id from URL parameters
ids := r.URL.Query()["id"]
@@ -74,7 +74,7 @@ func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) {
// get data from database
cmd := fmt.Sprintf("SELECT title,upstream_url,description_markdown FROM recipes WHERE (id='%s');", idStr)
log.Printf("Query: %s", cmd)
- rows, err := db.Backend.Query(cmd)
+ rows, err := db.Query(cmd)
if err != nil {
http.Error(w, "Database returned error: " + err.Error(), 500)
return
@@ -120,7 +120,7 @@ func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) {
ServeTemplate(w, "recipe", path, elements[0])
}
-func recipePost(w http.ResponseWriter, r *http.Request, db *Database) {
+func recipePost(w http.ResponseWriter, r *http.Request) {
// get id from URL parameters
ids := r.URL.Query()["id"]
@@ -141,10 +141,10 @@ func recipePost(w http.ResponseWriter, r *http.Request, db *Database) {
// read request body
buffer,_ := ioutil.ReadAll(r.Body) // FIXME error handling
body := string(buffer)
- updateRecipe(db, body, idStr)
+ updateRecipe(body, idStr)
}
-func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) {
+func recipeEditGet(w http.ResponseWriter, r *http.Request) {
// get id from URL parameters
ids := r.URL.Query()["id"]
@@ -164,7 +164,7 @@ func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) {
// get data from database
cmd := fmt.Sprintf("SELECT title,upstream_url,description_markdown FROM recipes WHERE (id='%s');", idStr)
log.Printf("Query: %s", cmd)
- rows, err := db.Backend.Query(cmd)
+ rows, err := db.Query(cmd)
if err != nil {
http.Error(w, "Got error from database: " + err.Error(), 500)
return
@@ -205,7 +205,7 @@ func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) {
ServeTemplate(w, "recipe", path, elements[0])
}
-func recipeEditPost(w http.ResponseWriter, r *http.Request, db *Database) {
+func recipeEditPost(w http.ResponseWriter, r *http.Request) {
// get id from URL parameters
ids := r.URL.Query()["id"]
@@ -225,13 +225,13 @@ func recipeEditPost(w http.ResponseWriter, r *http.Request, db *Database) {
// read request body
buffer,_ := ioutil.ReadAll(r.Body) // FIXME error handling
body := string(buffer)
- updateRecipe(db, body, idStr)
+ updateRecipe(body, idStr)
}
-func updateRecipe(db *Database, body string, idStr string) {
+func updateRecipe(body string, idStr string) {
// execute SQL UPDATE
- _,_ = db.Backend.Exec(`
+ _,_ = db.Exec(`
UPDATE
recipes
SET
@@ -275,14 +275,14 @@ func addRecipesGet(w http.ResponseWriter, r *http.Request) {
http.ServeFile(w, r, path)
}
-func addRecipesPost(w http.ResponseWriter, r *http.Request, db *Database) {
+func addRecipesPost(w http.ResponseWriter, r *http.Request) {
url := r.FormValue("url")
title := r.FormValue("title")
cmd := fmt.Sprintf("INSERT INTO recipes (title,upstream_url) VALUES ('%s', '%s')", title, url)
log.Println(cmd)
- res,err := db.Backend.Exec(cmd)
+ res,err := db.Exec(cmd)
if err != nil {
http.Error(w, "Could not add recipe.", 500)
return
diff --git a/main.go b/main.go
index 0473231..3afdf49 100644
--- a/main.go
+++ b/main.go
@@ -3,13 +3,15 @@ package main
import (
"log"
+ "database/sql"
)
var config RuntimeConfig
+var db *sql.DB
func main() {
config = GetRuntimeConfig()
log.Printf("Starting ceres with config file '%s'\n", config.Path)
- db := InitDatabase()
- runServer(&db)
+ db = setupDatabase()
+ runServer()
}
diff --git a/mux.go b/mux.go
index 3aed42d..2b37286 100644
--- a/mux.go
+++ b/mux.go
@@ -5,37 +5,37 @@ import (
"net/http"
)
-func indexMux(db *Database) func(http.ResponseWriter, *http.Request) {
+func indexMux() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
- indexGet(w, r, db)
+ indexGet(w, r)
default:
http.Error(w, "Bad Request", 400)
}
}
}
-func recipeMux(db *Database) func(http.ResponseWriter, *http.Request) {
+func recipeMux() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
- recipeGet(w, r, db)
+ recipeGet(w, r)
case "POST":
- recipePost(w, r, db)
+ recipePost(w, r)
default:
http.Error(w, "Bad Request", 400)
}
}
}
-func recipeEditMux(db *Database) func(http.ResponseWriter, *http.Request) {
+func recipeEditMux() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
- recipeEditGet(w, r, db)
+ recipeEditGet(w, r)
case "POST":
- recipeEditPost(w, r, db)
+ recipeEditPost(w, r)
default:
http.Error(w, "Bad Request", 400)
}
@@ -53,13 +53,13 @@ func recipeImageMux() func(http.ResponseWriter, *http.Request) {
}
}
-func addRecipesMux(db *Database) func(http.ResponseWriter, *http.Request) {
+func addRecipesMux() func(http.ResponseWriter, *http.Request) {
return func(w http.ResponseWriter, r *http.Request) {
switch r.Method {
case "GET":
addRecipesGet(w, r)
case "POST":
- addRecipesPost(w, r, db)
+ addRecipesPost(w, r)
default:
http.Error(w, "Bad Request", 400)
}
diff --git a/server.go b/server.go
index f7f67f4..d70373b 100644
--- a/server.go
+++ b/server.go
@@ -6,20 +6,20 @@ import (
"net/http"
)
-func setupRoutes(db *Database) {
+func setupRoutes() {
- http.HandleFunc("/", indexMux(db))
- http.HandleFunc("/recipe", recipeMux(db))
- http.HandleFunc("/recipe/edit", recipeEditMux(db))
+ http.HandleFunc("/", indexMux())
+ http.HandleFunc("/recipe", recipeMux())
+ http.HandleFunc("/recipe/edit", recipeEditMux())
http.HandleFunc("/recipe/image", recipeImageMux())
- http.HandleFunc("/add_recipes", addRecipesMux(db))
+ http.HandleFunc("/add_recipes", addRecipesMux())
http.HandleFunc("/static/style.css", staticStyleMux())
http.HandleFunc("/favicon.ico", faviconMux())
}
-func runServer(db *Database) {
+func runServer() {
- setupRoutes(db)
+ setupRoutes()
address := config.Http.Host + ":" + config.Http.Port
log.Println("Binding to 'http://" + address)
log.Fatal(http.ListenAndServe(address, nil))