From a0d215fa2e158586b93a6f1cc615bbc37bc8c58e Mon Sep 17 00:00:00 2001 From: Snowykami Date: Tue, 22 Jul 2025 08:50:16 +0800 Subject: [PATCH] :zap: implement email verification feature, add captcha validation middleware, and enhance user authentication flow --- go.mod | 7 +- go.sum | 6 + internal/controller/v1/user.go | 75 ++++++++++- internal/dto/user.go | 14 ++- internal/middleware/captcha.go | 27 +++- internal/model/session.go | 8 ++ internal/model/user.go | 16 ++- internal/repo/session.go | 23 ++++ internal/router/apiv1/user.go | 1 + internal/router/router.go | 15 +-- internal/service/user.go | 70 +++++++++-- .../assets/email/verification-code.tmpl | 70 +++++++++++ internal/static/embed.go | 31 +++++ internal/static/embed_test.go | 18 +++ pkg/constant/constant.go | 22 +++- pkg/errs/errors.go | 58 +++++++++ pkg/errs/errors_test.go | 19 +++ pkg/resps/texts.go | 10 +- pkg/utils/captcha.go | 93 ++++++++++++++ pkg/utils/email.go | 84 +++++++++++++ pkg/utils/env.go | 18 ++- pkg/utils/json_web_token.go | 52 ++++++++ pkg/utils/kvstore.go | 119 ++++++++++++++++++ pkg/utils/password.go | 10 +- pkg/utils/request_context.go | 1 + pkg/utils/strings.go | 27 ++++ 26 files changed, 844 insertions(+), 50 deletions(-) create mode 100644 internal/model/session.go create mode 100644 internal/repo/session.go create mode 100644 internal/static/assets/email/verification-code.tmpl create mode 100644 internal/static/embed.go create mode 100644 internal/static/embed_test.go create mode 100644 pkg/errs/errors.go create mode 100644 pkg/errs/errors_test.go create mode 100644 pkg/utils/captcha.go create mode 100644 pkg/utils/email.go create mode 100644 pkg/utils/json_web_token.go create mode 100644 pkg/utils/kvstore.go create mode 100644 pkg/utils/request_context.go create mode 100644 pkg/utils/strings.go diff --git a/go.mod b/go.mod index 04d38eb..72f4f4e 100644 --- a/go.mod +++ b/go.mod @@ -5,10 +5,14 @@ go 1.23.3 require ( github.com/cloudwego/hertz v0.10.1 github.com/glebarez/sqlite v1.11.0 + github.com/golang-jwt/jwt/v5 v5.2.3 github.com/joho/godotenv v1.5.1 github.com/sirupsen/logrus v1.9.3 + golang.org/x/crypto v0.31.0 + gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df gorm.io/driver/postgres v1.6.0 gorm.io/gorm v1.30.0 + resty.dev/v3 v3.0.0-beta.3 ) require ( @@ -38,15 +42,14 @@ require ( github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect - golang.org/x/crypto v0.31.0 // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/text v0.21.0 // indirect google.golang.org/protobuf v1.34.1 // indirect + gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc // indirect modernc.org/libc v1.22.5 // indirect modernc.org/mathutil v1.5.0 // indirect modernc.org/memory v1.5.0 // indirect modernc.org/sqlite v1.23.1 // indirect - resty.dev/v3 v3.0.0-beta.3 // indirect ) diff --git a/go.sum b/go.sum index 951b4e8..b902d98 100644 --- a/go.sum +++ b/go.sum @@ -26,6 +26,8 @@ github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9g github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/golang-jwt/jwt/v5 v5.2.3 h1:kkGXqQOBSDDWRhWNXTFpqGSCMyh/PLnqUvMGJPDJDs0= +github.com/golang-jwt/jwt/v5 v5.2.3/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -154,8 +156,12 @@ golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg= google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc h1:2gGKlE2+asNV9m7xrywl36YYNnBG5ZQ0r/BOOxqPpmk= +gopkg.in/alexcesaro/quotedprintable.v3 v3.0.0-20150716171945-2caba252f4dc/go.mod h1:m7x9LTH6d71AHyAX77c9yqWCCa3UKHcVEj9y7hAtKDk= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df h1:n7WqCuqOuCbNr617RXOY0AWRXxgwEyPp2z+p0+hgMuE= +gopkg.in/gomail.v2 v2.0.0-20160411212932-81ebce5c23df/go.mod h1:LRQQ+SO6ZHR7tOkpBDuZnXENFzX8qRjMDMyPD6BRkCw= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/controller/v1/user.go b/internal/controller/v1/user.go index 42bb509..3cea327 100644 --- a/internal/controller/v1/user.go +++ b/internal/controller/v1/user.go @@ -3,26 +3,66 @@ package v1 import ( "context" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/snowykami/neo-blog/internal/dto" + "github.com/snowykami/neo-blog/internal/service" + "github.com/snowykami/neo-blog/pkg/constant" + "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/resps" + "github.com/snowykami/neo-blog/pkg/utils" ) -type userType struct{} +type userType struct { + service service.UserService +} -var User = new(userType) +var User = &userType{ + service: service.NewUserService(), +} func (u *userType) Login(ctx context.Context, c *app.RequestContext) { - var userLoginReq dto.UserLoginReq - if err := c.BindAndValidate(&userLoginReq); err != nil { + var userLoginReq *dto.UserLoginReq + if err := c.BindAndValidate(userLoginReq); err != nil { resps.BadRequest(c, resps.ErrParamInvalid) } + resp, err := u.service.UserLogin(userLoginReq) + + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + } + if resp == nil { + resps.UnAuthorized(c, resps.ErrInvalidCredentials) + } else { + u.setTokenCookie(c, resp.Token, resp.RefreshToken) + resps.Ok(c, resps.Success, resp) + } } func (u *userType) Register(ctx context.Context, c *app.RequestContext) { + var userRegisterReq *dto.UserRegisterReq + if err := c.BindAndValidate(userRegisterReq); err != nil { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + resp, err := u.service.UserRegister(userRegisterReq) + + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + if resp == nil { + resps.UnAuthorized(c, resps.ErrInvalidCredentials) + return + } + u.setTokenCookie(c, resp.Token, resp.RefreshToken) + resps.Ok(c, resps.Success, resp) } func (u *userType) Logout(ctx context.Context, c *app.RequestContext) { - // TODO: Impl + u.clearTokenCookie(c) + resps.Ok(c, resps.Success, nil) } func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) { @@ -44,3 +84,28 @@ func (u *userType) Update(ctx context.Context, c *app.RequestContext) { func (u *userType) Delete(ctx context.Context, c *app.RequestContext) { // TODO: Impl } + +func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) { + var verifyEmailReq *dto.VerifyEmailReq + if err := c.BindAndValidate(verifyEmailReq); err != nil { + resps.BadRequest(c, resps.ErrParamInvalid) + return + } + resp, err := u.service.VerifyEmail(verifyEmailReq) + if err != nil { + serviceErr := errs.AsServiceError(err) + resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return + } + resps.Ok(c, resps.Success, resp) +} + +func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) { + c.SetCookie("token", token, utils.Env.GetenvAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true) + c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) +} + +func (u *userType) clearTokenCookie(c *app.RequestContext) { + c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) + c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true) +} diff --git a/internal/dto/user.go b/internal/dto/user.go index 4671a21..688a60c 100644 --- a/internal/dto/user.go +++ b/internal/dto/user.go @@ -14,9 +14,9 @@ type UserLoginReq struct { } type UserLoginResp struct { - Token string `json:"token"` - RefreshToken string `json:"refresh_token"` - User UserDto `json:"user"` + Token string `json:"token"` + RefreshToken string `json:"refresh_token"` + User *UserDto `json:"user"` } type UserRegisterReq struct { @@ -31,3 +31,11 @@ type UserRegisterResp struct { RefreshToken string `json:"refresh_token"` // 刷新令牌 User UserDto `json:"user"` // 用户信息 } + +type VerifyEmailReq struct { + Email string `json:"email"` // 邮箱地址 +} + +type VerifyEmailResp struct { + Success bool `json:"success"` // 验证码发送成功与否 +} diff --git a/internal/middleware/captcha.go b/internal/middleware/captcha.go index 989f096..46002c8 100644 --- a/internal/middleware/captcha.go +++ b/internal/middleware/captcha.go @@ -3,10 +3,35 @@ package middleware import ( "context" "github.com/cloudwego/hertz/pkg/app" + "github.com/sirupsen/logrus" + "github.com/snowykami/neo-blog/pkg/resps" + "github.com/snowykami/neo-blog/pkg/utils" ) +// UseCaptcha 中间件函数,用于X-Captcha-Token验证码 func UseCaptcha() app.HandlerFunc { + captchaConfig := utils.Captcha.GetCaptchaConfigFromEnv() return func(ctx context.Context, c *app.RequestContext) { - // TODO: Implement captcha validation logic here + CaptchaToken := string(c.GetHeader("X-Captcha-Token")) + if utils.IsDevMode && CaptchaToken == utils.Env.Get("CAPTCHA_DEV_PASSCODE", "dev_passcode") { + // 开发模式直接通过密钥 + c.Next(ctx) + return + } + ok, err := utils.Captcha.VerifyCaptcha(captchaConfig, CaptchaToken) + if err != nil { + logrus.Error("Captcha verification error:", err) + resps.InternalServerError(c, "Captcha verification failed") + c.Abort() + return + } + if !ok { + logrus.Warn("Captcha verification failed for token:", CaptchaToken) + resps.Forbidden(c, "Captcha verification failed") + c.Abort() + return + } + c.Next(ctx) // 如果验证码验证成功,则继续下一个处理程序 + return } } diff --git a/internal/model/session.go b/internal/model/session.go new file mode 100644 index 0000000..1d3cd9b --- /dev/null +++ b/internal/model/session.go @@ -0,0 +1,8 @@ +package model + +import "gorm.io/gorm" + +type Session struct { + gorm.Model + SessionKey string `gorm:"uniqueIndex"` // 会话密钥,唯一索引 +} diff --git a/internal/model/user.go b/internal/model/user.go index ff531bb..78f377e 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -1,6 +1,9 @@ package model -import "gorm.io/gorm" +import ( + "github.com/snowykami/neo-blog/internal/dto" + "gorm.io/gorm" +) type User struct { gorm.Model @@ -13,3 +16,14 @@ type User struct { Password string // 密码,存储加密后的值 } + +func (user *User) ToDto() *dto.UserDto { + return &dto.UserDto{ + Username: user.Username, + Nickname: user.Nickname, + AvatarUrl: user.AvatarUrl, + Email: user.Email, + Gender: user.Gender, + Role: user.Role, + } +} diff --git a/internal/repo/session.go b/internal/repo/session.go new file mode 100644 index 0000000..961b113 --- /dev/null +++ b/internal/repo/session.go @@ -0,0 +1,23 @@ +package repo + +import "github.com/snowykami/neo-blog/internal/model" + +type sessionRepo struct{} + +var Session = sessionRepo{} + +func (s *sessionRepo) SaveSession(sessionKey string) error { + session := &model.Session{ + SessionKey: sessionKey, + } + return db.Create(session).Error +} + +func (s *sessionRepo) IsSessionValid(sessionKey string) (bool, error) { + var count int64 + err := db.Model(&model.Session{}).Where("session_key = ?", sessionKey).Count(&count).Error + if err != nil { + return false, err + } + return count > 0, nil +} diff --git a/internal/router/apiv1/user.go b/internal/router/apiv1/user.go index 700f82a..032959e 100644 --- a/internal/router/apiv1/user.go +++ b/internal/router/apiv1/user.go @@ -13,6 +13,7 @@ func registerUserRoutes(group *route.RouterGroup) { { userGroupWithoutAuthNeedsCaptcha.POST("/login", v1.User.Login) userGroupWithoutAuthNeedsCaptcha.POST("/register", v1.User.Register) + userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList) userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin) userGroupWithoutAuth.GET("/u/:id", v1.User.Get) diff --git a/internal/router/router.go b/internal/router/router.go index 12d71df..a981426 100644 --- a/internal/router/router.go +++ b/internal/router/router.go @@ -1,25 +1,22 @@ package router import ( - "errors" "github.com/cloudwego/hertz/pkg/app/server" + "github.com/sirupsen/logrus" "github.com/snowykami/neo-blog/internal/router/apiv1" - "github.com/snowykami/neo-blog/pkg/constant" "github.com/snowykami/neo-blog/pkg/utils" ) var h *server.Hertz func Run() error { - mode := utils.Env.Get("MODE", constant.ModeProd) // dev | prod - switch mode { - case constant.ModeProd: + if utils.IsDevMode { + logrus.Infoln("Running in development mode") + return h.Run() + } else { + logrus.Infoln("Running in production mode") h.Spin() return nil - case constant.ModeDev: - return h.Run() - default: - return errors.New("unknown mode: " + mode) } } diff --git a/internal/service/user.go b/internal/service/user.go index 727a27d..a811ae4 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -1,16 +1,20 @@ package service import ( - "errors" + "github.com/sirupsen/logrus" "github.com/snowykami/neo-blog/internal/dto" "github.com/snowykami/neo-blog/internal/repo" + "github.com/snowykami/neo-blog/internal/static" "github.com/snowykami/neo-blog/pkg/constant" - "github.com/snowykami/neo-blog/pkg/resps" + "github.com/snowykami/neo-blog/pkg/errs" "github.com/snowykami/neo-blog/pkg/utils" + "time" ) type UserService interface { - UserLogin(dto *dto.UserLoginReq) (*dto.UserLoginResp, error) + UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error) + UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error) + VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) // TODO impl other user-related methods } @@ -20,17 +24,63 @@ func NewUserService() UserService { return &userService{} } -func (s *userService) UserLogin(dto *dto.UserLoginReq) (*dto.UserLoginResp, error) { - user, err := repo.User.GetByUsernameOrEmail(dto.Username) +func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) { + user, err := repo.User.GetByUsernameOrEmail(req.Username) if err != nil { - return nil, err + return nil, errs.ErrInternalServer } if user == nil { - return nil, errors.New(resps.ErrNotFound) + return nil, errs.ErrNotFound } - if utils.Password.VerifyPassword(dto.Password, user.Password, utils.Env.Get(constant.EnvVarPasswordSalt, "default_salt")) { - return nil, nil // TODO: Generate JWT token and return it in the response + if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) { + + token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetenvAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour))) + tokenString, err := token.ToString() + if err != nil { + return nil, errs.ErrInternalServer + } + + refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetenvAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour))) + refreshTokenString, err := refreshToken.ToString() + if err != nil { + return nil, errs.ErrInternalServer + } + // 对refresh token进行持久化存储 + err = repo.Session.SaveSession(refreshToken.SessionKey) + if err != nil { + return nil, errs.ErrInternalServer + } + resp := &dto.UserLoginResp{ + Token: tokenString, + RefreshToken: refreshTokenString, + User: user.ToDto(), + } + return resp, nil } else { - return nil, errors.New(resps.ErrInvalidCredentials) + return nil, errs.ErrInternalServer } } + +func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) { + return nil, nil +} + +func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) { + generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef") + kv := utils.KV.GetInstance() + kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10) + + template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{}) + if err != nil { + return nil, errs.ErrInternalServer + } + if utils.IsDevMode { + logrus.Infoln("%s's verification code is %s", req.Email, generatedVerificationCode) + } + err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true) + + if err != nil { + return nil, errs.ErrInternalServer + } + return &dto.VerifyEmailResp{Success: true}, nil +} diff --git a/internal/static/assets/email/verification-code.tmpl b/internal/static/assets/email/verification-code.tmpl new file mode 100644 index 0000000..9d42a3a --- /dev/null +++ b/internal/static/assets/email/verification-code.tmpl @@ -0,0 +1,70 @@ + + + + + {{.Title}} + + + +
+
+

欢迎使用 {{.Title}}

+
+
+

尊敬的用户 {{.Email}},您好!

+

{{.Details}} 以下是您的验证码:

+
{{.VerifyCode}}
+

请在 {{.Expire}} 分钟内使用此验证码完成验证。

+
+ +
+ + \ No newline at end of file diff --git a/internal/static/embed.go b/internal/static/embed.go new file mode 100644 index 0000000..c0cd989 --- /dev/null +++ b/internal/static/embed.go @@ -0,0 +1,31 @@ +package static + +import ( + "bytes" + "embed" + "fmt" + "html/template" +) + +//go:embed assets/* +var AssetsFS embed.FS + +// RenderTemplate 从嵌入的文件系统中读取模板并渲染 +func RenderTemplate(name string, data interface{}) (string, error) { + templatePath := "assets/" + name + templateContent, err := AssetsFS.ReadFile(templatePath) + if err != nil { + return "", fmt.Errorf("读取模板文件失败: %w", err) + } + // 解析模板 + tmpl, err := template.New(name).Parse(string(templateContent)) + if err != nil { + return "", fmt.Errorf("解析模板失败: %w", err) + } + // 渲染模板 + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return "", fmt.Errorf("渲染模板失败: %w", err) + } + return buf.String(), nil +} diff --git a/internal/static/embed_test.go b/internal/static/embed_test.go new file mode 100644 index 0000000..b9d5174 --- /dev/null +++ b/internal/static/embed_test.go @@ -0,0 +1,18 @@ +package static + +import ( + "testing" +) + +func TestRenderTemplate(t *testing.T) { + template, err := RenderTemplate("email/verification-code.tmpl", map[string]interface{}{ + "Title": "Test Page", + "Email": "xxx@.comcom", + "Details": "nihao", + }) + t.Logf(template) + if err != nil { + t.Errorf("渲染模板失败: %v", err) + return + } +} diff --git a/pkg/constant/constant.go b/pkg/constant/constant.go index e163972..67d0b2c 100644 --- a/pkg/constant/constant.go +++ b/pkg/constant/constant.go @@ -1,10 +1,22 @@ package constant const ( - ModeDev = "dev" - ModeProd = "prod" - RoleUser = "user" - RoleAdmin = "admin" + CaptchaTypeDisable = "disable" // 禁用验证码 + CaptchaTypeHCaptcha = "hcaptcha" // HCaptcha验证码 + CaptchaTypeTurnstile = "turnstile" // Turnstile验证码 + CaptchaTypeReCaptcha = "recaptcha" // ReCaptcha验证码 + ModeDev = "dev" + ModeProd = "prod" + RoleUser = "user" + RoleAdmin = "admin" - EnvVarPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐 + EnvKeyMode = "MODE" // 环境变量:运行模式 + EnvKeyJwtSecrete = "JWT_SECRET" // 环境变量:JWT密钥 + EnvKeyPasswordSalt = "PASSWORD_SALT" // 环境变量:密码盐 + EnvKeyTokenDuration = "TOKEN_DURATION" // 环境变量:令牌有效期 + EnvKeyTokenDurationDefault = 300 + EnvKeyRefreshTokenDuration = "REFRESH_TOKEN_DURATION" // 环境变量:刷新令牌有效期 + EnvKeyRefreshTokenDurationWithRemember = "REFRESH_TOKEN_DURATION_WITH_REMEMBER" // 环境变量:记住我刷新令牌有效期 + + KVKeyEmailVerificationCode = "email_verification_code" // KV存储:邮箱验证码 ) diff --git a/pkg/errs/errors.go b/pkg/errs/errors.go new file mode 100644 index 0000000..e147c4d --- /dev/null +++ b/pkg/errs/errors.go @@ -0,0 +1,58 @@ +package errs + +import ( + "errors" + "github.com/cloudwego/hertz/pkg/app" + "net/http" +) + +// ServiceError 业务错误结构 +type ServiceError struct { + Code int // 错误代码 + Message string // 错误消息 + Err error // 原始错误 +} + +func (e *ServiceError) Error() string { + if e.Err != nil { + return e.Message + ": " + e.Err.Error() + } + return e.Message +} + +// 常见业务错误 +var ( + ErrNotFound = &ServiceError{Code: http.StatusNotFound, Message: "not found"} + ErrInvalidCredentials = &ServiceError{Code: http.StatusUnauthorized, Message: "invalid credentials"} + ErrInternalServer = &ServiceError{Code: http.StatusInternalServerError, Message: "internal server error"} + ErrBadRequest = &ServiceError{Code: http.StatusBadRequest, Message: "invalid request parameters"} + ErrForbidden = &ServiceError{Code: http.StatusForbidden, Message: "access forbidden"} +) + +// New 创建自定义错误 +func New(code int, message string, err error) *ServiceError { + return &ServiceError{ + Code: code, + Message: message, + Err: err, + } +} + +// Is 判断错误类型 +func Is(err, target error) bool { + return errors.Is(err, target) +} + +// AsServiceError 将错误转换为ServiceError +func AsServiceError(err error) *ServiceError { + var serviceErr *ServiceError + if errors.As(err, &serviceErr) { + return serviceErr + } + return ErrInternalServer +} + +// HandleError 处理错误并返回HTTP状态码和消息 +func HandleError(c *app.RequestContext, err *ServiceError) { + +} diff --git a/pkg/errs/errors_test.go b/pkg/errs/errors_test.go new file mode 100644 index 0000000..c39ac45 --- /dev/null +++ b/pkg/errs/errors_test.go @@ -0,0 +1,19 @@ +package errs + +import ( + "testing" +) + +func TestAsServiceError(c) { + serviceError := ErrNotFound + err := AsServiceError(serviceError) + if err.Code != serviceError.Code || err.Message != serviceError.Message { + t.Errorf("Expected %v, got %v", serviceError, err) + } + + serviceError = New(520, "Custom error", nil) + err = AsServiceError(serviceError) + if err.Code != serviceError.Code || err.Message != serviceError.Message { + t.Errorf("Expected %v, got %v", serviceError, err) + } +} diff --git a/pkg/resps/texts.go b/pkg/resps/texts.go index 4afde10..d805aa2 100644 --- a/pkg/resps/texts.go +++ b/pkg/resps/texts.go @@ -1,10 +1,12 @@ package resps const ( - ErrParamInvalid = "invalid request parameters" - ErrUnauthorized = "unauthorized access" - ErrForbidden = "access forbidden" - ErrNotFound = "resource not found" + Success = "success" + ErrParamInvalid = "invalid request parameters" + ErrUnauthorized = "unauthorized access" + ErrForbidden = "access forbidden" + ErrNotFound = "resource not found" + ErrInternalServerError = "internal server error" ErrInvalidCredentials = "invalid credentials" ) diff --git a/pkg/utils/captcha.go b/pkg/utils/captcha.go new file mode 100644 index 0000000..ce1f541 --- /dev/null +++ b/pkg/utils/captcha.go @@ -0,0 +1,93 @@ +package utils + +import ( + "fmt" + "github.com/snowykami/neo-blog/pkg/constant" + "resty.dev/v3" +) + +type captchaUtils struct{} + +var Captcha = captchaUtils{} + +type CaptchaConfig struct { + Type string + SiteSecret string // Site secret key for the captcha service + SecretKey string // Secret key for the captcha service +} + +func (c *captchaUtils) GetCaptchaConfigFromEnv() *CaptchaConfig { + return &CaptchaConfig{ + Type: Env.Get("CAPTCHA_TYPE", "disable"), + SiteSecret: Env.Get("CAPTCHA_SITE_SECRET", ""), + SecretKey: Env.Get("CAPTCHA_SECRET_KEY", ""), + } +} + +// VerifyCaptcha 根据提供的配置和令牌验证验证码 +func (c *captchaUtils) VerifyCaptcha(captchaConfig *CaptchaConfig, captchaToken string) (bool, error) { + restyClient := resty.New() + switch captchaConfig.Type { + case constant.CaptchaTypeDisable: + return true, nil + case constant.CaptchaTypeHCaptcha: + result := make(map[string]any) + resp, err := restyClient.R(). + SetFormData(map[string]string{ + "secret": captchaConfig.SecretKey, + "response": captchaToken, + }).SetResult(&result).Post("https://hcaptcha.com/siteverify") + if err != nil { + return false, err + } + if resp.IsError() { + return false, nil + } + fmt.Printf("%#v\n", result) + if success, ok := result["success"].(bool); ok && success { + return true, nil + } else { + return false, nil + } + case constant.CaptchaTypeTurnstile: + result := make(map[string]any) + resp, err := restyClient.R(). + SetFormData(map[string]string{ + "secret": captchaConfig.SecretKey, + "response": captchaToken, + }).SetResult(&result).Post("https://challenges.cloudflare.com/turnstile/v0/siteverify") + if err != nil { + return false, err + } + if resp.IsError() { + return false, nil + } + fmt.Printf("%#v\n", result) + if success, ok := result["success"].(bool); ok && success { + return true, nil + } else { + return false, nil + } + case constant.CaptchaTypeReCaptcha: + result := make(map[string]any) + resp, err := restyClient.R(). + SetFormData(map[string]string{ + "secret": captchaConfig.SecretKey, + "response": captchaToken, + }).SetResult(&result).Post("https://www.google.com/recaptcha/api/siteverify") + if err != nil { + return false, err + } + if resp.IsError() { + return false, nil + } + fmt.Printf("%#v\n", result) + if success, ok := result["success"].(bool); ok && success { + return true, nil + } else { + return false, nil + } + default: + return false, fmt.Errorf("invalid captcha type: %s", captchaConfig.Type) + } +} diff --git a/pkg/utils/email.go b/pkg/utils/email.go new file mode 100644 index 0000000..1b1632c --- /dev/null +++ b/pkg/utils/email.go @@ -0,0 +1,84 @@ +package utils + +import ( + "bytes" + "crypto/tls" + "fmt" + "gopkg.in/gomail.v2" + "html/template" +) + +type emailUtils struct{} + +var Email = emailUtils{} + +type EmailConfig struct { + Enable bool // 邮箱启用状态 + Username string // 邮箱用户名 + Address string // 邮箱地址 + Host string // 邮箱服务器地址 + Port int // 邮箱服务器端口 + Password string // 邮箱密码 + SSL bool // 是否使用SSL +} + +// SendTemplate 发送HTML模板,从配置文件中读取邮箱配置,支持上下文控制 +func (e *emailUtils) SendTemplate(emailConfig *EmailConfig, target, subject, htmlTemplate string, data map[string]interface{}) error { + // 使用Go的模板系统处理HTML模板 + tmpl, err := template.New("email").Parse(htmlTemplate) + if err != nil { + return fmt.Errorf("解析模板失败: %w", err) + } + + var buf bytes.Buffer + if err := tmpl.Execute(&buf, data); err != nil { + return fmt.Errorf("执行模板失败: %w", err) + } + + // 发送处理后的HTML内容 + return e.SendEmail(emailConfig, target, subject, buf.String(), true) +} + +// SendEmail 使用gomail库发送邮件 +func (e *emailUtils) SendEmail(emailConfig *EmailConfig, target, subject, content string, isHTML bool) error { + if !emailConfig.Enable { + return nil + } + // 创建新邮件 + m := gomail.NewMessage() + m.SetHeader("From", emailConfig.Address) + m.SetHeader("To", target) + m.SetHeader("Subject", subject) + // 设置内容类型 + if isHTML { + m.SetBody("text/html", content) + } else { + m.SetBody("text/plain", content) + } + // 创建发送器 + d := gomail.NewDialer(emailConfig.Host, emailConfig.Port, emailConfig.Username, emailConfig.Password) + // 配置SSL/TLS + if emailConfig.SSL { + d.SSL = true + } else { + // 对于非SSL但需要STARTTLS的情况 + d.TLSConfig = &tls.Config{InsecureSkipVerify: true} + } + // 发送邮件 + if err := d.DialAndSend(m); err != nil { + return fmt.Errorf("发送邮件失败: %w", err) + } + return nil +} + +func (e *emailUtils) GetEmailConfigFromEnv() *EmailConfig { + return &EmailConfig{ + Enable: Env.GetenvAsBool("EMAIL_ENABLE", false), + Username: Env.Get("EMAIL_USERNAME", ""), + Address: Env.Get("EMAIL_ADDRESS", ""), + Host: Env.Get("EMAIL_HOST", "smtp.example.com"), + Port: Env.GetenvAsInt("EMAIL_PORT", 587), + Password: Env.Get("EMAIL_PASSWORD", ""), + SSL: Env.GetenvAsBool("EMAIL_SSL", true), + } +} diff --git a/pkg/utils/env.go b/pkg/utils/env.go index a081e27..3fbba78 100644 --- a/pkg/utils/env.go +++ b/pkg/utils/env.go @@ -2,19 +2,27 @@ package utils import ( "github.com/joho/godotenv" + "github.com/snowykami/neo-blog/pkg/constant" "os" "strconv" ) +var ( + IsDevMode = false +) + func init() { _ = godotenv.Load() + + // Init env + IsDevMode = Env.Get(constant.EnvKeyMode, constant.ModeDev) == constant.ModeDev } -type envType struct{} +type envUtils struct{} -var Env envType +var Env envUtils -func (e *envType) Get(key string, defaultValue ...string) string { +func (e *envUtils) Get(key string, defaultValue ...string) string { value := os.Getenv(key) if value == "" && len(defaultValue) > 0 { return defaultValue[0] @@ -22,7 +30,7 @@ func (e *envType) Get(key string, defaultValue ...string) string { return value } -func (e *envType) GetenvAsInt(key string, defaultValue ...int) int { +func (e *envUtils) GetenvAsInt(key string, defaultValue ...int) int { value := os.Getenv(key) if value == "" && len(defaultValue) > 0 { return defaultValue[0] @@ -34,7 +42,7 @@ func (e *envType) GetenvAsInt(key string, defaultValue ...int) int { return intValue } -func (e *envType) GetenvAsBool(key string, defaultValue ...bool) bool { +func (e *envUtils) GetenvAsBool(key string, defaultValue ...bool) bool { value := os.Getenv(key) if value == "" && len(defaultValue) > 0 { return defaultValue[0] diff --git a/pkg/utils/json_web_token.go b/pkg/utils/json_web_token.go new file mode 100644 index 0000000..e56e9a4 --- /dev/null +++ b/pkg/utils/json_web_token.go @@ -0,0 +1,52 @@ +package utils + +import ( + "github.com/golang-jwt/jwt/v5" + "github.com/snowykami/neo-blog/pkg/constant" + "time" +) + +type jwtUtils struct{} + +var Jwt = jwtUtils{} + +type Claims struct { + jwt.RegisteredClaims + UserID uint `json:"user_id"` + SessionKey string `json:"session_key"` // 会话ID,仅在有状态Token中使用 + Stateful bool `json:"stateful"` // 是否为有状态Token +} + +// NewClaims 创建一个新的Claims实例,对于无状态 +func (j *jwtUtils) NewClaims(userID uint, sessionKey string, stateful bool, duration time.Duration) *Claims { + return &Claims{ + UserID: userID, + SessionKey: sessionKey, + Stateful: stateful, + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + ExpiresAt: jwt.NewNumericDate(time.Now().Add(duration)), + }, + } +} + +// ToString 将Claims转换为JWT字符串 +func (c *Claims) ToString() (string, error) { + token := jwt.NewWithClaims(jwt.SigningMethodHS256, c) + return token.SignedString([]byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret"))) +} + +// ParseJsonWebTokenWithoutState 解析JWT令牌,不对有状态的Token进行状态检查 +func (j *jwtUtils) ParseJsonWebTokenWithoutState(tokenString string) (*Claims, error) { + claims := &Claims{} + token, err := jwt.ParseWithClaims(tokenString, claims, func(token *jwt.Token) (any, error) { + return []byte(Env.Get(constant.EnvKeyJwtSecrete, "default_jwt_secret")), nil + }) + if err != nil { + return nil, err + } + if !token.Valid { + return nil, jwt.ErrSignatureInvalid + } + return claims, nil +} diff --git a/pkg/utils/kvstore.go b/pkg/utils/kvstore.go new file mode 100644 index 0000000..5225c64 --- /dev/null +++ b/pkg/utils/kvstore.go @@ -0,0 +1,119 @@ +package utils + +import ( + "sync" + "time" +) + +type kvStoreUtils struct{} + +var KV kvStoreUtils + +// KVStore 是一个简单的内存键值存储系统 +type KVStore struct { + data map[string]storeItem + mutex sync.RWMutex +} + +// storeItem 代表存储的单个数据项 +type storeItem struct { + value interface{} + expiration int64 // Unix时间戳,0表示永不过期 +} + +// 全局单例 +var ( + kvStore *KVStore + kvStoreOnce sync.Once +) + +// GetInstance 获取KVStore单例实例 +func (kv *kvStoreUtils) GetInstance() *KVStore { + kvStoreOnce.Do(func() { + kvStore = &KVStore{ + data: make(map[string]storeItem), + } + // 启动清理过期项的协程 + go kvStore.startCleanupTimer() + }) + return kvStore +} + +// Set 设置键值对,可选指定过期时间 +func (s *KVStore) Set(key string, value interface{}, ttl time.Duration) { + s.mutex.Lock() + defer s.mutex.Unlock() + + var expiration int64 + if ttl > 0 { + expiration = time.Now().Add(ttl).Unix() + } + + s.data[key] = storeItem{ + value: value, + expiration: expiration, + } +} + +// Get 获取键对应的值,如果键不存在或已过期则返回(nil, false) +func (s *KVStore) Get(key string) (interface{}, bool) { + s.mutex.RLock() + defer s.mutex.RUnlock() + + item, exists := s.data[key] + if !exists { + return nil, false + } + + // 检查是否过期 + if item.expiration > 0 && time.Now().Unix() > item.expiration { + return nil, false + } + + return item.value, true +} + +// Delete 删除键值对 +func (s *KVStore) Delete(key string) { + s.mutex.Lock() + defer s.mutex.Unlock() + + delete(s.data, key) +} + +// Clear 清空所有键值对 +func (s *KVStore) Clear() { + s.mutex.Lock() + defer s.mutex.Unlock() + + s.data = make(map[string]storeItem) +} + +// startCleanupTimer 启动定期清理过期项的计时器 +func (s *KVStore) startCleanupTimer() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + s.cleanup() + } +} + +// cleanup 清理过期的数据项 +func (s *KVStore) cleanup() { + s.mutex.Lock() + defer s.mutex.Unlock() + + now := time.Now().Unix() + for key, item := range s.data { + if item.expiration > 0 && now > item.expiration { + delete(s.data, key) + } + } +} + +// secureRand 生成0到max-1之间的安全随机数 +func secureRand(max int) int { + // 简单实现,可以根据需要使用crypto/rand替换 + return int(time.Now().UnixNano() % int64(max)) +} diff --git a/pkg/utils/password.go b/pkg/utils/password.go index 8800b26..7ae7135 100644 --- a/pkg/utils/password.go +++ b/pkg/utils/password.go @@ -6,13 +6,13 @@ import ( "golang.org/x/crypto/bcrypt" ) -type PasswordType struct { +type PasswordUtils struct { } -var Password = PasswordType{} +var Password = PasswordUtils{} // HashPassword 密码哈希函数 -func (u *PasswordType) HashPassword(password string, salt string) (string, error) { +func (u *PasswordUtils) HashPassword(password string, salt string) (string, error) { saltedPassword := Password.addSalt(password, salt) hashedPassword, err := bcrypt.GenerateFromPassword([]byte(saltedPassword), bcrypt.DefaultCost) if err != nil { @@ -22,7 +22,7 @@ func (u *PasswordType) HashPassword(password string, salt string) (string, error } // VerifyPassword 验证密码 -func (u *PasswordType) VerifyPassword(password, hashedPassword string, salt string) bool { +func (u *PasswordUtils) VerifyPassword(password, hashedPassword string, salt string) bool { if len(hashedPassword) == 0 || len(salt) == 0 { // 防止oidc空密码出问题 return false @@ -33,7 +33,7 @@ func (u *PasswordType) VerifyPassword(password, hashedPassword string, salt stri } // addSalt 加盐函数 -func (u *PasswordType) addSalt(password string, salt string) string { +func (u *PasswordUtils) addSalt(password string, salt string) string { combined := password + salt hash := sha256.New() hash.Write([]byte(combined)) diff --git a/pkg/utils/request_context.go b/pkg/utils/request_context.go new file mode 100644 index 0000000..d4b585b --- /dev/null +++ b/pkg/utils/request_context.go @@ -0,0 +1 @@ +package utils diff --git a/pkg/utils/strings.go b/pkg/utils/strings.go new file mode 100644 index 0000000..ad396f2 --- /dev/null +++ b/pkg/utils/strings.go @@ -0,0 +1,27 @@ +package utils + +import "math/rand" + +type stringsUtils struct{} + +var Strings = stringsUtils{} + +func (s *stringsUtils) GenerateRandomString(length int) string { + const charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + result := make([]byte, length) + for i := range result { + result[i] = charset[rand.Intn(len(charset))] + } + return string(result) +} + +func (s *stringsUtils) GenerateRandomStringWithCharset(length int, charset string) string { + if charset == "" { + charset = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + } + result := make([]byte, length) + for i := range result { + result[i] = charset[rand.Intn(len(charset))] + } + return string(result) +}