diff --git a/internal/model/oidc_config.go b/internal/model/oidc_config.go index 7edfba7..274cf54 100644 --- a/internal/model/oidc_config.go +++ b/internal/model/oidc_config.go @@ -1,117 +1,117 @@ package model import ( - "fmt" - "time" + "fmt" + "time" - "github.com/sirupsen/logrus" - "github.com/snowykami/neo-blog/internal/dto" - "gorm.io/gorm" - "resty.dev/v3" + "github.com/sirupsen/logrus" + "github.com/snowykami/neo-blog/internal/dto" + "gorm.io/gorm" + "resty.dev/v3" ) type OidcConfig struct { - gorm.Model - Name string `gorm:"uniqueIndex"` // OIDC配置名称,唯一 - ClientID string // 客户端ID - ClientSecret string // 客户端密钥 - DisplayName string // 显示名称,例如:轻雪通行证 - Icon string // 图标url,为空则使用内置默认图标 - OidcDiscoveryUrl string // OpenID自动发现URL,例如 :https://pass.liteyuki.org/.well-known/openid-configuration - Enabled bool `gorm:"default:true"` // 是否启用 - Type string `gorm:"oauth2"` // OIDC类型,默认为oauth2,也可以为misskey - // 以下字段为自动获取字段,每次更新配置时自动填充 - Issuer string - AuthorizationEndpoint string - TokenEndpoint string - UserInfoEndpoint string - JwksUri string + gorm.Model + Name string `gorm:"uniqueIndex"` // OIDC配置名称,唯一 + ClientID string // 客户端ID + ClientSecret string // 客户端密钥 + DisplayName string // 显示名称,例如:轻雪通行证 + Icon string // 图标url,为空则使用内置默认图标 + OidcDiscoveryUrl string // OpenID自动发现URL,例如 :https://pass.liteyuki.org/.well-known/openid-configuration + Enabled bool `gorm:"default:true"` // 是否启用 + Type string `gorm:"oauth2"` // OIDC类型,默认为oauth2,也可以为misskey + // 以下字段为自动获取字段,每次更新配置时自动填充 + Issuer string + AuthorizationEndpoint string + TokenEndpoint string + UserInfoEndpoint string + JwksUri string } type oidcDiscoveryResp struct { - Issuer string `json:"issuer" validate:"required"` - AuthorizationEndpoint string `json:"authorization_endpoint" validate:"required"` - TokenEndpoint string `json:"token_endpoint" validate:"required"` - UserInfoEndpoint string `json:"userinfo_endpoint" validate:"required"` - JwksUri string `json:"jwks_uri" validate:"required"` - // 可选字段 - RegistrationEndpoint string `json:"registration_endpoint,omitempty"` - ScopesSupported []string `json:"scopes_supported,omitempty"` - ResponseTypesSupported []string `json:"response_types_supported,omitempty"` - GrantTypesSupported []string `json:"grant_types_supported,omitempty"` - SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` - IdTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` - ClaimsSupported []string `json:"claims_supported,omitempty"` - EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` + Issuer string `json:"issuer" validate:"required"` + AuthorizationEndpoint string `json:"authorization_endpoint" validate:"required"` + TokenEndpoint string `json:"token_endpoint" validate:"required"` + UserInfoEndpoint string `json:"userinfo_endpoint" validate:"required"` + JwksUri string `json:"jwks_uri" validate:"required"` + // 可选字段 + RegistrationEndpoint string `json:"registration_endpoint,omitempty"` + ScopesSupported []string `json:"scopes_supported,omitempty"` + ResponseTypesSupported []string `json:"response_types_supported,omitempty"` + GrantTypesSupported []string `json:"grant_types_supported,omitempty"` + SubjectTypesSupported []string `json:"subject_types_supported,omitempty"` + IdTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported,omitempty"` + ClaimsSupported []string `json:"claims_supported,omitempty"` + EndSessionEndpoint string `json:"end_session_endpoint,omitempty"` } func updateOidcConfigFromUrl(url string, typ string) (*oidcDiscoveryResp, error) { - client := resty.New() - client.SetTimeout(10 * time.Second) // 设置超时时间 - var discovery oidcDiscoveryResp - resp, err := client.R(). - SetHeader("Accept", "application/json"). - SetResult(&discovery). - Get(url) - if err != nil { - return nil, fmt.Errorf("请求OIDC发现端点失败: %w", err) - } - if resp.StatusCode() != 200 { - return nil, fmt.Errorf("请求OIDC发现端点失败,状态码: %d", resp.StatusCode()) - } - // 验证必要字段 - if typ == "misskey" { - discovery.UserInfoEndpoint = discovery.Issuer + "/api/users/me" // Misskey的用户信息端点 - discovery.JwksUri = discovery.Issuer + "/api/jwks" - } - fmt.Println(discovery) - if discovery.Issuer == "" || - discovery.AuthorizationEndpoint == "" || - discovery.TokenEndpoint == "" || - discovery.UserInfoEndpoint == "" || - discovery.JwksUri == "" { - return nil, fmt.Errorf("OIDC发现端点响应缺少必要字段") - } - return &discovery, nil + client := resty.New() + client.SetTimeout(10 * time.Second) // 设置超时时间 + var discovery oidcDiscoveryResp + resp, err := client.R(). + SetHeader("Accept", "application/json"). + SetResult(&discovery). + Get(url) + if err != nil { + return nil, fmt.Errorf("请求OIDC发现端点失败: %w", err) + } + if resp.StatusCode() != 200 { + return nil, fmt.Errorf("请求OIDC发现端点失败,状态码: %d", resp.StatusCode()) + } + // 验证必要字段 + if typ == "misskey" { + discovery.UserInfoEndpoint = discovery.Issuer + "/api/users/me" // Misskey的用户信息端点 + discovery.JwksUri = discovery.Issuer + "/api/jwks" + } + fmt.Println(discovery) + if discovery.Issuer == "" || + discovery.AuthorizationEndpoint == "" || + discovery.TokenEndpoint == "" || + discovery.UserInfoEndpoint == "" || + discovery.JwksUri == "" { + return nil, fmt.Errorf("OIDC发现端点响应缺少必要字段") + } + return &discovery, nil } func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) { - // 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息 - if tx.Statement.Changed("OidcDiscoveryUrl") || o.ID == 0 { - logrus.Infof("Updating OIDC config for %s, OidcDiscoveryUrl: %s", o.Name, o.OidcDiscoveryUrl) - discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl, o.Type) - if err != nil { - logrus.Error("Updating OIDC config failed: ", err) - return fmt.Errorf("updating OIDC config failed: %w", err) - } - o.Issuer = discoveryResp.Issuer - o.AuthorizationEndpoint = discoveryResp.AuthorizationEndpoint - o.TokenEndpoint = discoveryResp.TokenEndpoint - o.UserInfoEndpoint = discoveryResp.UserInfoEndpoint - o.JwksUri = discoveryResp.JwksUri - } - return nil + // 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息 + if tx.Statement.Changed("OidcDiscoveryUrl") || o.ID == 0 { + logrus.Infof("Updating OIDC config for %s, OidcDiscoveryUrl: %s", o.Name, o.OidcDiscoveryUrl) + discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl, o.Type) + if err != nil { + logrus.Error("Updating OIDC config failed: ", err) + return fmt.Errorf("updating OIDC config failed: %w", err) + } + o.Issuer = discoveryResp.Issuer + o.AuthorizationEndpoint = discoveryResp.AuthorizationEndpoint + o.TokenEndpoint = discoveryResp.TokenEndpoint + o.UserInfoEndpoint = discoveryResp.UserInfoEndpoint + o.JwksUri = discoveryResp.JwksUri + } + return nil } // ToUserDto 返回给用户侧 func (o *OidcConfig) ToUserDto() *dto.UserOidcConfigDto { - return &dto.UserOidcConfigDto{ - Name: o.Name, - DisplayName: o.DisplayName, - Icon: o.Icon, - } + return &dto.UserOidcConfigDto{ + Name: o.Name, + DisplayName: o.DisplayName, + Icon: o.Icon, + } } // ToAdminDto 返回给管理员侧 func (o *OidcConfig) ToAdminDto() *dto.AdminOidcConfigDto { - return &dto.AdminOidcConfigDto{ - ID: o.ID, - Name: o.Name, - ClientID: o.ClientID, - ClientSecret: o.ClientSecret, - DisplayName: o.DisplayName, - Icon: o.Icon, - OidcDiscoveryUrl: o.OidcDiscoveryUrl, - Enabled: o.Enabled, - } + return &dto.AdminOidcConfigDto{ + ID: o.ID, + Name: o.Name, + ClientID: o.ClientID, + ClientSecret: o.ClientSecret, + DisplayName: o.DisplayName, + Icon: o.Icon, + OidcDiscoveryUrl: o.OidcDiscoveryUrl, + Enabled: o.Enabled, + } } diff --git a/internal/repo/comment.go b/internal/repo/comment.go index 89593e8..f3ca686 100644 --- a/internal/repo/comment.go +++ b/internal/repo/comment.go @@ -214,9 +214,33 @@ func (cr *CommentRepo) ListComments(currentUserID, targetID, commentID uint, tar return items, nil } -func (cr *CommentRepo) CountReplyComments(commentID uint) (int64, error) { +func (cr *CommentRepo) CountReplyComments(currentUserID, commentID uint) (int64, error) { var count int64 - if err := GetDB().Model(&model.Comment{}).Where("reply_id = ?", commentID).Count(&count).Error; err != nil { + var masterID uint + + // 根据commentID查询所属对象的用户ID + comment, err := cr.GetComment(strconv.Itoa(int(commentID))) + if err != nil { + return 0, err + } + if comment.TargetType == constant.TargetTypePost { + post, err := Post.GetPostByID(strconv.Itoa(int(comment.TargetID))) + if err != nil { + return 0, err + } + masterID = post.UserID + } else { + // 如果不是文章类型,可以根据需要添加其他类型的处理逻辑 + return 0, errs.New(http.StatusBadRequest, "unsupported target type for counting replies", nil) + } + + query := GetDB().Model(&model.Comment{}).Where("reply_id = ?", commentID) + if currentUserID > 0 { + query = query.Where("(is_private = ? OR (is_private = ? AND (user_id = ? OR user_id = ?)))", false, true, currentUserID, masterID) + } else { + query = query.Where("is_private = ?", false) + } + if err := query.Count(&count).Error; err != nil { return 0, err } return count, nil diff --git a/internal/service/comment.go b/internal/service/comment.go index 85218d1..7161f6e 100644 --- a/internal/service/comment.go +++ b/internal/service/comment.go @@ -1,177 +1,177 @@ package service import ( - "context" - "strconv" + "context" + "strconv" - "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/constant" - "github.com/snowykami/neo-blog/internal/ctxutils" - "github.com/snowykami/neo-blog/internal/dto" - "github.com/snowykami/neo-blog/internal/model" - "github.com/snowykami/neo-blog/internal/repo" - "github.com/snowykami/neo-blog/pkg/errs" + "github.com/snowykami/neo-blog/internal/ctxutils" + "github.com/snowykami/neo-blog/internal/dto" + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/internal/repo" + "github.com/snowykami/neo-blog/pkg/errs" ) type CommentService struct{} func NewCommentService() *CommentService { - return &CommentService{} + return &CommentService{} } func (cs *CommentService) CreateComment(ctx context.Context, req *dto.CreateCommentReq) error { - currentUser, ok := ctxutils.GetCurrentUser(ctx) - if !ok { - return errs.ErrUnauthorized - } + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } - if ok, err := cs.checkTargetExists(req.TargetID, req.TargetType); !ok { - if err != nil { - return errs.New(errs.ErrBadRequest.Code, "target not found", err) - } - return errs.ErrBadRequest - } + if ok, err := cs.checkTargetExists(req.TargetID, req.TargetType); !ok { + if err != nil { + return errs.New(errs.ErrBadRequest.Code, "target not found", err) + } + return errs.ErrBadRequest + } - comment := &model.Comment{ - Content: req.Content, - ReplyID: req.ReplyID, - TargetID: req.TargetID, - TargetType: req.TargetType, - UserID: currentUser.ID, - IsPrivate: req.IsPrivate, - } + comment := &model.Comment{ + Content: req.Content, + ReplyID: req.ReplyID, + TargetID: req.TargetID, + TargetType: req.TargetType, + UserID: currentUser.ID, + IsPrivate: req.IsPrivate, + } - err := repo.Comment.CreateComment(comment) + err := repo.Comment.CreateComment(comment) - if err != nil { - return err - } + if err != nil { + return err + } - return nil + return nil } func (cs *CommentService) UpdateComment(ctx context.Context, req *dto.UpdateCommentReq) error { - currentUser, ok := ctxutils.GetCurrentUser(ctx) - if !ok { - return errs.ErrUnauthorized - } + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } - comment, err := repo.Comment.GetComment(strconv.Itoa(int(req.CommentID))) - if err != nil { - return err - } + comment, err := repo.Comment.GetComment(strconv.Itoa(int(req.CommentID))) + if err != nil { + return err + } - if currentUser.ID != comment.UserID { - return errs.ErrForbidden - } + if currentUser.ID != comment.UserID { + return errs.ErrForbidden + } - comment.Content = req.Content - comment.IsPrivate = req.IsPrivate + comment.Content = req.Content + comment.IsPrivate = req.IsPrivate - err = repo.Comment.UpdateComment(comment) + err = repo.Comment.UpdateComment(comment) - if err != nil { - return err - } + if err != nil { + return err + } - return nil + return nil } func (cs *CommentService) DeleteComment(ctx context.Context, commentID string) error { - currentUser, ok := ctxutils.GetCurrentUser(ctx) - if !ok { - return errs.ErrUnauthorized - } - if commentID == "" { - return errs.ErrBadRequest - } + currentUser, ok := ctxutils.GetCurrentUser(ctx) + if !ok { + return errs.ErrUnauthorized + } + if commentID == "" { + return errs.ErrBadRequest + } - comment, err := repo.Comment.GetComment(commentID) - if err != nil { - return errs.New(errs.ErrNotFound.Code, "comment not found", err) - } + comment, err := repo.Comment.GetComment(commentID) + if err != nil { + return errs.New(errs.ErrNotFound.Code, "comment not found", err) + } - if comment.UserID != currentUser.ID { - return errs.ErrForbidden - } + if comment.UserID != currentUser.ID { + return errs.ErrForbidden + } - if err := repo.Comment.DeleteComment(commentID); err != nil { - return err - } + if err := repo.Comment.DeleteComment(commentID); err != nil { + return err + } - return nil + return nil } func (cs *CommentService) GetComment(ctx context.Context, commentID string) (*dto.CommentDto, error) { - comment, err := repo.Comment.GetComment(commentID) + comment, err := repo.Comment.GetComment(commentID) - if err != nil { - return nil, errs.New(errs.ErrNotFound.Code, "comment not found", err) - } + if err != nil { + return nil, errs.New(errs.ErrNotFound.Code, "comment not found", err) + } - commentDto := dto.CommentDto{ - ID: comment.ID, - TargetID: comment.TargetID, - TargetType: comment.TargetType, - Content: comment.Content, - ReplyID: comment.ReplyID, - Depth: comment.Depth, - CreatedAt: comment.CreatedAt.String(), - UpdatedAt: comment.UpdatedAt.String(), - User: comment.User.ToDto(), - } + commentDto := dto.CommentDto{ + ID: comment.ID, + TargetID: comment.TargetID, + TargetType: comment.TargetType, + Content: comment.Content, + ReplyID: comment.ReplyID, + Depth: comment.Depth, + CreatedAt: comment.CreatedAt.String(), + UpdatedAt: comment.UpdatedAt.String(), + User: comment.User.ToDto(), + } - return &commentDto, err + return &commentDto, err } func (cs *CommentService) GetCommentList(ctx context.Context, req *dto.GetCommentListReq) ([]dto.CommentDto, error) { - currentUserID := uint(0) - if currentUser, ok := ctxutils.GetCurrentUser(ctx); ok { - currentUserID = currentUser.ID - } + currentUserID := uint(0) + if currentUser, ok := ctxutils.GetCurrentUser(ctx); ok { + currentUserID = currentUser.ID + } - comments, err := repo.Comment.ListComments(currentUserID, req.TargetID, req.CommentID, req.TargetType, req.Page, req.Size, req.OrderBy, req.Desc, req.Depth) - if err != nil { - return nil, errs.New(errs.ErrInternalServer.Code, "failed to list comments", err) - } + comments, err := repo.Comment.ListComments(currentUserID, req.TargetID, req.CommentID, req.TargetType, req.Page, req.Size, req.OrderBy, req.Desc, req.Depth) + if err != nil { + return nil, errs.New(errs.ErrInternalServer.Code, "failed to list comments", err) + } - commentDtos := make([]dto.CommentDto, 0) + commentDtos := make([]dto.CommentDto, 0) - for _, comment := range comments { - replyCount, _ := repo.Comment.CountReplyComments(comment.ID) - isLiked := false - if currentUserID != 0 { - isLiked, _ = repo.Like.IsLiked(currentUserID, comment.ID, constant.TargetTypeComment) - } + for _, comment := range comments { + replyCount, _ := repo.Comment.CountReplyComments(currentUserID, comment.ID) + isLiked := false + if currentUserID != 0 { + isLiked, _ = repo.Like.IsLiked(currentUserID, comment.ID, constant.TargetTypeComment) + } - commentDto := dto.CommentDto{ - ID: comment.ID, - Content: comment.Content, - TargetID: comment.TargetID, - TargetType: comment.TargetType, - ReplyID: comment.ReplyID, - CreatedAt: comment.CreatedAt.String(), - UpdatedAt: comment.UpdatedAt.String(), - Depth: comment.Depth, - User: comment.User.ToDto(), - ReplyCount: replyCount, - LikeCount: comment.LikeCount, - IsLiked: isLiked, - IsPrivate: comment.IsPrivate, - } - commentDtos = append(commentDtos, commentDto) - } - return commentDtos, nil + commentDto := dto.CommentDto{ + ID: comment.ID, + Content: comment.Content, + TargetID: comment.TargetID, + TargetType: comment.TargetType, + ReplyID: comment.ReplyID, + CreatedAt: comment.CreatedAt.String(), + UpdatedAt: comment.UpdatedAt.String(), + Depth: comment.Depth, + User: comment.User.ToDto(), + ReplyCount: replyCount, + LikeCount: comment.LikeCount, + IsLiked: isLiked, + IsPrivate: comment.IsPrivate, + } + commentDtos = append(commentDtos, commentDto) + } + return commentDtos, nil } func (cs *CommentService) checkTargetExists(targetID uint, targetType string) (bool, error) { - switch targetType { - case constant.TargetTypePost: - if _, err := repo.Post.GetPostByID(strconv.Itoa(int(targetID))); err != nil { - return false, errs.New(errs.ErrNotFound.Code, "post not found", err) - } - default: - return false, errs.New(errs.ErrBadRequest.Code, "invalid target type", nil) - } - return true, nil + switch targetType { + case constant.TargetTypePost: + if _, err := repo.Post.GetPostByID(strconv.Itoa(int(targetID))); err != nil { + return false, errs.New(errs.ErrNotFound.Code, "post not found", err) + } + default: + return false, errs.New(errs.ErrBadRequest.Code, "invalid target type", nil) + } + return true, nil } diff --git a/web/src/components/comment/comment-input.tsx b/web/src/components/comment/comment-input.tsx index 15ad715..24ba64b 100644 --- a/web/src/components/comment/comment-input.tsx +++ b/web/src/components/comment/comment-input.tsx @@ -27,15 +27,22 @@ export function CommentInput( } ) { const t = useTranslations('Comment') - const handleToLogin = useToLogin() - const toUserProfile = useToUserProfile(); + const commonT = useTranslations('Common') + const clickToLogin = useToLogin() + const clickToUserProfile = useToUserProfile(); const [isPrivate, setIsPrivate] = useState(initIsPrivate); const [commentContent, setCommentContent] = useState(initContent); const handleCommentSubmit = async () => { if (!user) { - toast.error({t("login_required")}); + // 通知 + toast.error(t("login_required"), { + action: { + label: commonT("login"), + onClick: clickToLogin, + }, + }) return; } if (!commentContent.trim()) { @@ -49,13 +56,13 @@ export function CommentInput( return (
-
toUserProfile(user.username) : handleToLogin} className="flex-shrink-0 w-10 h-10 fade-in"> +
clickToUserProfile(user.username) : clickToLogin} className="flex-shrink-0 w-10 h-10 fade-in"> {user ? getGravatarByUser(user) : null} {!user && }