diff --git a/.env.example b/.env.example index 844477d..dfb5fbd 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,7 @@ EMAIL_PORT=465 EMAIL_SSL=true # App settings +BASE_URL=https://blog.jason.moe MAX_REQUEST_BODY_SIZE=1000000 MODE=prod PORT=8888 diff --git a/internal/controller/v1/user.go b/internal/controller/v1/user.go index 4ebadba..c87c356 100644 --- a/internal/controller/v1/user.go +++ b/internal/controller/v1/user.go @@ -3,17 +3,17 @@ package v1 import ( "context" "github.com/cloudwego/hertz/pkg/app" - "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/snowykami/neo-blog/internal/ctxutils" "github.com/snowykami/neo-blog/internal/dto" "github.com/snowykami/neo-blog/internal/service" - "github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/resps" - "github.com/snowykami/neo-blog/pkg/utils" + "strconv" ) type userType struct { - service service.UserService + service *service.UserService } var User = &userType{ @@ -33,13 +33,11 @@ func (u *userType) Login(ctx context.Context, c *app.RequestContext) { resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) return } - if resp == nil { - resps.UnAuthorized(c, resps.ErrInvalidCredentials) - return - } else { - u.setTokenCookie(c, resp.Token, resp.RefreshToken) - resps.Ok(c, resps.Success, resp) - } + ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken) + resps.Ok(c, resps.Success, utils.H{ + "token": resp.Token, + "user": resp.User, + }) } func (u *userType) Register(ctx context.Context, c *app.RequestContext) { @@ -55,46 +53,101 @@ func (u *userType) Register(ctx context.Context, c *app.RequestContext) { resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) return } - if resp == nil { - resps.UnAuthorized(c, resps.ErrInvalidCredentials) - return - } - u.setTokenCookie(c, resp.Token, resp.RefreshToken) - resps.Ok(c, resps.Success, resp) + + ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken) + resps.Ok(c, resps.Success, utils.H{ + "token": resp.Token, + "user": resp.User, + }) } func (u *userType) Logout(ctx context.Context, c *app.RequestContext) { - u.clearTokenCookie(c) + ctxutils.ClearTokenAndRefreshTokenCookie(c) resps.Ok(c, resps.Success, nil) } func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) { - // TODO: Impl + resp, err := u.service.ListOidcConfigs() + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + resps.Ok(c, resps.Success, map[string]any{ + "oidc_configs": resp.OidcConfigs, + }) } func (u *userType) OidcLogin(ctx context.Context, c *app.RequestContext) { - // TODO: Impl + name := c.Param("name") + code := c.Param("code") + state := c.Param("state") + oidcLoginReq := &dto.OidcLoginReq{ + Name: name, + Code: code, + State: state, + } + resp, err := u.service.OidcLogin(oidcLoginReq) + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + ctxutils.SetTokenAndRefreshTokenCookie(c, resp.Token, resp.RefreshToken) + resps.Ok(c, resps.Success, map[string]any{ + "token": resp.Token, + "user": resp.User, + }) } -func (u *userType) Get(ctx context.Context, c *app.RequestContext) { - // TODO: Impl -} - -func (u *userType) Update(ctx context.Context, c *app.RequestContext) { - // TODO: Impl -} - -func (u *userType) Delete(ctx context.Context, c *app.RequestContext) { - // TODO: Impl -} - -func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) { - var verifyEmailReq dto.VerifyEmailReq - if err := c.BindAndValidate(&verifyEmailReq); err != nil { +func (u *userType) GetUser(ctx context.Context, c *app.RequestContext) { + userID := c.Param("id") + if userID == "" { resps.BadRequest(c, resps.ErrParamInvalid) return } - resp, err := u.service.VerifyEmail(&verifyEmailReq) + userIDInt, err := strconv.Atoi(userID) + if err != nil || userIDInt <= 0 { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + + resp, err := u.service.GetUser(&dto.GetUserReq{UserID: uint(userIDInt)}) + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + resps.Ok(c, resps.Success, resp.User) +} + +func (u *userType) UpdateUser(ctx context.Context, c *app.RequestContext) { + userID := c.Param("id") + if userID == "" { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + userIDInt, err := strconv.Atoi(userID) + if err != nil || userIDInt <= 0 { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + var updateUserReq dto.UpdateUserReq + if err := c.BindAndValidate(&updateUserReq); err != nil { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + updateUserReq.ID = uint(userIDInt) + currentUser := ctxutils.GetCurrentUser(ctx) + if currentUser == nil { + resps.UnAuthorized(c, resps.ErrUnauthorized) + return + } + if currentUser.ID != updateUserReq.ID { + resps.Forbidden(c, resps.ErrForbidden) + return + } + resp, err := u.service.UpdateUser(&updateUserReq) if err != nil { serviceErr := errs.AsServiceError(err) resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) @@ -103,12 +156,17 @@ func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) { resps.Ok(c, resps.Success, resp) } -func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) { - c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true) - c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) -} - -func (u *userType) clearTokenCookie(c *app.RequestContext) { - c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) - c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) +func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) { + var verifyEmailReq dto.VerifyEmailReq + if err := c.BindAndValidate(&verifyEmailReq); err != nil { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + resp, err := u.service.RequestVerifyEmail(&verifyEmailReq) + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + resps.Ok(c, resps.Success, resp) } diff --git a/internal/ctxutils/token.go b/internal/ctxutils/token.go new file mode 100644 index 0000000..195042a --- /dev/null +++ b/internal/ctxutils/token.go @@ -0,0 +1,22 @@ +package ctxutils + +import ( + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/utils" +) + +func SetTokenCookie(c *app.RequestContext, token string) { + c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true) +} + +func SetTokenAndRefreshTokenCookie(c *app.RequestContext, token, refreshToken string) { + c.SetCookie("token", token, utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true) + c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) +} + +func ClearTokenAndRefreshTokenCookie(c *app.RequestContext) { + c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) + c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) +} diff --git a/internal/ctxutils/user.go b/internal/ctxutils/user.go new file mode 100644 index 0000000..b1382bc --- /dev/null +++ b/internal/ctxutils/user.go @@ -0,0 +1,19 @@ +package ctxutils + +import ( + "context" + "github.com/snowykami/neo-blog/internal/model" + "github.com/snowykami/neo-blog/internal/repo" +) + +func GetCurrentUser(ctx context.Context) *model.User { + userIDValue := ctx.Value("user_id").(uint) + if userIDValue <= 0 { + return nil + } + user, err := repo.User.GetUserByID(userIDValue) + if err != nil || user == nil || user.ID == 0 { + return nil + } + return user +} diff --git a/internal/dto/dto.go b/internal/dto/dto.go index a74ed41..76d3a17 100644 --- a/internal/dto/dto.go +++ b/internal/dto/dto.go @@ -1,7 +1 @@ package dto - -type BaseResp struct { - Code int `json:"code"` - Message string `json:"message"` - Data any `json:"data"` -} diff --git a/internal/dto/user.go b/internal/dto/user.go index dab5ed2..0811780 100644 --- a/internal/dto/user.go +++ b/internal/dto/user.go @@ -1,6 +1,7 @@ package dto type UserDto struct { + ID uint `json:"id"` // 用户ID Username string `json:"username"` // 用户名 Nickname string `json:"nickname"` AvatarUrl string `json:"avatar_url"` // 头像URL @@ -8,6 +9,13 @@ type UserDto struct { Gender string `json:"gender"` Role string `json:"role"` } + +type OidcConfigDto struct { + Name string `json:"name"` // OIDC配置名称 + DisplayName string `json:"display_name"` // OIDC配置显示名称 + Icon string `json:"icon"` // OIDC配置图标URL + LoginUrl string `json:"login_url"` // OIDC登录URL +} type UserLoginReq struct { Username string `json:"username"` // username or email Password string `json:"password"` @@ -40,3 +48,39 @@ type VerifyEmailReq struct { type VerifyEmailResp struct { Success bool `json:"success"` // 验证码发送成功与否 } + +type OidcLoginReq struct { + Name string `json:"name"` // OIDC配置名称 + Code string `json:"code"` // OIDC授权码 + State string `json:"state"` +} + +type OidcLoginResp struct { + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` + User *UserDto `json:"user"` +} + +type ListOidcConfigResp struct { + OidcConfigs []OidcConfigDto `json:"oidc_configs"` // OIDC配置列表 +} + +type GetUserReq struct { + UserID uint `json:"user_id"` +} + +type GetUserResp struct { + User *UserDto `json:"user"` // 用户信息 +} + +type UpdateUserReq struct { + ID uint `json:"id"` + Username string `json:"username"` + Nickname string `json:"nickname"` + AvatarUrl string `json:"avatar_url"` + Gender string `json:"gender"` +} + +type UpdateUserResp struct { + User *UserDto `json:"user"` // 更新后的用户信息 +} diff --git a/internal/middleware/auth.go b/internal/middleware/auth.go index 730bc8a..14c2aef 100644 --- a/internal/middleware/auth.go +++ b/internal/middleware/auth.go @@ -3,10 +3,53 @@ package middleware import ( "context" "github.com/cloudwego/hertz/pkg/app" + "github.com/snowykami/neo-blog/internal/ctxutils" + "github.com/snowykami/neo-blog/internal/repo" + "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/resps" + "github.com/snowykami/neo-blog/pkg/utils" + "time" ) func UseAuth() app.HandlerFunc { return func(ctx context.Context, c *app.RequestContext) { - // TODO: Implement authentication logic here + // For cookie + token := string(c.Cookie("token")) + refreshToken := string(c.Cookie("refresh_token")) + tokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(token) + if err == nil && tokenClaims != nil { + ctx = context.WithValue(ctx, "user_id", tokenClaims.UserID) + c.Next(ctx) + return + } + + // token 失效 使用refresh token重新签发和鉴权 + refreshTokenClaims, err := utils.Jwt.ParseJsonWebTokenWithoutState(refreshToken) + if err == nil && refreshTokenClaims != nil { + ok, err := isStatefulJwtValid(refreshTokenClaims) + if err == nil && ok { + ctx = context.WithValue(ctx, "user_id", refreshTokenClaims.UserID) // 修改这里,使用refreshTokenClaims + c.Next(ctx) + newTokenClaims := utils.Jwt.NewClaims(refreshTokenClaims.UserID, refreshTokenClaims.SessionKey, refreshTokenClaims.Stateful, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour))) + newToken, err := newTokenClaims.ToString() + if err == nil { + ctxutils.SetTokenCookie(c, newToken) + } else { + resps.InternalServerError(c, resps.ErrInternalServerError) + } + return + } + } + + // 所有认证方式都失败,返回未授权错误 + resps.UnAuthorized(c, resps.ErrUnauthorized) + c.Abort() } } + +func isStatefulJwtValid(claims *utils.Claims) (bool, error) { + if !claims.Stateful { + return true, nil + } + return repo.Session.IsSessionValid(claims.SessionKey) +} diff --git a/internal/model/oidc_config.go b/internal/model/oidc_config.go index f372ebe..28c4082 100644 --- a/internal/model/oidc_config.go +++ b/internal/model/oidc_config.go @@ -2,6 +2,7 @@ package model import ( "fmt" + "github.com/snowykami/neo-blog/internal/dto" "gorm.io/gorm" "resty.dev/v3" "time" @@ -9,14 +10,13 @@ import ( type OidcConfig struct { gorm.Model - Name string `gorm:"uniqueIndex"` - ClientID string `gorm:"column:client_id"` // 客户端ID - ClientSecret string `gorm:"column:client_secret"` // 客户端密钥 - DisplayName string `gorm:"column:display_name"` // 显示名称,例如:轻雪通行证 - GroupsClaim *string `gorm:"default:groups"` // 组声明,默认为:"groups" - Icon *string `gorm:"column:icon"` // 图标url,为空则使用内置默认图标 - OidcDiscoveryUrl string `gorm:"column:oidc_discovery_url"` // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration - Enabled bool `gorm:"column:enabled;default:true"` // 是否启用 + Name string `gorm:"uniqueIndex"` // OIDC配置名称,唯一 + ClientID string // 客户端ID + ClientSecret string // 客户端密钥 + DisplayName string // 显示名称,例如:轻雪通行证 + Icon string // 图标url,为空则使用内置默认图标 + OidcDiscoveryUrl string // OpenID自动发现URL,例如 :https://pass.liteyuki.icu/.well-known/openid-configuration + Enabled bool `gorm:"default:true"` // 是否启用 // 以下字段为自动获取字段,每次更新配置时自动填充 Issuer string AuthorizationEndpoint string @@ -68,11 +68,6 @@ func updateOidcConfigFromUrl(url string) (*oidcDiscoveryResp, error) { } func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) { - // 设置默认值 - if o.GroupsClaim == nil { - defaultGroupsClaim := "groups" - o.GroupsClaim = &defaultGroupsClaim - } // 只有在创建新记录或更新 OidcDiscoveryUrl 字段时才更新端点信息 if tx.Statement.Changed("OidcDiscoveryUrl") { discoveryResp, err := updateOidcConfigFromUrl(o.OidcDiscoveryUrl) @@ -87,3 +82,12 @@ func (o *OidcConfig) BeforeSave(tx *gorm.DB) (err error) { } return nil } + +// ToDto 不包含LoginUrl,在service层自行实现 +func (o *OidcConfig) ToDto() *dto.OidcConfigDto { + return &dto.OidcConfigDto{ + Name: o.Name, + DisplayName: o.DisplayName, + Icon: o.Icon, + } +} diff --git a/internal/model/user.go b/internal/model/user.go index 525cfe7..652db49 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -16,8 +16,17 @@ type User struct { Password string // 密码,存储加密后的值 } +type UserOpenID struct { + gorm.Model + UserID uint `gorm:"uniqueIndex"` + User User `gorm:"foreignKey:UserID;references:ID"` + Issuer string `gorm:"index"` // OIDC Issuer + Sub string `gorm:"index"` // OIDC Sub openid +} + func (user *User) ToDto() *dto.UserDto { return &dto.UserDto{ + ID: user.ID, Username: user.Username, Nickname: user.Nickname, AvatarUrl: user.AvatarUrl, diff --git a/internal/repo/user.go b/internal/repo/user.go index dcab751..9a3b6c9 100644 --- a/internal/repo/user.go +++ b/internal/repo/user.go @@ -6,7 +6,7 @@ type userRepo struct{} var User = &userRepo{} -func (user *userRepo) GetByUsername(username string) (*model.User, error) { +func (user *userRepo) GetUserByUsername(username string) (*model.User, error) { var userModel model.User if err := GetDB().Where("username = ?", username).First(&userModel).Error; err != nil { return nil, err @@ -14,7 +14,7 @@ func (user *userRepo) GetByUsername(username string) (*model.User, error) { return &userModel, nil } -func (user *userRepo) GetByEmail(email string) (*model.User, error) { +func (user *userRepo) GetUserByEmail(email string) (*model.User, error) { var userModel model.User if err := GetDB().Where("email = ?", email).First(&userModel).Error; err != nil { return nil, err @@ -22,7 +22,15 @@ func (user *userRepo) GetByEmail(email string) (*model.User, error) { return &userModel, nil } -func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User, error) { +func (user *userRepo) GetUserByID(id uint) (*model.User, error) { + var userModel model.User + if err := GetDB().Where("id = ?", id).First(&userModel).Error; err != nil { + return nil, err + } + return &userModel, nil +} + +func (user *userRepo) GetUserByUsernameOrEmail(usernameOrEmail string) (*model.User, error) { var userModel model.User if err := GetDB().Where("username = ? OR email = ?", usernameOrEmail, usernameOrEmail).First(&userModel).Error; err != nil { return nil, err @@ -30,14 +38,14 @@ func (user *userRepo) GetByUsernameOrEmail(usernameOrEmail string) (*model.User, return &userModel, nil } -func (user *userRepo) Create(userModel *model.User) error { +func (user *userRepo) CreateUser(userModel *model.User) error { if err := GetDB().Create(userModel).Error; err != nil { return err } return nil } -func (user *userRepo) Update(userModel *model.User) error { +func (user *userRepo) UpdateUser(userModel *model.User) error { if err := GetDB().Updates(userModel).Error; err != nil { return err } @@ -59,3 +67,40 @@ func (user *userRepo) CheckEmailExists(email string) (bool, error) { } return count > 0, nil } + +func (user *userRepo) ListOidcConfigs(onlyEnabled bool) ([]model.OidcConfig, error) { + var configs []model.OidcConfig + if onlyEnabled { + if err := GetDB().Where("enabled = ?", true).Find(&configs).Error; err != nil { + return nil, err + } + } else { + if err := GetDB().Find(&configs).Error; err != nil { + return nil, err + } + } + return configs, nil +} + +func (user *userRepo) GetOidcConfigByName(name string) (*model.OidcConfig, error) { + var config model.OidcConfig + if err := GetDB().Where("name = ?", name).First(&config).Error; err != nil { + return nil, err + } + return &config, nil +} + +func (user *userRepo) CreateOrUpdateUserOpenID(userOpenID *model.UserOpenID) error { + if err := GetDB().Save(userOpenID).Error; err != nil { + return err + } + return nil +} + +func (user *userRepo) GetUserOpenIDByIssuerAndSub(issuer, sub string) (*model.UserOpenID, error) { + var userOpenID model.UserOpenID + if err := GetDB().Where("issuer = ? AND sub = ?", issuer, sub).First(&userOpenID).Error; err != nil { + return nil, err + } + return &userOpenID, nil +} diff --git a/internal/router/apiv1/user.go b/internal/router/apiv1/user.go index 032959e..96578a0 100644 --- a/internal/router/apiv1/user.go +++ b/internal/router/apiv1/user.go @@ -16,9 +16,8 @@ func registerUserRoutes(group *route.RouterGroup) { userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList) userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin) - userGroupWithoutAuth.GET("/u/:id", v1.User.Get) + userGroupWithoutAuth.GET("/u/:id", v1.User.GetUser) userGroup.POST("/logout", v1.User.Logout) - userGroup.PUT("/u/:id", v1.User.Update) - userGroup.DELETE("/u/:id", v1.User.Delete) + userGroup.PUT("/u/:id", v1.User.UpdateUser) } } diff --git a/internal/service/user.go b/internal/service/user.go index b666bf2..6b5083d 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -1,6 +1,7 @@ package service import ( + "errors" "github.com/sirupsen/logrus" "github.com/snowykami/neo-blog/internal/dto" "github.com/snowykami/neo-blog/internal/model" @@ -9,25 +10,20 @@ import ( "github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/utils" + "gorm.io/gorm" "net/http" + "strings" "time" ) -type UserService interface { - UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error) - UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error) - VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) - // TODO impl other user-related methods +type UserService struct{} + +func NewUserService() *UserService { + return &UserService{} } -type userService struct{} - -func NewUserService() UserService { - return &userService{} -} - -func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) { - user, err := repo.User.GetByUsernameOrEmail(req.Username) +func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) { + user, err := repo.User.GetUserByUsernameOrEmail(req.Username) if err != nil { return nil, errs.ErrInternalServer } @@ -35,26 +31,14 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro return nil, errs.ErrNotFound } if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) { - - token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour))) - tokenString, err := token.ToString() - if err != nil { - return nil, errs.ErrInternalServer - } - - refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour))) - refreshTokenString, err := refreshToken.ToString() - if err != nil { - return nil, errs.ErrInternalServer - } - // 对refresh token进行持久化存储 - err = repo.Session.SaveSession(refreshToken.SessionKey) + token, refreshToken, err := s.generate2Token(user.ID) if err != nil { + logrus.Errorln("Failed to generate tokens:", err) return nil, errs.ErrInternalServer } resp := &dto.UserLoginResp{ - Token: tokenString, - RefreshToken: refreshTokenString, + Token: token, + RefreshToken: refreshToken, User: user.ToDto(), } return resp, nil @@ -63,16 +47,19 @@ func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, erro } } -func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) { +func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) { // 验证邮箱验证码 if !utils.Env.GetAsBool("ENABLE_REGISTER", true) { return nil, errs.ErrForbidden } if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) { - kv := utils.KV.GetInstance() - verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + ":" + req.Email) - if !ok || verificationCode != req.VerificationCode { - return nil, errs.ErrInvalidCredentials + ok, err := s.verifyEmail(req.Email, req.VerificationCode) + if err != nil { + logrus.Errorln("Failed to verify email:", err) + return nil, errs.ErrInternalServer + } + if !ok { + return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil) } } // 检查用户名或邮箱是否已存在 @@ -101,39 +88,28 @@ func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterR Role: "user", Password: hashedPassword, } - err = repo.User.Create(newUser) + err = repo.User.CreateUser(newUser) if err != nil { return nil, errs.ErrInternalServer } // 生成访问令牌和刷新令牌 - token := utils.Jwt.NewClaims(newUser.ID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour))) - tokenString, err := token.ToString() + token, refreshToken, err := s.generate2Token(newUser.ID) if err != nil { + logrus.Errorln("Failed to generate tokens:", err) return nil, errs.ErrInternalServer } - refreshToken := utils.Jwt.NewClaims(newUser.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour))) - refreshTokenString, err := refreshToken.ToString() - if err != nil { - return nil, errs.ErrInternalServer - } - // 对refresh token进行持久化存储 - err = repo.Session.SaveSession(refreshToken.SessionKey) - if err != nil { - return nil, errs.ErrInternalServer - } - resp := &dto.UserRegisterResp{ - Token: tokenString, - RefreshToken: refreshTokenString, + Token: token, + RefreshToken: refreshToken, User: newUser.ToDto(), } return resp, nil } -func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) { +func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) { generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef") kv := utils.KV.GetInstance() - kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10) + kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10) template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{}) if err != nil { @@ -149,3 +125,220 @@ func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp } return &dto.VerifyEmailResp{Success: true}, nil } + +func (s *UserService) ListOidcConfigs() (*dto.ListOidcConfigResp, error) { + enabledOidcConfigs, err := repo.User.ListOidcConfigs(true) + if err != nil { + return nil, errs.ErrInternalServer + } + var oidcConfigsDtos []dto.OidcConfigDto + + for _, oidcConfig := range enabledOidcConfigs { + state := utils.Strings.GenerateRandomString(32) + kvStore := utils.KV.GetInstance() + kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute) + oidcConfigsDtos = append(oidcConfigsDtos, dto.OidcConfigDto{ + Name: oidcConfig.Name, + DisplayName: oidcConfig.DisplayName, + Icon: oidcConfig.Icon, + LoginUrl: utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{ + "client_id": oidcConfig.ClientID, + "redirect_uri": strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/") + constant.OidcUri + oidcConfig.Name, + "response_type": "code", + "scope": "openid email profile", + "state": state, + }), + }) + } + return &dto.ListOidcConfigResp{ + OidcConfigs: oidcConfigsDtos, + }, nil +} + +func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) { + // 验证state + kvStore := utils.KV.GetInstance() + storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State) + if !ok || storedName != req.Name { + return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil) + } + // 获取OIDC配置 + oidcConfig, err := repo.User.GetOidcConfigByName(req.Name) + if err != nil { + return nil, errs.ErrInternalServer + } + if oidcConfig == nil { + return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil) + } + // 请求访问令牌 + tokenResp, err := utils.Oidc.RequestToken( + oidcConfig.TokenEndpoint, + oidcConfig.ClientID, + oidcConfig.ClientSecret, + req.Code, + strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name, + ) + if err != nil { + logrus.Errorln("Failed to request OIDC token:", err) + return nil, errs.ErrInternalServer + } + userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken) + if err != nil { + logrus.Errorln("Failed to request OIDC user info:", err) + return nil, errs.ErrInternalServer + } + + // 绑定过登录 + userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub) + if !errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errs.ErrInternalServer + } + if userOpenID != nil { + user, err := repo.User.GetUserByID(userOpenID.UserID) + if err != nil { + return nil, errs.ErrInternalServer + } + token, refreshToken, err := s.generate2Token(user.ID) + if err != nil { + logrus.Errorln("Failed to generate tokens:", err) + return nil, errs.ErrInternalServer + } + resp := &dto.OidcLoginResp{ + Token: token, + RefreshToken: refreshToken, + User: user.ToDto(), + } + return resp, nil + } else { + // 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户 + user, err := repo.User.GetUserByEmail(userInfo.Email) + if !errors.Is(err, gorm.ErrRecordNotFound) { + logrus.Errorln("Failed to get user by email:", err) + return nil, errs.ErrInternalServer + } + if user != nil { + userOpenID = &model.UserOpenID{ + UserID: user.ID, + Issuer: oidcConfig.Issuer, + Sub: userInfo.Sub, + } + err = repo.User.CreateOrUpdateUserOpenID(userOpenID) + if err != nil { + logrus.Errorln("Failed to create or update user OpenID:", err) + return nil, errs.ErrInternalServer + } + token, refreshToken, err := s.generate2Token(user.ID) + if err != nil { + logrus.Errorln("Failed to generate tokens:", err) + return nil, errs.ErrInternalServer + } + resp := &dto.OidcLoginResp{ + Token: token, + RefreshToken: refreshToken, + User: user.ToDto(), + } + return resp, nil + } else { + user = &model.User{ + Username: userInfo.Name, + Nickname: userInfo.Name, + AvatarUrl: userInfo.Picture, + Email: userInfo.Email, + } + err = repo.User.CreateUser(user) + if err != nil { + logrus.Errorln("Failed to create user:", err) + return nil, errs.ErrInternalServer + } + userOpenID = &model.UserOpenID{ + UserID: user.ID, + Issuer: oidcConfig.Issuer, + Sub: userInfo.Sub, + } + err = repo.User.CreateOrUpdateUserOpenID(userOpenID) + if err != nil { + logrus.Errorln("Failed to create or update user OpenID:", err) + return nil, errs.ErrInternalServer + } + token, refreshToken, err := s.generate2Token(user.ID) + if err != nil { + logrus.Errorln("Failed to generate tokens:", err) + return nil, errs.ErrInternalServer + } + resp := &dto.OidcLoginResp{ + Token: token, + RefreshToken: refreshToken, + User: user.ToDto(), + } + return resp, nil + } + } +} + +func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) { + if req.UserID == 0 { + return nil, errs.New(http.StatusBadRequest, "user_id is required", nil) + } + user, err := repo.User.GetUserByID(req.UserID) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errs.ErrNotFound + } + logrus.Errorln("Failed to get user by ID:", err) + return nil, errs.ErrInternalServer + } + if user == nil { + return nil, errs.ErrNotFound + } + return &dto.GetUserResp{ + User: user.ToDto(), + }, nil +} + +func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) { + user := &model.User{ + Model: gorm.Model{ + ID: req.ID, + }, + Username: req.Username, + Nickname: req.Nickname, + Gender: req.Gender, + AvatarUrl: req.AvatarUrl, + } + err := repo.User.UpdateUser(user) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, errs.ErrNotFound + } + logrus.Errorln("Failed to update user:", err) + return nil, errs.ErrInternalServer + } + return &dto.UpdateUserResp{}, nil +} + +func (s *UserService) generate2Token(userID uint) (string, string, error) { + token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour))) + tokenString, err := token.ToString() + if err != nil { + return "", "", errs.ErrInternalServer + } + refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour))) + refreshTokenString, err := refreshToken.ToString() + if err != nil { + return "", "", errs.ErrInternalServer + } + err = repo.Session.SaveSession(refreshToken.SessionKey) + if err != nil { + return "", "", errs.ErrInternalServer + } + return tokenString, refreshTokenString, nil +} + +func (s *UserService) verifyEmail(email, code string) (bool, error) { + kv := utils.KV.GetInstance() + verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email) + if !ok || verificationCode != code { + return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil) + } + return true, nil +} diff --git a/pkg/constant/constant.go b/pkg/constant/constant.go index 67d0b2c..00f9c3a 100644 --- a/pkg/constant/constant.go +++ b/pkg/constant/constant.go @@ -10,6 +10,7 @@ const ( RoleUser = "user" RoleAdmin = "admin" + EnvKeyBaseUrl = "BASE_URL" // 环境变量:基础URL EnvKeyMode = "MODE" // 环境变量:运行模式 EnvKeyJwtSecrete = "JWT_SECRET" // 环境变量:JWT密钥 EnvKeyPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐 @@ -18,5 +19,9 @@ const ( EnvKeyRefreshTokenDuration = "REFRESH_TOKEN_DURATION" // 环境变量:刷新令牌有效期 EnvKeyRefreshTokenDurationWithRemember = "REFRESH_TOKEN_DURATION_WITH_REMEMBER" // 环境变量:记住我刷新令牌有效期 - KVKeyEmailVerificationCode = "email_verification_code" // KV存储:邮箱验证码 + KVKeyEmailVerificationCode = "email_verification_code:" // KV存储:邮箱验证码 + KVKeyOidcState = "oidc_state:" // KV存储:OIDC状态 + + OidcUri = "/user/oidc/login" // OIDC登录URI + DefaultBaseUrl = "http://localhost:3000" // 默认BaseUrl ) diff --git a/pkg/resps/resps.go b/pkg/resps/resps.go index 409c168..06467e1 100644 --- a/pkg/resps/resps.go +++ b/pkg/resps/resps.go @@ -11,6 +11,7 @@ func Custom(c *app.RequestContext, status int, message string, data any) { "message": message, "data": data, }) + c.Abort() } func Ok(c *app.RequestContext, message string, data any) { diff --git a/pkg/utils/jwt.go b/pkg/utils/jwt.go index e56e9a4..f5f0334 100644 --- a/pkg/utils/jwt.go +++ b/pkg/utils/jwt.go @@ -36,7 +36,7 @@ func (c *Claims) ToString() (string, error) { return token.SignedString([]byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret"))) } -// ParseJsonWebTokenWithoutState 解析JWT令牌,不对有状态的Token进行状态检查 +// ParseJsonWebTokenWithoutState 解析JWT令牌,仅检查无状态下是否valid,不对有状态的Token进行状态检查 func (j *jwtUtils) ParseJsonWebTokenWithoutState(tokenString string) (*Claims, error) { claims := &Claims{} token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { diff --git a/pkg/utils/oidc.go b/pkg/utils/oidc.go new file mode 100644 index 0000000..2568807 --- /dev/null +++ b/pkg/utils/oidc.go @@ -0,0 +1,72 @@ +package utils + +import ( + "fmt" + "resty.dev/v3" +) + +type oidcUtils struct{} + +var Oidc = oidcUtils{} + +// RequestToken 请求访问令牌 +func (u *oidcUtils) RequestToken(tokenEndpoint, clientID, clientSecret, code, redirectURI string) (*TokenResponse, error) { + client := resty.New() + tokenResp, err := client.R(). + SetFormData(map[string]string{ + "grant_type": "authorization_code", + "client_id": clientID, + "client_secret": clientSecret, + "code": code, + "redirect_uri": redirectURI, + }). + SetHeader("Accept", "application/json"). + SetResult(&TokenResponse{}). + Post(tokenEndpoint) + + if err != nil { + return nil, err + } + + if tokenResp.StatusCode() != 200 { + return nil, fmt.Errorf("状态码: %d,响应: %s", tokenResp.StatusCode(), tokenResp.String()) + } + return tokenResp.Result().(*TokenResponse), nil +} + +// RequestUserInfo 请求用户信息 +func (u *oidcUtils) RequestUserInfo(userInfoEndpoint, accessToken string) (*UserInfo, error) { + client := resty.New() + userInfoResp, err := client.R(). + SetHeader("Authorization", "Bearer "+accessToken). + SetHeader("Accept", "application/json"). + SetResult(&UserInfo{}). + Get(userInfoEndpoint) + + if err != nil { + return nil, err + } + + if userInfoResp.StatusCode() != 200 { + return nil, fmt.Errorf("状态码: %d,响应: %s", userInfoResp.StatusCode(), userInfoResp.String()) + } + + return userInfoResp.Result().(*UserInfo), nil +} + +type TokenResponse struct { + AccessToken string `json:"access_token"` + TokenType string `json:"token_type"` + ExpiresIn int `json:"expires_in"` + IDToken string `json:"id_token,omitempty"` + RefreshToken string `json:"refresh_token,omitempty"` +} + +// UserInfo 定义用户信息结构 +type UserInfo struct { + Sub string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` + Picture string `json:"picture,omitempty"` + Groups []string `json:"groups,omitempty"` // 可选字段,OIDC提供的用户组信息 +} diff --git a/pkg/utils/url.go b/pkg/utils/url.go new file mode 100644 index 0000000..78b50f3 --- /dev/null +++ b/pkg/utils/url.go @@ -0,0 +1,20 @@ +package utils + +import "net/url" + +type urlUtils struct{} + +var Url = &urlUtils{} + +func (u *urlUtils) BuildUrl(baseUrl string, queryParams map[string]string) string { + newUrl, err := url.Parse(baseUrl) + if err != nil { + return baseUrl + } + q := newUrl.Query() + for key, value := range queryParams { + q.Set(key, value) + } + newUrl.RawQuery = q.Encode() + return newUrl.String() +}