feat: Refactor comment section to correctly handle API response structure
All checks were successful
Push to Helm Chart Repository / build (push) Successful in 9s

fix: Update Gravatar URL size and improve avatar rendering logic

style: Adjust footer margin for better layout consistency

refactor: Remove old navbar component and integrate new layout structure

feat: Enhance user profile page with user header component

chore: Remove unused user profile component

fix: Update posts per page configuration for better pagination

feat: Extend device context to support system theme mode

refactor: Remove unused device hook

fix: Improve storage state hook for better error handling

i18n: Add new translations for blog home page

feat: Implement pagination component for better navigation

feat: Create theme toggle component for improved user experience

feat: Introduce responsive navbar or side layout with theme toggle

feat: Develop custom select component for better UI consistency

feat: Create user header component to display user information

chore: Add query key constants for better code maintainability
This commit is contained in:
2025-09-12 00:26:08 +08:00
parent b3e8a5ef77
commit d1d8aa529f
36 changed files with 1443 additions and 731 deletions

View File

@ -128,13 +128,13 @@ func (cc *CommentController) GetCommentList(ctx context.Context, c *app.RequestC
TargetType: c.Query("target_type"),
CommentID: commentID,
}
resp, err := cc.service.GetCommentList(ctx, &req)
commentDtos, err := cc.service.GetCommentList(ctx, &req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, resp)
resps.Ok(c, resps.Success, utils.H{"comments": commentDtos})
}
func (cc *CommentController) ReactComment(ctx context.Context, c *app.RequestContext) {

View File

@ -1,121 +1,150 @@
package v1
import (
"context"
"slices"
"strings"
"context"
"slices"
"strings"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/snowykami/neo-blog/internal/ctxutils"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/service"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/resps"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/snowykami/neo-blog/internal/ctxutils"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/service"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/resps"
)
type PostController struct {
service *service.PostService
service *service.PostService
}
func NewPostController() *PostController {
return &PostController{
service: service.NewPostService(),
}
return &PostController{
service: service.NewPostService(),
}
}
func (p *PostController) Create(ctx context.Context, c *app.RequestContext) {
var req dto.CreateOrUpdatePostReq
if err := c.BindAndValidate(&req); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
postID, err := p.service.CreatePost(ctx, &req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, utils.H{"id": postID})
var req dto.CreateOrUpdatePostReq
if err := c.BindAndValidate(&req); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
postID, err := p.service.CreatePost(ctx, &req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, utils.H{"id": postID})
}
func (p *PostController) Delete(ctx context.Context, c *app.RequestContext) {
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
if err := p.service.DeletePost(ctx, id); err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, nil)
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
if err := p.service.DeletePost(ctx, id); err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, nil)
}
func (p *PostController) Get(ctx context.Context, c *app.RequestContext) {
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
post, err := p.service.GetPost(ctx, id)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
if post == nil {
resps.NotFound(c, resps.ErrNotFound)
return
}
resps.Ok(c, resps.Success, post)
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
post, err := p.service.GetPost(ctx, id)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
if post == nil {
resps.NotFound(c, resps.ErrNotFound)
return
}
resps.Ok(c, resps.Success, post)
}
func (p *PostController) Update(ctx context.Context, c *app.RequestContext) {
var req dto.CreateOrUpdatePostReq
if err := c.BindAndValidate(&req); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
postID, err := p.service.UpdatePost(ctx, id, &req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, utils.H{"id": postID})
var req dto.CreateOrUpdatePostReq
if err := c.BindAndValidate(&req); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
id := c.Param("id")
if id == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
postID, err := p.service.UpdatePost(ctx, id, &req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, utils.H{"id": postID})
}
func (p *PostController) List(ctx context.Context, c *app.RequestContext) {
pagination := ctxutils.GetPaginationParams(c)
if pagination.OrderBy == "" {
pagination.OrderBy = constant.OrderByUpdatedAt
}
if pagination.OrderBy != "" && !slices.Contains(constant.OrderByEnumPost, pagination.OrderBy) {
resps.BadRequest(c, "无效的排序字段")
return
}
keywords := c.Query("keywords")
keywordsArray := strings.Split(keywords, ",")
req := &dto.ListPostReq{
Keywords: keywordsArray,
Page: pagination.Page,
Size: pagination.Size,
OrderBy: pagination.OrderBy,
Desc: pagination.Desc,
}
posts, err := p.service.ListPosts(ctx, req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, posts)
pagination := ctxutils.GetPaginationParams(c)
if pagination.OrderBy == "" {
pagination.OrderBy = constant.OrderByUpdatedAt
}
if pagination.OrderBy != "" && !slices.Contains(constant.OrderByEnumPost, pagination.OrderBy) {
resps.BadRequest(c, "无效的排序字段")
return
}
keywords := c.Query("keywords")
keywordsArray := strings.Split(keywords, ",")
labels := c.Query("labels")
labelStringArray := strings.Split(labels, ",")
labelRule := c.Query("label_rule")
if labelRule != "intersection" {
labelRule = "union"
}
labelDtos := make([]dto.LabelDto, 0, len(labelStringArray))
for _, labelString := range labelStringArray {
// :分割key和value
if labelString == "" {
continue
}
parts := strings.SplitN(labelString, ":", 2)
if len(parts) == 2 {
labelDtos = append(labelDtos, dto.LabelDto{
Key: parts[0],
Value: parts[1],
})
} else {
labelDtos = append(labelDtos, dto.LabelDto{
Key: parts[0],
Value: "",
})
}
}
req := &dto.ListPostReq{
Keywords: keywordsArray,
Labels: labelDtos,
LabelRule: labelRule,
Page: pagination.Page,
Size: pagination.Size,
OrderBy: pagination.OrderBy,
Desc: pagination.Desc,
}
posts, total, err := p.service.ListPosts(ctx, req)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, utils.H{"posts": posts, "total": total})
}

View File

@ -31,11 +31,13 @@ type CreateOrUpdatePostReq struct {
}
type ListPostReq struct {
Keywords []string `json:"keywords"` // 关键词列表
OrderBy string `json:"order_by"` // 排序方式
Page uint64 `json:"page"` // 页码
Size uint64 `json:"size"`
Desc bool `json:"desc"`
Keywords []string `json:"keywords"` // 关键词列表
OrderBy string `json:"order_by"` // 排序方式
Page uint64 `json:"page"` // 页码
Size uint64 `json:"size"`
Desc bool `json:"desc"`
Labels []LabelDto `json:"labels"`
LabelRule string `json:"label_rule"` // 标签过滤规则 union or intersection
}
type ListPostResp struct {

View File

@ -12,7 +12,7 @@ type User struct {
AvatarUrl string
Email string `gorm:"uniqueIndex"`
Gender string
Role string `gorm:"default:'user'"`
Role string `gorm:"default:'user'"` // user editor admin
Language string `gorm:"default:'en'"`
Password string // 密码,存储加密后的值
}

View File

@ -18,6 +18,26 @@ func (l *labelRepo) GetLabelByKey(key string) (*model.Label, error) {
return &label, nil
}
func (l *labelRepo) GetLabelByValue(value string) (*model.Label, error) {
var label model.Label
if err := GetDB().Where("value = ?", value).First(&label).Error; err != nil {
return nil, err
}
return &label, nil
}
func (l *labelRepo) GetLabelByKeyAndValue(key, value string) (*model.Label, error) {
var label model.Label
query := GetDB().Where("key = ?", key)
if value != "" {
query = query.Where("value = ?", value)
}
if err := GetDB().Where(query).First(&label).Error; err != nil {
return nil, err
}
return &label, nil
}
func (l *labelRepo) GetLabelByID(id string) (*model.Label, error) {
var label model.Label
if err := GetDB().Where("id = ?", id).First(&label).Error; err != nil {

View File

@ -1,12 +1,15 @@
package repo
import (
"errors"
"net/http"
"slices"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/model"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"gorm.io/gorm"
)
type postRepo struct{}
@ -48,9 +51,9 @@ func (p *postRepo) UpdatePost(post *model.Post) error {
return nil
}
func (p *postRepo) ListPosts(currentUserID uint, keywords []string, page, size uint64, orderBy string, desc bool) ([]model.Post, error) {
func (p *postRepo) ListPosts(currentUserID uint, keywords []string, labels []dto.LabelDto, labelRule string, page, size uint64, orderBy string, desc bool) ([]model.Post, int64, error) {
if !slices.Contains(constant.OrderByEnumPost, orderBy) {
return nil, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil)
return nil, 0, errs.New(http.StatusBadRequest, "invalid order_by parameter", nil)
}
query := GetDB().Model(&model.Post{}).Preload("User")
if currentUserID > 0 {
@ -58,20 +61,43 @@ func (p *postRepo) ListPosts(currentUserID uint, keywords []string, page, size u
} else {
query = query.Where("is_private = ?", false)
}
if len(labels) > 0 {
var labelIds []uint
for _, labelDto := range labels {
label, _ := Label.GetLabelByKeyAndValue(labelDto.Key, labelDto.Value)
labelIds = append(labelIds, label.ID)
}
if labelRule == "intersection" {
query = query.Joins("JOIN post_labels ON post_labels.post_id = posts.id").
Where("post_labels.label_id IN ?", labelIds).
Group("posts.id").
Having("COUNT(DISTINCT post_labels.label_id) = ?", len(labelIds))
} else {
query = query.Joins("JOIN post_labels ON post_labels.post_id = posts.id").
Where("post_labels.label_id IN ?", labelIds)
}
}
if len(keywords) > 0 {
for _, keyword := range keywords {
if keyword != "" {
// 使用LIKE进行模糊匹配搜索标题、内容和标签
query = query.Where("title LIKE ? OR content LIKE ?", // TODO: 支持标签搜索
query = query.Where("title LIKE ? OR content LIKE ?",
"%"+keyword+"%", "%"+keyword+"%")
}
}
}
var total int64
if err := query.Count(&total).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, 0, err
}
items, _, err := PaginateQuery[model.Post](query, page, size, orderBy, desc)
if err != nil {
return nil, err
return nil, 0, err
}
return items, nil
return items, total, nil
}
func (p *postRepo) ToggleLikePost(postID uint, userID uint) (bool, error) {

View File

@ -114,17 +114,17 @@ func (p *PostService) UpdatePost(ctx context.Context, id string, req *dto.Create
return post.ID, nil
}
func (p *PostService) ListPosts(ctx context.Context, req *dto.ListPostReq) ([]*dto.PostDto, error) {
func (p *PostService) ListPosts(ctx context.Context, req *dto.ListPostReq) ([]*dto.PostDto, int64, error) {
postDtos := make([]*dto.PostDto, 0)
currentUserID, _ := ctxutils.GetCurrentUserID(ctx)
posts, err := repo.Post.ListPosts(currentUserID, req.Keywords, req.Page, req.Size, req.OrderBy, req.Desc)
posts, total, err := repo.Post.ListPosts(currentUserID, req.Keywords, req.Labels, req.LabelRule, req.Page, req.Size, req.OrderBy, req.Desc)
if err != nil {
return nil, errs.New(errs.ErrInternalServer.Code, "failed to list posts", err)
return nil, total, errs.New(errs.ErrInternalServer.Code, "failed to list posts", err)
}
for _, post := range posts {
postDtos = append(postDtos, post.ToDtoWithShortContent(100))
}
return postDtos, nil
return postDtos, total, nil
}
func (p *PostService) ToggleLikePost(ctx context.Context, id string) (bool, error) {

View File

@ -1,400 +1,400 @@
package service
import (
"errors"
"fmt"
"net/http"
"strings"
"time"
"errors"
"fmt"
"net/http"
"strings"
"time"
"github.com/sirupsen/logrus"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/model"
"github.com/snowykami/neo-blog/internal/repo"
"github.com/snowykami/neo-blog/internal/static"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/utils"
"gorm.io/gorm"
"github.com/sirupsen/logrus"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/model"
"github.com/snowykami/neo-blog/internal/repo"
"github.com/snowykami/neo-blog/internal/static"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/utils"
"gorm.io/gorm"
)
type UserService struct{}
func NewUserService() *UserService {
return &UserService{}
return &UserService{}
}
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
logrus.Warnf("User not found: %s", req.Username)
return nil, errs.ErrNotFound
}
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.UserLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
return nil, errs.ErrInternalServer
}
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
logrus.Warnf("User not found: %s", req.Username)
return nil, errs.ErrNotFound
}
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.UserLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
return nil, errs.ErrInternalServer
}
}
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
// 验证邮箱验证码
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
return nil, errs.ErrForbidden
}
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
if err != nil {
logrus.Errorln("Failed to verify email:", err)
return nil, errs.ErrInternalServer
}
if !ok {
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
}
// 检查用户名或邮箱是否已存在
usernameExist, err := repo.User.CheckUsernameExists(req.Username)
if err != nil {
return nil, errs.ErrInternalServer
}
emailExist, err := repo.User.CheckEmailExists(req.Email)
if err != nil {
return nil, errs.ErrInternalServer
}
if usernameExist || emailExist {
return nil, errs.New(http.StatusConflict, "Username or email already exists", nil)
}
// 创建新用户
hashedPassword, err := utils.Password.HashPassword(req.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt"))
if err != nil {
logrus.Errorln("Failed to hash password:", err)
return nil, errs.ErrInternalServer
}
newUser := &model.User{
Username: req.Username,
Nickname: req.Nickname,
Email: req.Email,
Gender: "",
Role: "user",
Password: hashedPassword,
}
err = repo.User.CreateUser(newUser)
if err != nil {
return nil, errs.ErrInternalServer
}
// 创建默认管理员账户
if newUser.ID == 1 {
newUser.Role = constant.RoleAdmin
err = repo.User.UpdateUser(newUser)
if err != nil {
logrus.Errorln("Failed to update user role to admin:", err)
return nil, errs.ErrInternalServer
}
}
// 生成访问令牌和刷新令牌
token, refreshToken, err := s.generate2Token(newUser.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.UserRegisterResp{
Token: token,
RefreshToken: refreshToken,
User: newUser.ToDto(),
}
return resp, nil
// 验证邮箱验证码
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
return nil, errs.ErrForbidden
}
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
if err != nil {
logrus.Errorln("Failed to verify email:", err)
return nil, errs.ErrInternalServer
}
if !ok {
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
}
// 检查用户名或邮箱是否已存在
usernameExist, err := repo.User.CheckUsernameExists(req.Username)
if err != nil {
return nil, errs.ErrInternalServer
}
emailExist, err := repo.User.CheckEmailExists(req.Email)
if err != nil {
return nil, errs.ErrInternalServer
}
if usernameExist || emailExist {
return nil, errs.New(http.StatusConflict, "Username or email already exists", nil)
}
// 创建新用户
hashedPassword, err := utils.Password.HashPassword(req.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt"))
if err != nil {
logrus.Errorln("Failed to hash password:", err)
return nil, errs.ErrInternalServer
}
newUser := &model.User{
Username: req.Username,
Nickname: req.Nickname,
Email: req.Email,
Gender: "",
Role: "user",
Password: hashedPassword,
}
err = repo.User.CreateUser(newUser)
if err != nil {
return nil, errs.ErrInternalServer
}
// 创建默认管理员账户
if newUser.ID == 1 {
newUser.Role = constant.RoleAdmin
err = repo.User.UpdateUser(newUser)
if err != nil {
logrus.Errorln("Failed to update user role to admin:", err)
return nil, errs.ErrInternalServer
}
}
// 生成访问令牌和刷新令牌
token, refreshToken, err := s.generate2Token(newUser.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.UserRegisterResp{
Token: token,
RefreshToken: refreshToken,
User: newUser.ToDto(),
}
return resp, nil
}
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
kv := utils.KV.GetInstance()
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
kv := utils.KV.GetInstance()
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
if err != nil {
return nil, errs.ErrInternalServer
}
if utils.IsDevMode {
logrus.Infof("%s's verification code is %s", req.Email, generatedVerificationCode)
}
err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true)
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
if err != nil {
return nil, errs.ErrInternalServer
}
if utils.IsDevMode {
logrus.Infof("%s's verification code is %s", req.Email, generatedVerificationCode)
}
err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true)
if err != nil {
return nil, errs.ErrInternalServer
}
return &dto.VerifyEmailResp{Success: true}, nil
if err != nil {
return nil, errs.ErrInternalServer
}
return &dto.VerifyEmailResp{Success: true}, nil
}
func (s *UserService) ListOidcConfigs() ([]dto.UserOidcConfigDto, error) {
enabledOidcConfigs, err := repo.Oidc.ListOidcConfigs(true)
if err != nil {
return nil, errs.ErrInternalServer
}
var oidcConfigsDtos []dto.UserOidcConfigDto
enabledOidcConfigs, err := repo.Oidc.ListOidcConfigs(true)
if err != nil {
return nil, errs.ErrInternalServer
}
var oidcConfigsDtos []dto.UserOidcConfigDto
for _, oidcConfig := range enabledOidcConfigs {
state := utils.Strings.GenerateRandomString(32)
kvStore := utils.KV.GetInstance()
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
loginUrl := utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
"client_id": oidcConfig.ClientID,
"redirect_uri": fmt.Sprintf("%s%s%s/%sREDIRECT_BACK", // 这个大占位符给前端替换用的替换时也要uri编码因为是层层包的
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
constant.ApiSuffix,
constant.OidcUri,
oidcConfig.Name,
),
"response_type": "code",
"scope": "openid email profile",
"state": state,
})
for _, oidcConfig := range enabledOidcConfigs {
state := utils.Strings.GenerateRandomString(32)
kvStore := utils.KV.GetInstance()
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
loginUrl := utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
"client_id": oidcConfig.ClientID,
"redirect_uri": fmt.Sprintf("%s%s%s/%sREDIRECT_BACK", // 这个大占位符给前端替换用的替换时也要uri编码因为是层层包的
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
constant.ApiSuffix,
constant.OidcUri,
oidcConfig.Name,
),
"response_type": "code",
"scope": "openid email profile",
"state": state,
})
if oidcConfig.Type == constant.OidcProviderTypeMisskey {
// Misskey OIDC 特殊处理
loginUrl = utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
"client_id": oidcConfig.ClientID,
"redirect_uri": fmt.Sprintf("%s%s%s/%s", // 这个大占位符给前端替换用的替换时也要uri编码因为是层层包的
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
constant.ApiSuffix,
constant.OidcUri,
oidcConfig.Name,
),
"response_type": "code",
"scope": "read:account",
"state": state,
})
}
if oidcConfig.Type == constant.OidcProviderTypeMisskey {
// Misskey OIDC 特殊处理
loginUrl = utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
"client_id": oidcConfig.ClientID,
"redirect_uri": fmt.Sprintf("%s%s%s/%s", // 这个大占位符给前端替换用的替换时也要uri编码因为是层层包的
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
constant.ApiSuffix,
constant.OidcUri,
oidcConfig.Name,
),
"response_type": "code",
"scope": "read:account",
"state": state,
})
}
oidcConfigsDtos = append(oidcConfigsDtos, dto.UserOidcConfigDto{
Name: oidcConfig.Name,
DisplayName: oidcConfig.DisplayName,
Icon: oidcConfig.Icon,
LoginUrl: loginUrl,
})
}
return oidcConfigsDtos, nil
oidcConfigsDtos = append(oidcConfigsDtos, dto.UserOidcConfigDto{
Name: oidcConfig.Name,
DisplayName: oidcConfig.DisplayName,
Icon: oidcConfig.Icon,
LoginUrl: loginUrl,
})
}
return oidcConfigsDtos, nil
}
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
// 验证state
kvStore := utils.KV.GetInstance()
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
if !ok || storedName != req.Name {
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
}
// 获取OIDC配置
oidcConfig, err := repo.Oidc.GetOidcConfigByName(req.Name)
if err != nil {
return nil, errs.ErrInternalServer
}
if oidcConfig == nil {
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
}
// 请求访问令牌
tokenResp, err := utils.Oidc.RequestToken(
oidcConfig.TokenEndpoint,
oidcConfig.ClientID,
oidcConfig.ClientSecret,
req.Code,
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
)
if err != nil {
logrus.Errorln("Failed to request OIDC token:", err)
return nil, errs.ErrInternalServer
}
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
if err != nil {
logrus.Errorln("Failed to request OIDC user info:", err)
return nil, errs.ErrInternalServer
}
// 验证state
kvStore := utils.KV.GetInstance()
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
if !ok || storedName != req.Name {
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
}
// 获取OIDC配置
oidcConfig, err := repo.Oidc.GetOidcConfigByName(req.Name)
if err != nil {
return nil, errs.ErrInternalServer
}
if oidcConfig == nil {
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
}
// 请求访问令牌
tokenResp, err := utils.Oidc.RequestToken(
oidcConfig.TokenEndpoint,
oidcConfig.ClientID,
oidcConfig.ClientSecret,
req.Code,
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
)
if err != nil {
logrus.Errorln("Failed to request OIDC token:", err)
return nil, errs.ErrInternalServer
}
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
if err != nil {
logrus.Errorln("Failed to request OIDC user info:", err)
return nil, errs.ErrInternalServer
}
// 绑定过登录
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrInternalServer
}
if userOpenID != nil {
user, err := repo.User.GetUserByID(userOpenID.UserID)
if err != nil {
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
user, err := repo.User.GetUserByEmail(userInfo.Email)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
logrus.Errorln("Failed to get user by email:", err)
return nil, errs.ErrInternalServer
}
if user != nil {
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
user = &model.User{
Username: userInfo.Name,
Nickname: userInfo.Name,
AvatarUrl: userInfo.Picture,
Email: userInfo.Email,
}
err = repo.User.CreateUser(user)
if err != nil {
logrus.Errorln("Failed to create user:", err)
return nil, errs.ErrInternalServer
}
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
}
}
// 绑定过登录
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrInternalServer
}
if userOpenID != nil {
user, err := repo.User.GetUserByID(userOpenID.UserID)
if err != nil {
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
user, err := repo.User.GetUserByEmail(userInfo.Email)
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
logrus.Errorln("Failed to get user by email:", err)
return nil, errs.ErrInternalServer
}
if user != nil {
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
user = &model.User{
Username: userInfo.Name,
Nickname: userInfo.Name,
AvatarUrl: userInfo.Picture,
Email: userInfo.Email,
}
err = repo.User.CreateUser(user)
if err != nil {
logrus.Errorln("Failed to create user:", err)
return nil, errs.ErrInternalServer
}
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
}
}
}
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
if req.UserID == 0 {
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
}
user, err := repo.User.GetUserByID(req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to get user by ID:", err)
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
return &dto.GetUserResp{
User: user.ToDto(),
}, nil
if req.UserID == 0 {
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
}
user, err := repo.User.GetUserByID(req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to get user by ID:", err)
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
return &dto.GetUserResp{
User: user.ToDto(),
}, nil
}
func (s *UserService) GetUserByUsername(req *dto.GetUserByUsernameReq) (*dto.GetUserResp, error) {
if req.Username == "" {
return nil, errs.New(http.StatusBadRequest, "username is required", nil)
}
user, err := repo.User.GetUserByUsername(req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to get user by username:", err)
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
return &dto.GetUserResp{
User: user.ToDto(),
}, nil
if req.Username == "" {
return nil, errs.New(http.StatusBadRequest, "username is required", nil)
}
user, err := repo.User.GetUserByUsername(req.Username)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to get user by username:", err)
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
return &dto.GetUserResp{
User: user.ToDto(),
}, nil
}
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
user := &model.User{
Model: gorm.Model{
ID: req.ID,
},
Username: req.Username,
Nickname: req.Nickname,
Gender: req.Gender,
AvatarUrl: req.AvatarUrl,
}
err := repo.User.UpdateUser(user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to update user:", err)
return nil, errs.ErrInternalServer
}
return &dto.UpdateUserResp{}, nil
user := &model.User{
Model: gorm.Model{
ID: req.ID,
},
Username: req.Username,
Nickname: req.Nickname,
Gender: req.Gender,
AvatarUrl: req.AvatarUrl,
}
err := repo.User.UpdateUser(user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to update user:", err)
return nil, errs.ErrInternalServer
}
return &dto.UpdateUserResp{}, nil
}
func (s *UserService) generate2Token(userID uint) (string, string, error) {
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault))*time.Second)
tokenString, err := token.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, constant.EnvKeyRefreshTokenDurationDefault))*time.Second)
refreshTokenString, err := refreshToken.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
err = repo.Session.SaveSession(refreshToken.SessionKey)
if err != nil {
return "", "", errs.ErrInternalServer
}
return tokenString, refreshTokenString, nil
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault))*time.Second)
tokenString, err := token.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, constant.EnvKeyRefreshTokenDurationDefault))*time.Second)
refreshTokenString, err := refreshToken.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
err = repo.Session.SaveSession(refreshToken.SessionKey)
if err != nil {
return "", "", errs.ErrInternalServer
}
return tokenString, refreshTokenString, nil
}
func (s *UserService) verifyEmail(email, code string) (bool, error) {
kv := utils.KV.GetInstance()
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
if !ok || verificationCode != code {
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
return true, nil
kv := utils.KV.GetInstance()
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
if !ok || verificationCode != code {
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
return true, nil
}