diff --git a/store.go b/store.go
index 4a9ca96..5575ddc 100644
--- a/store.go
+++ b/store.go
@@ -56,6 +56,7 @@ type ShareStore interface {
GetShares() ([]Share, error)
GetShareUsers(share Share) ([]ShareUser, error)
GetShareLogins(share Share, username string) ([]Login, error)
+ GetShareAccess(share Share, username string) (ShareRole, error)
FindShareByLogin(username, loginName string) (LoginShare, error)
FindSharesByUser(username string) ([]UserShare, error)
@@ -384,6 +385,30 @@ func (store *DBStore) AddUserToShare(share Share, username string, role ShareRol
})
}
+func (store *DBStore) GetShareAccess(share Share, username string) (ShareRole, error) {
+ if strings.Contains(username, ":") {
+ return "", ErrInvalidUsername
+ }
+
+ var shareRole ShareRole
+
+ err := store.db.View(func(tx *buntdb.Tx) error {
+ if _, err := tx.Get(share.key()); err != nil {
+ return ErrShareNotFound
+ }
+ if val, err := tx.Get(share.userKey(username)); err == buntdb.ErrNotFound {
+ return ErrUserNotFound
+ } else if err != nil {
+ return fmt.Errorf("cannot get user: %w", err)
+ } else {
+ shareRole = ShareRole(val)
+ }
+ return nil
+ })
+
+ return shareRole, err
+}
+
func (store *DBStore) removeUserFromShare(tx *buntdb.Tx, share Share, username string) error {
var logins []string
loginsPrefix := share.loginKey(username, "")
diff --git a/store_test.go b/store_test.go
index 030f14d..4d3612e 100644
--- a/store_test.go
+++ b/store_test.go
@@ -244,6 +244,23 @@ func TestStoreShareHandling(t *testing.T) {
}
})
+ t.Run("can get share access", func(t *testing.T) {
+ shareRole, err := store.GetShareAccess(share1, user1.Username)
+ if err != nil {
+ t.Errorf("unexpected error: %v", err)
+ }
+ if shareRole != ShareRoleAdmin {
+ t.Errorf("unexpected role: %s", shareRole)
+ }
+ })
+
+ t.Run("cannot get share access if user has none", func(t *testing.T) {
+ _, err := store.GetShareAccess(share2, user1.Username)
+ if err != ErrUserNotFound {
+ t.Errorf("unexpected error: %v", err)
+ }
+ })
+
t.Run("can list users", func(t *testing.T) {
users, err := store.GetShareUsers(share1)
if err != nil {
diff --git a/webadmin.go b/webadmin.go
index ff8f312..146766b 100644
--- a/webadmin.go
+++ b/webadmin.go
@@ -57,6 +57,7 @@ type sessionContext struct {
h *webAdminHandler
w http.ResponseWriter
r *http.Request
+ user User
baseModel map[string]interface{}
}
@@ -181,6 +182,11 @@ func newWebAdminHandler(app *app) *webAdminHandler {
ar.Get("/users", func(w http.ResponseWriter, r *http.Request) {
sessionContext := h.buildSessionContext(w, r)
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+
users, err := app.userStore.GetUsers()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -194,6 +200,11 @@ func newWebAdminHandler(app *app) *webAdminHandler {
ar.Get("/shares", func(w http.ResponseWriter, r *http.Request) {
sessionContext := h.buildSessionContext(w, r)
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+
shares, err := app.shareStore.GetShares()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -224,12 +235,20 @@ func newWebAdminHandler(app *app) *webAdminHandler {
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
sessionContext := h.buildSessionContext(w, r)
- shareId := r.URL.Query().Get("share")
- if shareId == "" {
- http.Error(w, "invalid share id", http.StatusBadRequest)
+ share, err := app.shareStore.GetShare(r.URL.Query().Get("share"))
+ if err != nil {
+ sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "")
return
}
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username)
+ if err != nil || shareRole != ShareRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+ }
+
users, err := app.userStore.GetUsers()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@@ -237,7 +256,7 @@ func newWebAdminHandler(app *app) *webAdminHandler {
}
sessionContext.RenderPage(h.tplShareAddUser, map[string]interface{}{
- "ShareId": shareId,
+ "ShareId": share.UUID,
"Users": users,
})
})
@@ -246,12 +265,19 @@ func newWebAdminHandler(app *app) *webAdminHandler {
returnURL := "share-add-user?share=" + r.FormValue("share")
- shareId, err := uuid.FromString(r.FormValue("share"))
+ share, err := app.shareStore.GetShare(r.FormValue("share"))
if err != nil {
sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "")
return
}
- share := Share{UUID: shareId}
+
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username)
+ if err != nil || shareRole != ShareRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+ }
user, err := app.userStore.GetUser(r.FormValue("user"))
if err == ErrUserNotFound {
@@ -276,12 +302,19 @@ func newWebAdminHandler(app *app) *webAdminHandler {
returnURL := "shares"
- shareId, err := uuid.FromString(r.FormValue("share"))
+ share, err := app.shareStore.GetShare(r.FormValue("share"))
if err != nil {
sessionContext.RenderError(template.HTML("Internal error: "+err.Error()), "")
return
}
- share := Share{UUID: shareId}
+
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username)
+ if err != nil || shareRole != ShareRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+ }
err = app.shareStore.RemoveUserFromShare(share, r.FormValue("user"))
if err != nil {
@@ -325,6 +358,14 @@ func newWebAdminHandler(app *app) *webAdminHandler {
return
}
+ if sessionContext.user.Role != GlobalRoleAdmin {
+ shareRole, err := app.shareStore.GetShareAccess(share, sessionContext.user.Username)
+ if err != nil || shareRole != ShareRoleAdmin {
+ sessionContext.Unauthorized()
+ return
+ }
+ }
+
message := fmt.Sprintf(`You are about to delete the share %s (%s).
This will delete all data permanently.
Are you sure you want to continue?`, share.UUID, share.Name)
@@ -357,6 +398,7 @@ func (h *webAdminHandler) buildSessionContext(w http.ResponseWriter, r *http.Req
}
sessionUser := userFromContext(r)
if sessionUser != nil {
+ sessionContext.user = *sessionUser
sessionContext.baseModel["SessionUser"] = sessionUser
}
return sessionContext
@@ -387,6 +429,10 @@ func (s *sessionContext) RenderError(msg template.HTML, returnURL string) {
s.RenderPage(s.h.tplError, model)
}
+func (s *sessionContext) Unauthorized() {
+ http.Error(s.w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
+}
+
func (s *sessionContext) RequestConfirmation(msg template.HTML, returnURL string) int {
if s.r.FormValue("_yes") != "" {
return confirmAccepted