]> code.octet-stream.net Git - broadcaster/commitdiff
Require authentication consistently
authorThomas Karpiniec <tom.karpiniec@outlook.com>
Wed, 23 Oct 2024 01:14:10 +0000 (12:14 +1100)
committerThomas Karpiniec <tom.karpiniec@outlook.com>
Wed, 23 Oct 2024 01:14:10 +0000 (12:14 +1100)
server/main.go

index 8f76cbfacff9d9afb601c6b762d320203bc34daf..5f0a7487c3aff8e0082588386f0f32abc745ae3b 100644 (file)
@@ -22,6 +22,7 @@ const formatString = "2006-01-02T15:04"
 
 //go:embed templates/*
 var content embed.FS
+
 //var content = os.DirFS("../broadcaster-server/")
 
 var config ServerConfig = NewServerConfig()
@@ -79,36 +80,71 @@ func main() {
 
        // Authenticated routes
 
-       http.HandleFunc("/", homePage)
-       http.HandleFunc("/logout", logOutPage)
-       http.HandleFunc("/change-password", changePasswordPage)
+       http.Handle("/", requireUser(homePage))
+       http.Handle("/logout", requireUser(logOutPage))
+       http.Handle("/change-password", requireUser(changePasswordPage))
 
-       http.HandleFunc("/playlists/", playlistSection)
-       http.HandleFunc("/files/", fileSection)
-       http.HandleFunc("/radios/", radioSection)
+       http.Handle("/playlists/", requireUser(playlistSection))
+       http.Handle("/files/", requireUser(fileSection))
+       http.Handle("/radios/", requireUser(radioSection))
 
-       http.Handle("/radio-ws", websocket.Handler(RadioSync))
-       http.Handle("/web-ws", websocket.Handler(WebSync))
-       http.HandleFunc("/stop", stopPage)
+       http.Handle("/stop", requireUser(stopPage))
 
        // Admin routes
 
+       // TODO: user management
+
+       // Websocket routes, which perform their own auth
+
+       http.Handle("/radio-ws", websocket.Handler(RadioSync))
+       http.Handle("/web-ws", websocket.Handler(WebSync))
+
        err := http.ListenAndServe(config.BindAddress+":"+strconv.Itoa(config.Port), nil)
        if err != nil {
                log.Fatal(err)
        }
 }
 
+type authenticatedHandler func(http.ResponseWriter, *http.Request, User)
+
+type AuthMiddleware struct {
+       handler     authenticatedHandler
+       mustBeAdmin bool
+}
+
+func (m AuthMiddleware) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+       user, err := currentUser(w, r)
+       if err != nil || (m.mustBeAdmin && !user.IsAdmin) {
+               http.Redirect(w, r, "/login", http.StatusFound)
+               return
+       }
+       m.handler(w, r, user)
+}
+
+func requireUser(handler authenticatedHandler) AuthMiddleware {
+       return AuthMiddleware{
+               handler:     handler,
+               mustBeAdmin: false,
+       }
+}
+
+func requireAdmin(handler authenticatedHandler) AuthMiddleware {
+       return AuthMiddleware{
+               handler:     handler,
+               mustBeAdmin: true,
+       }
+}
+
 type HeaderData struct {
        SelectedMenu string
-       Username string
+       Username     string
 }
 
 func renderHeader(w http.ResponseWriter, selectedMenu string) {
        tmpl := template.Must(template.ParseFS(content, "templates/header.html"))
        data := HeaderData{
                SelectedMenu: selectedMenu,
-               Username: "username",
+               Username:     "username",
        }
        err := tmpl.Execute(w, data)
        if err != nil {
@@ -129,7 +165,7 @@ type HomeData struct {
        Username string
 }
 
-func homePage(w http.ResponseWriter, r *http.Request) {
+func homePage(w http.ResponseWriter, r *http.Request, user User) {
        renderHeader(w, "status")
        tmpl := template.Must(template.ParseFS(content, "templates/index.html"))
        data := HomeData{
@@ -145,7 +181,6 @@ type LogInData struct {
 }
 
 func logInPage(w http.ResponseWriter, r *http.Request) {
-       log.Println("Log in page!")
        r.ParseForm()
        username := r.Form["username"]
        password := r.Form["password"]
@@ -170,7 +205,7 @@ func logInPage(w http.ResponseWriter, r *http.Request) {
        renderFooter(w)
 }
 
-func playlistSection(w http.ResponseWriter, r *http.Request) {
+func playlistSection(w http.ResponseWriter, r *http.Request, user User) {
        path := strings.Split(r.URL.Path, "/")
        if len(path) != 3 {
                http.NotFound(w, r)
@@ -194,7 +229,7 @@ func playlistSection(w http.ResponseWriter, r *http.Request) {
        }
 }
 
-func fileSection(w http.ResponseWriter, r *http.Request) {
+func fileSection(w http.ResponseWriter, r *http.Request, user User) {
        path := strings.Split(r.URL.Path, "/")
        if len(path) != 3 {
                http.NotFound(w, r)
@@ -212,7 +247,7 @@ func fileSection(w http.ResponseWriter, r *http.Request) {
        }
 }
 
-func radioSection(w http.ResponseWriter, r *http.Request) {
+func radioSection(w http.ResponseWriter, r *http.Request, user User) {
        path := strings.Split(r.URL.Path, "/")
        if len(path) != 3 {
                http.NotFound(w, r)
@@ -241,12 +276,7 @@ type ChangePasswordPageData struct {
        ShowForm bool
 }
 
-func changePasswordPage(w http.ResponseWriter, r *http.Request) {
-       user, err := currentUser(w, r)
-       if err != nil {
-               http.Redirect(w, r, "/login", http.StatusFound)
-               return
-       }
+func changePasswordPage(w http.ResponseWriter, r *http.Request, user User) {
        var data ChangePasswordPageData
        if r.Method == "POST" {
                err := r.ParseForm()
@@ -275,7 +305,7 @@ func changePasswordPage(w http.ResponseWriter, r *http.Request) {
        }
        renderHeader(w, "change-password")
        tmpl := template.Must(template.ParseFS(content, "templates/change_password.html"))
-       err = tmpl.Execute(w, data)
+       err := tmpl.Execute(w, data)
        if err != nil {
                log.Fatal(err)
        }
@@ -512,7 +542,7 @@ func uploadFile(w http.ResponseWriter, r *http.Request) {
        http.Redirect(w, r, "/files/", http.StatusFound)
 }
 
-func logOutPage(w http.ResponseWriter, r *http.Request) {
+func logOutPage(w http.ResponseWriter, r *http.Request, user User) {
        clearSessionCookie(w)
        renderHeader(w, "")
        tmpl := template.Must(template.ParseFS(content, "templates/logout.html"))
@@ -520,12 +550,7 @@ func logOutPage(w http.ResponseWriter, r *http.Request) {
        renderFooter(w)
 }
 
-func stopPage(w http.ResponseWriter, r *http.Request) {
-       _, err := currentUser(w, r)
-       if err != nil {
-               http.Redirect(w, r, "/login", http.StatusFound)
-               return
-       }
+func stopPage(w http.ResponseWriter, r *http.Request, user User) {
        r.ParseForm()
        radioId, err := strconv.Atoi(r.Form.Get("radioId"))
        if err != nil {