diff options
author | xengineering <me@xengineering.eu> | 2023-02-11 13:46:45 +0100 |
---|---|---|
committer | xengineering <me@xengineering.eu> | 2023-02-11 13:46:45 +0100 |
commit | 6b52b4bbc81c64b29daff043b83e21c38f332052 (patch) | |
tree | 3a5d41f75228fe6ece3df5c109365d54b0ba6d04 | |
parent | 312c59564700da719dbafad11c2d9f647b46d912 (diff) | |
download | ceres-6b52b4bbc81c64b29daff043b83e21c38f332052.tar ceres-6b52b4bbc81c64b29daff043b83e21c38f332052.tar.zst ceres-6b52b4bbc81c64b29daff043b83e21c38f332052.zip |
Switch to global database pointer
Passing the database pointer around is a lot of text and has no benefit.
-rw-r--r-- | database.go | 79 | ||||
-rw-r--r-- | handler.go | 28 | ||||
-rw-r--r-- | main.go | 6 | ||||
-rw-r--r-- | mux.go | 20 | ||||
-rw-r--r-- | server.go | 14 |
5 files changed, 62 insertions, 85 deletions
diff --git a/database.go b/database.go index d19572a..1efd22e 100644 --- a/database.go +++ b/database.go @@ -20,17 +20,26 @@ import ( const databaseSchemaVersion int = 2 // this defines the needed version for the // executable -type Database struct { - target string - Backend *sql.DB -} +func setupDatabase() *sql.DB { + + u,err := user.Current() + if err != nil { + log.Fatal(err) + } + target := fmt.Sprintf("%s@unix(%s)/%s", u.Username, config.Database.Socket, + config.Database.Database) + + db,err := sql.Open("mysql", target) + if err != nil { + log.Fatal(err) + } -func InitDatabase() Database { + err = db.Ping() + if err != nil { + log.Fatal(err) + } - db := NewDatabase() - db.Connect() - db.Ping() - db.Migrate() + migrate(db) // allow graceful shutdown var listener = make(chan os.Signal) @@ -39,53 +48,19 @@ func InitDatabase() Database { go func() { signal := <-listener log.Printf("\nGot signal '%+v'. Shutting down ...\n", signal) - db.Cleanup() + dbCleanup(db) os.Exit(0) // TODO this does not belong to a database - write utils file 'shutdown.go' }() return db } -func NewDatabase() Database { - - db := Database{} - - var username string - user_ptr,err := user.Current() - if err != nil { - log.Fatal(err) - } - username = user_ptr.Username - db.target = fmt.Sprintf("%s@unix(%s)/%s", username, config.Database.Socket, - config.Database.Database) - - return db -} - -func (db *Database) Connect() { - var err error - db.Backend,err = sql.Open("mysql", db.target) - if err != nil { - log.Fatal(err) - } - log.Printf("Connected to database '%s'\n", db.target) -} - -func (db *Database) Ping() { - err := db.Backend.Ping() - if err != nil { - log.Fatal(err) - } else { - log.Println("Database is responding") - } -} - -func (db *Database) Migrate() { +func migrate(db *sql.DB) { const t = databaseSchemaVersion // targeted database schema version for { - v := db.SchemaVersion() // read schema version from DB table + v := schemaVersion(db) // read schema version from DB table // handle current database schema which is newer than targeted one if v > t { @@ -102,12 +77,12 @@ func (db *Database) Migrate() { log.Printf("Starting database schema migration to version %d.\n", v+1) path := filepath.Join(config.Database.Migrations, fmt.Sprintf("%04d_migration.sql", v+1)) - RunSql(path) + RunSqlScript(path) log.Printf("Finished database schema migration to version %d.\n", v+1) } } -func RunSql(path string) { +func RunSqlScript(path string) { script, err := os.Open(path) if err != nil { @@ -133,11 +108,11 @@ func RunSql(path string) { } } -func (db *Database) SchemaVersion() int { +func schemaVersion(db *sql.DB) int { // ask database for schema version cmd := "SELECT value FROM meta WHERE (identifier='version');" - rows, err := db.Backend.Query(cmd) + rows, err := db.Query(cmd) // handle missing meta table if err != nil { @@ -165,8 +140,8 @@ func (db *Database) SchemaVersion() int { return v } -func (db *Database) Cleanup() { - err := db.Backend.Close() +func dbCleanup(db *sql.DB) { + err := db.Close() if err != nil { log.Println("Could not close database connection") } else { @@ -17,12 +17,12 @@ const ( VALID_ID_REGEX = `^[0-9]+$` ) -func indexGet(w http.ResponseWriter, r *http.Request, db *Database) { +func indexGet(w http.ResponseWriter, r *http.Request) { // get data from database cmd := "SELECT id,title FROM recipes ORDER BY title;" log.Printf("Query: %s", cmd) - rows, err := db.Backend.Query(cmd) + rows, err := db.Query(cmd) if err != nil { http.Error(w, "Failed to load recipes from database.", 500) return @@ -53,7 +53,7 @@ func indexGet(w http.ResponseWriter, r *http.Request, db *Database) { ServeTemplate(w, "index", path, elements) } -func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) { +func recipeGet(w http.ResponseWriter, r *http.Request) { // get id from URL parameters ids := r.URL.Query()["id"] @@ -74,7 +74,7 @@ func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) { // get data from database cmd := fmt.Sprintf("SELECT title,upstream_url,description_markdown FROM recipes WHERE (id='%s');", idStr) log.Printf("Query: %s", cmd) - rows, err := db.Backend.Query(cmd) + rows, err := db.Query(cmd) if err != nil { http.Error(w, "Database returned error: " + err.Error(), 500) return @@ -120,7 +120,7 @@ func recipeGet(w http.ResponseWriter, r *http.Request, db *Database) { ServeTemplate(w, "recipe", path, elements[0]) } -func recipePost(w http.ResponseWriter, r *http.Request, db *Database) { +func recipePost(w http.ResponseWriter, r *http.Request) { // get id from URL parameters ids := r.URL.Query()["id"] @@ -141,10 +141,10 @@ func recipePost(w http.ResponseWriter, r *http.Request, db *Database) { // read request body buffer,_ := ioutil.ReadAll(r.Body) // FIXME error handling body := string(buffer) - updateRecipe(db, body, idStr) + updateRecipe(body, idStr) } -func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) { +func recipeEditGet(w http.ResponseWriter, r *http.Request) { // get id from URL parameters ids := r.URL.Query()["id"] @@ -164,7 +164,7 @@ func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) { // get data from database cmd := fmt.Sprintf("SELECT title,upstream_url,description_markdown FROM recipes WHERE (id='%s');", idStr) log.Printf("Query: %s", cmd) - rows, err := db.Backend.Query(cmd) + rows, err := db.Query(cmd) if err != nil { http.Error(w, "Got error from database: " + err.Error(), 500) return @@ -205,7 +205,7 @@ func recipeEditGet(w http.ResponseWriter, r *http.Request, db *Database) { ServeTemplate(w, "recipe", path, elements[0]) } -func recipeEditPost(w http.ResponseWriter, r *http.Request, db *Database) { +func recipeEditPost(w http.ResponseWriter, r *http.Request) { // get id from URL parameters ids := r.URL.Query()["id"] @@ -225,13 +225,13 @@ func recipeEditPost(w http.ResponseWriter, r *http.Request, db *Database) { // read request body buffer,_ := ioutil.ReadAll(r.Body) // FIXME error handling body := string(buffer) - updateRecipe(db, body, idStr) + updateRecipe(body, idStr) } -func updateRecipe(db *Database, body string, idStr string) { +func updateRecipe(body string, idStr string) { // execute SQL UPDATE - _,_ = db.Backend.Exec(` + _,_ = db.Exec(` UPDATE recipes SET @@ -275,14 +275,14 @@ func addRecipesGet(w http.ResponseWriter, r *http.Request) { http.ServeFile(w, r, path) } -func addRecipesPost(w http.ResponseWriter, r *http.Request, db *Database) { +func addRecipesPost(w http.ResponseWriter, r *http.Request) { url := r.FormValue("url") title := r.FormValue("title") cmd := fmt.Sprintf("INSERT INTO recipes (title,upstream_url) VALUES ('%s', '%s')", title, url) log.Println(cmd) - res,err := db.Backend.Exec(cmd) + res,err := db.Exec(cmd) if err != nil { http.Error(w, "Could not add recipe.", 500) return @@ -3,13 +3,15 @@ package main import ( "log" + "database/sql" ) var config RuntimeConfig +var db *sql.DB func main() { config = GetRuntimeConfig() log.Printf("Starting ceres with config file '%s'\n", config.Path) - db := InitDatabase() - runServer(&db) + db = setupDatabase() + runServer() } @@ -5,37 +5,37 @@ import ( "net/http" ) -func indexMux(db *Database) func(http.ResponseWriter, *http.Request) { +func indexMux() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": - indexGet(w, r, db) + indexGet(w, r) default: http.Error(w, "Bad Request", 400) } } } -func recipeMux(db *Database) func(http.ResponseWriter, *http.Request) { +func recipeMux() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": - recipeGet(w, r, db) + recipeGet(w, r) case "POST": - recipePost(w, r, db) + recipePost(w, r) default: http.Error(w, "Bad Request", 400) } } } -func recipeEditMux(db *Database) func(http.ResponseWriter, *http.Request) { +func recipeEditMux() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": - recipeEditGet(w, r, db) + recipeEditGet(w, r) case "POST": - recipeEditPost(w, r, db) + recipeEditPost(w, r) default: http.Error(w, "Bad Request", 400) } @@ -53,13 +53,13 @@ func recipeImageMux() func(http.ResponseWriter, *http.Request) { } } -func addRecipesMux(db *Database) func(http.ResponseWriter, *http.Request) { +func addRecipesMux() func(http.ResponseWriter, *http.Request) { return func(w http.ResponseWriter, r *http.Request) { switch r.Method { case "GET": addRecipesGet(w, r) case "POST": - addRecipesPost(w, r, db) + addRecipesPost(w, r) default: http.Error(w, "Bad Request", 400) } @@ -6,20 +6,20 @@ import ( "net/http" ) -func setupRoutes(db *Database) { +func setupRoutes() { - http.HandleFunc("/", indexMux(db)) - http.HandleFunc("/recipe", recipeMux(db)) - http.HandleFunc("/recipe/edit", recipeEditMux(db)) + http.HandleFunc("/", indexMux()) + http.HandleFunc("/recipe", recipeMux()) + http.HandleFunc("/recipe/edit", recipeEditMux()) http.HandleFunc("/recipe/image", recipeImageMux()) - http.HandleFunc("/add_recipes", addRecipesMux(db)) + http.HandleFunc("/add_recipes", addRecipesMux()) http.HandleFunc("/static/style.css", staticStyleMux()) http.HandleFunc("/favicon.ico", faviconMux()) } -func runServer(db *Database) { +func runServer() { - setupRoutes(db) + setupRoutes() address := config.Http.Host + ":" + config.Http.Port log.Println("Binding to 'http://" + address) log.Fatal(http.ListenAndServe(address, nil)) |