mirror of
https://github.com/snowykami/neo-blog.git
synced 2025-09-26 19:16:24 +00:00
refactor: 优化评论模块代码格式,提升可读性
This commit is contained in:
@ -1,16 +1,16 @@
|
|||||||
package repo
|
package repo
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/snowykami/neo-blog/internal/model"
|
"github.com/snowykami/neo-blog/internal/model"
|
||||||
"github.com/snowykami/neo-blog/pkg/constant"
|
"github.com/snowykami/neo-blog/pkg/constant"
|
||||||
"github.com/snowykami/neo-blog/pkg/errs"
|
"github.com/snowykami/neo-blog/pkg/errs"
|
||||||
"github.com/snowykami/neo-blog/pkg/utils"
|
"github.com/snowykami/neo-blog/pkg/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type CommentRepo struct {
|
type CommentRepo struct {
|
||||||
@ -21,227 +21,195 @@ var Comment = &CommentRepo{}
|
|||||||
// 检查设置父评论是否会造成循环引用
|
// 检查设置父评论是否会造成循环引用
|
||||||
// 它通过向上遍历潜在父评论的所有祖先来实现
|
// 它通过向上遍历潜在父评论的所有祖先来实现
|
||||||
func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) {
|
func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) {
|
||||||
// 如果没有父评论,则不可能有循环
|
// 如果没有父评论,则不可能有循环
|
||||||
if parentID == 0 {
|
if parentID == 0 {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
currentID := parentID
|
currentID := parentID
|
||||||
for currentID != 0 {
|
for currentID != 0 {
|
||||||
// 如果在向上追溯的过程中找到了自己的ID,说明存在循环
|
// 如果在向上追溯的过程中找到了自己的ID,说明存在循环
|
||||||
if currentID == commentID {
|
if currentID == commentID {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var parent model.Comment
|
var parent model.Comment
|
||||||
if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil {
|
if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil {
|
||||||
// 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环
|
// 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
// 继续向上追溯
|
// 继续向上追溯
|
||||||
currentID = parent.ReplyID
|
currentID = parent.ReplyID
|
||||||
}
|
}
|
||||||
|
|
||||||
// 已经追溯到树的根节点,没有发现循环
|
// 已经追溯到树的根节点,没有发现循环
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// 递归删除子评论的辅助函数
|
// 递归删除子评论的辅助函数
|
||||||
func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error {
|
func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error {
|
||||||
var children []*model.Comment
|
var children []*model.Comment
|
||||||
// 1. 找到所有直接子评论
|
// 1. 找到所有直接子评论
|
||||||
if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil {
|
if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 对每一个子评论,递归删除它的子评论
|
// 2. 对每一个子评论,递归删除它的子评论
|
||||||
for _, child := range children {
|
for _, child := range children {
|
||||||
if err := cr.deleteChildren(tx, child.ID); err != nil {
|
if err := cr.deleteChildren(tx, child.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 删除当前层级的子评论
|
// 3. 删除当前层级的子评论
|
||||||
if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil {
|
if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CommentRepo) CreateComment(comment *model.Comment) error {
|
func (cr *CommentRepo) CreateComment(comment *model.Comment) error {
|
||||||
err := GetDB().Transaction(func(tx *gorm.DB) error {
|
err := GetDB().Transaction(func(tx *gorm.DB) error {
|
||||||
depth := 0
|
depth := 0
|
||||||
if comment.ReplyID != 0 {
|
if comment.ReplyID != 0 {
|
||||||
isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID)
|
isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err // 检查过程中发生数据库错误
|
return err // 检查过程中发生数据库错误
|
||||||
}
|
}
|
||||||
if isCircular {
|
if isCircular {
|
||||||
return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil)
|
return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil)
|
||||||
}
|
}
|
||||||
var parentComment model.Comment
|
var parentComment model.Comment
|
||||||
if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil {
|
if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
parentComment.CommentCount += 1
|
parentComment.CommentCount += 1
|
||||||
if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil {
|
if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
depth = parentComment.Depth + 1
|
depth = parentComment.Depth + 1
|
||||||
}
|
}
|
||||||
if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) {
|
if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) {
|
||||||
return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil)
|
return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil)
|
||||||
}
|
}
|
||||||
comment.Depth = depth
|
comment.Depth = depth
|
||||||
if err := tx.Create(comment).Error; err != nil {
|
if err := tx.Create(comment).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CommentRepo) UpdateComment(comment *model.Comment) error {
|
func (cr *CommentRepo) UpdateComment(comment *model.Comment) error {
|
||||||
if comment.ID == 0 {
|
if comment.ID == 0 {
|
||||||
return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
|
return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil {
|
if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CommentRepo) DeleteComment(commentID string) error {
|
func (cr *CommentRepo) DeleteComment(commentID string) error {
|
||||||
if commentID == "" {
|
if commentID == "" {
|
||||||
return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
|
return errs.New(http.StatusBadRequest, "invalid comment ID", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
err := GetDB().Transaction(func(tx *gorm.DB) error {
|
err := GetDB().Transaction(func(tx *gorm.DB) error {
|
||||||
var comment model.Comment
|
var comment model.Comment
|
||||||
|
|
||||||
// 1. 查找主评论
|
// 1. 查找主评论
|
||||||
if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil {
|
if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 2. 删除子评论
|
// 2. 删除子评论
|
||||||
if err := cr.deleteChildren(tx, comment.ID); err != nil {
|
if err := cr.deleteChildren(tx, comment.ID); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 3. 删除主评论
|
// 3. 删除主评论
|
||||||
if err := tx.Delete(&comment).Error; err != nil {
|
if err := tx.Delete(&comment).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// 4. 更新父评论的回复计数
|
// 4. 更新父评论的回复计数
|
||||||
if comment.ReplyID != 0 {
|
if comment.ReplyID != 0 {
|
||||||
var parent model.Comment
|
var parent model.Comment
|
||||||
if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil {
|
if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
parent.CommentCount -= 1
|
parent.CommentCount -= 1
|
||||||
|
|
||||||
if err := tx.Save(&parent).Error; err != nil {
|
if err := tx.Save(&parent).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) {
|
func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) {
|
||||||
var comment model.Comment
|
var comment model.Comment
|
||||||
if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil {
|
if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &comment, nil
|
return &comment, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, targetType string, page, size uint64, orderBy string, desc bool, depth int) ([]model.Comment, error) {
|
func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, targetType string, page, size uint64, orderBy string, desc bool, depth int) ([]model.Comment, error) {
|
||||||
if !slices.Contains(constant.OrderByEnumComment, orderBy) {
|
if !slices.Contains(constant.OrderByEnumComment, orderBy) {
|
||||||
return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil)
|
return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
var masterID uint
|
var masterID uint
|
||||||
|
|
||||||
if targetType == constant.TargetTypePost {
|
if targetType == constant.TargetTypePost {
|
||||||
post, err := Post.GetPostByID(strconv.Itoa(int(targetID)))
|
post, err := Post.GetPostByID(strconv.Itoa(int(targetID)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
masterID = post.UserID
|
masterID = post.UserID
|
||||||
}
|
}
|
||||||
|
|
||||||
query := GetDB().Model(&model.Comment{}).Preload("User")
|
query := GetDB().Model(&model.Comment{}).Preload("User")
|
||||||
|
|
||||||
if commentID > 0 {
|
if commentID > 0 {
|
||||||
query = query.Where("reply_id = ?", commentID)
|
query = query.Where("reply_id = ?", commentID)
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentUserID > 0 {
|
if currentUserID > 0 {
|
||||||
query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID)
|
query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID)
|
||||||
} else {
|
} else {
|
||||||
query = query.Where("is_private = ?", false)
|
query = query.Where("is_private = ?", false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if depth >= 0 {
|
if depth >= 0 {
|
||||||
query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth)
|
query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth)
|
||||||
} else {
|
} else {
|
||||||
query = query.Where("target_id = ? AND target_type = ?", targetID, targetType)
|
query = query.Where("target_id = ? AND target_type = ?", targetID, targetType)
|
||||||
}
|
}
|
||||||
|
|
||||||
items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc)
|
items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return items, nil
|
return items, nil
|
||||||
}
|
|
||||||
|
|
||||||
func (cr *CommentRepo) CountReplyComments(currentUserID, commentID uint) (int64, error) {
|
|
||||||
var count int64
|
|
||||||
var masterID uint
|
|
||||||
|
|
||||||
// 根据commentID查询所属对象的用户ID
|
|
||||||
comment, err := cr.GetComment(strconv.Itoa(int(commentID)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
if comment.TargetType == constant.TargetTypePost {
|
|
||||||
post, err := Post.GetPostByID(strconv.Itoa(int(comment.TargetID)))
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
masterID = post.UserID
|
|
||||||
} else {
|
|
||||||
// 如果不是文章类型,可以根据需要添加其他类型的处理逻辑
|
|
||||||
return 0, errs.New(http.StatusBadRequest, "unsupported target type for counting replies", nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
query := GetDB().Model(&model.Comment{}).Where("reply_id = ?", commentID)
|
|
||||||
if currentUserID > 0 {
|
|
||||||
query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID)
|
|
||||||
} else {
|
|
||||||
query = query.Where("is_private = ?", false)
|
|
||||||
}
|
|
||||||
if err := query.Count(&count).Error; err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user