diff --git a/store.go b/store.go index 3aa562b..19661a3 100644 --- a/store.go +++ b/store.go @@ -26,6 +26,7 @@ package main import ( + "encoding/json" "errors" "fmt" "strings" @@ -64,9 +65,10 @@ const ( ) type User struct { - Username string - Password string - Role GlobalRole + Username string + Password string `json:"-"` + PasswordHash string `json:"Password"` + Role GlobalRole } type Share struct { @@ -96,18 +98,6 @@ func (u User) key() string { return usersPrefix + u.Username } -func (u User) prefix() string { - return fmt.Sprintf("user:%s", u.Username) -} - -func (u User) roleKey() string { - return u.prefix() + ":role" -} - -func (u User) passwordKey() string { - return u.prefix() + ":password" -} - func NewDBStore(filename string) (*DBStore, error) { db, err := buntdb.Open(filename) if err != nil { @@ -128,24 +118,21 @@ var ErrUserExists = errors.New("user already exists") var ErrUserNotFound = errors.New("user not found") var ErrInvalidUsername = errors.New("invalid username") -func (store *DBStore) setUserValues(tx *buntdb.Tx, user User) (exists bool, err error) { - pwString, err := buildPassword(user.Username, user.Password) - if err != nil { - return false, fmt.Errorf("cannot hash password: %w", err) +func (u User) merge(updates User) (User, error) { + merged := u + + if updates.Password != "" { + pwHash, err := hashPassword(updates.Password) + if err != nil { + return u, fmt.Errorf("cannot hash password: %w", err) + } + merged.PasswordHash = pwHash + } + if updates.Role != "" { + merged.Role = updates.Role } - if _, replaced, err := tx.Set(user.roleKey(), string(user.Role), nil); err != nil { - return false, err - } else if replaced { - exists = true - } - if _, replaced, err := tx.Set(user.passwordKey(), pwString, nil); err != nil { - return false, err - } else if replaced { - exists = true - } - - return exists, nil + return merged, nil } func (store *DBStore) AddUser(user User) (err error) { @@ -153,17 +140,17 @@ func (store *DBStore) AddUser(user User) (err error) { return ErrInvalidUsername } if err = store.db.Update(func(tx *buntdb.Tx) error { - if _, exists, err := tx.Set(user.key(), "", nil); err != nil { + userBytes, err := json.Marshal(user) + if err != nil { + return fmt.Errorf("cannot marshal user: %w", err) + } + + if _, exists, err := tx.Set(user.key(), string(userBytes), nil); err != nil { return err } else if exists { return ErrUserExists } - if exists, err := store.setUserValues(tx, user); err != nil { - return err - } else if exists { - return ErrUserExists - } return nil }); err != nil { return err @@ -175,22 +162,16 @@ func (store *DBStore) GetUser(username string) (user User, err error) { if strings.Contains(username, ":") { return user, ErrInvalidUsername } - user.Username = username if err := store.db.View(func(tx *buntdb.Tx) error { - if val, err := tx.Get(user.roleKey()); err != nil && err != buntdb.ErrNotFound { + user.Username = username + if val, err := tx.Get(user.key()); err != nil && err != buntdb.ErrNotFound { return err } else if err == buntdb.ErrNotFound { return ErrUserNotFound - } else { - user.Role = GlobalRole(val) - } - if val, err := tx.Get(user.passwordKey()); err != nil && err != buntdb.ErrNotFound { - return err - } else if err == buntdb.ErrNotFound { - return ErrUserNotFound - } else { - user.Password = val + } else if err := json.Unmarshal([]byte(val), &user); err != nil { + return fmt.Errorf("cannot unmarshal user: %w", err) } + return nil }); err != nil { return user, err @@ -200,24 +181,26 @@ func (store *DBStore) GetUser(username string) (user User, err error) { func (store *DBStore) GetUsers() (users []User, err error) { err = store.db.View(func(tx *buntdb.Tx) error { - if err := tx.AscendKeys(usersPrefix+"*", func(key, value string) bool { + var processingError error + + if err := tx.AscendKeys(userPrefix+"*", func(key, value string) bool { var user User - user.Username = strings.TrimPrefix(key, usersPrefix) + + if err := json.Unmarshal([]byte(value), &user); err != nil { + processingError = err + return false + } + + // Just in case ... + user.Username = strings.TrimPrefix(key, userPrefix) + users = append(users, user) return true }); err != nil { return err } - for i := range users { - if roleString, err := tx.Get(users[i].roleKey()); err != nil { - return err - } else { - users[i].Role = GlobalRole(roleString) - } - } - - return nil + return processingError }) return users, err @@ -228,7 +211,29 @@ func (store *DBStore) Update(user User) error { return ErrInvalidUsername } return store.db.Update(func(tx *buntdb.Tx) error { - _, err := store.setUserValues(tx, user) + var existingUser User + + if val, err := tx.Get(user.key()); err != nil && err != buntdb.ErrNotFound { + return err + } else if err == buntdb.ErrNotFound { + return ErrUserNotFound + } else if err := json.Unmarshal([]byte(val), &existingUser); err != nil { + return fmt.Errorf("cannot unmarshal user: %w", err) + } + + mergedUser, err := existingUser.merge(user) + if err != nil { + return fmt.Errorf("cannot merge user: %w", err) + } + mergedUser.Username = user.Username + + userBytes, err := json.Marshal(mergedUser) + if err != nil { + return fmt.Errorf("cannot marshal user: %w", err) + } + + _, _, err = tx.Set(mergedUser.key(), string(userBytes), nil) + return err }) } @@ -239,35 +244,19 @@ func (store *DBStore) RemoveUser(username string) (err error) { } return store.db.Update(func(tx *buntdb.Tx) error { user := User{Username: username} - - // Delete the main key first. This is a good indicator if the user generally exists. if _, err := tx.Delete(user.key()); err == buntdb.ErrNotFound { return ErrUserNotFound } else if err != nil { return err } - - // Now get all attributes and delete them as well. One by one. - var keys []string - if err := tx.AscendKeys(user.prefix()+":*", func(key, value string) bool { - keys = append(keys, key) - return true - }); err != nil { - return fmt.Errorf("cannot iterate keys: %w", err) - } - for _, key := range keys { - if _, err := tx.Delete(key); err != nil { - return fmt.Errorf("cannot remove key: %w", err) - } - } return nil }) } -func buildPassword(username string, password string) (string, error) { +func hashPassword(password string) (string, error) { hash, err := bcrypt.GenerateFromPassword([]byte(password), 0) if err != nil { return "", err } - return fmt.Sprintf("%s:%s", username, hash), nil + return string(hash), nil } diff --git a/store_test.go b/store_test.go index 0d42e2f..b33b2af 100644 --- a/store_test.go +++ b/store_test.go @@ -31,7 +31,7 @@ func TestStoreUserHandling(t *testing.T) { }) t.Run("adding users should work", func(t *testing.T) { - if err := store.AddUser(User{"myuser", "mypass", GlobalRoleUser}); err != nil { + if err := store.AddUser(User{Username: "myuser", Password: "mypass", Role: GlobalRoleUser}); err != nil { t.Errorf("cannot add user: %v", err) }