mirror of
https://github.com/snowykami/neo-blog.git
synced 2025-09-03 15:56:22 +00:00
⚡ refactor user service methods, implement OIDC login and user management features, and enhance token handling
This commit is contained in:
@ -23,6 +23,7 @@ EMAIL_PORT=465
|
||||
EMAIL_SSL=true
|
||||
|
||||
# App settings
|
||||
BASE_URL=https://blog.jason.moe
|
||||
MAX_REQUEST_BODY_SIZE=1000000
|
||||
MODE=prod
|
||||
PORT=8888
|
||||
|
@ -3,17 +3,17 @@ package v1
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
"github.com/cloudwego/hertz/pkg/protocol"
|
||||
"github.com/cloudwego/hertz/pkg/common/utils"
|
||||
"github.com/snowykami/neo-blog/internal/ctxutils"
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"github.com/snowykami/neo-blog/internal/service"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/errs"
|
||||
"github.com/snowykami/neo-blog/pkg/resps"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
type userType struct {
|
||||
service service.UserService
|
||||
service *service.UserService
|
||||
}
|
||||
|
||||
var User = &userType{
|
||||
@ -33,13 +33,11 @@ func (u *userType) Login(ctx context.Context, c *app.RequestContext) {
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
||||
return
|
||||
} else {
|
||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, utils.H{
|
||||
"token": resp.Token,
|
||||
"user": resp.User,
|
||||
})
|
||||
}
|
||||
|
||||
func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
||||
@ -55,46 +53,101 @@ func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
||||
return
|
||||
}
|
||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
|
||||
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, utils.H{
|
||||
"token": resp.Token,
|
||||
"user": resp.User,
|
||||
})
|
||||
}
|
||||
|
||||
func (u *userType) Logout(ctx context.Context, c *app.RequestContext) {
|
||||
u.clearTokenCookie(c)
|
||||
ctxutils.ClearTokenAndRefreshTokenCookie(c)
|
||||
resps.Ok(c, resps.Success, nil)
|
||||
}
|
||||
|
||||
func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
resp, err := u.service.ListOidcConfigs()
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
resps.Ok(c, resps.Success, map[string]any{
|
||||
"oidc_configs": resp.OidcConfigs,
|
||||
})
|
||||
}
|
||||
|
||||
func (u *userType) OidcLogin(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
name := c.Param("name")
|
||||
code := c.Param("code")
|
||||
state := c.Param("state")
|
||||
oidcLoginReq := &dto.OidcLoginReq{
|
||||
Name: name,
|
||||
Code: code,
|
||||
State: state,
|
||||
}
|
||||
resp, err := u.service.OidcLogin(oidcLoginReq)
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, map[string]any{
|
||||
"token": resp.Token,
|
||||
"user": resp.User,
|
||||
})
|
||||
}
|
||||
|
||||
func (u *userType) Get(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
}
|
||||
|
||||
func (u *userType) Update(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
}
|
||||
|
||||
func (u *userType) Delete(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
}
|
||||
|
||||
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
||||
var verifyEmailReq dto.VerifyEmailReq
|
||||
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
|
||||
func (u *userType) GetUser(ctx context.Context, c *app.RequestContext) {
|
||||
userID := c.Param("id")
|
||||
if userID == "" {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
resp, err := u.service.VerifyEmail(&verifyEmailReq)
|
||||
userIDInt, err := strconv.Atoi(userID)
|
||||
if err != nil || userIDInt <= 0 {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := u.service.GetUser(&dto.GetUserReq{UserID: uint(userIDInt)})
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
resps.Ok(c, resps.Success, resp.User)
|
||||
}
|
||||
|
||||
func (u *userType) UpdateUser(ctx context.Context, c *app.RequestContext) {
|
||||
userID := c.Param("id")
|
||||
if userID == "" {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
userIDInt, err := strconv.Atoi(userID)
|
||||
if err != nil || userIDInt <= 0 {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
var updateUserReq dto.UpdateUserReq
|
||||
if err := c.BindAndValidate(&updateUserReq); err != nil {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
updateUserReq.ID = uint(userIDInt)
|
||||
currentUser := ctxutils.GetCurrentUser(ctx)
|
||||
if currentUser == nil {
|
||||
resps.UnAuthorized(c, resps.ErrUnauthorized)
|
||||
return
|
||||
}
|
||||
if currentUser.ID != updateUserReq.ID {
|
||||
resps.Forbidden(c, resps.ErrForbidden)
|
||||
return
|
||||
}
|
||||
resp, err := u.service.UpdateUser(&updateUserReq)
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
@ -103,12 +156,17 @@ func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
|
||||
func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) {
|
||||
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
||||
|
||||
func (u *userType) clearTokenCookie(c *app.RequestContext) {
|
||||
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
||||
var verifyEmailReq dto.VerifyEmailReq
|
||||
if err := c.BindAndValidate(&verifyEmailReq); err != nil {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
resp, err := u.service.RequestVerifyEmail(&verifyEmailReq)
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
|
22
internal/ctxutils/token.go
Normal file
22
internal/ctxutils/token.go
Normal file
@ -0,0 +1,22 @@
|
||||
package ctxutils
|
||||
|
||||
import (
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
"github.com/cloudwego/hertz/pkg/protocol"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
)
|
||||
|
||||
func SetTokenCookie(c *app.RequestContext, token string) {
|
||||
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
||||
|
||||
func SetTokenAndRefreshTokenCookie(c *app.RequestContext, token, refreshToken string) {
|
||||
c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
||||
|
||||
func ClearTokenAndRefreshTokenCookie(c *app.RequestContext) {
|
||||
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
19
internal/ctxutils/user.go
Normal file
19
internal/ctxutils/user.go
Normal file
@ -0,0 +1,19 @@
|
||||
package ctxutils
|
||||
|
||||
import (
|
||||
"context"
|
||||
"github.com/snowykami/neo-blog/internal/model"
|
||||
"github.com/snowykami/neo-blog/internal/repo"
|
||||
)
|
||||
|
||||
func GetCurrentUser(ctx context.Context) *model.User {
|
||||
userIDValue := ctx.Value("user_id").(uint)
|
||||
if userIDValue <= 0 {
|
||||
return nil
|
||||
}
|
||||
user, err := repo.User.GetUserByID(userIDValue)
|
||||
if err != nil || user == nil || user.ID == 0 {
|
||||
return nil
|
||||
}
|
||||
return user
|
||||
}
|
@ -1,7 +1 @@
|
||||
package dto
|
||||
|
||||
type BaseResp struct {
|
||||
Code int `json:"code"`
|
||||
Message string `json:"message"`
|
||||
Data any `json:"data"`
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package dto
|
||||
|
||||
type UserDto struct {
|
||||
ID uint `json:"id"` // 用户ID
|
||||
Username string `json:"username"` // 用户名
|
||||
Nickname string `json:"nickname"`
|
||||
AvatarUrl string `json:"avatar_url"` // 头像URL
|
||||
@ -8,6 +9,13 @@ type UserDto struct {
|
||||
Gender string `json:"gender"`
|
||||
Role string `json:"role"`
|
||||
}
|
||||
|
||||
type OidcConfigDto struct {
|
||||
Name string `json:"name"` // OIDC配置名称
|
||||
DisplayName string `json:"display_name"` // OIDC配置显示名称
|
||||
Icon string `json:"icon"` // OIDC配置图标URL
|
||||
LoginUrl string `json:"login_url"` // OIDC登录URL
|
||||
}
|
||||
type UserLoginReq struct {
|
||||
Username string `json:"username"` // username or email
|
||||
Password string `json:"password"`
|
||||
@ -40,3 +48,39 @@ type VerifyEmailReq struct {
|
||||
type VerifyEmailResp struct {
|
||||
Success bool `json:"success"` // 验证码发送成功与否
|
||||
}
|
||||
|
||||
type OidcLoginReq struct {
|
||||
Name string `json:"name"` // OIDC配置名称
|
||||
Code string `json:"code"` // OIDC授权码
|
||||
State string `json:"state"`
|
||||
}
|
||||
|
||||
type OidcLoginResp struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserDto `json:"user"`
|
||||
}
|
||||
|
||||
type ListOidcConfigResp struct {
|
||||
OidcConfigs []OidcConfigDto `json:"oidc_configs"` // OIDC配置列表
|
||||
}
|
||||
|
||||
type GetUserReq struct {
|
||||
UserID uint `json:"user_id"`
|
||||
}
|
||||
|
||||
type GetUserResp struct {
|
||||
User *UserDto `json:"user"` // 用户信息
|
||||
}
|
||||
|
||||
type UpdateUserReq struct {
|
||||
ID uint `json:"id"`
|
||||
Username string `json:"username"`
|
||||
Nickname string `json:"nickname"`
|
||||
AvatarUrl string `json:"avatar_url"`
|
||||
Gender string `json:"gender"`
|
||||
}
|
||||
|
||||
type UpdateUserResp struct {
|
||||
User *UserDto `json:"user"` // 更新后的用户信息
|
||||
}
|
||||
|
@ -3,10 +3,53 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
"github.com/snowykami/neo-blog/internal/ctxutils"
|
||||
"github.com/snowykami/neo-blog/internal/repo"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/resps"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UseAuth() app.HandlerFunc {
|
||||
return func(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Implement authentication logic here
|
||||
// For cookie
|
||||
token := string(c.Cookie("token"))
|
||||
refreshToken := string(c.Cookie("refresh_token"))
|
||||
tokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(token)
|
||||
if err == nil && tokenClaims != nil {
|
||||
ctx = context.WithValue(ctx, "user_id", tokenClaims.UserID)
|
||||
c.Next(ctx)
|
||||
return
|
||||
}
|
||||
|
||||
// token 失效 使用refresh token重新签发和鉴权
|
||||
refreshTokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(refreshToken)
|
||||
if err == nil && refreshTokenClaims != nil {
|
||||
ok, err := isStatefulJwtValid(refreshTokenClaims)
|
||||
if err == nil && ok {
|
||||
ctx = context.WithValue(ctx, "user_id", refreshTokenClaims.UserID) // 修改这里,使用refreshTokenClaims
|
||||
c.Next(ctx)
|
||||
newTokenClaims := utils.Jwt.NewClaims(refreshTokenClaims.UserID, refreshTokenClaims.SessionKey, refreshTokenClaims.Stateful, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||
newToken, err := newTokenClaims.ToString()
|
||||
if err == nil {
|
||||
ctxutils.SetTokenCookie(c, newToken)
|
||||
} else {
|
||||
resps.InternalServerError(c, resps.ErrInternalServerError)
|
||||
}
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// 所有认证方式都失败,返回未授权错误
|
||||
resps.UnAuthorized(c, resps.ErrUnauthorized)
|
||||
c.Abort()
|
||||
}
|
||||
}
|
||||
|
||||
func isStatefulJwtValid(claims *utils.Claims) (bool, error) {
|
||||
if !claims.Stateful {
|
||||
return true, nil
|
||||
}
|
||||
return repo.Session.IsSessionValid(claims.SessionKey)
|
||||
}
|
||||
|
@ -2,6 +2,7 @@ package model
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"gorm.io/gorm"
|
||||
"resty.dev/v3"
|
||||
"time"
|
||||
@ -9,14 +10,13 @@ import (
|
||||
|
||||
type OidcConfig struct {
|
||||
gorm.Model
|
||||
Name string `gorm:"uniqueIndex"`
|
||||
ClientID string `gorm:"column:client_id"` // 客户端ID
|
||||
ClientSecret string `gorm:"column:client_secret"` // 客户端密钥
|
||||
DisplayName string `gorm:"column:display_name"` // 显示名称,例如:轻雪通行证
|
||||
GroupsClaim *string `gorm:"default:groups"` // 组声明,默认为:"groups"
|
||||
Icon *string `gorm:"column:icon"` // 图标url,为空则使用内置默认图标
|
||||
OidcDiscoveryUrl string `gorm:"column:oidc_discovery_url"` // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration
|
||||
Enabled bool `gorm:"column:enabled;default:true"` // 是否启用
|
||||
Name string `gorm:"uniqueIndex"` // OIDC配置名称,唯一
|
||||
ClientID string // 客户端ID
|
||||
ClientSecret string // 客户端密钥
|
||||
DisplayName string // 显示名称,例如:轻雪通行证
|
||||
Icon string // 图标url,为空则使用内置默认图标
|
||||
OidcDiscoveryUrl string // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration
|
||||
Enabled bool `gorm:"default:true"` // 是否启用
|
||||
// 以下字段为自动获取字段,每次更新配置时自动填充
|
||||
Issuer string
|
||||
AuthorizationEndpoint string
|
||||
@ -68,11 +68,6 @@ func updateOidcConfigFromUrl(url string) (*oidcDiscoveryResp, error) {
|
||||
}
|
||||
|
||||
func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
|
||||
// 设置默认值
|
||||
if o.GroupsClaim == nil {
|
||||
defaultGroupsClaim := "groups"
|
||||
o.GroupsClaim = &defaultGroupsClaim
|
||||
}
|
||||
// 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息
|
||||
if tx.Statement.Changed("OidcDiscoveryUrl") {
|
||||
discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl)
|
||||
@ -87,3 +82,12 @@ func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ToDto 不包含LoginUrl,在service层自行实现
|
||||
func (o *OidcConfig) ToDto() *dto.OidcConfigDto {
|
||||
return &dto.OidcConfigDto{
|
||||
Name: o.Name,
|
||||
DisplayName: o.DisplayName,
|
||||
Icon: o.Icon,
|
||||
}
|
||||
}
|
||||
|
@ -16,8 +16,17 @@ type User struct {
|
||||
Password string // 密码,存储加密后的值
|
||||
}
|
||||
|
||||
type UserOpenID struct {
|
||||
gorm.Model
|
||||
UserID uint `gorm:"uniqueIndex"`
|
||||
User User `gorm:"foreignKey:UserID;references:ID"`
|
||||
Issuer string `gorm:"index"` // OIDC Issuer
|
||||
Sub string `gorm:"index"` // OIDC Sub openid
|
||||
}
|
||||
|
||||
func (user *User) ToDto() *dto.UserDto {
|
||||
return &dto.UserDto{
|
||||
ID: user.ID,
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
AvatarUrl: user.AvatarUrl,
|
||||
|
@ -6,7 +6,7 @@ type userRepo struct{}
|
||||
|
||||
var User = &userRepo{}
|
||||
|
||||
func (user *userRepo) GetByUsername(username string) (*model.User, error) {
|
||||
func (user *userRepo) GetUserByUsername(username string) (*model.User, error) {
|
||||
var userModel model.User
|
||||
if err := GetDB().Where("username = ?", username).First(&userModel).Error; err != nil {
|
||||
return nil, err
|
||||
@ -14,7 +14,7 @@ func (user *userRepo) GetByUsername(username string) (*model.User, error) {
|
||||
return &userModel, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) GetByEmail(email string) (*model.User, error) {
|
||||
func (user *userRepo) GetUserByEmail(email string) (*model.User, error) {
|
||||
var userModel model.User
|
||||
if err := GetDB().Where("email = ?", email).First(&userModel).Error; err != nil {
|
||||
return nil, err
|
||||
@ -22,7 +22,15 @@ func (user *userRepo) GetByEmail(email string) (*model.User, error) {
|
||||
return &userModel, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
|
||||
func (user *userRepo) GetUserByID(id uint) (*model.User, error) {
|
||||
var userModel model.User
|
||||
if err := GetDB().Where("id = ?", id).First(&userModel).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &userModel, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) GetUserByUsernameOrEmail(usernameOrEmail string) (*model.User, error) {
|
||||
var userModel model.User
|
||||
if err := GetDB().Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).First(&userModel).Error; err != nil {
|
||||
return nil, err
|
||||
@ -30,14 +38,14 @@ func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User,
|
||||
return &userModel, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) Create(userModel *model.User) error {
|
||||
func (user *userRepo) CreateUser(userModel *model.User) error {
|
||||
if err := GetDB().Create(userModel).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *userRepo) Update(userModel *model.User) error {
|
||||
func (user *userRepo) UpdateUser(userModel *model.User) error {
|
||||
if err := GetDB().Updates(userModel).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
@ -59,3 +67,40 @@ func (user *userRepo) CheckEmailExists(email string) (bool, error) {
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) ListOidcConfigs(onlyEnabled bool) ([]model.OidcConfig, error) {
|
||||
var configs []model.OidcConfig
|
||||
if onlyEnabled {
|
||||
if err := GetDB().Where("enabled = ?", true).Find(&configs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := GetDB().Find(&configs).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return configs, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) GetOidcConfigByName(name string) (*model.OidcConfig, error) {
|
||||
var config model.OidcConfig
|
||||
if err := GetDB().Where("name = ?", name).First(&config).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (user *userRepo) CreateOrUpdateUserOpenID(userOpenID *model.UserOpenID) error {
|
||||
if err := GetDB().Save(userOpenID).Error; err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *userRepo) GetUserOpenIDByIssuerAndSub(issuer, sub string) (*model.UserOpenID, error) {
|
||||
var userOpenID model.UserOpenID
|
||||
if err := GetDB().Where("issuer = ? AND sub = ?", issuer, sub).First(&userOpenID).Error; err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &userOpenID, nil
|
||||
}
|
||||
|
@ -16,9 +16,8 @@ func registerUserRoutes(group *route.RouterGroup) {
|
||||
userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code
|
||||
userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList)
|
||||
userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin)
|
||||
userGroupWithoutAuth.GET("/u/:id", v1.User.Get)
|
||||
userGroupWithoutAuth.GET("/u/:id", v1.User.GetUser)
|
||||
userGroup.POST("/logout", v1.User.Logout)
|
||||
userGroup.PUT("/u/:id", v1.User.Update)
|
||||
userGroup.DELETE("/u/:id", v1.User.Delete)
|
||||
userGroup.PUT("/u/:id", v1.User.UpdateUser)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"github.com/snowykami/neo-blog/internal/model"
|
||||
@ -9,25 +10,20 @@ import (
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/errs"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
"gorm.io/gorm"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserService interface {
|
||||
UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error)
|
||||
UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error)
|
||||
VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error)
|
||||
// TODO impl other user-related methods
|
||||
type UserService struct{}
|
||||
|
||||
func NewUserService() *UserService {
|
||||
return &UserService{}
|
||||
}
|
||||
|
||||
type userService struct{}
|
||||
|
||||
func NewUserService() UserService {
|
||||
return &userService{}
|
||||
}
|
||||
|
||||
func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||
user, err := repo.User.GetByUsernameOrEmail(req.Username)
|
||||
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
@ -35,26 +31,14 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
|
||||
return nil, errs.ErrNotFound
|
||||
}
|
||||
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
||||
|
||||
token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
||||
tokenString, err := token.ToString()
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
|
||||
refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||
refreshTokenString, err := refreshToken.ToString()
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
// 对refresh token进行持久化存储
|
||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||
token, refreshToken, err := s.generate2Token(user.ID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to generate tokens:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
resp := &dto.UserLoginResp{
|
||||
Token: tokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
Token: token,
|
||||
RefreshToken: refreshToken,
|
||||
User: user.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
@ -63,16 +47,19 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
||||
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
||||
// 验证邮箱验证码
|
||||
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
|
||||
return nil, errs.ErrForbidden
|
||||
}
|
||||
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
|
||||
kv := utils.KV.GetInstance()
|
||||
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + ":" + req.Email)
|
||||
if !ok || verificationCode != req.VerificationCode {
|
||||
return nil, errs.ErrInvalidCredentials
|
||||
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to verify email:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if !ok {
|
||||
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||
}
|
||||
}
|
||||
// 检查用户名或邮箱是否已存在
|
||||
@ -101,39 +88,28 @@ func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterR
|
||||
Role: "user",
|
||||
Password: hashedPassword,
|
||||
}
|
||||
err = repo.User.Create(newUser)
|
||||
err = repo.User.CreateUser(newUser)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
// 生成访问令牌和刷新令牌
|
||||
token := utils.Jwt.NewClaims(newUser.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
||||
tokenString, err := token.ToString()
|
||||
token, refreshToken, err := s.generate2Token(newUser.ID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to generate tokens:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
refreshToken := utils.Jwt.NewClaims(newUser.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||
refreshTokenString, err := refreshToken.ToString()
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
// 对refresh token进行持久化存储
|
||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
|
||||
resp := &dto.UserRegisterResp{
|
||||
Token: tokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
Token: token,
|
||||
RefreshToken: refreshToken,
|
||||
User: newUser.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
||||
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
||||
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
||||
kv := utils.KV.GetInstance()
|
||||
kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10)
|
||||
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
|
||||
|
||||
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
||||
if err != nil {
|
||||
@ -149,3 +125,220 @@ func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp
|
||||
}
|
||||
return &dto.VerifyEmailResp{Success: true}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) ListOidcConfigs() (*dto.ListOidcConfigResp, error) {
|
||||
enabledOidcConfigs, err := repo.User.ListOidcConfigs(true)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
var oidcConfigsDtos []dto.OidcConfigDto
|
||||
|
||||
for _, oidcConfig := range enabledOidcConfigs {
|
||||
state := utils.Strings.GenerateRandomString(32)
|
||||
kvStore := utils.KV.GetInstance()
|
||||
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
|
||||
oidcConfigsDtos = append(oidcConfigsDtos, dto.OidcConfigDto{
|
||||
Name: oidcConfig.Name,
|
||||
DisplayName: oidcConfig.DisplayName,
|
||||
Icon: oidcConfig.Icon,
|
||||
LoginUrl: utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
||||
"client_id": oidcConfig.ClientID,
|
||||
"redirect_uri": strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/") + constant.OidcUri + oidcConfig.Name,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
}),
|
||||
})
|
||||
}
|
||||
return &dto.ListOidcConfigResp{
|
||||
OidcConfigs: oidcConfigsDtos,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
|
||||
// 验证state
|
||||
kvStore := utils.KV.GetInstance()
|
||||
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
|
||||
if !ok || storedName != req.Name {
|
||||
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
|
||||
}
|
||||
// 获取OIDC配置
|
||||
oidcConfig, err := repo.User.GetOidcConfigByName(req.Name)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if oidcConfig == nil {
|
||||
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
|
||||
}
|
||||
// 请求访问令牌
|
||||
tokenResp, err := utils.Oidc.RequestToken(
|
||||
oidcConfig.TokenEndpoint,
|
||||
oidcConfig.ClientID,
|
||||
oidcConfig.ClientSecret,
|
||||
req.Code,
|
||||
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
|
||||
)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to request OIDC token:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to request OIDC user info:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
|
||||
// 绑定过登录
|
||||
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if userOpenID != nil {
|
||||
user, err := repo.User.GetUserByID(userOpenID.UserID)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
token, refreshToken, err := s.generate2Token(user.ID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to generate tokens:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
resp := &dto.OidcLoginResp{
|
||||
Token: token,
|
||||
RefreshToken: refreshToken,
|
||||
User: user.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
} else {
|
||||
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
|
||||
user, err := repo.User.GetUserByEmail(userInfo.Email)
|
||||
if !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
logrus.Errorln("Failed to get user by email:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if user != nil {
|
||||
userOpenID = &model.UserOpenID{
|
||||
UserID: user.ID,
|
||||
Issuer: oidcConfig.Issuer,
|
||||
Sub: userInfo.Sub,
|
||||
}
|
||||
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
token, refreshToken, err := s.generate2Token(user.ID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to generate tokens:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
resp := &dto.OidcLoginResp{
|
||||
Token: token,
|
||||
RefreshToken: refreshToken,
|
||||
User: user.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
} else {
|
||||
user = &model.User{
|
||||
Username: userInfo.Name,
|
||||
Nickname: userInfo.Name,
|
||||
AvatarUrl: userInfo.Picture,
|
||||
Email: userInfo.Email,
|
||||
}
|
||||
err = repo.User.CreateUser(user)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to create user:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
userOpenID = &model.UserOpenID{
|
||||
UserID: user.ID,
|
||||
Issuer: oidcConfig.Issuer,
|
||||
Sub: userInfo.Sub,
|
||||
}
|
||||
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
token, refreshToken, err := s.generate2Token(user.ID)
|
||||
if err != nil {
|
||||
logrus.Errorln("Failed to generate tokens:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
resp := &dto.OidcLoginResp{
|
||||
Token: token,
|
||||
RefreshToken: refreshToken,
|
||||
User: user.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
|
||||
if req.UserID == 0 {
|
||||
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
|
||||
}
|
||||
user, err := repo.User.GetUserByID(req.UserID)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errs.ErrNotFound
|
||||
}
|
||||
logrus.Errorln("Failed to get user by ID:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errs.ErrNotFound
|
||||
}
|
||||
return &dto.GetUserResp{
|
||||
User: user.ToDto(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
|
||||
user := &model.User{
|
||||
Model: gorm.Model{
|
||||
ID: req.ID,
|
||||
},
|
||||
Username: req.Username,
|
||||
Nickname: req.Nickname,
|
||||
Gender: req.Gender,
|
||||
AvatarUrl: req.AvatarUrl,
|
||||
}
|
||||
err := repo.User.UpdateUser(user)
|
||||
if err != nil {
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, errs.ErrNotFound
|
||||
}
|
||||
logrus.Errorln("Failed to update user:", err)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
return &dto.UpdateUserResp{}, nil
|
||||
}
|
||||
|
||||
func (s *UserService) generate2Token(userID uint) (string, string, error) {
|
||||
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
||||
tokenString, err := token.ToString()
|
||||
if err != nil {
|
||||
return "", "", errs.ErrInternalServer
|
||||
}
|
||||
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||
refreshTokenString, err := refreshToken.ToString()
|
||||
if err != nil {
|
||||
return "", "", errs.ErrInternalServer
|
||||
}
|
||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||
if err != nil {
|
||||
return "", "", errs.ErrInternalServer
|
||||
}
|
||||
return tokenString, refreshTokenString, nil
|
||||
}
|
||||
|
||||
func (s *UserService) verifyEmail(email, code string) (bool, error) {
|
||||
kv := utils.KV.GetInstance()
|
||||
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
|
||||
if !ok || verificationCode != code {
|
||||
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ const (
|
||||
RoleUser = "user"
|
||||
RoleAdmin = "admin"
|
||||
|
||||
EnvKeyBaseUrl = "BASE_URL" // 环境变量:基础URL
|
||||
EnvKeyMode = "MODE" // 环境变量:运行模式
|
||||
EnvKeyJwtSecrete = "JWT_SECRET" // 环境变量:JWT密钥
|
||||
EnvKeyPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐
|
||||
@ -18,5 +19,9 @@ const (
|
||||
EnvKeyRefreshTokenDuration = "REFRESH_TOKEN_DURATION" // 环境变量:刷新令牌有效期
|
||||
EnvKeyRefreshTokenDurationWithRemember = "REFRESH_TOKEN_DURATION_WITH_REMEMBER" // 环境变量:记住我刷新令牌有效期
|
||||
|
||||
KVKeyEmailVerificationCode = "email_verification_code" // KV存储:邮箱验证码
|
||||
KVKeyEmailVerificationCode = "email_verification_code:" // KV存储:邮箱验证码
|
||||
KVKeyOidcState = "oidc_state:" // KV存储:OIDC状态
|
||||
|
||||
OidcUri = "/user/oidc/login" // OIDC登录URI
|
||||
DefaultBaseUrl = "http://localhost:3000" // 默认BaseUrl
|
||||
)
|
||||
|
@ -11,6 +11,7 @@ func Custom(c *app.RequestContext, status int, message string, data any) {
|
||||
"message": message,
|
||||
"data": data,
|
||||
})
|
||||
c.Abort()
|
||||
}
|
||||
|
||||
func Ok(c *app.RequestContext, message string, data any) {
|
||||
|
@ -36,7 +36,7 @@ func (c *Claims) ToString() (string, error) {
|
||||
return token.SignedString([]byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret")))
|
||||
}
|
||||
|
||||
// ParseJsonWebTokenWithoutState 解析JWT令牌,不对有状态的Token进行状态检查
|
||||
// ParseJsonWebTokenWithoutState 解析JWT令牌,仅检查无状态下是否valid,不对有状态的Token进行状态检查
|
||||
func (j *jwtUtils) ParseJsonWebTokenWithoutState(tokenString string) (*Claims, error) {
|
||||
claims := &Claims{}
|
||||
token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) {
|
||||
|
72
pkg/utils/oidc.go
Normal file
72
pkg/utils/oidc.go
Normal file
@ -0,0 +1,72 @@
|
||||
package utils
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"resty.dev/v3"
|
||||
)
|
||||
|
||||
type oidcUtils struct{}
|
||||
|
||||
var Oidc = oidcUtils{}
|
||||
|
||||
// RequestToken 请求访问令牌
|
||||
func (u *oidcUtils) RequestToken(tokenEndpoint, clientID, clientSecret, code, redirectURI string) (*TokenResponse, error) {
|
||||
client := resty.New()
|
||||
tokenResp, err := client.R().
|
||||
SetFormData(map[string]string{
|
||||
"grant_type": "authorization_code",
|
||||
"client_id": clientID,
|
||||
"client_secret": clientSecret,
|
||||
"code": code,
|
||||
"redirect_uri": redirectURI,
|
||||
}).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetResult(&TokenResponse{}).
|
||||
Post(tokenEndpoint)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if tokenResp.StatusCode() != 200 {
|
||||
return nil, fmt.Errorf("状态码: %d,响应: %s", tokenResp.StatusCode(), tokenResp.String())
|
||||
}
|
||||
return tokenResp.Result().(*TokenResponse), nil
|
||||
}
|
||||
|
||||
// RequestUserInfo 请求用户信息
|
||||
func (u *oidcUtils) RequestUserInfo(userInfoEndpoint, accessToken string) (*UserInfo, error) {
|
||||
client := resty.New()
|
||||
userInfoResp, err := client.R().
|
||||
SetHeader("Authorization", "Bearer "+accessToken).
|
||||
SetHeader("Accept", "application/json").
|
||||
SetResult(&UserInfo{}).
|
||||
Get(userInfoEndpoint)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if userInfoResp.StatusCode() != 200 {
|
||||
return nil, fmt.Errorf("状态码: %d,响应: %s", userInfoResp.StatusCode(), userInfoResp.String())
|
||||
}
|
||||
|
||||
return userInfoResp.Result().(*UserInfo), nil
|
||||
}
|
||||
|
||||
type TokenResponse struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
IDToken string `json:"id_token,omitempty"`
|
||||
RefreshToken string `json:"refresh_token,omitempty"`
|
||||
}
|
||||
|
||||
// UserInfo 定义用户信息结构
|
||||
type UserInfo struct {
|
||||
Sub string `json:"sub"`
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
Picture string `json:"picture,omitempty"`
|
||||
Groups []string `json:"groups,omitempty"` // 可选字段,OIDC提供的用户组信息
|
||||
}
|
20
pkg/utils/url.go
Normal file
20
pkg/utils/url.go
Normal file
@ -0,0 +1,20 @@
|
||||
package utils
|
||||
|
||||
import "net/url"
|
||||
|
||||
type urlUtils struct{}
|
||||
|
||||
var Url = &urlUtils{}
|
||||
|
||||
func (u *urlUtils) BuildUrl(baseUrl string, queryParams map[string]string) string {
|
||||
newUrl, err := url.Parse(baseUrl)
|
||||
if err != nil {
|
||||
return baseUrl
|
||||
}
|
||||
q := newUrl.Query()
|
||||
for key, value := range queryParams {
|
||||
q.Set(key, value)
|
||||
}
|
||||
newUrl.RawQuery = q.Encode()
|
||||
return newUrl.String()
|
||||
}
|
Reference in New Issue
Block a user