refactor user service methods, implement OIDC login and user management features, and enhance token handling

This commit is contained in:
2025-07-22 20:45:05 +08:00
parent f07200b0b9
commit cbe73121f2
17 changed files with 655 additions and 126 deletions

View File

@ -3,17 +3,17 @@ package v1
import (
"context"
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/snowykami/neo-blog/internal/ctxutils"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/service"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/resps"
"github.com/snowykami/neo-blog/pkg/utils"
"strconv"
)
type userType struct {
service service.UserService
service *service.UserService
}
var User = &userType{
@ -33,13 +33,11 @@ func (u *userType) Login(ctx context.Context, c *app.RequestContext) {
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
if resp == nil {
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
return
} else {
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
resps.Ok(c, resps.Success, resp)
}
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
resps.Ok(c, resps.Success, utils.H{
"token": resp.Token,
"user": resp.User,
})
}
func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
@ -55,46 +53,101 @@ func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
if resp == nil {
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
return
}
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
resps.Ok(c, resps.Success, resp)
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
resps.Ok(c, resps.Success, utils.H{
"token": resp.Token,
"user": resp.User,
})
}
func (u *userType) Logout(ctx context.Context, c *app.RequestContext) {
u.clearTokenCookie(c)
ctxutils.ClearTokenAndRefreshTokenCookie(c)
resps.Ok(c, resps.Success, nil)
}
func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) {
// TODO: Impl
resp, err := u.service.ListOidcConfigs()
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, map[string]any{
"oidc_configs": resp.OidcConfigs,
})
}
func (u *userType) OidcLogin(ctx context.Context, c *app.RequestContext) {
// TODO: Impl
name := c.Param("name")
code := c.Param("code")
state := c.Param("state")
oidcLoginReq := &dto.OidcLoginReq{
Name: name,
Code: code,
State: state,
}
resp, err := u.service.OidcLogin(oidcLoginReq)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
resps.Ok(c, resps.Success, map[string]any{
"token": resp.Token,
"user": resp.User,
})
}
func (u *userType) Get(ctx context.Context, c *app.RequestContext) {
// TODO: Impl
}
func (u *userType) Update(ctx context.Context, c *app.RequestContext) {
// TODO: Impl
}
func (u *userType) Delete(ctx context.Context, c *app.RequestContext) {
// TODO: Impl
}
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
var verifyEmailReq dto.VerifyEmailReq
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
func (u *userType) GetUser(ctx context.Context, c *app.RequestContext) {
userID := c.Param("id")
if userID == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
resp, err := u.service.VerifyEmail(&verifyEmailReq)
userIDInt, err := strconv.Atoi(userID)
if err != nil || userIDInt <= 0 {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
resp, err := u.service.GetUser(&dto.GetUserReq{UserID: uint(userIDInt)})
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, resp.User)
}
func (u *userType) UpdateUser(ctx context.Context, c *app.RequestContext) {
userID := c.Param("id")
if userID == "" {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
userIDInt, err := strconv.Atoi(userID)
if err != nil || userIDInt <= 0 {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
var updateUserReq dto.UpdateUserReq
if err := c.BindAndValidate(&updateUserReq); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
updateUserReq.ID = uint(userIDInt)
currentUser := ctxutils.GetCurrentUser(ctx)
if currentUser == nil {
resps.UnAuthorized(c, resps.ErrUnauthorized)
return
}
if currentUser.ID != updateUserReq.ID {
resps.Forbidden(c, resps.ErrForbidden)
return
}
resp, err := u.service.UpdateUser(&updateUserReq)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
@ -103,12 +156,17 @@ func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
resps.Ok(c, resps.Success, resp)
}
func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) {
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
}
func (u *userType) clearTokenCookie(c *app.RequestContext) {
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
var verifyEmailReq dto.VerifyEmailReq
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
resps.BadRequest(c, resps.ErrParamInvalid)
return
}
resp, err := u.service.RequestVerifyEmail(&verifyEmailReq)
if err != nil {
serviceErr := errs.AsServiceError(err)
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
return
}
resps.Ok(c, resps.Success, resp)
}

View File

@ -0,0 +1,22 @@
package ctxutils
import (
"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/protocol"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/utils"
)
func SetTokenCookie(c *app.RequestContext, token string) {
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
}
func SetTokenAndRefreshTokenCookie(c *app.RequestContext, token, refreshToken string) {
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
}
func ClearTokenAndRefreshTokenCookie(c *app.RequestContext) {
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
}

19
internal/ctxutils/user.go Normal file
View File

@ -0,0 +1,19 @@
package ctxutils
import (
"context"
"github.com/snowykami/neo-blog/internal/model"
"github.com/snowykami/neo-blog/internal/repo"
)
func GetCurrentUser(ctx context.Context) *model.User {
userIDValue := ctx.Value("user_id").(uint)
if userIDValue <= 0 {
return nil
}
user, err := repo.User.GetUserByID(userIDValue)
if err != nil || user == nil || user.ID == 0 {
return nil
}
return user
}

View File

@ -1,7 +1 @@
package dto
type BaseResp struct {
Code int `json:"code"`
Message string `json:"message"`
Data any `json:"data"`
}

View File

@ -1,6 +1,7 @@
package dto
type UserDto struct {
ID uint `json:"id"` // 用户ID
Username string `json:"username"` // 用户名
Nickname string `json:"nickname"`
AvatarUrl string `json:"avatar_url"` // 头像URL
@ -8,6 +9,13 @@ type UserDto struct {
Gender string `json:"gender"`
Role string `json:"role"`
}
type OidcConfigDto struct {
Name string `json:"name"` // OIDC配置名称
DisplayName string `json:"display_name"` // OIDC配置显示名称
Icon string `json:"icon"` // OIDC配置图标URL
LoginUrl string `json:"login_url"` // OIDC登录URL
}
type UserLoginReq struct {
Username string `json:"username"` // username or email
Password string `json:"password"`
@ -40,3 +48,39 @@ type VerifyEmailReq struct {
type VerifyEmailResp struct {
Success bool `json:"success"` // 验证码发送成功与否
}
type OidcLoginReq struct {
Name string `json:"name"` // OIDC配置名称
Code string `json:"code"` // OIDC授权码
State string `json:"state"`
}
type OidcLoginResp struct {
Token string `json:"token"`
RefreshToken string `json:"refresh_token"`
User *UserDto `json:"user"`
}
type ListOidcConfigResp struct {
OidcConfigs []OidcConfigDto `json:"oidc_configs"` // OIDC配置列表
}
type GetUserReq struct {
UserID uint `json:"user_id"`
}
type GetUserResp struct {
User *UserDto `json:"user"` // 用户信息
}
type UpdateUserReq struct {
ID uint `json:"id"`
Username string `json:"username"`
Nickname string `json:"nickname"`
AvatarUrl string `json:"avatar_url"`
Gender string `json:"gender"`
}
type UpdateUserResp struct {
User *UserDto `json:"user"` // 更新后的用户信息
}

View File

@ -3,10 +3,53 @@ package middleware
import (
"context"
"github.com/cloudwego/hertz/pkg/app"
"github.com/snowykami/neo-blog/internal/ctxutils"
"github.com/snowykami/neo-blog/internal/repo"
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/resps"
"github.com/snowykami/neo-blog/pkg/utils"
"time"
)
func UseAuth() app.HandlerFunc {
return func(ctx context.Context, c *app.RequestContext) {
// TODO: Implement authentication logic here
// For cookie
token := string(c.Cookie("token"))
refreshToken := string(c.Cookie("refresh_token"))
tokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(token)
if err == nil && tokenClaims != nil {
ctx = context.WithValue(ctx, "user_id", tokenClaims.UserID)
c.Next(ctx)
return
}
// token 失效 使用refresh token重新签发和鉴权
refreshTokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(refreshToken)
if err == nil && refreshTokenClaims != nil {
ok, err := isStatefulJwtValid(refreshTokenClaims)
if err == nil && ok {
ctx = context.WithValue(ctx, "user_id", refreshTokenClaims.UserID) // 修改这里使用refreshTokenClaims
c.Next(ctx)
newTokenClaims := utils.Jwt.NewClaims(refreshTokenClaims.UserID, refreshTokenClaims.SessionKey, refreshTokenClaims.Stateful, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
newToken, err := newTokenClaims.ToString()
if err == nil {
ctxutils.SetTokenCookie(c, newToken)
} else {
resps.InternalServerError(c, resps.ErrInternalServerError)
}
return
}
}
// 所有认证方式都失败,返回未授权错误
resps.UnAuthorized(c, resps.ErrUnauthorized)
c.Abort()
}
}
func isStatefulJwtValid(claims *utils.Claims) (bool, error) {
if !claims.Stateful {
return true, nil
}
return repo.Session.IsSessionValid(claims.SessionKey)
}

View File

@ -2,6 +2,7 @@ package model
import (
"fmt"
"github.com/snowykami/neo-blog/internal/dto"
"gorm.io/gorm"
"resty.dev/v3"
"time"
@ -9,14 +10,13 @@ import (
type OidcConfig struct {
gorm.Model
Name string `gorm:"uniqueIndex"`
ClientID string `gorm:"column:client_id"` // 客户端ID
ClientSecret string `gorm:"column:client_secret"` // 客户端密钥
DisplayName string `gorm:"column:display_name"` // 显示名称,例如:轻雪通行证
GroupsClaim *string `gorm:"default:groups"` // 组声明,默认为:"groups"
Icon *string `gorm:"column:icon"` // 图标url为空则使用内置默认图标
OidcDiscoveryUrl string `gorm:"column:oidc_discovery_url"` // OpenID自动发现URL例如 https://pass.liteyuki.icu/.well-known/openid-configuration
Enabled bool `gorm:"column:enabled;default:true"` // 是否启用
Name string `gorm:"uniqueIndex"` // OIDC配置名称唯一
ClientID string // 客户端ID
ClientSecret string // 客户端密钥
DisplayName string // 显示名称,例如:轻雪通行证
Icon string // 图标url为空则使用内置默认图标
OidcDiscoveryUrl string // OpenID自动发现URL例如 https://pass.liteyuki.icu/.well-known/openid-configuration
Enabled bool `gorm:"default:true"` // 是否启用
// 以下字段为自动获取字段,每次更新配置时自动填充
Issuer string
AuthorizationEndpoint string
@ -68,11 +68,6 @@ func updateOidcConfigFromUrl(url string) (*oidcDiscoveryResp, error) {
}
func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
// 设置默认值
if o.GroupsClaim == nil {
defaultGroupsClaim := "groups"
o.GroupsClaim = &defaultGroupsClaim
}
// 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息
if tx.Statement.Changed("OidcDiscoveryUrl") {
discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl)
@ -87,3 +82,12 @@ func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
}
return nil
}
// ToDto 不包含LoginUrl在service层自行实现
func (o *OidcConfig) ToDto() *dto.OidcConfigDto {
return &dto.OidcConfigDto{
Name: o.Name,
DisplayName: o.DisplayName,
Icon: o.Icon,
}
}

View File

@ -16,8 +16,17 @@ type User struct {
Password string // 密码,存储加密后的值
}
type UserOpenID struct {
gorm.Model
UserID uint `gorm:"uniqueIndex"`
User User `gorm:"foreignKey:UserID;references:ID"`
Issuer string `gorm:"index"` // OIDC Issuer
Sub string `gorm:"index"` // OIDC Sub openid
}
func (user *User) ToDto() *dto.UserDto {
return &dto.UserDto{
ID: user.ID,
Username: user.Username,
Nickname: user.Nickname,
AvatarUrl: user.AvatarUrl,

View File

@ -6,7 +6,7 @@ type userRepo struct{}
var User = &userRepo{}
func (user *userRepo) GetByUsername(username string) (*model.User, error) {
func (user *userRepo) GetUserByUsername(username string) (*model.User, error) {
var userModel model.User
if err := GetDB().Where("username = ?", username).First(&userModel).Error; err != nil {
return nil, err
@ -14,7 +14,7 @@ func (user *userRepo) GetByUsername(username string) (*model.User, error) {
return &userModel, nil
}
func (user *userRepo) GetByEmail(email string) (*model.User, error) {
func (user *userRepo) GetUserByEmail(email string) (*model.User, error) {
var userModel model.User
if err := GetDB().Where("email = ?", email).First(&userModel).Error; err != nil {
return nil, err
@ -22,7 +22,15 @@ func (user *userRepo) GetByEmail(email string) (*model.User, error) {
return &userModel, nil
}
func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
func (user *userRepo) GetUserByID(id uint) (*model.User, error) {
var userModel model.User
if err := GetDB().Where("id = ?", id).First(&userModel).Error; err != nil {
return nil, err
}
return &userModel, nil
}
func (user *userRepo) GetUserByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
var userModel model.User
if err := GetDB().Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).First(&userModel).Error; err != nil {
return nil, err
@ -30,14 +38,14 @@ func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User,
return &userModel, nil
}
func (user *userRepo) Create(userModel *model.User) error {
func (user *userRepo) CreateUser(userModel *model.User) error {
if err := GetDB().Create(userModel).Error; err != nil {
return err
}
return nil
}
func (user *userRepo) Update(userModel *model.User) error {
func (user *userRepo) UpdateUser(userModel *model.User) error {
if err := GetDB().Updates(userModel).Error; err != nil {
return err
}
@ -59,3 +67,40 @@ func (user *userRepo) CheckEmailExists(email string) (bool, error) {
}
return count > 0, nil
}
func (user *userRepo) ListOidcConfigs(onlyEnabled bool) ([]model.OidcConfig, error) {
var configs []model.OidcConfig
if onlyEnabled {
if err := GetDB().Where("enabled = ?", true).Find(&configs).Error; err != nil {
return nil, err
}
} else {
if err := GetDB().Find(&configs).Error; err != nil {
return nil, err
}
}
return configs, nil
}
func (user *userRepo) GetOidcConfigByName(name string) (*model.OidcConfig, error) {
var config model.OidcConfig
if err := GetDB().Where("name = ?", name).First(&config).Error; err != nil {
return nil, err
}
return &config, nil
}
func (user *userRepo) CreateOrUpdateUserOpenID(userOpenID *model.UserOpenID) error {
if err := GetDB().Save(userOpenID).Error; err != nil {
return err
}
return nil
}
func (user *userRepo) GetUserOpenIDByIssuerAndSub(issuer, sub string) (*model.UserOpenID, error) {
var userOpenID model.UserOpenID
if err := GetDB().Where("issuer = ? AND sub = ?", issuer, sub).First(&userOpenID).Error; err != nil {
return nil, err
}
return &userOpenID, nil
}

View File

@ -16,9 +16,8 @@ func registerUserRoutes(group *route.RouterGroup) {
userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code
userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList)
userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin)
userGroupWithoutAuth.GET("/u/:id", v1.User.Get)
userGroupWithoutAuth.GET("/u/:id", v1.User.GetUser)
userGroup.POST("/logout", v1.User.Logout)
userGroup.PUT("/u/:id", v1.User.Update)
userGroup.DELETE("/u/:id", v1.User.Delete)
userGroup.PUT("/u/:id", v1.User.UpdateUser)
}
}

View File

@ -1,6 +1,7 @@
package service
import (
"errors"
"github.com/sirupsen/logrus"
"github.com/snowykami/neo-blog/internal/dto"
"github.com/snowykami/neo-blog/internal/model"
@ -9,25 +10,20 @@ import (
"github.com/snowykami/neo-blog/pkg/constant"
"github.com/snowykami/neo-blog/pkg/errs"
"github.com/snowykami/neo-blog/pkg/utils"
"gorm.io/gorm"
"net/http"
"strings"
"time"
)
type UserService interface {
UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error)
UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error)
VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error)
// TODO impl other user-related methods
type UserService struct{}
func NewUserService() *UserService {
return &UserService{}
}
type userService struct{}
func NewUserService() UserService {
return &userService{}
}
func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
user, err := repo.User.GetByUsernameOrEmail(req.Username)
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
if err != nil {
return nil, errs.ErrInternalServer
}
@ -35,26 +31,14 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
return nil, errs.ErrNotFound
}
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
tokenString, err := token.ToString()
if err != nil {
return nil, errs.ErrInternalServer
}
refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
refreshTokenString, err := refreshToken.ToString()
if err != nil {
return nil, errs.ErrInternalServer
}
// 对refresh token进行持久化存储
err = repo.Session.SaveSession(refreshToken.SessionKey)
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.UserLoginResp{
Token: tokenString,
RefreshToken: refreshTokenString,
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
@ -63,16 +47,19 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
}
}
func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
// 验证邮箱验证码
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
return nil, errs.ErrForbidden
}
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
kv := utils.KV.GetInstance()
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + ":" + req.Email)
if !ok || verificationCode != req.VerificationCode {
return nil, errs.ErrInvalidCredentials
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
if err != nil {
logrus.Errorln("Failed to verify email:", err)
return nil, errs.ErrInternalServer
}
if !ok {
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
}
// 检查用户名或邮箱是否已存在
@ -101,39 +88,28 @@ func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterR
Role: "user",
Password: hashedPassword,
}
err = repo.User.Create(newUser)
err = repo.User.CreateUser(newUser)
if err != nil {
return nil, errs.ErrInternalServer
}
// 生成访问令牌和刷新令牌
token := utils.Jwt.NewClaims(newUser.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
tokenString, err := token.ToString()
token, refreshToken, err := s.generate2Token(newUser.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
refreshToken := utils.Jwt.NewClaims(newUser.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
refreshTokenString, err := refreshToken.ToString()
if err != nil {
return nil, errs.ErrInternalServer
}
// 对refresh token进行持久化存储
err = repo.Session.SaveSession(refreshToken.SessionKey)
if err != nil {
return nil, errs.ErrInternalServer
}
resp := &dto.UserRegisterResp{
Token: tokenString,
RefreshToken: refreshTokenString,
Token: token,
RefreshToken: refreshToken,
User: newUser.ToDto(),
}
return resp, nil
}
func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
kv := utils.KV.GetInstance()
kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10)
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
if err != nil {
@ -149,3 +125,220 @@ func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp
}
return &dto.VerifyEmailResp{Success: true}, nil
}
func (s *UserService) ListOidcConfigs() (*dto.ListOidcConfigResp, error) {
enabledOidcConfigs, err := repo.User.ListOidcConfigs(true)
if err != nil {
return nil, errs.ErrInternalServer
}
var oidcConfigsDtos []dto.OidcConfigDto
for _, oidcConfig := range enabledOidcConfigs {
state := utils.Strings.GenerateRandomString(32)
kvStore := utils.KV.GetInstance()
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
oidcConfigsDtos = append(oidcConfigsDtos, dto.OidcConfigDto{
Name: oidcConfig.Name,
DisplayName: oidcConfig.DisplayName,
Icon: oidcConfig.Icon,
LoginUrl: utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
"client_id": oidcConfig.ClientID,
"redirect_uri": strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/") + constant.OidcUri + oidcConfig.Name,
"response_type": "code",
"scope": "openid email profile",
"state": state,
}),
})
}
return &dto.ListOidcConfigResp{
OidcConfigs: oidcConfigsDtos,
}, nil
}
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
// 验证state
kvStore := utils.KV.GetInstance()
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
if !ok || storedName != req.Name {
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
}
// 获取OIDC配置
oidcConfig, err := repo.User.GetOidcConfigByName(req.Name)
if err != nil {
return nil, errs.ErrInternalServer
}
if oidcConfig == nil {
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
}
// 请求访问令牌
tokenResp, err := utils.Oidc.RequestToken(
oidcConfig.TokenEndpoint,
oidcConfig.ClientID,
oidcConfig.ClientSecret,
req.Code,
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
)
if err != nil {
logrus.Errorln("Failed to request OIDC token:", err)
return nil, errs.ErrInternalServer
}
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
if err != nil {
logrus.Errorln("Failed to request OIDC user info:", err)
return nil, errs.ErrInternalServer
}
// 绑定过登录
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
if !errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrInternalServer
}
if userOpenID != nil {
user, err := repo.User.GetUserByID(userOpenID.UserID)
if err != nil {
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
user, err := repo.User.GetUserByEmail(userInfo.Email)
if !errors.Is(err, gorm.ErrRecordNotFound) {
logrus.Errorln("Failed to get user by email:", err)
return nil, errs.ErrInternalServer
}
if user != nil {
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
} else {
user = &model.User{
Username: userInfo.Name,
Nickname: userInfo.Name,
AvatarUrl: userInfo.Picture,
Email: userInfo.Email,
}
err = repo.User.CreateUser(user)
if err != nil {
logrus.Errorln("Failed to create user:", err)
return nil, errs.ErrInternalServer
}
userOpenID = &model.UserOpenID{
UserID: user.ID,
Issuer: oidcConfig.Issuer,
Sub: userInfo.Sub,
}
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
if err != nil {
logrus.Errorln("Failed to create or update user OpenID:", err)
return nil, errs.ErrInternalServer
}
token, refreshToken, err := s.generate2Token(user.ID)
if err != nil {
logrus.Errorln("Failed to generate tokens:", err)
return nil, errs.ErrInternalServer
}
resp := &dto.OidcLoginResp{
Token: token,
RefreshToken: refreshToken,
User: user.ToDto(),
}
return resp, nil
}
}
}
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
if req.UserID == 0 {
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
}
user, err := repo.User.GetUserByID(req.UserID)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to get user by ID:", err)
return nil, errs.ErrInternalServer
}
if user == nil {
return nil, errs.ErrNotFound
}
return &dto.GetUserResp{
User: user.ToDto(),
}, nil
}
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
user := &model.User{
Model: gorm.Model{
ID: req.ID,
},
Username: req.Username,
Nickname: req.Nickname,
Gender: req.Gender,
AvatarUrl: req.AvatarUrl,
}
err := repo.User.UpdateUser(user)
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, errs.ErrNotFound
}
logrus.Errorln("Failed to update user:", err)
return nil, errs.ErrInternalServer
}
return &dto.UpdateUserResp{}, nil
}
func (s *UserService) generate2Token(userID uint) (string, string, error) {
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
tokenString, err := token.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
refreshTokenString, err := refreshToken.ToString()
if err != nil {
return "", "", errs.ErrInternalServer
}
err = repo.Session.SaveSession(refreshToken.SessionKey)
if err != nil {
return "", "", errs.ErrInternalServer
}
return tokenString, refreshTokenString, nil
}
func (s *UserService) verifyEmail(email, code string) (bool, error) {
kv := utils.KV.GetInstance()
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
if !ok || verificationCode != code {
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
}
return true, nil
}