mirror of
https://github.com/snowykami/neo-blog.git
synced 2025-09-04 08:16:24 +00:00
⚡ refactor user service methods, implement OIDC login and user management features, and enhance token handling
This commit is contained in:
@ -23,6 +23,7 @@ EMAIL_PORT=465
|
|||||||
EMAIL_SSL=true
|
EMAIL_SSL=true
|
||||||
|
|
||||||
# App settings
|
# App settings
|
||||||
|
BASE_URL=https://blog.jason.moe
|
||||||
MAX_REQUEST_BODY_SIZE=1000000
|
MAX_REQUEST_BODY_SIZE=1000000
|
||||||
MODE=prod
|
MODE=prod
|
||||||
PORT=8888
|
PORT=8888
|
||||||
|
@ -3,17 +3,17 @@ package v1
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
"github.com/cloudwego/hertz/pkg/protocol"
|
"github.com/cloudwego/hertz/pkg/common/utils"
|
||||||
|
"github.com/snowykami/neo-blog/internal/ctxutils"
|
||||||
"github.com/snowykami/neo-blog/internal/dto"
|
"github.com/snowykami/neo-blog/internal/dto"
|
||||||
"github.com/snowykami/neo-blog/internal/service"
|
"github.com/snowykami/neo-blog/internal/service"
|
||||||
"github.com/snowykami/neo-blog/pkg/constant"
|
|
||||||
"github.com/snowykami/neo-blog/pkg/errs"
|
"github.com/snowykami/neo-blog/pkg/errs"
|
||||||
"github.com/snowykami/neo-blog/pkg/resps"
|
"github.com/snowykami/neo-blog/pkg/resps"
|
||||||
"github.com/snowykami/neo-blog/pkg/utils"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userType struct {
|
type userType struct {
|
||||||
service service.UserService
|
service *service.UserService
|
||||||
}
|
}
|
||||||
|
|
||||||
var User = &userType{
|
var User = &userType{
|
||||||
@ -33,13 +33,11 @@ func (u *userType) Login(ctx context.Context, c *app.RequestContext) {
|
|||||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resp == nil {
|
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
resps.Ok(c, resps.Success, utils.H{
|
||||||
return
|
"token": resp.Token,
|
||||||
} else {
|
"user": resp.User,
|
||||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
})
|
||||||
resps.Ok(c, resps.Success, resp)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
||||||
@ -55,46 +53,101 @@ func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
|||||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if resp == nil {
|
|
||||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||||
return
|
resps.Ok(c, resps.Success, utils.H{
|
||||||
}
|
"token": resp.Token,
|
||||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
"user": resp.User,
|
||||||
resps.Ok(c, resps.Success, resp)
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) Logout(ctx context.Context, c *app.RequestContext) {
|
func (u *userType) Logout(ctx context.Context, c *app.RequestContext) {
|
||||||
u.clearTokenCookie(c)
|
ctxutils.ClearTokenAndRefreshTokenCookie(c)
|
||||||
resps.Ok(c, resps.Success, nil)
|
resps.Ok(c, resps.Success, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) {
|
func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) {
|
||||||
// TODO: Impl
|
resp, err := u.service.ListOidcConfigs()
|
||||||
|
if err != nil {
|
||||||
|
serviceErr := errs.AsServiceError(err)
|
||||||
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resps.Ok(c, resps.Success, map[string]any{
|
||||||
|
"oidc_configs": resp.OidcConfigs,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) OidcLogin(ctx context.Context, c *app.RequestContext) {
|
func (u *userType) OidcLogin(ctx context.Context, c *app.RequestContext) {
|
||||||
// TODO: Impl
|
name := c.Param("name")
|
||||||
|
code := c.Param("code")
|
||||||
|
state := c.Param("state")
|
||||||
|
oidcLoginReq := &dto.OidcLoginReq{
|
||||||
|
Name: name,
|
||||||
|
Code: code,
|
||||||
|
State: state,
|
||||||
|
}
|
||||||
|
resp, err := u.service.OidcLogin(oidcLoginReq)
|
||||||
|
if err != nil {
|
||||||
|
serviceErr := errs.AsServiceError(err)
|
||||||
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||||
|
resps.Ok(c, resps.Success, map[string]any{
|
||||||
|
"token": resp.Token,
|
||||||
|
"user": resp.User,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) Get(ctx context.Context, c *app.RequestContext) {
|
func (u *userType) GetUser(ctx context.Context, c *app.RequestContext) {
|
||||||
// TODO: Impl
|
userID := c.Param("id")
|
||||||
}
|
if userID == "" {
|
||||||
|
|
||||||
func (u *userType) Update(ctx context.Context, c *app.RequestContext) {
|
|
||||||
// TODO: Impl
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *userType) Delete(ctx context.Context, c *app.RequestContext) {
|
|
||||||
// TODO: Impl
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
|
||||||
var verifyEmailReq dto.VerifyEmailReq
|
|
||||||
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
|
|
||||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
resp, err := u.service.VerifyEmail(&verifyEmailReq)
|
userIDInt, err := strconv.Atoi(userID)
|
||||||
|
if err != nil || userIDInt <= 0 {
|
||||||
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := u.service.GetUser(&dto.GetUserReq{UserID: uint(userIDInt)})
|
||||||
|
if err != nil {
|
||||||
|
serviceErr := errs.AsServiceError(err)
|
||||||
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resps.Ok(c, resps.Success, resp.User)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *userType) UpdateUser(ctx context.Context, c *app.RequestContext) {
|
||||||
|
userID := c.Param("id")
|
||||||
|
if userID == "" {
|
||||||
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
userIDInt, err := strconv.Atoi(userID)
|
||||||
|
if err != nil || userIDInt <= 0 {
|
||||||
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
var updateUserReq dto.UpdateUserReq
|
||||||
|
if err := c.BindAndValidate(&updateUserReq); err != nil {
|
||||||
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
updateUserReq.ID = uint(userIDInt)
|
||||||
|
currentUser := ctxutils.GetCurrentUser(ctx)
|
||||||
|
if currentUser == nil {
|
||||||
|
resps.UnAuthorized(c, resps.ErrUnauthorized)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if currentUser.ID != updateUserReq.ID {
|
||||||
|
resps.Forbidden(c, resps.ErrForbidden)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resp, err := u.service.UpdateUser(&updateUserReq)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
serviceErr := errs.AsServiceError(err)
|
serviceErr := errs.AsServiceError(err)
|
||||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
@ -103,12 +156,17 @@ func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
|||||||
resps.Ok(c, resps.Success, resp)
|
resps.Ok(c, resps.Success, resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) {
|
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
||||||
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
var verifyEmailReq dto.VerifyEmailReq
|
||||||
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
|
||||||
}
|
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||||
|
return
|
||||||
func (u *userType) clearTokenCookie(c *app.RequestContext) {
|
}
|
||||||
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
resp, err := u.service.RequestVerifyEmail(&verifyEmailReq)
|
||||||
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
if err != nil {
|
||||||
|
serviceErr := errs.AsServiceError(err)
|
||||||
|
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
resps.Ok(c, resps.Success, resp)
|
||||||
}
|
}
|
||||||
|
22
internal/ctxutils/token.go
Normal file
22
internal/ctxutils/token.go
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
package ctxutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
|
"github.com/cloudwego/hertz/pkg/protocol"
|
||||||
|
"github.com/snowykami/neo-blog/pkg/constant"
|
||||||
|
"github.com/snowykami/neo-blog/pkg/utils"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetTokenCookie(c *app.RequestContext, token string) {
|
||||||
|
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SetTokenAndRefreshTokenCookie(c *app.RequestContext, token, refreshToken string) {
|
||||||
|
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||||
|
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ClearTokenAndRefreshTokenCookie(c *app.RequestContext) {
|
||||||
|
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||||
|
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||||
|
}
|
19
internal/ctxutils/user.go
Normal file
19
internal/ctxutils/user.go
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
package ctxutils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"github.com/snowykami/neo-blog/internal/model"
|
||||||
|
"github.com/snowykami/neo-blog/internal/repo"
|
||||||
|
)
|
||||||
|
|
||||||
|
func GetCurrentUser(ctx context.Context) *model.User {
|
||||||
|
userIDValue := ctx.Value("user_id").(uint)
|
||||||
|
if userIDValue <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
user, err := repo.User.GetUserByID(userIDValue)
|
||||||
|
if err != nil || user == nil || user.ID == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return user
|
||||||
|
}
|
@ -1,7 +1 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type BaseResp struct {
|
|
||||||
Code int `json:"code"`
|
|
||||||
Message string `json:"message"`
|
|
||||||
Data any `json:"data"`
|
|
||||||
}
|
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package dto
|
package dto
|
||||||
|
|
||||||
type UserDto struct {
|
type UserDto struct {
|
||||||
|
ID uint `json:"id"` // 用户ID
|
||||||
Username string `json:"username"` // 用户名
|
Username string `json:"username"` // 用户名
|
||||||
Nickname string `json:"nickname"`
|
Nickname string `json:"nickname"`
|
||||||
AvatarUrl string `json:"avatar_url"` // 头像URL
|
AvatarUrl string `json:"avatar_url"` // 头像URL
|
||||||
@ -8,6 +9,13 @@ type UserDto struct {
|
|||||||
Gender string `json:"gender"`
|
Gender string `json:"gender"`
|
||||||
Role string `json:"role"`
|
Role string `json:"role"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OidcConfigDto struct {
|
||||||
|
Name string `json:"name"` // OIDC配置名称
|
||||||
|
DisplayName string `json:"display_name"` // OIDC配置显示名称
|
||||||
|
Icon string `json:"icon"` // OIDC配置图标URL
|
||||||
|
LoginUrl string `json:"login_url"` // OIDC登录URL
|
||||||
|
}
|
||||||
type UserLoginReq struct {
|
type UserLoginReq struct {
|
||||||
Username string `json:"username"` // username or email
|
Username string `json:"username"` // username or email
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
@ -40,3 +48,39 @@ type VerifyEmailReq struct {
|
|||||||
type VerifyEmailResp struct {
|
type VerifyEmailResp struct {
|
||||||
Success bool `json:"success"` // 验证码发送成功与否
|
Success bool `json:"success"` // 验证码发送成功与否
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type OidcLoginReq struct {
|
||||||
|
Name string `json:"name"` // OIDC配置名称
|
||||||
|
Code string `json:"code"` // OIDC授权码
|
||||||
|
State string `json:"state"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OidcLoginResp struct {
|
||||||
|
Token string `json:"token"`
|
||||||
|
RefreshToken string `json:"refresh_token"`
|
||||||
|
User *UserDto `json:"user"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ListOidcConfigResp struct {
|
||||||
|
OidcConfigs []OidcConfigDto `json:"oidc_configs"` // OIDC配置列表
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetUserReq struct {
|
||||||
|
UserID uint `json:"user_id"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type GetUserResp struct {
|
||||||
|
User *UserDto `json:"user"` // 用户信息
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateUserReq struct {
|
||||||
|
ID uint `json:"id"`
|
||||||
|
Username string `json:"username"`
|
||||||
|
Nickname string `json:"nickname"`
|
||||||
|
AvatarUrl string `json:"avatar_url"`
|
||||||
|
Gender string `json:"gender"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type UpdateUserResp struct {
|
||||||
|
User *UserDto `json:"user"` // 更新后的用户信息
|
||||||
|
}
|
||||||
|
@ -3,10 +3,53 @@ package middleware
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"github.com/cloudwego/hertz/pkg/app"
|
"github.com/cloudwego/hertz/pkg/app"
|
||||||
|
"github.com/snowykami/neo-blog/internal/ctxutils"
|
||||||
|
"github.com/snowykami/neo-blog/internal/repo"
|
||||||
|
"github.com/snowykami/neo-blog/pkg/constant"
|
||||||
|
"github.com/snowykami/neo-blog/pkg/resps"
|
||||||
|
"github.com/snowykami/neo-blog/pkg/utils"
|
||||||
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
func UseAuth() app.HandlerFunc {
|
func UseAuth() app.HandlerFunc {
|
||||||
return func(ctx context.Context, c *app.RequestContext) {
|
return func(ctx context.Context, c *app.RequestContext) {
|
||||||
// TODO: Implement authentication logic here
|
// For cookie
|
||||||
|
token := string(c.Cookie("token"))
|
||||||
|
refreshToken := string(c.Cookie("refresh_token"))
|
||||||
|
tokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(token)
|
||||||
|
if err == nil && tokenClaims != nil {
|
||||||
|
ctx = context.WithValue(ctx, "user_id", tokenClaims.UserID)
|
||||||
|
c.Next(ctx)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// token 失效 使用refresh token重新签发和鉴权
|
||||||
|
refreshTokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(refreshToken)
|
||||||
|
if err == nil && refreshTokenClaims != nil {
|
||||||
|
ok, err := isStatefulJwtValid(refreshTokenClaims)
|
||||||
|
if err == nil && ok {
|
||||||
|
ctx = context.WithValue(ctx, "user_id", refreshTokenClaims.UserID) // 修改这里,使用refreshTokenClaims
|
||||||
|
c.Next(ctx)
|
||||||
|
newTokenClaims := utils.Jwt.NewClaims(refreshTokenClaims.UserID, refreshTokenClaims.SessionKey, refreshTokenClaims.Stateful, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||||
|
newToken, err := newTokenClaims.ToString()
|
||||||
|
if err == nil {
|
||||||
|
ctxutils.SetTokenCookie(c, newToken)
|
||||||
|
} else {
|
||||||
|
resps.InternalServerError(c, resps.ErrInternalServerError)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 所有认证方式都失败,返回未授权错误
|
||||||
|
resps.UnAuthorized(c, resps.ErrUnauthorized)
|
||||||
|
c.Abort()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isStatefulJwtValid(claims *utils.Claims) (bool, error) {
|
||||||
|
if !claims.Stateful {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
return repo.Session.IsSessionValid(claims.SessionKey)
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@ package model
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"github.com/snowykami/neo-blog/internal/dto"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
"resty.dev/v3"
|
"resty.dev/v3"
|
||||||
"time"
|
"time"
|
||||||
@ -9,14 +10,13 @@ import (
|
|||||||
|
|
||||||
type OidcConfig struct {
|
type OidcConfig struct {
|
||||||
gorm.Model
|
gorm.Model
|
||||||
Name string `gorm:"uniqueIndex"`
|
Name string `gorm:"uniqueIndex"` // OIDC配置名称,唯一
|
||||||
ClientID string `gorm:"column:client_id"` // 客户端ID
|
ClientID string // 客户端ID
|
||||||
ClientSecret string `gorm:"column:client_secret"` // 客户端密钥
|
ClientSecret string // 客户端密钥
|
||||||
DisplayName string `gorm:"column:display_name"` // 显示名称,例如:轻雪通行证
|
DisplayName string // 显示名称,例如:轻雪通行证
|
||||||
GroupsClaim *string `gorm:"default:groups"` // 组声明,默认为:"groups"
|
Icon string // 图标url,为空则使用内置默认图标
|
||||||
Icon *string `gorm:"column:icon"` // 图标url,为空则使用内置默认图标
|
OidcDiscoveryUrl string // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration
|
||||||
OidcDiscoveryUrl string `gorm:"column:oidc_discovery_url"` // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration
|
Enabled bool `gorm:"default:true"` // 是否启用
|
||||||
Enabled bool `gorm:"column:enabled;default:true"` // 是否启用
|
|
||||||
// 以下字段为自动获取字段,每次更新配置时自动填充
|
// 以下字段为自动获取字段,每次更新配置时自动填充
|
||||||
Issuer string
|
Issuer string
|
||||||
AuthorizationEndpoint string
|
AuthorizationEndpoint string
|
||||||
@ -68,11 +68,6 @@ func updateOidcConfigFromUrl(url string) (*oidcDiscoveryResp, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
|
func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
|
||||||
// 设置默认值
|
|
||||||
if o.GroupsClaim == nil {
|
|
||||||
defaultGroupsClaim := "groups"
|
|
||||||
o.GroupsClaim = &defaultGroupsClaim
|
|
||||||
}
|
|
||||||
// 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息
|
// 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息
|
||||||
if tx.Statement.Changed("OidcDiscoveryUrl") {
|
if tx.Statement.Changed("OidcDiscoveryUrl") {
|
||||||
discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl)
|
discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl)
|
||||||
@ -87,3 +82,12 @@ func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
|
|||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ToDto 不包含LoginUrl,在service层自行实现
|
||||||
|
func (o *OidcConfig) ToDto() *dto.OidcConfigDto {
|
||||||
|
return &dto.OidcConfigDto{
|
||||||
|
Name: o.Name,
|
||||||
|
DisplayName: o.DisplayName,
|
||||||
|
Icon: o.Icon,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -16,8 +16,17 @@ type User struct {
|
|||||||
Password string // 密码,存储加密后的值
|
Password string // 密码,存储加密后的值
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UserOpenID struct {
|
||||||
|
gorm.Model
|
||||||
|
UserID uint `gorm:"uniqueIndex"`
|
||||||
|
User User `gorm:"foreignKey:UserID;references:ID"`
|
||||||
|
Issuer string `gorm:"index"` // OIDC Issuer
|
||||||
|
Sub string `gorm:"index"` // OIDC Sub openid
|
||||||
|
}
|
||||||
|
|
||||||
func (user *User) ToDto() *dto.UserDto {
|
func (user *User) ToDto() *dto.UserDto {
|
||||||
return &dto.UserDto{
|
return &dto.UserDto{
|
||||||
|
ID: user.ID,
|
||||||
Username: user.Username,
|
Username: user.Username,
|
||||||
Nickname: user.Nickname,
|
Nickname: user.Nickname,
|
||||||
AvatarUrl: user.AvatarUrl,
|
AvatarUrl: user.AvatarUrl,
|
||||||
|
@ -6,7 +6,7 @@ type userRepo struct{}
|
|||||||
|
|
||||||
var User = &userRepo{}
|
var User = &userRepo{}
|
||||||
|
|
||||||
func (user *userRepo) GetByUsername(username string) (*model.User, error) {
|
func (user *userRepo) GetUserByUsername(username string) (*model.User, error) {
|
||||||
var userModel model.User
|
var userModel model.User
|
||||||
if err := GetDB().Where("username = ?", username).First(&userModel).Error; err != nil {
|
if err := GetDB().Where("username = ?", username).First(&userModel).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -14,7 +14,7 @@ func (user *userRepo) GetByUsername(username string) (*model.User, error) {
|
|||||||
return &userModel, nil
|
return &userModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *userRepo) GetByEmail(email string) (*model.User, error) {
|
func (user *userRepo) GetUserByEmail(email string) (*model.User, error) {
|
||||||
var userModel model.User
|
var userModel model.User
|
||||||
if err := GetDB().Where("email = ?", email).First(&userModel).Error; err != nil {
|
if err := GetDB().Where("email = ?", email).First(&userModel).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -22,7 +22,15 @@ func (user *userRepo) GetByEmail(email string) (*model.User, error) {
|
|||||||
return &userModel, nil
|
return &userModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
|
func (user *userRepo) GetUserByID(id uint) (*model.User, error) {
|
||||||
|
var userModel model.User
|
||||||
|
if err := GetDB().Where("id = ?", id).First(&userModel).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &userModel, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *userRepo) GetUserByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
|
||||||
var userModel model.User
|
var userModel model.User
|
||||||
if err := GetDB().Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).First(&userModel).Error; err != nil {
|
if err := GetDB().Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).First(&userModel).Error; err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -30,14 +38,14 @@ func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User,
|
|||||||
return &userModel, nil
|
return &userModel, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *userRepo) Create(userModel *model.User) error {
|
func (user *userRepo) CreateUser(userModel *model.User) error {
|
||||||
if err := GetDB().Create(userModel).Error; err != nil {
|
if err := GetDB().Create(userModel).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (user *userRepo) Update(userModel *model.User) error {
|
func (user *userRepo) UpdateUser(userModel *model.User) error {
|
||||||
if err := GetDB().Updates(userModel).Error; err != nil {
|
if err := GetDB().Updates(userModel).Error; err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@ -59,3 +67,40 @@ func (user *userRepo) CheckEmailExists(email string) (bool, error) {
|
|||||||
}
|
}
|
||||||
return count > 0, nil
|
return count > 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (user *userRepo) ListOidcConfigs(onlyEnabled bool) ([]model.OidcConfig, error) {
|
||||||
|
var configs []model.OidcConfig
|
||||||
|
if onlyEnabled {
|
||||||
|
if err := GetDB().Where("enabled = ?", true).Find(&configs).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err := GetDB().Find(&configs).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return configs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *userRepo) GetOidcConfigByName(name string) (*model.OidcConfig, error) {
|
||||||
|
var config model.OidcConfig
|
||||||
|
if err := GetDB().Where("name = ?", name).First(&config).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *userRepo) CreateOrUpdateUserOpenID(userOpenID *model.UserOpenID) error {
|
||||||
|
if err := GetDB().Save(userOpenID).Error; err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (user *userRepo) GetUserOpenIDByIssuerAndSub(issuer, sub string) (*model.UserOpenID, error) {
|
||||||
|
var userOpenID model.UserOpenID
|
||||||
|
if err := GetDB().Where("issuer = ? AND sub = ?", issuer, sub).First(&userOpenID).Error; err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &userOpenID, nil
|
||||||
|
}
|
||||||
|
@ -16,9 +16,8 @@ func registerUserRoutes(group *route.RouterGroup) {
|
|||||||
userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code
|
userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code
|
||||||
userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList)
|
userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList)
|
||||||
userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin)
|
userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin)
|
||||||
userGroupWithoutAuth.GET("/u/:id", v1.User.Get)
|
userGroupWithoutAuth.GET("/u/:id", v1.User.GetUser)
|
||||||
userGroup.POST("/logout", v1.User.Logout)
|
userGroup.POST("/logout", v1.User.Logout)
|
||||||
userGroup.PUT("/u/:id", v1.User.Update)
|
userGroup.PUT("/u/:id", v1.User.UpdateUser)
|
||||||
userGroup.DELETE("/u/:id", v1.User.Delete)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/snowykami/neo-blog/internal/dto"
|
"github.com/snowykami/neo-blog/internal/dto"
|
||||||
"github.com/snowykami/neo-blog/internal/model"
|
"github.com/snowykami/neo-blog/internal/model"
|
||||||
@ -9,25 +10,20 @@ import (
|
|||||||
"github.com/snowykami/neo-blog/pkg/constant"
|
"github.com/snowykami/neo-blog/pkg/constant"
|
||||||
"github.com/snowykami/neo-blog/pkg/errs"
|
"github.com/snowykami/neo-blog/pkg/errs"
|
||||||
"github.com/snowykami/neo-blog/pkg/utils"
|
"github.com/snowykami/neo-blog/pkg/utils"
|
||||||
|
"gorm.io/gorm"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserService interface {
|
type UserService struct{}
|
||||||
UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error)
|
|
||||||
UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error)
|
func NewUserService() *UserService {
|
||||||
VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error)
|
return &UserService{}
|
||||||
// TODO impl other user-related methods
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type userService struct{}
|
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||||
|
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
|
||||||
func NewUserService() UserService {
|
|
||||||
return &userService{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
|
||||||
user, err := repo.User.GetByUsernameOrEmail(req.Username)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
@ -35,26 +31,14 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
|
|||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
||||||
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
|
||||||
tokenString, err := token.ToString()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.ErrInternalServer
|
|
||||||
}
|
|
||||||
|
|
||||||
refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
|
||||||
refreshTokenString, err := refreshToken.ToString()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.ErrInternalServer
|
|
||||||
}
|
|
||||||
// 对refresh token进行持久化存储
|
|
||||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.UserLoginResp{
|
resp := &dto.UserLoginResp{
|
||||||
Token: tokenString,
|
Token: token,
|
||||||
RefreshToken: refreshTokenString,
|
RefreshToken: refreshToken,
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
@ -63,16 +47,19 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
||||||
// 验证邮箱验证码
|
// 验证邮箱验证码
|
||||||
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
|
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
|
||||||
return nil, errs.ErrForbidden
|
return nil, errs.ErrForbidden
|
||||||
}
|
}
|
||||||
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
|
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
|
||||||
kv := utils.KV.GetInstance()
|
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
|
||||||
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + ":" + req.Email)
|
if err != nil {
|
||||||
if !ok || verificationCode != req.VerificationCode {
|
logrus.Errorln("Failed to verify email:", err)
|
||||||
return nil, errs.ErrInvalidCredentials
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
if !ok {
|
||||||
|
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 检查用户名或邮箱是否已存在
|
// 检查用户名或邮箱是否已存在
|
||||||
@ -101,39 +88,28 @@ func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterR
|
|||||||
Role: "user",
|
Role: "user",
|
||||||
Password: hashedPassword,
|
Password: hashedPassword,
|
||||||
}
|
}
|
||||||
err = repo.User.Create(newUser)
|
err = repo.User.CreateUser(newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
// 生成访问令牌和刷新令牌
|
// 生成访问令牌和刷新令牌
|
||||||
token := utils.Jwt.NewClaims(newUser.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
token, refreshToken, err := s.generate2Token(newUser.ID)
|
||||||
tokenString, err := token.ToString()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
refreshToken := utils.Jwt.NewClaims(newUser.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
|
||||||
refreshTokenString, err := refreshToken.ToString()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.ErrInternalServer
|
|
||||||
}
|
|
||||||
// 对refresh token进行持久化存储
|
|
||||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errs.ErrInternalServer
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &dto.UserRegisterResp{
|
resp := &dto.UserRegisterResp{
|
||||||
Token: tokenString,
|
Token: token,
|
||||||
RefreshToken: refreshTokenString,
|
RefreshToken: refreshToken,
|
||||||
User: newUser.ToDto(),
|
User: newUser.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
||||||
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
||||||
kv := utils.KV.GetInstance()
|
kv := utils.KV.GetInstance()
|
||||||
kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10)
|
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
|
||||||
|
|
||||||
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@ -149,3 +125,220 @@ func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp
|
|||||||
}
|
}
|
||||||
return &dto.VerifyEmailResp{Success: true}, nil
|
return &dto.VerifyEmailResp{Success: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *UserService) ListOidcConfigs() (*dto.ListOidcConfigResp, error) {
|
||||||
|
enabledOidcConfigs, err := repo.User.ListOidcConfigs(true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
var oidcConfigsDtos []dto.OidcConfigDto
|
||||||
|
|
||||||
|
for _, oidcConfig := range enabledOidcConfigs {
|
||||||
|
state := utils.Strings.GenerateRandomString(32)
|
||||||
|
kvStore := utils.KV.GetInstance()
|
||||||
|
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
|
||||||
|
oidcConfigsDtos = append(oidcConfigsDtos, dto.OidcConfigDto{
|
||||||
|
Name: oidcConfig.Name,
|
||||||
|
DisplayName: oidcConfig.DisplayName,
|
||||||
|
Icon: oidcConfig.Icon,
|
||||||
|
LoginUrl: utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
||||||
|
"client_id": oidcConfig.ClientID,
|
||||||
|
"redirect_uri": strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/") + constant.OidcUri + oidcConfig.Name,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
"state": state,
|
||||||
|
}),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return &dto.ListOidcConfigResp{
|
||||||
|
OidcConfigs: oidcConfigsDtos,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
|
||||||
|
// 验证state
|
||||||
|
kvStore := utils.KV.GetInstance()
|
||||||
|
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
|
||||||
|
if !ok || storedName != req.Name {
|
||||||
|
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
|
||||||
|
}
|
||||||
|
// 获取OIDC配置
|
||||||
|
oidcConfig, err := repo.User.GetOidcConfigByName(req.Name)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
if oidcConfig == nil {
|
||||||
|
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
|
||||||
|
}
|
||||||
|
// 请求访问令牌
|
||||||
|
tokenResp, err := utils.Oidc.RequestToken(
|
||||||
|
oidcConfig.TokenEndpoint,
|
||||||
|
oidcConfig.ClientID,
|
||||||
|
oidcConfig.ClientSecret,
|
||||||
|
req.Code,
|
||||||
|
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to request OIDC token:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to request OIDC user info:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
|
||||||
|
// 绑定过登录
|
||||||
|
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
if userOpenID != nil {
|
||||||
|
user, err := repo.User.GetUserByID(userOpenID.UserID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
resp := &dto.OidcLoginResp{
|
||||||
|
Token: token,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
User: user.ToDto(),
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
} else {
|
||||||
|
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
|
||||||
|
user, err := repo.User.GetUserByEmail(userInfo.Email)
|
||||||
|
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
logrus.Errorln("Failed to get user by email:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
if user != nil {
|
||||||
|
userOpenID = &model.UserOpenID{
|
||||||
|
UserID: user.ID,
|
||||||
|
Issuer: oidcConfig.Issuer,
|
||||||
|
Sub: userInfo.Sub,
|
||||||
|
}
|
||||||
|
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
resp := &dto.OidcLoginResp{
|
||||||
|
Token: token,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
User: user.ToDto(),
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
} else {
|
||||||
|
user = &model.User{
|
||||||
|
Username: userInfo.Name,
|
||||||
|
Nickname: userInfo.Name,
|
||||||
|
AvatarUrl: userInfo.Picture,
|
||||||
|
Email: userInfo.Email,
|
||||||
|
}
|
||||||
|
err = repo.User.CreateUser(user)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to create user:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
userOpenID = &model.UserOpenID{
|
||||||
|
UserID: user.ID,
|
||||||
|
Issuer: oidcConfig.Issuer,
|
||||||
|
Sub: userInfo.Sub,
|
||||||
|
}
|
||||||
|
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
|
if err != nil {
|
||||||
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
resp := &dto.OidcLoginResp{
|
||||||
|
Token: token,
|
||||||
|
RefreshToken: refreshToken,
|
||||||
|
User: user.ToDto(),
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
|
||||||
|
if req.UserID == 0 {
|
||||||
|
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
|
||||||
|
}
|
||||||
|
user, err := repo.User.GetUserByID(req.UserID)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errs.ErrNotFound
|
||||||
|
}
|
||||||
|
logrus.Errorln("Failed to get user by ID:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
if user == nil {
|
||||||
|
return nil, errs.ErrNotFound
|
||||||
|
}
|
||||||
|
return &dto.GetUserResp{
|
||||||
|
User: user.ToDto(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
|
||||||
|
user := &model.User{
|
||||||
|
Model: gorm.Model{
|
||||||
|
ID: req.ID,
|
||||||
|
},
|
||||||
|
Username: req.Username,
|
||||||
|
Nickname: req.Nickname,
|
||||||
|
Gender: req.Gender,
|
||||||
|
AvatarUrl: req.AvatarUrl,
|
||||||
|
}
|
||||||
|
err := repo.User.UpdateUser(user)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
|
return nil, errs.ErrNotFound
|
||||||
|
}
|
||||||
|
logrus.Errorln("Failed to update user:", err)
|
||||||
|
return nil, errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
return &dto.UpdateUserResp{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) generate2Token(userID uint) (string, string, error) {
|
||||||
|
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
||||||
|
tokenString, err := token.ToString()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||||
|
refreshTokenString, err := refreshToken.ToString()
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||||
|
if err != nil {
|
||||||
|
return "", "", errs.ErrInternalServer
|
||||||
|
}
|
||||||
|
return tokenString, refreshTokenString, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *UserService) verifyEmail(email, code string) (bool, error) {
|
||||||
|
kv := utils.KV.GetInstance()
|
||||||
|
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
|
||||||
|
if !ok || verificationCode != code {
|
||||||
|
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||||
|
}
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
@ -10,6 +10,7 @@ const (
|
|||||||
RoleUser = "user"
|
RoleUser = "user"
|
||||||
RoleAdmin = "admin"
|
RoleAdmin = "admin"
|
||||||
|
|
||||||
|
EnvKeyBaseUrl = "BASE_URL" // 环境变量:基础URL
|
||||||
EnvKeyMode = "MODE" // 环境变量:运行模式
|
EnvKeyMode = "MODE" // 环境变量:运行模式
|
||||||
EnvKeyJwtSecrete = "JWT_SECRET" // 环境变量:JWT密钥
|
EnvKeyJwtSecrete = "JWT_SECRET" // 环境变量:JWT密钥
|
||||||
EnvKeyPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐
|
EnvKeyPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐
|
||||||
@ -18,5 +19,9 @@ const (
|
|||||||
EnvKeyRefreshTokenDuration = "REFRESH_TOKEN_DURATION" // 环境变量:刷新令牌有效期
|
EnvKeyRefreshTokenDuration = "REFRESH_TOKEN_DURATION" // 环境变量:刷新令牌有效期
|
||||||
EnvKeyRefreshTokenDurationWithRemember = "REFRESH_TOKEN_DURATION_WITH_REMEMBER" // 环境变量:记住我刷新令牌有效期
|
EnvKeyRefreshTokenDurationWithRemember = "REFRESH_TOKEN_DURATION_WITH_REMEMBER" // 环境变量:记住我刷新令牌有效期
|
||||||
|
|
||||||
KVKeyEmailVerificationCode = "email_verification_code" // KV存储:邮箱验证码
|
KVKeyEmailVerificationCode = "email_verification_code:" // KV存储:邮箱验证码
|
||||||
|
KVKeyOidcState = "oidc_state:" // KV存储:OIDC状态
|
||||||
|
|
||||||
|
OidcUri = "/user/oidc/login" // OIDC登录URI
|
||||||
|
DefaultBaseUrl = "http://localhost:3000" // 默认BaseUrl
|
||||||
)
|
)
|
||||||
|
@ -11,6 +11,7 @@ func Custom(c *app.RequestContext, status int, message string, data any) {
|
|||||||
"message": message,
|
"message": message,
|
||||||
"data": data,
|
"data": data,
|
||||||
})
|
})
|
||||||
|
c.Abort()
|
||||||
}
|
}
|
||||||
|
|
||||||
func Ok(c *app.RequestContext, message string, data any) {
|
func Ok(c *app.RequestContext, message string, data any) {
|
||||||
|
@ -36,7 +36,7 @@ func (c *Claims) ToString() (string, error) {
|
|||||||
return token.SignedString([]byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret")))
|
return token.SignedString([]byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret")))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseJsonWebTokenWithoutState 解析JWT令牌,不对有状态的Token进行状态检查
|
// ParseJsonWebTokenWithoutState 解析JWT令牌,仅检查无状态下是否valid,不对有状态的Token进行状态检查
|
||||||
func (j *jwtUtils) ParseJsonWebTokenWithoutState(tokenString string) (*Claims, error) {
|
func (j *jwtUtils) ParseJsonWebTokenWithoutState(tokenString string) (*Claims, error) {
|
||||||
claims := &Claims{}
|
claims := &Claims{}
|
||||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
|
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
|
||||||
|
72
pkg/utils/oidc.go
Normal file
72
pkg/utils/oidc.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"resty.dev/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
type oidcUtils struct{}
|
||||||
|
|
||||||
|
var Oidc = oidcUtils{}
|
||||||
|
|
||||||
|
// RequestToken 请求访问令牌
|
||||||
|
func (u *oidcUtils) RequestToken(tokenEndpoint, clientID, clientSecret, code, redirectURI string) (*TokenResponse, error) {
|
||||||
|
client := resty.New()
|
||||||
|
tokenResp, err := client.R().
|
||||||
|
SetFormData(map[string]string{
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"client_id": clientID,
|
||||||
|
"client_secret": clientSecret,
|
||||||
|
"code": code,
|
||||||
|
"redirect_uri": redirectURI,
|
||||||
|
}).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetResult(&TokenResponse{}).
|
||||||
|
Post(tokenEndpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if tokenResp.StatusCode() != 200 {
|
||||||
|
return nil, fmt.Errorf("状态码: %d,响应: %s", tokenResp.StatusCode(), tokenResp.String())
|
||||||
|
}
|
||||||
|
return tokenResp.Result().(*TokenResponse), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestUserInfo 请求用户信息
|
||||||
|
func (u *oidcUtils) RequestUserInfo(userInfoEndpoint, accessToken string) (*UserInfo, error) {
|
||||||
|
client := resty.New()
|
||||||
|
userInfoResp, err := client.R().
|
||||||
|
SetHeader("Authorization", "Bearer "+accessToken).
|
||||||
|
SetHeader("Accept", "application/json").
|
||||||
|
SetResult(&UserInfo{}).
|
||||||
|
Get(userInfoEndpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if userInfoResp.StatusCode() != 200 {
|
||||||
|
return nil, fmt.Errorf("状态码: %d,响应: %s", userInfoResp.StatusCode(), userInfoResp.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfoResp.Result().(*UserInfo), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token,omitempty"`
|
||||||
|
RefreshToken string `json:"refresh_token,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// UserInfo 定义用户信息结构
|
||||||
|
type UserInfo struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
Picture string `json:"picture,omitempty"`
|
||||||
|
Groups []string `json:"groups,omitempty"` // 可选字段,OIDC提供的用户组信息
|
||||||
|
}
|
20
pkg/utils/url.go
Normal file
20
pkg/utils/url.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import "net/url"
|
||||||
|
|
||||||
|
type urlUtils struct{}
|
||||||
|
|
||||||
|
var Url = &urlUtils{}
|
||||||
|
|
||||||
|
func (u *urlUtils) BuildUrl(baseUrl string, queryParams map[string]string) string {
|
||||||
|
newUrl, err := url.Parse(baseUrl)
|
||||||
|
if err != nil {
|
||||||
|
return baseUrl
|
||||||
|
}
|
||||||
|
q := newUrl.Query()
|
||||||
|
for key, value := range queryParams {
|
||||||
|
q.Set(key, value)
|
||||||
|
}
|
||||||
|
newUrl.RawQuery = q.Encode()
|
||||||
|
return newUrl.String()
|
||||||
|
}
|
Reference in New Issue
Block a user