refactor: 优化评论模块代码格式,提升可读性

This commit is contained in:
2025-09-10 13:00:14 +08:00
parent 6651de5858
commit 2bf6c50c68

View File

@ -1,16 +1,16 @@
package repo package repo
import ( import (
"errors" "errors"
"net/http" "net/http"
"slices" "slices"
"strconv" "strconv"
"github.com/snowykami/neo-blog/internal/model" "github.com/snowykami/neo-blog/internal/model"
"github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/utils" "github.com/snowykami/neo-blog/pkg/utils"
"gorm.io/gorm" "gorm.io/gorm"
) )
type CommentRepo struct { type CommentRepo struct {
@ -21,227 +21,195 @@ var Comment = &CommentRepo{}
// 检查设置父评论是否会造成循环引用 // 检查设置父评论是否会造成循环引用
// 它通过向上遍历潜在父评论的所有祖先来实现 // 它通过向上遍历潜在父评论的所有祖先来实现
func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) { func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) {
// 如果没有父评论,则不可能有循环 // 如果没有父评论,则不可能有循环
if parentID == 0 { if parentID == 0 {
return false, nil return false, nil
} }
currentID := parentID currentID := parentID
for currentID != 0 { for currentID != 0 {
// 如果在向上追溯的过程中找到了自己的ID说明存在循环 // 如果在向上追溯的过程中找到了自己的ID说明存在循环
if currentID == commentID { if currentID == commentID {
return true, nil return true, nil
} }
var parent model.Comment var parent model.Comment
if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil { if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil {
// 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环 // 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环
if errors.Is(err, gorm.ErrRecordNotFound) { if errors.Is(err, gorm.ErrRecordNotFound) {
return false, nil return false, nil
} }
return false, err return false, err
} }
// 继续向上追溯 // 继续向上追溯
currentID = parent.ReplyID currentID = parent.ReplyID
} }
// 已经追溯到树的根节点,没有发现循环 // 已经追溯到树的根节点,没有发现循环
return false, nil return false, nil
} }
// 递归删除子评论的辅助函数 // 递归删除子评论的辅助函数
func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error { func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error {
var children []*model.Comment var children []*model.Comment
// 1. 找到所有直接子评论 // 1. 找到所有直接子评论
if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil { if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil {
return err return err
} }
// 2. 对每一个子评论,递归删除它的子评论 // 2. 对每一个子评论,递归删除它的子评论
for _, child := range children { for _, child := range children {
if err := cr.deleteChildren(tx, child.ID); err != nil { if err := cr.deleteChildren(tx, child.ID); err != nil {
return err return err
} }
} }
// 3. 删除当前层级的子评论 // 3. 删除当前层级的子评论
if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil { if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil {
return err return err
} }
return nil return nil
} }
func (cr *CommentRepo) CreateComment(comment *model.Comment) error { func (cr *CommentRepo) CreateComment(comment *model.Comment) error {
err := GetDB().Transaction(func(tx *gorm.DB) error { err := GetDB().Transaction(func(tx *gorm.DB) error {
depth := 0 depth := 0
if comment.ReplyID != 0 { if comment.ReplyID != 0 {
isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID) isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID)
if err != nil { if err != nil {
return err // 检查过程中发生数据库错误 return err // 检查过程中发生数据库错误
} }
if isCircular { if isCircular {
return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil) return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil)
} }
var parentComment model.Comment var parentComment model.Comment
if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil { if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil {
return err return err
} }
parentComment.CommentCount += 1 parentComment.CommentCount += 1
if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil { if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil {
return err return err
} }
depth = parentComment.Depth + 1 depth = parentComment.Depth + 1
} }
if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) { if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) {
return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil) return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil)
} }
comment.Depth = depth comment.Depth = depth
if err := tx.Create(comment).Error; err != nil { if err := tx.Create(comment).Error; err != nil {
return err return err
} }
return nil return nil
}) })
return err return err
} }
func (cr *CommentRepo) UpdateComment(comment *model.Comment) error { func (cr *CommentRepo) UpdateComment(comment *model.Comment) error {
if comment.ID == 0 { if comment.ID == 0 {
return errs.New(http.StatusBadRequest, "invalid comment ID", nil) return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
} }
if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil { if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil {
return err return err
} }
return nil return nil
} }
func (cr *CommentRepo) DeleteComment(commentID string) error { func (cr *CommentRepo) DeleteComment(commentID string) error {
if commentID == "" { if commentID == "" {
return errs.New(http.StatusBadRequest, "invalid comment ID", nil) return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
} }
err := GetDB().Transaction(func(tx *gorm.DB) error { err := GetDB().Transaction(func(tx *gorm.DB) error {
var comment model.Comment var comment model.Comment
// 1. 查找主评论 // 1. 查找主评论
if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil { if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil {
return err return err
} }
// 2. 删除子评论 // 2. 删除子评论
if err := cr.deleteChildren(tx, comment.ID); err != nil { if err := cr.deleteChildren(tx, comment.ID); err != nil {
return err return err
} }
// 3. 删除主评论 // 3. 删除主评论
if err := tx.Delete(&comment).Error; err != nil { if err := tx.Delete(&comment).Error; err != nil {
return err return err
} }
// 4. 更新父评论的回复计数 // 4. 更新父评论的回复计数
if comment.ReplyID != 0 { if comment.ReplyID != 0 {
var parent model.Comment var parent model.Comment
if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil { if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil {
return err return err
} }
parent.CommentCount -= 1 parent.CommentCount -= 1
if err := tx.Save(&parent).Error; err != nil { if err := tx.Save(&parent).Error; err != nil {
return err return err
} }
} }
return nil return nil
}) })
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) { func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) {
var comment model.Comment var comment model.Comment
if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil { if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil {
return nil, err return nil, err
} }
return &comment, nil return &comment, nil
} }
func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, targetType string, page, size uint64, orderBy string, desc bool, depth int) ([]model.Comment, error) { func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, targetType string, page, size uint64, orderBy string, desc bool, depth int) ([]model.Comment, error) {
if !slices.Contains(constant.OrderByEnumComment, orderBy) { if !slices.Contains(constant.OrderByEnumComment, orderBy) {
return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil) return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil)
} }
var masterID uint var masterID uint
if targetType == constant.TargetTypePost { if targetType == constant.TargetTypePost {
post, err := Post.GetPostByID(strconv.Itoa(int(targetID))) post, err := Post.GetPostByID(strconv.Itoa(int(targetID)))
if err != nil { if err != nil {
return nil, err return nil, err
} }
masterID = post.UserID masterID = post.UserID
} }
query := GetDB().Model(&model.Comment{}).Preload("User") query := GetDB().Model(&model.Comment{}).Preload("User")
if commentID > 0 { if commentID > 0 {
query = query.Where("reply_id = ?", commentID) query = query.Where("reply_id = ?", commentID)
} }
if currentUserID > 0 { if currentUserID > 0 {
query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID) query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID)
} else { } else {
query = query.Where("is_private = ?", false) query = query.Where("is_private = ?", false)
} }
if depth >= 0 { if depth >= 0 {
query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth) query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth)
} else { } else {
query = query.Where("target_id = ? AND target_type = ?", targetID, targetType) query = query.Where("target_id = ? AND target_type = ?", targetID, targetType)
} }
items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc) items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
return items, nil return items, nil
}
func (cr *CommentRepo) CountReplyComments(currentUserID, commentID uint) (int64, error) {
var count int64
var masterID uint
// 根据commentID查询所属对象的用户ID
comment, err := cr.GetComment(strconv.Itoa(int(commentID)))
if err != nil {
return 0, err
}
if comment.TargetType == constant.TargetTypePost {
post, err := Post.GetPostByID(strconv.Itoa(int(comment.TargetID)))
if err != nil {
return 0, err
}
masterID = post.UserID
} else {
// 如果不是文章类型,可以根据需要添加其他类型的处理逻辑
return 0, errs.New(http.StatusBadRequest, "unsupported target type for counting replies", nil)
}
query := GetDB().Model(&model.Comment{}).Where("reply_id = ?", commentID)
if currentUserID > 0 {
query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID)
} else {
query = query.Where("is_private = ?", false)
}
if err := query.Count(&count).Error; err != nil {
return 0, err
}
return count, nil
} }