diff --git a/internal/controller/v1/comment.go b/internal/controller/v1/comment.go index b4eba19..ee616a0 100644 --- a/internal/controller/v1/comment.go +++ b/internal/controller/v1/comment.go @@ -6,6 +6,7 @@ import ( "strconv" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/snowykami/neo-blog/internal/ctxutils" "github.com/snowykami/neo-blog/internal/dto" "github.com/snowykami/neo-blog/internal/service" @@ -30,13 +31,13 @@ func (cc *CommentController) CreateComment(ctx context.Context, c *app.RequestCo resps.BadRequest(c, err.Error()) return } - err := cc.service.CreateComment(ctx, &req) + commentID, err := cc.service.CreateComment(ctx, &req) if err != nil { serviceErr := errs.AsServiceError(err) resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) return } - resps.Ok(c, resps.Success, nil) + resps.Ok(c, resps.Success, utils.H{"id": commentID}) } func (cc *CommentController) UpdateComment(ctx context.Context, c *app.RequestContext) { diff --git a/internal/controller/v1/post.go b/internal/controller/v1/post.go index bcb8e24..4641836 100644 --- a/internal/controller/v1/post.go +++ b/internal/controller/v1/post.go @@ -2,6 +2,9 @@ package v1 import ( "context" + "slices" + "strings" + "github.com/cloudwego/hertz/pkg/app" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/snowykami/neo-blog/internal/ctxutils" @@ -10,8 +13,6 @@ import ( "github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/resps" - "slices" - "strings" ) type PostController struct { diff --git a/internal/controller/v1/user.go b/internal/controller/v1/user.go index 60a79f7..b010540 100644 --- a/internal/controller/v1/user.go +++ b/internal/controller/v1/user.go @@ -69,7 +69,7 @@ func (u *UserController) Logout(ctx context.Context, c *app.RequestContext) { ctxutils.ClearTokenAndRefreshTokenCookie(c) resps.Ok(c, resps.Success, nil) // 尝试吊销服务端状态:若用户登录的情况下 - // TODO: 这里可以添加服务端状态的吊销逻辑 + // TODO: 添加服务端状态的吊销逻辑 } func (u *UserController) OidcList(ctx context.Context, c *app.RequestContext) { @@ -176,3 +176,11 @@ func (u *UserController) VerifyEmail(ctx context.Context, c *app.RequestContext) } resps.Ok(c, resps.Success, resp) } + +func (u *UserController) ChangePassword(ctx context.Context, c *app.RequestContext) { + // TODO: 实现修改密码功能 +} + +func (u *UserController) ChangeEmail(ctx context.Context, c *app.RequestContext) { + // TODO: 实现修改邮箱功能 +} diff --git a/internal/model/comment.go b/internal/model/comment.go index ad4b157..9685d67 100644 --- a/internal/model/comment.go +++ b/internal/model/comment.go @@ -1,6 +1,8 @@ package model -import "gorm.io/gorm" +import ( + "gorm.io/gorm" +) type Comment struct { gorm.Model diff --git a/internal/repo/comment.go b/internal/repo/comment.go index 9438b7a..b18ec94 100644 --- a/internal/repo/comment.go +++ b/internal/repo/comment.go @@ -1,16 +1,16 @@ package repo import ( - "errors" - "net/http" - "slices" - "strconv" + "errors" + "net/http" + "slices" + "strconv" - "github.com/snowykami/neo-blog/internal/model" - "github.com/snowykami/neo-blog/pkg/constant" - "github.com/snowykami/neo-blog/pkg/errs" - "github.com/snowykami/neo-blog/pkg/utils" - "gorm.io/gorm" + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/errs" + "github.com/snowykami/neo-blog/pkg/utils" + "gorm.io/gorm" ) type CommentRepo struct { @@ -21,195 +21,237 @@ var Comment = &CommentRepo{} // 检查设置父评论是否会造成循环引用 // 它通过向上遍历潜在父评论的所有祖先来实现 func (cr *CommentRepo) isCircularReference(tx *gorm.DB, commentID, parentID uint) (bool, error) { - // 如果没有父评论,则不可能有循环 - if parentID == 0 { - return false, nil - } + // 如果没有父评论,则不可能有循环 + if parentID == 0 { + return false, nil + } - currentID := parentID - for currentID != 0 { - // 如果在向上追溯的过程中找到了自己的ID,说明存在循环 - if currentID == commentID { - return true, nil - } + currentID := parentID + for currentID != 0 { + // 如果在向上追溯的过程中找到了自己的ID,说明存在循环 + if currentID == commentID { + return true, nil + } - var parent model.Comment - if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil { - // 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环 - if errors.Is(err, gorm.ErrRecordNotFound) { - return false, nil - } - return false, err - } - // 继续向上追溯 - currentID = parent.ReplyID - } + var parent model.Comment + if err := tx.Where("id = ?", currentID).First(&parent).Error; err != nil { + // 如果祖先链中的某个评论不存在,说明链已经断开,不可能形成循环 + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + // 继续向上追溯 + currentID = parent.ReplyID + } - // 已经追溯到树的根节点,没有发现循环 - return false, nil + // 已经追溯到树的根节点,没有发现循环 + return false, nil } // 递归删除子评论的辅助函数 func (cr *CommentRepo) deleteChildren(tx *gorm.DB, parentID uint) error { - var children []*model.Comment - // 1. 找到所有直接子评论 - if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil { - return err - } + var children []*model.Comment + // 1. 找到所有直接子评论 + if err := tx.Where("reply_id = ?", parentID).Find(&children).Error; err != nil { + return err + } - // 2. 对每一个子评论,递归删除它的子评论 - for _, child := range children { - if err := cr.deleteChildren(tx, child.ID); err != nil { - return err - } - } + // 2. 对每一个子评论,递归删除它的子评论 + for _, child := range children { + if err := cr.deleteChildren(tx, child.ID); err != nil { + return err + } + } - // 3. 删除当前层级的子评论 - if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil { - return err - } + // 3. 删除当前层级的子评论 + if err := tx.Where("reply_id = ?", parentID).Delete(&model.Comment{}).Error; err != nil { + return err + } - return nil + return nil } -func (cr *CommentRepo) CreateComment(comment *model.Comment) error { - err := GetDB().Transaction(func(tx *gorm.DB) error { - depth := 0 - if comment.ReplyID != 0 { - isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID) - if err != nil { - return err // 检查过程中发生数据库错误 - } - if isCircular { - return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil) - } - var parentComment model.Comment - if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil { - return err - } - parentComment.CommentCount += 1 - if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil { - return err - } - depth = parentComment.Depth + 1 - } - if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) { - return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil) - } - comment.Depth = depth - if err := tx.Create(comment).Error; err != nil { - return err - } - return nil - }) - return err +func (cr *CommentRepo) CreateComment(comment *model.Comment) (uint, error) { + var commentID uint + err := GetDB().Transaction(func(tx *gorm.DB) error { + depth := 0 + if comment.ReplyID != 0 { + isCircular, err := cr.isCircularReference(tx, comment.ID, comment.ReplyID) + if err != nil { + return err + } + if isCircular { + return errs.New(http.StatusBadRequest, "circular reference detected in comment tree", nil) + } + var parentComment model.Comment + if err := tx.Where("id = ?", comment.ReplyID).First(&parentComment).Error; err != nil { + return err + } + parentComment.CommentCount += 1 + if err := tx.Model(&parentComment).UpdateColumn("CommentCount", parentComment.CommentCount).Error; err != nil { + return err + } + depth = parentComment.Depth + 1 + } + if depth > utils.Env.GetAsInt(constant.EnvKeyMaxReplyDepth, constant.MaxReplyDepthDefault) { + return errs.New(http.StatusBadRequest, "exceeded maximum reply depth", nil) + } + comment.Depth = depth + if err := tx.Create(comment).Error; err != nil { + return err + } + commentID = comment.ID // 记录主键 + switch comment.TargetType { + case constant.TargetTypePost: + var count int64 + if err := tx.Model(&model.Comment{}). + Where("target_id = ? AND target_type = ?", comment.TargetID, constant.TargetTypePost). + Count(&count).Error; err != nil { + return err + } + if err := tx.Model(&model.Post{}).Where("id = ?", comment.TargetID). + UpdateColumn("comment_count", count).Error; err != nil { + return err + } + default: + return errs.New(http.StatusBadRequest, "unsupported target type: "+comment.TargetType, nil) + } + return nil + }) + return commentID, err } - func (cr *CommentRepo) UpdateComment(comment *model.Comment) error { - if comment.ID == 0 { - return errs.New(http.StatusBadRequest, "invalid comment ID", nil) - } + if comment.ID == 0 { + return errs.New(http.StatusBadRequest, "invalid comment ID", nil) + } - if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil { - return err - } + if err := GetDB().Select("IsPrivate", "Content").Updates(comment).Error; err != nil { + return err + } - return nil + return nil } func (cr *CommentRepo) DeleteComment(commentID string) error { - if commentID == "" { - return errs.New(http.StatusBadRequest, "invalid comment ID", nil) - } + if commentID == "" { + return errs.New(http.StatusBadRequest, "invalid comment ID", nil) + } - err := GetDB().Transaction(func(tx *gorm.DB) error { - var comment model.Comment + err := GetDB().Transaction(func(tx *gorm.DB) error { + var comment model.Comment - // 1. 查找主评论 - if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil { - return err - } + // 1. 查找主评论 + if err := tx.Where("id = ?", commentID).First(&comment).Error; err != nil { + return err + } - // 2. 删除子评论 - if err := cr.deleteChildren(tx, comment.ID); err != nil { - return err - } + // 2. 删除子评论 + if err := cr.deleteChildren(tx, comment.ID); err != nil { + return err + } - // 3. 删除主评论 - if err := tx.Delete(&comment).Error; err != nil { - return err - } + // 3. 删除主评论 + if err := tx.Delete(&comment).Error; err != nil { + return err + } - // 4. 更新父评论的回复计数 - if comment.ReplyID != 0 { - var parent model.Comment - if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil { - return err - } + // 4. 更新父评论的回复计数 + if comment.ReplyID != 0 { + var parent model.Comment + if err := tx.Where("id = ?", comment.ReplyID).First(&parent).Error; err != nil { + return err + } - parent.CommentCount -= 1 + parent.CommentCount -= 1 - if err := tx.Save(&parent).Error; err != nil { - return err - } - } + if err := tx.Save(&parent).Error; err != nil { + return err + } + } - return nil - }) + // 5. 更新目标的评论数量 + switch comment.TargetType { + case constant.TargetTypePost: + var count int64 + if err := tx.Model(&model.Comment{}). + Where("target_id = ? AND target_type = ?", comment.TargetID, constant.TargetTypePost). + Count(&count).Error; err != nil { + return err + } + if err := tx.Model(&model.Post{}).Where("id = ?", comment.TargetID). + UpdateColumn("comment_count", count).Error; err != nil { + return err + } + default: + return errs.New(http.StatusBadRequest, "unsupported target type: "+comment.TargetType, nil) + } - if err != nil { - return err - } + return nil + }) - return nil + if err != nil { + return err + } + + return nil } func (cr *CommentRepo) GetComment(commentID string) (*model.Comment, error) { - var comment model.Comment - if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil { - return nil, err - } - return &comment, nil + var comment model.Comment + if err := GetDB().Where("id = ?", commentID).Preload("User").First(&comment).Error; err != nil { + return nil, err + } + return &comment, nil } func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, targetType string, page, size uint64, orderBy string, desc bool, depth int) ([]model.Comment, error) { - if !slices.Contains(constant.OrderByEnumComment, orderBy) { - return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil) - } + if !slices.Contains(constant.OrderByEnumComment, orderBy) { + return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil) + } - var masterID uint + var masterID uint - if targetType == constant.TargetTypePost { - post, err := Post.GetPostByID(strconv.Itoa(int(targetID))) - if err != nil { - return nil, err - } - masterID = post.UserID - } + if targetType == constant.TargetTypePost { + post, err := Post.GetPostByID(strconv.Itoa(int(targetID))) + if err != nil { + return nil, err + } + masterID = post.UserID + } - query := GetDB().Model(&model.Comment{}).Preload("User") + query := GetDB().Model(&model.Comment{}).Preload("User") - if commentID > 0 { - query = query.Where("reply_id = ?", commentID) - } + if commentID > 0 { + query = query.Where("reply_id = ?", commentID) + } - if currentUserID > 0 { - query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID) - } else { - query = query.Where("is_private = ?", false) - } + if currentUserID > 0 { + query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID) + } else { + query = query.Where("is_private = ?", false) + } - if depth >= 0 { - query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth) - } else { - query = query.Where("target_id = ? AND target_type = ?", targetID, targetType) - } + if depth >= 0 { + query = query.Where("target_id = ? AND target_type = ? AND depth = ?", targetID, targetType, depth) + } else { + query = query.Where("target_id = ? AND target_type = ?", targetID, targetType) + } - items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc) + items, _, err := PaginateQuery[model.Comment](query, page, size, orderBy, desc) - if err != nil { - return nil, err - } + if err != nil { + return nil, err + } - return items, nil + return items, nil +} + +func (cr *CommentRepo) CountComments(targetType string, targetID uint) (int64, error) { + var count int64 + err := GetDB().Model(&model.Comment{}).Where("target_id = ? AND target_type = ?", targetID, targetType).Count(&count).Error + if err != nil { + return 0, err + } + return count, nil } diff --git a/internal/repo/like.go b/internal/repo/like.go index 9caffec..f9c53ae 100644 --- a/internal/repo/like.go +++ b/internal/repo/like.go @@ -1,12 +1,12 @@ package repo import ( - "errors" + "errors" - "github.com/sirupsen/logrus" - "github.com/snowykami/neo-blog/internal/model" - "github.com/snowykami/neo-blog/pkg/constant" - "gorm.io/gorm" + "github.com/sirupsen/logrus" + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/pkg/constant" + "gorm.io/gorm" ) type likeRepo struct{} @@ -14,96 +14,96 @@ type likeRepo struct{} var Like = &likeRepo{} func (l *likeRepo) ToggleLike(userID, targetID uint, targetType string) (bool, error) { - err := l.checkTargetType(targetType) - if err != nil { - return false, err - } - var finalStatus bool - err = GetDB().Transaction(func(tx *gorm.DB) error { - isLiked, err := l.IsLiked(userID, targetID, targetType) - if err != nil { - logrus.Error(err) - return err - } - if isLiked { - if err := - tx.Where("target_type = ? AND target_id = ? AND user_id = ?", targetType, targetID, userID). - Delete(&model.Like{TargetType: targetType, TargetID: targetID, UserID: userID}). - Error; err != nil { - logrus.Error(err) - return err - } - finalStatus = false - } else { - like := &model.Like{ - TargetType: targetType, - TargetID: targetID, - UserID: userID, - } - if err := tx.Create(like).Error; err != nil { - return err - } - finalStatus = true - } - // 重新计算点赞数量 - var count int64 - if err := tx.Model(&model.Like{}).Where("target_type = ? AND target_id = ?", targetType, targetID).Count(&count).Error; err != nil { - return err - } - // 更新目标的点赞数量 - //switch targetType { - //case constant.TargetTypePost: - // if err := tx.Model(&model.Post{}).Where("id = ?", targetID).UpdateColumn("like_count", count).Error; err != nil { - // return err - // } - //case constant.TargetTypeComment: - // if err := tx.Model(&model.Comment{}).Where("id = ?", targetID).UpdateColumn("like_count", count).Error; err != nil { - // return err - // } - //default: - // return errors.New("invalid target type") - //} - return nil - }) - return finalStatus, err + err := l.checkTargetType(targetType) + if err != nil { + return false, err + } + var finalStatus bool + err = GetDB().Transaction(func(tx *gorm.DB) error { + isLiked, err := l.IsLiked(userID, targetID, targetType) + if err != nil { + logrus.Error(err) + return err + } + if isLiked { + if err := + tx.Where("target_type = ? AND target_id = ? AND user_id = ?", targetType, targetID, userID). + Delete(&model.Like{TargetType: targetType, TargetID: targetID, UserID: userID}). + Error; err != nil { + logrus.Error(err) + return err + } + finalStatus = false + } else { + like := &model.Like{ + TargetType: targetType, + TargetID: targetID, + UserID: userID, + } + if err := tx.Create(like).Error; err != nil { + return err + } + finalStatus = true + } + // 重新计算点赞数量 + var count int64 + if err := tx.Model(&model.Like{}).Where("target_type = ? AND target_id = ?", targetType, targetID).Count(&count).Error; err != nil { + return err + } + // 更新目标的点赞数量 + //switch targetType { + //case constant.TargetTypePost: + // if err := tx.Model(&model.Post{}).Where("id = ?", targetID).UpdateColumn("like_count", count).Error; err != nil { + // return err + // } + //case constant.TargetTypeComment: + // if err := tx.Model(&model.Comment{}).Where("id = ?", targetID).UpdateColumn("like_count", count).Error; err != nil { + // return err + // } + //default: + // return errors.New("invalid target type") + //} + return nil + }) + return finalStatus, err } // IsLiked 检查是否点赞 func (l *likeRepo) IsLiked(userID, targetID uint, targetType string) (bool, error) { - err := l.checkTargetType(targetType) - if err != nil { - return false, err - } - var like model.Like - err = GetDB().Where("target_type = ? AND target_id = ? AND user_id = ?", targetType, targetID, userID).First(&like).Error - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - return false, nil - } - return false, err - } - return true, nil + err := l.checkTargetType(targetType) + if err != nil { + return false, err + } + var like model.Like + err = GetDB().Where("target_type = ? AND target_id = ? AND user_id = ?", targetType, targetID, userID).First(&like).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + return false, err + } + return true, nil } // Count 点赞计数 func (l *likeRepo) Count(targetID uint, targetType string) (int64, error) { - err := l.checkTargetType(targetType) - if err != nil { - return 0, err - } - var count int64 - err = GetDB().Model(&model.Like{}).Where("target_type = ? AND target_id = ?", targetType, targetID).Count(&count).Error - if err != nil { - return 0, err - } - return count, nil + err := l.checkTargetType(targetType) + if err != nil { + return 0, err + } + var count int64 + err = GetDB().Model(&model.Like{}).Where("target_type = ? AND target_id = ?", targetType, targetID).Count(&count).Error + if err != nil { + return 0, err + } + return count, nil } func (l *likeRepo) checkTargetType(targetType string) error { - switch targetType { - case constant.TargetTypePost, constant.TargetTypeComment: - return nil - default: - return errors.New("invalid target type") - } + switch targetType { + case constant.TargetTypePost, constant.TargetTypeComment: + return nil + default: + return errors.New("invalid target type") + } } diff --git a/internal/router/apiv1/user.go b/internal/router/apiv1/user.go index 602620a..25bba99 100644 --- a/internal/router/apiv1/user.go +++ b/internal/router/apiv1/user.go @@ -1,25 +1,27 @@ package apiv1 import ( - "github.com/cloudwego/hertz/pkg/route" - "github.com/snowykami/neo-blog/internal/controller/v1" - "github.com/snowykami/neo-blog/internal/middleware" + "github.com/cloudwego/hertz/pkg/route" + "github.com/snowykami/neo-blog/internal/controller/v1" + "github.com/snowykami/neo-blog/internal/middleware" ) func registerUserRoutes(group *route.RouterGroup) { - userController := v1.NewUserController() - userGroup := group.Group("/user").Use(middleware.UseAuth(true)) - userGroupWithoutAuth := group.Group("/user").Use(middleware.UseAuth(false)) - userGroupWithoutAuthNeedsCaptcha := userGroupWithoutAuth.Use(middleware.UseCaptcha()) - { - userGroupWithoutAuthNeedsCaptcha.POST("/login", userController.Login) - userGroupWithoutAuthNeedsCaptcha.POST("/register", userController.Register) - userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", userController.VerifyEmail) // Send email verification code - userGroupWithoutAuth.GET("/oidc/list", userController.OidcList) - userGroupWithoutAuth.GET("/oidc/login/:name", userController.OidcLogin) - userGroupWithoutAuth.GET("/u/:id", userController.GetUser) - userGroup.GET("/me", userController.GetUser) - userGroupWithoutAuth.POST("/logout", userController.Logout) - userGroup.PUT("/u/:id", userController.UpdateUser) - } + userController := v1.NewUserController() + userGroup := group.Group("/user").Use(middleware.UseAuth(true)) + userGroupWithoutAuth := group.Group("/user").Use(middleware.UseAuth(false)) + userGroupWithoutAuthNeedsCaptcha := userGroupWithoutAuth.Use(middleware.UseCaptcha()) + { + userGroupWithoutAuthNeedsCaptcha.POST("/login", userController.Login) + userGroupWithoutAuthNeedsCaptcha.POST("/register", userController.Register) + userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", userController.VerifyEmail) // Send email verification code + userGroupWithoutAuth.GET("/oidc/list", userController.OidcList) + userGroupWithoutAuth.GET("/oidc/login/:name", userController.OidcLogin) + userGroupWithoutAuth.GET("/u/:id", userController.GetUser) + userGroup.GET("/me", userController.GetUser) + userGroupWithoutAuth.POST("/logout", userController.Logout) + userGroup.PUT("/u/:id", userController.UpdateUser) + userGroup.PUT("/password/edit", userController.ChangePassword) + userGroup.PUT("/email/edit", userController.ChangeEmail) + } } diff --git a/internal/service/comment.go b/internal/service/comment.go index 416f6d2..b4ebf77 100644 --- a/internal/service/comment.go +++ b/internal/service/comment.go @@ -19,17 +19,17 @@ func NewCommentService() *CommentService { return &CommentService{} } -func (cs *CommentService) CreateComment(ctx context.Context, req *dto.CreateCommentReq) error { +func (cs *CommentService) CreateComment(ctx context.Context, req *dto.CreateCommentReq) (uint, error) { currentUser, ok := ctxutils.GetCurrentUser(ctx) if !ok { - return errs.ErrUnauthorized + return 0, errs.ErrUnauthorized } if ok, err := cs.checkTargetExists(req.TargetID, req.TargetType); !ok { if err != nil { - return errs.New(errs.ErrBadRequest.Code, "target not found", err) + return 0, errs.New(errs.ErrBadRequest.Code, "target not found", err) } - return errs.ErrBadRequest + return 0, errs.ErrBadRequest } comment := &model.Comment{ @@ -41,13 +41,13 @@ func (cs *CommentService) CreateComment(ctx context.Context, req *dto.CreateComm IsPrivate: req.IsPrivate, } - err := repo.Comment.CreateComment(comment) + commentID, err := repo.Comment.CreateComment(comment) if err != nil { - return err + return 0, err } - return nil + return commentID, nil } func (cs *CommentService) UpdateComment(ctx context.Context, req *dto.UpdateCommentReq) error { diff --git a/web/src/api/comment.ts b/web/src/api/comment.ts index a3d1610..75d8719 100644 --- a/web/src/api/comment.ts +++ b/web/src/api/comment.ts @@ -20,8 +20,8 @@ export async function createComment( replyId: number | null isPrivate: boolean } -): Promise> { - const res = await axiosClient.post>('/comment/c', { +): Promise> { + const res = await axiosClient.post>('/comment/c', { targetType, targetId, content, @@ -52,7 +52,6 @@ export async function deleteComment({ id }: { id: number }): Promise { await axiosClient.delete(`/comment/c/${id}`) } - export async function listComments({ targetType, targetId, @@ -82,4 +81,9 @@ export async function listComments({ } }) return res.data +} + +export async function getComment({ id }: { id: number }): Promise> { + const res = await axiosClient.get>(`/comment/c/${id}`) + return res.data } \ No newline at end of file diff --git a/web/src/components/blog-post/blog-post.tsx b/web/src/components/blog-post/blog-post.tsx index d5e8b46..3d220d6 100644 --- a/web/src/components/blog-post/blog-post.tsx +++ b/web/src/components/blog-post/blog-post.tsx @@ -4,8 +4,8 @@ import { Calendar, Clock, FileText, Flame, Heart, MessageCircle, PenLine, Square import { RenderMarkdown } from "@/components/common/markdown"; import { isMobileByUA } from "@/utils/server/device"; import { calculateReadingTime } from "@/utils/common/post"; -import {CommentSection} from "@/components/comment"; -import { TargetType } from '../../models/types'; +import { CommentSection } from "@/components/comment"; +import { TargetType } from '@/models/types'; function PostMeta({ post }: { post: Post }) { return ( @@ -139,7 +139,7 @@ async function BlogPost({ post }: { post: Post }) { {/* */} - + ); } diff --git a/web/src/components/comment/comment-item.tsx b/web/src/components/comment/comment-item.tsx index 9daf81d..ad58d95 100644 --- a/web/src/components/comment/comment-item.tsx +++ b/web/src/components/comment/comment-item.tsx @@ -12,7 +12,6 @@ import { useDoubleConfirm } from "@/hooks/use-double-confirm"; import { CommentInput } from "./comment-input"; import { createComment, deleteComment, listComments, updateComment } from "@/api/comment"; import { OrderBy } from "@/models/common"; -import config from "@/config"; import { formatDateTime } from "@/utils/common/datetime"; @@ -23,7 +22,8 @@ export function CommentItem( parentComment, onCommentDelete, activeInput, - setActiveInputId + setActiveInputId, + onReplySubmitted // 评论区计数更新用 }: { user: User | null, comment: Comment, @@ -31,6 +31,7 @@ export function CommentItem( onCommentDelete: ({ commentId }: { commentId: number }) => void, activeInput: { id: number; type: 'reply' | 'edit' } | null, setActiveInputId: (input: { id: number; type: 'reply' | 'edit' } | null) => void, + onReplySubmitted: ({ commentContent, isPrivate }: { commentContent: string, isPrivate: boolean }) => void, } ) { const locale = useLocale(); @@ -96,7 +97,7 @@ export function CommentItem( orderBy: OrderBy.CreatedAt, desc: false, page: 1, - size: config.commentsPerPage, + size: 999999, commentId: comment.id } ).then(response => { @@ -136,6 +137,7 @@ export function CommentItem( setShowReplies(true); setActiveInputId(null); setReplyCount(replyCount + 1); + onReplySubmitted({ commentContent, isPrivate }); }).catch(error => { toast.error(t("comment_failed") + ": " + error?.response?.data?.message || error?.message @@ -289,6 +291,7 @@ export function CommentItem( onCommentDelete={onReplyDelete} activeInput={activeInput} setActiveInputId={setActiveInputId} + onReplySubmitted={onReplySubmitted} /> ))} diff --git a/web/src/components/comment/index.tsx b/web/src/components/comment/index.tsx index a6005f0..d9ef0f6 100644 --- a/web/src/components/comment/index.tsx +++ b/web/src/components/comment/index.tsx @@ -5,7 +5,7 @@ import { useTranslations } from "next-intl"; import { Suspense, useEffect, useState } from "react"; import { toast } from "sonner"; import { Comment } from "@/models/comment"; -import { createComment, deleteComment, listComments } from "@/api/comment"; +import { createComment, deleteComment, getComment, listComments } from "@/api/comment"; import { TargetType } from "@/models/types"; import { OrderBy } from "@/models/common"; import { Separator } from "@/components/ui/separator"; @@ -22,18 +22,22 @@ import "./style.css"; export function CommentSection( { targetType, - targetId + targetId, + totalCount = 0 }: { targetType: TargetType, - targetId: number + targetId: number, + totalCount?: number } ) { const t = useTranslations('Comment') const [currentUser, setCurrentUser] = useState(null); const [comments, setComments] = useState([]); - const [refreshCommentsKey, setRefreshCommentsKey] = useState(0); const [activeInput, setActiveInput] = useState<{ id: number; type: 'reply' | 'edit' } | null>(null); + const [page, setPage] = useState(1); // 当前页码 + const [totalCommentCount, setTotalCommentCount] = useState(totalCount); // 评论总数 + const [needLoadMore, setNeedLoadMore] = useState(true); // 是否需要加载更多,当最后一次获取的评论数小于分页大小时设为false // 获取当前登录用户 useEffect(() => { @@ -51,13 +55,13 @@ export function CommentSection( depth: 0, orderBy: OrderBy.CreatedAt, desc: true, - page: 1, + page: page, size: config.commentsPerPage, commentId: 0 }).then(response => { setComments(response.data); }); - }, [refreshCommentsKey]) + }, []) const onCommentSubmitted = ({ commentContent, isPrivate }: { commentContent: string, isPrivate: boolean }) => { createComment({ @@ -66,25 +70,56 @@ export function CommentSection( content: commentContent, replyId: null, isPrivate, - }).then(() => { + }).then(res => { toast.success(t("comment_success")); - setRefreshCommentsKey(k => k + 1); + setTotalCommentCount(c => c + 1); + setComments(prevComments => prevComments.slice(0, -1)); + getComment({ id: res.data.id }).then(response => { + console.log("New comment fetched:", response.data); + setComments(prevComments => [response.data, ...prevComments]); + }); + setActiveInput(null); }) } + const onReplySubmitted = ({ }: { commentContent: string, isPrivate: boolean }) => { + setTotalCommentCount(c => c + 1); + } + const onCommentDelete = ({ commentId }: { commentId: number }) => { deleteComment({ id: commentId }).then(() => { toast.success(t("delete_success")); - setRefreshCommentsKey(k => k + 1); + setComments(prevComments => prevComments.filter(comment => comment.id !== commentId)); + setTotalCommentCount(c => c - 1); }).catch(error => { toast.error(t("delete_failed") + ": " + error.message); }); } + const handleLoadMore = () => { + const nextPage = page + 1; + listComments({ + targetType, + targetId, + depth: 0, + orderBy: OrderBy.CreatedAt, + desc: true, + page: nextPage, + size: config.commentsPerPage, + commentId: 0 + }).then(response => { + if (response.data.length < config.commentsPerPage) { + setNeedLoadMore(false); + } + setComments(prevComments => [...prevComments, ...response.data]); + setPage(nextPage); + }); + } + return (
-
{t("comment")}
+
{t("comment")} ({totalCommentCount})
}> {comments.map((comment, idx) => ( -
+
))} + {needLoadMore ? +

+ {t("load_more")} +

+ : +

+ {t("no_more")} +

+ }
) diff --git a/web/src/config.ts b/web/src/config.ts index 926f740..57846dd 100644 --- a/web/src/config.ts +++ b/web/src/config.ts @@ -15,7 +15,7 @@ const config = { bodyWidth: "80vw", bodyWidthMobile: "100vw", postsPerPage: 12, - commentsPerPage: 20, + commentsPerPage: 8, footer: { text: "Liteyuki ICP备 1145141919810", links: [] diff --git a/web/src/locales/zh-CN.json b/web/src/locales/zh-CN.json index e7f2685..074506f 100644 --- a/web/src/locales/zh-CN.json +++ b/web/src/locales/zh-CN.json @@ -21,7 +21,9 @@ "like": "点赞", "like_failed": "点赞失败", "like_success": "点赞成功", + "load_more": "加载更多", "login_required": "请先登录后再操作", + "no_more": "没有更多了!", "placeholder": "写下你的评论...", "private": "私密评论", "private_placeholder": "悄悄地说一句...",