diff --git a/internal/controller/v1/comment.go b/internal/controller/v1/comment.go index b898066..3f78e5a 100644 --- a/internal/controller/v1/comment.go +++ b/internal/controller/v1/comment.go @@ -2,12 +2,16 @@ package v1 import ( "context" + "slices" + "strconv" + "github.com/cloudwego/hertz/pkg/app" + "github.com/snowykami/neo-blog/internal/ctxutils" "github.com/snowykami/neo-blog/internal/dto" "github.com/snowykami/neo-blog/internal/service" + "github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/resps" - "strconv" ) type CommentController struct { @@ -23,7 +27,7 @@ func NewCommentController() *CommentController { func (cc *CommentController) CreateComment(ctx context.Context, c *app.RequestContext) { var req dto.CreateCommentReq if err := c.BindAndValidate(&req); err != nil { - resps.BadRequest(c, resps.ErrParamInvalid) + resps.BadRequest(c, err.Error()) return } err := cc.service.CreateComment(ctx, &req) @@ -85,7 +89,35 @@ func (cc *CommentController) GetComment(ctx context.Context, c *app.RequestConte } func (cc *CommentController) GetCommentList(ctx context.Context, c *app.RequestContext) { - // pagenation := ctxutils.GetPaginationParams(c) + pagination := ctxutils.GetPaginationParams(c) + if pagination.OrderBy == "" { + pagination.OrderBy = constant.OrderByUpdatedAt + } + if pagination.OrderBy != "" && !slices.Contains(constant.OrderByEnumComment, pagination.OrderBy) { + resps.BadRequest(c, "无效的排序字段") + return + } + targetID, err := strconv.Atoi(c.Query("target_id")) + if err != nil { + resps.BadRequest(c, "无效的 target_id") + return + } + + req := dto.GetCommentListReq{ + Desc: pagination.Desc, + OrderBy: pagination.OrderBy, + Page: pagination.Page, + Size: pagination.Size, + TargetID: uint(targetID), + TargetType: c.Query("target_type"), + } + resp, err := cc.service.GetCommentList(ctx, &req) + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + resps.Ok(c, resps.Success, resp) } func (cc *CommentController) ReactComment(ctx context.Context, c *app.RequestContext) {} diff --git a/internal/dto/comment.go b/internal/dto/comment.go index a5568f9..a4dd7be 100644 --- a/internal/dto/comment.go +++ b/internal/dto/comment.go @@ -17,9 +17,19 @@ type CreateCommentReq struct { TargetType string `json:"target_type" binding:"required"` // 目标类型,如 "post", "page" Content string `json:"content" binding:"required"` // 评论内容 ReplyID uint `json:"reply_id"` // 回复的评论ID + IsPrivate bool `json:"is_private" binding:"required"` // 是否私密 } type UpdateCommentReq struct { CommentID uint `json:"comment_id" binding:"required"` // 评论ID Content string `json:"content" binding:"required"` // 评论内容 } + +type GetCommentListReq struct { + TargetID uint `json:"target_id" binding:"required"` + TargetType string `json:"target_type" binding:"required"` + OrderBy string `json:"order_by"` // 排序方式 + Page uint64 `json:"page"` // 页码 + Size uint64 `json:"size"` + Desc bool `json:"desc"` +} \ No newline at end of file diff --git a/internal/repo/comment.go b/internal/repo/comment.go index d304b47..8ccb45f 100644 --- a/internal/repo/comment.go +++ b/internal/repo/comment.go @@ -1,33 +1,207 @@ package repo -import "github.com/snowykami/neo-blog/internal/model" +import ( + "errors" + "net/http" + "slices" + "strconv" + + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/errs" + "gorm.io/gorm" +) type CommentRepo struct { } var Comment = &CommentRepo{} -func (cr *CommentRepo) CreateComment(comment *model.Comment) error { - // Implementation for creating a comment +// 检查设置父评论是否会造成循环引用 +// 它通过向上遍历潜在父评论的所有祖先来实现 +func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) { + // 如果没有父评论,则不可能有循环 + if parentID == 0 { + return false, nil + } + + currentID := parentID + for currentID != 0 { + // 如果在向上追溯的过程中找到了自己的ID,说明存在循环 + if currentID == commentID { + return true, nil + } + + var parent model.Comment + if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil { + // 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环 + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + // 继续向上追溯 + currentID = parent.ReplyID + } + + // 已经追溯到树的根节点,没有发现循环 + return false, nil +} + + +// 递归删除子评论的辅助函数 +func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error { + var children []*model.Comment + // 1. 找到所有直接子评论 + if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil { + return err + } + + // 2. 对每一个子评论,递归删除它的子评论 + for _, child := range children { + if err := cr.deleteChildren(tx, child.ID); err != nil { + return err + } + } + + // 3. 删除当前层级的子评论 + if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil { + return err + } + return nil } +func (cr *CommentRepo) CreateComment(comment *model.Comment) error { + err := GetDB().Transaction(func(tx *gorm.DB) error { + depth := 0 + if comment.ReplyID != 0 { + isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID) + if err != nil { + return err // 检查过程中发生数据库错误 + } + if isCircular { + return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil) + } + var parentComment model.Comment + if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil { + return err + } + parentComment.CommentCount += 1 + if err := tx.Save(&parentComment).Error; err != nil { + return err + } + depth = parentComment.Depth + 1 + } + + comment.Depth = depth + + if err := tx.Create(comment).Error; err != nil { + return err + } + + return nil + }) + + return err +} + func (cr *CommentRepo) UpdateComment(comment *model.Comment) error { - // Implementation for updating a comment + if comment.ID == 0 { + return errs.New(http.StatusBadRequest, "invalid comment ID", nil) + } + + if err := GetDB().Updates(comment).Error; err != nil { + return err + } + return nil } func (cr *CommentRepo) DeleteComment(commentID string) error { - // Implementation for deleting a comment + if commentID == "" { + return errs.New(http.StatusBadRequest, "invalid comment ID", nil) + } + + err := GetDB().Transaction(func(tx *gorm.DB) error { + var comment model.Comment + + // 1. 查找主评论 + if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil { + return err + } + + // 2. 删除子评论 + if err := cr.deleteChildren(tx, comment.ID); err != nil { + return err + } + + // 3. 删除主评论 + if err := tx.Delete(&comment).Error; err != nil { + return err + } + + // 4. 更新父评论的回复计数 + if comment.ReplyID != 0 { + var parent model.Comment + if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil { + return err + } + + parent.CommentCount -= 1 + + if err := tx.Save(&parent).Error; err != nil { + return err + } + } + + return nil + }) + + if err != nil { + return err + } + return nil } func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) { - // Implementation for getting a comment by ID - return nil, nil + var comment model.Comment + if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil { + return nil, err + } + return &comment, nil } -func (cr *CommentRepo) ListComments(currentUserID uint, page, size uint, orderBy string, desc bool) ([]model.Comment, error) { - // Implementation for listing comments for a post - return nil, nil +func (cr *CommentRepo) ListComments(currentUserID uint, targetID uint, targetType string, page, size uint64, orderBy string, desc bool) ([]model.Comment, error) { + if !slices.Contains(constant.OrderByEnumComment, orderBy) { + return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil) + } + + var masterID uint + + if targetType == constant.TargetTypePost { + post, err := Post.GetPostByID(strconv.Itoa(int(targetID))) + if err != nil { + return nil, err + } + masterID = post.UserID + } + + query := GetDB().Model(&model.Comment{}).Preload("User") + + if currentUserID > 0 { + query = query.Where("is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?))", false, true, currentUserID, masterID) + } else { + query = query.Where("is_private = ?", false) + } + + query = query.Where("target_id = ? AND target_type = ?", targetID, targetType) + + items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc) + if err != nil { + return nil, err + } + + return items, nil } diff --git a/internal/service/comment.go b/internal/service/comment.go index 4cfc2b1..1839bca 100644 --- a/internal/service/comment.go +++ b/internal/service/comment.go @@ -2,7 +2,13 @@ package service import ( "context" + "strconv" + + "github.com/snowykami/neo-blog/internal/ctxutils" "github.com/snowykami/neo-blog/internal/dto" + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/internal/repo" + "github.com/snowykami/neo-blog/pkg/errs" ) type CommentService struct{} @@ -12,25 +18,124 @@ func NewCommentService() *CommentService { } func (cs *CommentService) CreateComment(ctx context.Context, req *dto.CreateCommentReq) error { + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } + + comment := &model.Comment{ + Content: req.Content, + ReplyID: req.ReplyID, + TargetID: req.TargetID, + TargetType: req.TargetType, + UserID: currentUser.ID, + IsPrivate: req.IsPrivate, + } + + err := repo.Comment.CreateComment(comment) + + if err != nil { + return err + } + return nil } func (cs *CommentService) UpdateComment(ctx context.Context, req *dto.UpdateCommentReq) error { - // Implementation for updating a comment + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } + + comment, err := repo.Comment.GetComment(strconv.Itoa(int(req.CommentID))) + if err != nil { + return err + } + + if currentUser.ID != comment.UserID { + return errs.ErrForbidden + } + + comment.Content = req.Content + + err = repo.Comment.UpdateComment(comment) + + if err != nil { + return err + } + return nil } func (cs *CommentService) DeleteComment(ctx context.Context, commentID string) error { - // Implementation for deleting a comment + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } + if commentID == "" { + return errs.ErrBadRequest + } + + comment, err := repo.Comment.GetComment(commentID) + if err != nil { + return errs.New(errs.ErrNotFound.Code, "comment not found", err) + } + + if comment.UserID != currentUser.ID { + return errs.ErrForbidden + } + + if err := repo.Comment.DeleteComment(commentID); err != nil { + return err + } + return nil } func (cs *CommentService) GetComment(ctx context.Context, commentID string) (*dto.CommentDto, error) { - // Implementation for getting a single comment - return nil, nil + comment, err := repo.Comment.GetComment(commentID) + + if err != nil { + return nil, errs.New(errs.ErrNotFound.Code, "comment not found", err) + } + + commentDto := dto.CommentDto{ + ID: comment.ID, + TargetID: comment.TargetID, + TargetType: comment.TargetType, + Content: comment.Content, + ReplyID: comment.ReplyID, + Depth: comment.Depth, + CreatedAt: comment.CreatedAt.String(), + UpdatedAt: comment.UpdatedAt.String(), + User: *comment.User.ToDto(), + } + + return &commentDto, err } -//func (cs *CommentService) GetCommentList(ctx context.Context, req *dto.GetCommentListReq) ([]dto.CommentDto, error) { -// // Implementation for getting a list of comments -// return nil, nil -//} +func (cs *CommentService) GetCommentList(ctx context.Context, req *dto.GetCommentListReq) ([]dto.CommentDto, error) { + currentUser, _ := ctxutils.GetCurrentUser(ctx) + + comments, err := repo.Comment.ListComments(currentUser.ID, req.TargetID, req.TargetType, req.Page, req.Size, req.OrderBy, req.Desc) + if err != nil { + return nil, errs.New(errs.ErrInternalServer.Code, "failed to list comments", err) + } + + commentDtos := make([]dto.CommentDto, 0) + + for _, comment := range comments { + commentDto := dto.CommentDto{ + ID: comment.ID, + Content: comment.Content, + TargetID: comment.TargetID, + TargetType: comment.TargetType, + CreatedAt: comment.CreatedAt.String(), + UpdatedAt: comment.UpdatedAt.String(), + Depth: comment.Depth, + User: *comment.User.ToDto(), + } + commentDtos = append(commentDtos, commentDto) + } + return commentDtos, nil +} diff --git a/pkg/constant/constant.go b/pkg/constant/constant.go index c1e2a83..0dff5d2 100644 --- a/pkg/constant/constant.go +++ b/pkg/constant/constant.go @@ -41,4 +41,5 @@ const ( var ( OrderByEnumPost = []string{OrderByCreatedAt, OrderByUpdatedAt, OrderByLikeCount, OrderByCommentCount, OrderByViewCount, OrderByHeat} // 帖子可用的排序方式 + OrderByEnumComment = []string{OrderByCreatedAt, OrderByUpdatedAt, OrderByCommentCount} // 评论可用的排序方式 )