mirror of
https://github.com/snowykami/neo-blog.git
synced 2025-09-26 11:06:23 +00:00
fix: Closes #6 优化登录时的状态反馈,且登录失败后会刷新验证码
This commit is contained in:
@ -1,400 +1,400 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/snowykami/neo-blog/internal/dto"
|
"github.com/snowykami/neo-blog/internal/dto"
|
||||||
"github.com/snowykami/neo-blog/internal/model"
|
"github.com/snowykami/neo-blog/internal/model"
|
||||||
"github.com/snowykami/neo-blog/internal/repo"
|
"github.com/snowykami/neo-blog/internal/repo"
|
||||||
"github.com/snowykami/neo-blog/internal/static"
|
"github.com/snowykami/neo-blog/internal/static"
|
||||||
"github.com/snowykami/neo-blog/pkg/constant"
|
"github.com/snowykami/neo-blog/pkg/constant"
|
||||||
"github.com/snowykami/neo-blog/pkg/errs"
|
"github.com/snowykami/neo-blog/pkg/errs"
|
||||||
"github.com/snowykami/neo-blog/pkg/utils"
|
"github.com/snowykami/neo-blog/pkg/utils"
|
||||||
"gorm.io/gorm"
|
"gorm.io/gorm"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserService struct{}
|
type UserService struct{}
|
||||||
|
|
||||||
func NewUserService() *UserService {
|
func NewUserService() *UserService {
|
||||||
return &UserService{}
|
return &UserService{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
func (s *UserService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||||
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
|
user, err := repo.User.GetUserByUsernameOrEmail(req.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
logrus.Warnf("User not found: %s", req.Username)
|
logrus.Warnf("User not found: %s", req.Username)
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
||||||
token, refreshToken, err := s.generate2Token(user.ID)
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to generate tokens:", err)
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.UserLoginResp{
|
resp := &dto.UserLoginResp{
|
||||||
Token: token,
|
Token: token,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
} else {
|
} else {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.New(http.StatusUnauthorized, "Invalid username or password", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
func (s *UserService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
||||||
// 验证邮箱验证码
|
// 验证邮箱验证码
|
||||||
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
|
if !utils.Env.GetAsBool("ENABLE_REGISTER", true) {
|
||||||
return nil, errs.ErrForbidden
|
return nil, errs.ErrForbidden
|
||||||
}
|
}
|
||||||
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
|
if utils.Env.GetAsBool("ENABLE_EMAIL_VERIFICATION", true) {
|
||||||
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
|
ok, err := s.verifyEmail(req.Email, req.VerificationCode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to verify email:", err)
|
logrus.Errorln("Failed to verify email:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
return nil, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 检查用户名或邮箱是否已存在
|
// 检查用户名或邮箱是否已存在
|
||||||
usernameExist, err := repo.User.CheckUsernameExists(req.Username)
|
usernameExist, err := repo.User.CheckUsernameExists(req.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
emailExist, err := repo.User.CheckEmailExists(req.Email)
|
emailExist, err := repo.User.CheckEmailExists(req.Email)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if usernameExist || emailExist {
|
if usernameExist || emailExist {
|
||||||
return nil, errs.New(http.StatusConflict, "Username or email already exists", nil)
|
return nil, errs.New(http.StatusConflict, "Username or email already exists", nil)
|
||||||
}
|
}
|
||||||
// 创建新用户
|
// 创建新用户
|
||||||
hashedPassword, err := utils.Password.HashPassword(req.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt"))
|
hashedPassword, err := utils.Password.HashPassword(req.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to hash password:", err)
|
logrus.Errorln("Failed to hash password:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
newUser := &model.User{
|
newUser := &model.User{
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Nickname: req.Nickname,
|
Nickname: req.Nickname,
|
||||||
Email: req.Email,
|
Email: req.Email,
|
||||||
Gender: "",
|
Gender: "",
|
||||||
Role: "user",
|
Role: "user",
|
||||||
Password: hashedPassword,
|
Password: hashedPassword,
|
||||||
}
|
}
|
||||||
err = repo.User.CreateUser(newUser)
|
err = repo.User.CreateUser(newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
// 创建默认管理员账户
|
// 创建默认管理员账户
|
||||||
if newUser.ID == 1 {
|
if newUser.ID == 1 {
|
||||||
newUser.Role = constant.RoleAdmin
|
newUser.Role = constant.RoleAdmin
|
||||||
err = repo.User.UpdateUser(newUser)
|
err = repo.User.UpdateUser(newUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to update user role to admin:", err)
|
logrus.Errorln("Failed to update user role to admin:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// 生成访问令牌和刷新令牌
|
// 生成访问令牌和刷新令牌
|
||||||
token, refreshToken, err := s.generate2Token(newUser.ID)
|
token, refreshToken, err := s.generate2Token(newUser.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to generate tokens:", err)
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.UserRegisterResp{
|
resp := &dto.UserRegisterResp{
|
||||||
Token: token,
|
Token: token,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
User: newUser.ToDto(),
|
User: newUser.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
func (s *UserService) RequestVerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
||||||
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
||||||
kv := utils.KV.GetInstance()
|
kv := utils.KV.GetInstance()
|
||||||
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
|
kv.Set(constant.KVKeyEmailVerificationCode+req.Email, generatedVerificationCode, time.Minute*10)
|
||||||
|
|
||||||
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if utils.IsDevMode {
|
if utils.IsDevMode {
|
||||||
logrus.Infof("%s's verification code is %s", req.Email, generatedVerificationCode)
|
logrus.Infof("%s's verification code is %s", req.Email, generatedVerificationCode)
|
||||||
}
|
}
|
||||||
err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true)
|
err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
return &dto.VerifyEmailResp{Success: true}, nil
|
return &dto.VerifyEmailResp{Success: true}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) ListOidcConfigs() ([]dto.UserOidcConfigDto, error) {
|
func (s *UserService) ListOidcConfigs() ([]dto.UserOidcConfigDto, error) {
|
||||||
enabledOidcConfigs, err := repo.Oidc.ListOidcConfigs(true)
|
enabledOidcConfigs, err := repo.Oidc.ListOidcConfigs(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
var oidcConfigsDtos []dto.UserOidcConfigDto
|
var oidcConfigsDtos []dto.UserOidcConfigDto
|
||||||
|
|
||||||
for _, oidcConfig := range enabledOidcConfigs {
|
for _, oidcConfig := range enabledOidcConfigs {
|
||||||
state := utils.Strings.GenerateRandomString(32)
|
state := utils.Strings.GenerateRandomString(32)
|
||||||
kvStore := utils.KV.GetInstance()
|
kvStore := utils.KV.GetInstance()
|
||||||
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
|
kvStore.Set(constant.KVKeyOidcState+state, oidcConfig.Name, 5*time.Minute)
|
||||||
loginUrl := utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
loginUrl := utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
||||||
"client_id": oidcConfig.ClientID,
|
"client_id": oidcConfig.ClientID,
|
||||||
"redirect_uri": fmt.Sprintf("%s%s%s/%sREDIRECT_BACK", // 这个大占位符给前端替换用的,替换时也要uri编码因为是层层包的
|
"redirect_uri": fmt.Sprintf("%s%s%s/%sREDIRECT_BACK", // 这个大占位符给前端替换用的,替换时也要uri编码因为是层层包的
|
||||||
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
|
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
|
||||||
constant.ApiSuffix,
|
constant.ApiSuffix,
|
||||||
constant.OidcUri,
|
constant.OidcUri,
|
||||||
oidcConfig.Name,
|
oidcConfig.Name,
|
||||||
),
|
),
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"scope": "openid email profile",
|
"scope": "openid email profile",
|
||||||
"state": state,
|
"state": state,
|
||||||
})
|
})
|
||||||
|
|
||||||
if oidcConfig.Type == constant.OidcProviderTypeMisskey {
|
if oidcConfig.Type == constant.OidcProviderTypeMisskey {
|
||||||
// Misskey OIDC 特殊处理
|
// Misskey OIDC 特殊处理
|
||||||
loginUrl = utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
loginUrl = utils.Url.BuildUrl(oidcConfig.AuthorizationEndpoint, map[string]string{
|
||||||
"client_id": oidcConfig.ClientID,
|
"client_id": oidcConfig.ClientID,
|
||||||
"redirect_uri": fmt.Sprintf("%s%s%s/%s", // 这个大占位符给前端替换用的,替换时也要uri编码因为是层层包的
|
"redirect_uri": fmt.Sprintf("%s%s%s/%s", // 这个大占位符给前端替换用的,替换时也要uri编码因为是层层包的
|
||||||
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
|
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/"),
|
||||||
constant.ApiSuffix,
|
constant.ApiSuffix,
|
||||||
constant.OidcUri,
|
constant.OidcUri,
|
||||||
oidcConfig.Name,
|
oidcConfig.Name,
|
||||||
),
|
),
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
"scope": "read:account",
|
"scope": "read:account",
|
||||||
"state": state,
|
"state": state,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcConfigsDtos = append(oidcConfigsDtos, dto.UserOidcConfigDto{
|
oidcConfigsDtos = append(oidcConfigsDtos, dto.UserOidcConfigDto{
|
||||||
Name: oidcConfig.Name,
|
Name: oidcConfig.Name,
|
||||||
DisplayName: oidcConfig.DisplayName,
|
DisplayName: oidcConfig.DisplayName,
|
||||||
Icon: oidcConfig.Icon,
|
Icon: oidcConfig.Icon,
|
||||||
LoginUrl: loginUrl,
|
LoginUrl: loginUrl,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
return oidcConfigsDtos, nil
|
return oidcConfigsDtos, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
|
func (s *UserService) OidcLogin(req *dto.OidcLoginReq) (*dto.OidcLoginResp, error) {
|
||||||
// 验证state
|
// 验证state
|
||||||
kvStore := utils.KV.GetInstance()
|
kvStore := utils.KV.GetInstance()
|
||||||
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
|
storedName, ok := kvStore.Get(constant.KVKeyOidcState + req.State)
|
||||||
if !ok || storedName != req.Name {
|
if !ok || storedName != req.Name {
|
||||||
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
|
return nil, errs.New(http.StatusForbidden, "invalid oidc state", nil)
|
||||||
}
|
}
|
||||||
// 获取OIDC配置
|
// 获取OIDC配置
|
||||||
oidcConfig, err := repo.Oidc.GetOidcConfigByName(req.Name)
|
oidcConfig, err := repo.Oidc.GetOidcConfigByName(req.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if oidcConfig == nil {
|
if oidcConfig == nil {
|
||||||
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
|
return nil, errs.New(http.StatusNotFound, "OIDC configuration not found", nil)
|
||||||
}
|
}
|
||||||
// 请求访问令牌
|
// 请求访问令牌
|
||||||
tokenResp, err := utils.Oidc.RequestToken(
|
tokenResp, err := utils.Oidc.RequestToken(
|
||||||
oidcConfig.TokenEndpoint,
|
oidcConfig.TokenEndpoint,
|
||||||
oidcConfig.ClientID,
|
oidcConfig.ClientID,
|
||||||
oidcConfig.ClientSecret,
|
oidcConfig.ClientSecret,
|
||||||
req.Code,
|
req.Code,
|
||||||
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
|
strings.TrimSuffix(utils.Env.Get(constant.EnvKeyBaseUrl, constant.DefaultBaseUrl), "/")+constant.OidcUri+oidcConfig.Name,
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to request OIDC token:", err)
|
logrus.Errorln("Failed to request OIDC token:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
|
userInfo, err := utils.Oidc.RequestUserInfo(oidcConfig.UserInfoEndpoint, tokenResp.AccessToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to request OIDC user info:", err)
|
logrus.Errorln("Failed to request OIDC user info:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
|
|
||||||
// 绑定过登录
|
// 绑定过登录
|
||||||
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
|
userOpenID, err := repo.User.GetUserOpenIDByIssuerAndSub(oidcConfig.Issuer, userInfo.Sub)
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if userOpenID != nil {
|
if userOpenID != nil {
|
||||||
user, err := repo.User.GetUserByID(userOpenID.UserID)
|
user, err := repo.User.GetUserByID(userOpenID.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
token, refreshToken, err := s.generate2Token(user.ID)
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to generate tokens:", err)
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.OidcLoginResp{
|
resp := &dto.OidcLoginResp{
|
||||||
Token: token,
|
Token: token,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
} else {
|
} else {
|
||||||
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
|
// 若没有绑定过登录,则先通过邮箱查找用户,若没有再创建新用户
|
||||||
user, err := repo.User.GetUserByEmail(userInfo.Email)
|
user, err := repo.User.GetUserByEmail(userInfo.Email)
|
||||||
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
logrus.Errorln("Failed to get user by email:", err)
|
logrus.Errorln("Failed to get user by email:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if user != nil {
|
if user != nil {
|
||||||
userOpenID = &model.UserOpenID{
|
userOpenID = &model.UserOpenID{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Issuer: oidcConfig.Issuer,
|
Issuer: oidcConfig.Issuer,
|
||||||
Sub: userInfo.Sub,
|
Sub: userInfo.Sub,
|
||||||
}
|
}
|
||||||
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to create or update user OpenID:", err)
|
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
token, refreshToken, err := s.generate2Token(user.ID)
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to generate tokens:", err)
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.OidcLoginResp{
|
resp := &dto.OidcLoginResp{
|
||||||
Token: token,
|
Token: token,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
} else {
|
} else {
|
||||||
user = &model.User{
|
user = &model.User{
|
||||||
Username: userInfo.Name,
|
Username: userInfo.Name,
|
||||||
Nickname: userInfo.Name,
|
Nickname: userInfo.Name,
|
||||||
AvatarUrl: userInfo.Picture,
|
AvatarUrl: userInfo.Picture,
|
||||||
Email: userInfo.Email,
|
Email: userInfo.Email,
|
||||||
}
|
}
|
||||||
err = repo.User.CreateUser(user)
|
err = repo.User.CreateUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to create user:", err)
|
logrus.Errorln("Failed to create user:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
userOpenID = &model.UserOpenID{
|
userOpenID = &model.UserOpenID{
|
||||||
UserID: user.ID,
|
UserID: user.ID,
|
||||||
Issuer: oidcConfig.Issuer,
|
Issuer: oidcConfig.Issuer,
|
||||||
Sub: userInfo.Sub,
|
Sub: userInfo.Sub,
|
||||||
}
|
}
|
||||||
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
err = repo.User.CreateOrUpdateUserOpenID(userOpenID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to create or update user OpenID:", err)
|
logrus.Errorln("Failed to create or update user OpenID:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
token, refreshToken, err := s.generate2Token(user.ID)
|
token, refreshToken, err := s.generate2Token(user.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.Errorln("Failed to generate tokens:", err)
|
logrus.Errorln("Failed to generate tokens:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
resp := &dto.OidcLoginResp{
|
resp := &dto.OidcLoginResp{
|
||||||
Token: token,
|
Token: token,
|
||||||
RefreshToken: refreshToken,
|
RefreshToken: refreshToken,
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}
|
}
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
|
func (s *UserService) GetUser(req *dto.GetUserReq) (*dto.GetUserResp, error) {
|
||||||
if req.UserID == 0 {
|
if req.UserID == 0 {
|
||||||
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
|
return nil, errs.New(http.StatusBadRequest, "user_id is required", nil)
|
||||||
}
|
}
|
||||||
user, err := repo.User.GetUserByID(req.UserID)
|
user, err := repo.User.GetUserByID(req.UserID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
logrus.Errorln("Failed to get user by ID:", err)
|
logrus.Errorln("Failed to get user by ID:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
return &dto.GetUserResp{
|
return &dto.GetUserResp{
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) GetUserByUsername(req *dto.GetUserByUsernameReq) (*dto.GetUserResp, error) {
|
func (s *UserService) GetUserByUsername(req *dto.GetUserByUsernameReq) (*dto.GetUserResp, error) {
|
||||||
if req.Username == "" {
|
if req.Username == "" {
|
||||||
return nil, errs.New(http.StatusBadRequest, "username is required", nil)
|
return nil, errs.New(http.StatusBadRequest, "username is required", nil)
|
||||||
}
|
}
|
||||||
user, err := repo.User.GetUserByUsername(req.Username)
|
user, err := repo.User.GetUserByUsername(req.Username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
logrus.Errorln("Failed to get user by username:", err)
|
logrus.Errorln("Failed to get user by username:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
if user == nil {
|
if user == nil {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
return &dto.GetUserResp{
|
return &dto.GetUserResp{
|
||||||
User: user.ToDto(),
|
User: user.ToDto(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
|
func (s *UserService) UpdateUser(req *dto.UpdateUserReq) (*dto.UpdateUserResp, error) {
|
||||||
user := &model.User{
|
user := &model.User{
|
||||||
Model: gorm.Model{
|
Model: gorm.Model{
|
||||||
ID: req.ID,
|
ID: req.ID,
|
||||||
},
|
},
|
||||||
Username: req.Username,
|
Username: req.Username,
|
||||||
Nickname: req.Nickname,
|
Nickname: req.Nickname,
|
||||||
Gender: req.Gender,
|
Gender: req.Gender,
|
||||||
AvatarUrl: req.AvatarUrl,
|
AvatarUrl: req.AvatarUrl,
|
||||||
}
|
}
|
||||||
err := repo.User.UpdateUser(user)
|
err := repo.User.UpdateUser(user)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||||
return nil, errs.ErrNotFound
|
return nil, errs.ErrNotFound
|
||||||
}
|
}
|
||||||
logrus.Errorln("Failed to update user:", err)
|
logrus.Errorln("Failed to update user:", err)
|
||||||
return nil, errs.ErrInternalServer
|
return nil, errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
return &dto.UpdateUserResp{}, nil
|
return &dto.UpdateUserResp{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) generate2Token(userID uint) (string, string, error) {
|
func (s *UserService) generate2Token(userID uint) (string, string, error) {
|
||||||
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault))*time.Second)
|
token := utils.Jwt.NewClaims(userID, "", false, time.Duration(utils.Env.GetAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault))*time.Second)
|
||||||
tokenString, err := token.ToString()
|
tokenString, err := token.ToString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errs.ErrInternalServer
|
return "", "", errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, constant.EnvKeyRefreshTokenDurationDefault))*time.Second)
|
refreshToken := utils.Jwt.NewClaims(userID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetAsInt(constant.EnvKeyRefreshTokenDuration, constant.EnvKeyRefreshTokenDurationDefault))*time.Second)
|
||||||
refreshTokenString, err := refreshToken.ToString()
|
refreshTokenString, err := refreshToken.ToString()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errs.ErrInternalServer
|
return "", "", errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", "", errs.ErrInternalServer
|
return "", "", errs.ErrInternalServer
|
||||||
}
|
}
|
||||||
return tokenString, refreshTokenString, nil
|
return tokenString, refreshTokenString, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *UserService) verifyEmail(email, code string) (bool, error) {
|
func (s *UserService) verifyEmail(email, code string) (bool, error) {
|
||||||
kv := utils.KV.GetInstance()
|
kv := utils.KV.GetInstance()
|
||||||
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
|
verificationCode, ok := kv.Get(constant.KVKeyEmailVerificationCode + email)
|
||||||
if !ok || verificationCode != code {
|
if !ok || verificationCode != code {
|
||||||
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
return false, errs.New(http.StatusForbidden, "Invalid email verification code", nil)
|
||||||
}
|
}
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
@ -20,6 +20,7 @@ import { useRouter, useSearchParams } from "next/navigation"
|
|||||||
import { useTranslations } from "next-intl"
|
import { useTranslations } from "next-intl"
|
||||||
import Captcha from "../common/captcha"
|
import Captcha from "../common/captcha"
|
||||||
import { CaptchaProvider } from "@/models/captcha"
|
import { CaptchaProvider } from "@/models/captcha"
|
||||||
|
import { toast } from "sonner"
|
||||||
|
|
||||||
export function LoginForm({
|
export function LoginForm({
|
||||||
className,
|
className,
|
||||||
@ -47,7 +48,7 @@ export function LoginForm({
|
|||||||
setOidcConfigs(res.data || []) // 确保是数组
|
setOidcConfigs(res.data || []) // 确保是数组
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error("Error fetching OIDC configs:", error)
|
toast.error(t("fetch_oidc_configs_failed") + (error?.message ? `: ${error.message}` : ""))
|
||||||
setOidcConfigs([]) // 错误时设置为空数组
|
setOidcConfigs([]) // 错误时设置为空数组
|
||||||
})
|
})
|
||||||
}, [])
|
}, [])
|
||||||
@ -58,7 +59,8 @@ export function LoginForm({
|
|||||||
setCaptchaProps(res.data)
|
setCaptchaProps(res.data)
|
||||||
})
|
})
|
||||||
.catch((error) => {
|
.catch((error) => {
|
||||||
console.error("Error fetching captcha config:", error)
|
toast.error(t("fetch_captcha_config_failed") + (error?.message ? `: ${error.message}` : ""))
|
||||||
|
setCaptchaProps(null)
|
||||||
})
|
})
|
||||||
}, [refreshCaptchaKey])
|
}, [refreshCaptchaKey])
|
||||||
|
|
||||||
@ -67,11 +69,14 @@ export function LoginForm({
|
|||||||
e.preventDefault()
|
e.preventDefault()
|
||||||
userLogin({ username, password, captcha: captchaToken || "" })
|
userLogin({ username, password, captcha: captchaToken || "" })
|
||||||
.then(res => {
|
.then(res => {
|
||||||
console.log("Login successful:", res)
|
toast.success(t("login_success") + ` ${res.data.user.nickname || res.data.user.username}`);
|
||||||
router.push(redirectBack)
|
router.push(redirectBack)
|
||||||
})
|
})
|
||||||
.catch(error => {
|
.catch(error => {
|
||||||
console.error("Login failed:", error)
|
console.log(error)
|
||||||
|
toast.error(t("login_failed") + (error?.response?.data?.message ? `: ${error.response.data.message}` : ""))
|
||||||
|
setRefreshCaptchaKey(k => k + 1)
|
||||||
|
setCaptchaToken(null)
|
||||||
})
|
})
|
||||||
.finally(() => {
|
.finally(() => {
|
||||||
setIsLogging(false)
|
setIsLogging(false)
|
||||||
|
@ -56,7 +56,11 @@
|
|||||||
},
|
},
|
||||||
"Login": {
|
"Login": {
|
||||||
"captcha_error": "验证错误,请重试。",
|
"captcha_error": "验证错误,请重试。",
|
||||||
|
"fetch_captcha_config_failed": "获取验证码失败,请稍后重试。",
|
||||||
|
"fetch_oidc_configs_failed": "获取第三方身份提供者配置失败。",
|
||||||
"logging": "正在登录...",
|
"logging": "正在登录...",
|
||||||
|
"login_success": "登录成功!",
|
||||||
|
"login_failed": "登录失败",
|
||||||
"welcome": "欢迎回来",
|
"welcome": "欢迎回来",
|
||||||
"with_oidc": "使用第三方身份提供者",
|
"with_oidc": "使用第三方身份提供者",
|
||||||
"or_continue_with_local_account": "或使用用户名和密码",
|
"or_continue_with_local_account": "或使用用户名和密码",
|
||||||
@ -69,7 +73,6 @@
|
|||||||
"by_logging_in_you_agree_to_our": "登录即表示你同意我们的",
|
"by_logging_in_you_agree_to_our": "登录即表示你同意我们的",
|
||||||
"terms_of_service": "服务条款",
|
"terms_of_service": "服务条款",
|
||||||
"and": "和",
|
"and": "和",
|
||||||
"privacy_policy": "隐私政策",
|
"privacy_policy": "隐私政策"
|
||||||
"login_failed": "登录失败,请检查你的凭据。"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
Reference in New Issue
Block a user