From fe32ffe1819038f7edadd749b7c1ef84667d1c81 Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Mon, 30 Dec 2024 10:21:57 -0800 Subject: [PATCH] Merge updatecommentattachment functions (#33044) Extract from #32178 --- models/issues/comment.go | 57 ++++++++++++----------------------- models/issues/comment_test.go | 18 +++++++++++ models/issues/issue_update.go | 15 ++------- routers/web/repo/issue.go | 2 +- 4 files changed, 42 insertions(+), 50 deletions(-) diff --git a/models/issues/comment.go b/models/issues/comment.go index e4537aa872..880a6ce8c5 100644 --- a/models/issues/comment.go +++ b/models/issues/comment.go @@ -592,26 +592,26 @@ func (c *Comment) LoadAttachments(ctx context.Context) error { return nil } -// UpdateAttachments update attachments by UUIDs for the comment -func (c *Comment) UpdateAttachments(ctx context.Context, uuids []string) error { - ctx, committer, err := db.TxContext(ctx) - if err != nil { - return err +// UpdateCommentAttachments update attachments by UUIDs for the comment +func UpdateCommentAttachments(ctx context.Context, c *Comment, uuids []string) error { + if len(uuids) == 0 { + return nil } - defer committer.Close() - - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err) - } - for i := 0; i < len(attachments); i++ { - attachments[i].IssueID = c.IssueID - attachments[i].CommentID = c.ID - if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { - return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) + return db.WithTx(ctx, func(ctx context.Context) error { + attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, uuids) + if err != nil { + return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", uuids, err) } - } - return committer.Commit() + for i := 0; i < len(attachments); i++ { + attachments[i].IssueID = c.IssueID + attachments[i].CommentID = c.ID + if err := repo_model.UpdateAttachment(ctx, attachments[i]); err != nil { + return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) + } + } + c.Attachments = attachments + return nil + }) } // LoadAssigneeUserAndTeam if comment.Type is CommentTypeAssignees, then load assignees @@ -878,7 +878,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment // Check comment type. switch opts.Type { case CommentTypeCode: - if err = updateAttachments(ctx, opts, comment); err != nil { + if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil { return err } if comment.ReviewID != 0 { @@ -898,7 +898,7 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment } fallthrough case CommentTypeReview: - if err = updateAttachments(ctx, opts, comment); err != nil { + if err = UpdateCommentAttachments(ctx, comment, opts.Attachments); err != nil { return err } case CommentTypeReopen, CommentTypeClose: @@ -910,23 +910,6 @@ func updateCommentInfos(ctx context.Context, opts *CreateCommentOptions, comment return UpdateIssueCols(ctx, opts.Issue, "updated_unix") } -func updateAttachments(ctx context.Context, opts *CreateCommentOptions, comment *Comment) error { - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err) - } - for i := range attachments { - attachments[i].IssueID = opts.Issue.ID - attachments[i].CommentID = comment.ID - // No assign value could be 0, so ignore AllCols(). - if _, err = db.GetEngine(ctx).ID(attachments[i].ID).Update(attachments[i]); err != nil { - return fmt.Errorf("update attachment [%d]: %w", attachments[i].ID, err) - } - } - comment.Attachments = attachments - return nil -} - func createDeadlineComment(ctx context.Context, doer *user_model.User, issue *Issue, newDeadlineUnix timeutil.TimeStamp) (*Comment, error) { var content string var commentType CommentType diff --git a/models/issues/comment_test.go b/models/issues/comment_test.go index d81f33f953..32b86891e8 100644 --- a/models/issues/comment_test.go +++ b/models/issues/comment_test.go @@ -45,6 +45,24 @@ func TestCreateComment(t *testing.T) { unittest.AssertInt64InRange(t, now, then, int64(updatedIssue.UpdatedUnix)) } +func Test_UpdateCommentAttachment(t *testing.T) { + assert.NoError(t, unittest.PrepareTestDatabase()) + + comment := unittest.AssertExistsAndLoadBean(t, &issues_model.Comment{ID: 1}) + attachment := repo_model.Attachment{ + Name: "test.txt", + } + assert.NoError(t, db.Insert(db.DefaultContext, &attachment)) + + err := issues_model.UpdateCommentAttachments(db.DefaultContext, comment, []string{attachment.UUID}) + assert.NoError(t, err) + + attachment2 := unittest.AssertExistsAndLoadBean(t, &repo_model.Attachment{ID: attachment.ID}) + assert.EqualValues(t, attachment.Name, attachment2.Name) + assert.EqualValues(t, comment.ID, attachment2.CommentID) + assert.EqualValues(t, comment.IssueID, attachment2.IssueID) +} + func TestFetchCodeComments(t *testing.T) { assert.NoError(t, unittest.PrepareTestDatabase()) diff --git a/models/issues/issue_update.go b/models/issues/issue_update.go index 03863fe968..479834045c 100644 --- a/models/issues/issue_update.go +++ b/models/issues/issue_update.go @@ -405,19 +405,10 @@ func NewIssueWithIndex(ctx context.Context, doer *user_model.User, opts NewIssue return err } - if len(opts.Attachments) > 0 { - attachments, err := repo_model.GetAttachmentsByUUIDs(ctx, opts.Attachments) - if err != nil { - return fmt.Errorf("getAttachmentsByUUIDs [uuids: %v]: %w", opts.Attachments, err) - } - - for i := 0; i < len(attachments); i++ { - attachments[i].IssueID = opts.Issue.ID - if _, err = e.ID(attachments[i].ID).Update(attachments[i]); err != nil { - return fmt.Errorf("update attachment [id: %d]: %w", attachments[i].ID, err) - } - } + if err := UpdateIssueAttachments(ctx, opts.Issue.ID, opts.Attachments); err != nil { + return err } + if err = opts.Issue.LoadAttributes(ctx); err != nil { return err } diff --git a/routers/web/repo/issue.go b/routers/web/repo/issue.go index a3a4e73d7b..f150897a2d 100644 --- a/routers/web/repo/issue.go +++ b/routers/web/repo/issue.go @@ -602,7 +602,7 @@ func updateAttachments(ctx *context.Context, item any, files []string) error { case *issues_model.Issue: err = issues_model.UpdateIssueAttachments(ctx, content.ID, files) case *issues_model.Comment: - err = content.UpdateAttachments(ctx, files) + err = issues_model.UpdateCommentAttachments(ctx, content, files) default: return fmt.Errorf("unknown Type: %T", content) }