diff --git a/store.go b/store.go index 4a9ca96..5575ddc 100644 --- a/store.go +++ b/store.go @@ -56,6 +56,7 @@ type ShareStore interface { GetShares() ([]Share, error) GetShareUsers(share Share) ([]ShareUser, error) GetShareLogins(share Share, username string) ([]Login, error) + GetShareAccess(share Share, username string) (ShareRole, error) FindShareByLogin(username, loginName string) (LoginShare, error) FindSharesByUser(username string) ([]UserShare, error) @@ -384,6 +385,30 @@ func (store *DBStore) AddUserToShare(share Share, username string, role ShareRol }) } +func (store *DBStore) GetShareAccess(share Share, username string) (ShareRole, error) { + if strings.Contains(username, ":") { + return "", ErrInvalidUsername + } + + var shareRole ShareRole + + err := store.db.View(func(tx *buntdb.Tx) error { + if _, err := tx.Get(share.key()); err != nil { + return ErrShareNotFound + } + if val, err := tx.Get(share.userKey(username)); err == buntdb.ErrNotFound { + return ErrUserNotFound + } else if err != nil { + return fmt.Errorf("cannot get user: %w", err) + } else { + shareRole = ShareRole(val) + } + return nil + }) + + return shareRole, err +} + func (store *DBStore) removeUserFromShare(tx *buntdb.Tx, share Share, username string) error { var logins []string loginsPrefix := share.loginKey(username, "") diff --git a/store_test.go b/store_test.go index 030f14d..4d3612e 100644 --- a/store_test.go +++ b/store_test.go @@ -244,6 +244,23 @@ func TestStoreShareHandling(t *testing.T) { } }) + t.Run("can get share access", func(t *testing.T) { + shareRole, err := store.GetShareAccess(share1, user1.Username) + if err != nil { + t.Errorf("unexpected error: %v", err) + } + if shareRole != ShareRoleAdmin { + t.Errorf("unexpected role: %s", shareRole) + } + }) + + t.Run("cannot get share access if user has none", func(t *testing.T) { + _, err := store.GetShareAccess(share2, user1.Username) + if err != ErrUserNotFound { + t.Errorf("unexpected error: %v", err) + } + }) + t.Run("can list users", func(t *testing.T) { users, err := store.GetShareUsers(share1) if err != nil { diff --git a/webadmin.go b/webadmin.go index ff8f312..146766b 100644 --- a/webadmin.go +++ b/webadmin.go @@ -57,6 +57,7 @@ type sessionContext struct { h *webAdminHandler w http.ResponseWriter r *http.Request + user User baseModel map[string]interface{} } @@ -181,6 +182,11 @@ func newWebAdminHandler(app *app) *webAdminHandler { ar.Get("/users", func(w http.ResponseWriter, r *http.Request) { sessionContext := h.buildSessionContext(w, r) + if sessionContext.user.Role != GlobalRoleAdmin { + sessionContext.Unauthorized() + return + } + users, err := app.userStore.GetUsers() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -194,6 +200,11 @@ func newWebAdminHandler(app *app) *webAdminHandler { ar.Get("/shares", func(w http.ResponseWriter, r *http.Request) { sessionContext := h.buildSessionContext(w, r) + if sessionContext.user.Role != GlobalRoleAdmin { + sessionContext.Unauthorized() + return + } + shares, err := app.shareStore.GetShares() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -224,12 +235,20 @@ func newWebAdminHandler(app *app) *webAdminHandler { r.Get("/", func(w http.ResponseWriter, r *http.Request) { sessionContext := h.buildSessionContext(w, r) - shareId := r.URL.Query().Get("share") - if shareId == "" { - http.Error(w, "invalid share id", http.StatusBadRequest) + share, err := app.shareStore.GetShare(r.URL.Query().Get("share")) + if err != nil { + sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "") return } + if sessionContext.user.Role != GlobalRoleAdmin { + shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username) + if err != nil || shareRole != ShareRoleAdmin { + sessionContext.Unauthorized() + return + } + } + users, err := app.userStore.GetUsers() if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -237,7 +256,7 @@ func newWebAdminHandler(app *app) *webAdminHandler { } sessionContext.RenderPage(h.tplShareAddUser, map[string]interface{}{ - "ShareId": shareId, + "ShareId": share.UUID, "Users": users, }) }) @@ -246,12 +265,19 @@ func newWebAdminHandler(app *app) *webAdminHandler { returnURL := "share-add-user?share=" + r.FormValue("share") - shareId, err := uuid.FromString(r.FormValue("share")) + share, err := app.shareStore.GetShare(r.FormValue("share")) if err != nil { sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "") return } - share := Share{UUID: shareId} + + if sessionContext.user.Role != GlobalRoleAdmin { + shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username) + if err != nil || shareRole != ShareRoleAdmin { + sessionContext.Unauthorized() + return + } + } user, err := app.userStore.GetUser(r.FormValue("user")) if err == ErrUserNotFound { @@ -276,12 +302,19 @@ func newWebAdminHandler(app *app) *webAdminHandler { returnURL := "shares" - shareId, err := uuid.FromString(r.FormValue("share")) + share, err := app.shareStore.GetShare(r.FormValue("share")) if err != nil { sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "") return } - share := Share{UUID: shareId} + + if sessionContext.user.Role != GlobalRoleAdmin { + shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username) + if err != nil || shareRole != ShareRoleAdmin { + sessionContext.Unauthorized() + return + } + } err = app.shareStore.RemoveUserFromShare(share, r.FormValue("user")) if err != nil { @@ -325,6 +358,14 @@ func newWebAdminHandler(app *app) *webAdminHandler { return } + if sessionContext.user.Role != GlobalRoleAdmin { + shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username) + if err != nil || shareRole != ShareRoleAdmin { + sessionContext.Unauthorized() + return + } + } + message := fmt.Sprintf(`You are about to delete the share %s (%s).
This will delete all data permanently.

Are you sure you want to continue?`, share.UUID, share.Name) @@ -357,6 +398,7 @@ func (h *webAdminHandler) buildSessionContext(w http.ResponseWriter, r *http.Req } sessionUser := userFromContext(r) if sessionUser != nil { + sessionContext.user = *sessionUser sessionContext.baseModel["SessionUser"] = sessionUser } return sessionContext @@ -387,6 +429,10 @@ func (s *sessionContext) RenderError(msg template.HTML, returnURL string) { s.RenderPage(s.h.tplError, model) } +func (s *sessionContext) Unauthorized() { + http.Error(s.w, http.StatusText(http.StatusForbidden), http.StatusForbidden) +} + func (s *sessionContext) RequestConfirmation(msg template.HTML, returnURL string) int { if s.r.FormValue("_yes") != "" { return confirmAccepted