diff --git a/models/issues/comment_list.go b/models/issues/comment_list.go index 61ac1c8f56..7a133d1c16 100644 --- a/models/issues/comment_list.go +++ b/models/issues/comment_list.go @@ -23,7 +23,7 @@ func (comments CommentList) LoadPosters(ctx context.Context) error { } posterIDs := container.FilterSlice(comments, func(c *Comment) (int64, bool) { - return c.PosterID, c.Poster == nil && c.PosterID > 0 + return c.PosterID, c.Poster == nil && user_model.IsValidUserID(c.PosterID) }) posterMaps, err := getPostersByIDs(ctx, posterIDs) @@ -33,7 +33,7 @@ func (comments CommentList) LoadPosters(ctx context.Context) error { for _, comment := range comments { if comment.Poster == nil { - comment.Poster = getPoster(comment.PosterID, posterMaps) + comment.PosterID, comment.Poster = user_model.GetUserFromMap(comment.PosterID, posterMaps) } } return nil @@ -165,7 +165,7 @@ func (comments CommentList) loadOldMilestones(ctx context.Context) error { func (comments CommentList) getAssigneeIDs() []int64 { return container.FilterSlice(comments, func(comment *Comment) (int64, bool) { - return comment.AssigneeID, comment.AssigneeID > 0 + return comment.AssigneeID, user_model.IsValidUserID(comment.AssigneeID) }) } @@ -206,11 +206,7 @@ func (comments CommentList) loadAssignees(ctx context.Context) error { } for _, comment := range comments { - comment.Assignee = assignees[comment.AssigneeID] - if comment.Assignee == nil { - comment.AssigneeID = user_model.GhostUserID - comment.Assignee = user_model.NewGhostUser() - } + comment.AssigneeID, comment.Assignee = user_model.GetUserFromMap(comment.AssigneeID, assignees) } return nil } diff --git a/models/issues/comment_list_test.go b/models/issues/comment_list_test.go new file mode 100644 index 0000000000..66037d7358 --- /dev/null +++ b/models/issues/comment_list_test.go @@ -0,0 +1,86 @@ +// Copyright 2024 The Forgejo Authors +// SPDX-License-Identifier: MIT + +package issues + +import ( + "testing" + + "code.gitea.io/gitea/models/db" + repo_model "code.gitea.io/gitea/models/repo" + "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCommentListLoadUser(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + issue := unittest.AssertExistsAndLoadBean(t, &Issue{}) + repo := unittest.AssertExistsAndLoadBean(t, &repo_model.Repository{ID: issue.RepoID}) + doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: repo.OwnerID}) + + for _, testCase := range []struct { + poster int64 + assignee int64 + user *user_model.User + }{ + { + poster: user_model.ActionsUserID, + assignee: user_model.ActionsUserID, + user: user_model.NewActionsUser(), + }, + { + poster: user_model.GhostUserID, + assignee: user_model.GhostUserID, + user: user_model.NewGhostUser(), + }, + { + poster: doer.ID, + assignee: doer.ID, + user: doer, + }, + { + poster: 0, + assignee: 0, + user: user_model.NewGhostUser(), + }, + { + poster: -200, + assignee: -200, + user: user_model.NewGhostUser(), + }, + { + poster: 200, + assignee: 200, + user: user_model.NewGhostUser(), + }, + } { + t.Run(testCase.user.Name, func(t *testing.T) { + comment, err := CreateComment(db.DefaultContext, &CreateCommentOptions{ + Type: CommentTypeComment, + Doer: testCase.user, + Repo: repo, + Issue: issue, + Content: "Hello", + }) + assert.NoError(t, err) + + list := CommentList{comment} + + comment.PosterID = testCase.poster + comment.Poster = nil + assert.NoError(t, list.LoadPosters(db.DefaultContext)) + require.NotNil(t, comment.Poster) + assert.Equal(t, testCase.user.ID, comment.Poster.ID) + + comment.AssigneeID = testCase.assignee + comment.Assignee = nil + assert.NoError(t, list.loadAssignees(db.DefaultContext)) + require.NotNil(t, comment.Assignee) + assert.Equal(t, testCase.user.ID, comment.Assignee.ID) + }) + } +} diff --git a/models/issues/issue_list.go b/models/issues/issue_list.go index fbfa7584a0..fe6c630a31 100644 --- a/models/issues/issue_list.go +++ b/models/issues/issue_list.go @@ -79,7 +79,7 @@ func (issues IssueList) LoadPosters(ctx context.Context) error { } posterIDs := container.FilterSlice(issues, func(issue *Issue) (int64, bool) { - return issue.PosterID, issue.Poster == nil && issue.PosterID > 0 + return issue.PosterID, issue.Poster == nil && user_model.IsValidUserID(issue.PosterID) }) posterMaps, err := getPostersByIDs(ctx, posterIDs) @@ -89,7 +89,7 @@ func (issues IssueList) LoadPosters(ctx context.Context) error { for _, issue := range issues { if issue.Poster == nil { - issue.Poster = getPoster(issue.PosterID, posterMaps) + issue.PosterID, issue.Poster = user_model.GetUserFromMap(issue.PosterID, posterMaps) } } return nil @@ -115,20 +115,6 @@ func getPostersByIDs(ctx context.Context, posterIDs []int64) (map[int64]*user_mo return posterMaps, nil } -func getPoster(posterID int64, posterMaps map[int64]*user_model.User) *user_model.User { - if posterID == user_model.ActionsUserID { - return user_model.NewActionsUser() - } - if posterID <= 0 { - return nil - } - poster, ok := posterMaps[posterID] - if !ok { - return user_model.NewGhostUser() - } - return poster -} - func (issues IssueList) getIssueIDs() []int64 { ids := make([]int64, 0, len(issues)) for _, issue := range issues { diff --git a/models/issues/issue_list_test.go b/models/issues/issue_list_test.go index 10ba38a64b..50bbd5c667 100644 --- a/models/issues/issue_list_test.go +++ b/models/issues/issue_list_test.go @@ -9,9 +9,11 @@ import ( "code.gitea.io/gitea/models/db" issues_model "code.gitea.io/gitea/models/issues" "code.gitea.io/gitea/models/unittest" + user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/setting" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIssueList_LoadRepositories(t *testing.T) { @@ -78,3 +80,50 @@ func TestIssueList_LoadAttributes(t *testing.T) { assert.Equal(t, issue.ID == 1, issue.IsRead, "unexpected is_read value for issue[%d]", issue.ID) } } + +func TestIssueListLoadUser(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + issue := unittest.AssertExistsAndLoadBean(t, &issues_model.Issue{}) + doer := unittest.AssertExistsAndLoadBean(t, &user_model.User{ID: 1}) + + for _, testCase := range []struct { + poster int64 + user *user_model.User + }{ + { + poster: user_model.ActionsUserID, + user: user_model.NewActionsUser(), + }, + { + poster: user_model.GhostUserID, + user: user_model.NewGhostUser(), + }, + { + poster: doer.ID, + user: doer, + }, + { + poster: 0, + user: user_model.NewGhostUser(), + }, + { + poster: -200, + user: user_model.NewGhostUser(), + }, + { + poster: 200, + user: user_model.NewGhostUser(), + }, + } { + t.Run(testCase.user.Name, func(t *testing.T) { + list := issues_model.IssueList{issue} + + issue.PosterID = testCase.poster + issue.Poster = nil + assert.NoError(t, list.LoadPosters(db.DefaultContext)) + require.NotNil(t, issue.Poster) + assert.Equal(t, testCase.user.ID, issue.Poster.ID) + }) + } +} diff --git a/models/user/user.go b/models/user/user.go index f6d649eaf3..b1731021fd 100644 --- a/models/user/user.go +++ b/models/user/user.go @@ -939,6 +939,20 @@ func GetUserByIDs(ctx context.Context, ids []int64) ([]*User, error) { return users, err } +func IsValidUserID(id int64) bool { + return id > 0 || id == GhostUserID || id == ActionsUserID +} + +func GetUserFromMap(id int64, idMap map[int64]*User) (int64, *User) { + if user, ok := idMap[id]; ok { + return id, user + } + if id == ActionsUserID { + return ActionsUserID, NewActionsUser() + } + return GhostUserID, NewGhostUser() +} + // GetPossibleUserByID returns the user if id > 0 or return system usrs if id < 0 func GetPossibleUserByID(ctx context.Context, id int64) (*User, error) { switch id { diff --git a/models/user/user_test.go b/models/user/user_test.go index abeff078c5..5bd1f21b5c 100644 --- a/models/user/user_test.go +++ b/models/user/user_test.go @@ -35,6 +35,39 @@ func TestOAuth2Application_LoadUser(t *testing.T) { assert.NotNil(t, user) } +func TestIsValidUserID(t *testing.T) { + assert.False(t, user_model.IsValidUserID(-30)) + assert.False(t, user_model.IsValidUserID(0)) + assert.True(t, user_model.IsValidUserID(user_model.GhostUserID)) + assert.True(t, user_model.IsValidUserID(user_model.ActionsUserID)) + assert.True(t, user_model.IsValidUserID(200)) +} + +func TestGetUserFromMap(t *testing.T) { + id := int64(200) + idMap := map[int64]*user_model.User{ + id: {ID: id}, + } + + ghostID := int64(user_model.GhostUserID) + actionsID := int64(user_model.ActionsUserID) + actualID, actualUser := user_model.GetUserFromMap(-20, idMap) + assert.Equal(t, ghostID, actualID) + assert.Equal(t, ghostID, actualUser.ID) + + actualID, actualUser = user_model.GetUserFromMap(0, idMap) + assert.Equal(t, ghostID, actualID) + assert.Equal(t, ghostID, actualUser.ID) + + actualID, actualUser = user_model.GetUserFromMap(ghostID, idMap) + assert.Equal(t, ghostID, actualID) + assert.Equal(t, ghostID, actualUser.ID) + + actualID, actualUser = user_model.GetUserFromMap(actionsID, idMap) + assert.Equal(t, actionsID, actualID) + assert.Equal(t, actionsID, actualUser.ID) +} + func TestGetUserByName(t *testing.T) { defer tests.AddFixtures("models/user/fixtures/")() assert.NoError(t, unittest.PrepareTestDatabase())