mirror of
https://github.com/snowykami/neo-blog.git
synced 2025-09-03 15:56:22 +00:00
⚡ implement email verification feature, add captcha validation middleware, and enhance user authentication flow
This commit is contained in:
@ -3,26 +3,66 @@ package v1
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
"github.com/cloudwego/hertz/pkg/protocol"
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"github.com/snowykami/neo-blog/internal/service"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/errs"
|
||||
"github.com/snowykami/neo-blog/pkg/resps"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
)
|
||||
|
||||
type userType struct{}
|
||||
type userType struct {
|
||||
service service.UserService
|
||||
}
|
||||
|
||||
var User = new(userType)
|
||||
var User = &userType{
|
||||
service: service.NewUserService(),
|
||||
}
|
||||
|
||||
func (u *userType) Login(ctx context.Context, c *app.RequestContext) {
|
||||
var userLoginReq dto.UserLoginReq
|
||||
if err := c.BindAndValidate(&userLoginReq); err != nil {
|
||||
var userLoginReq *dto.UserLoginReq
|
||||
if err := c.BindAndValidate(userLoginReq); err != nil {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
}
|
||||
resp, err := u.service.UserLogin(userLoginReq)
|
||||
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
}
|
||||
if resp == nil {
|
||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
||||
} else {
|
||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
}
|
||||
|
||||
func (u *userType) Register(ctx context.Context, c *app.RequestContext) {
|
||||
var userRegisterReq *dto.UserRegisterReq
|
||||
if err := c.BindAndValidate(userRegisterReq); err != nil {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
resp, err := u.service.UserRegister(userRegisterReq)
|
||||
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
if resp == nil {
|
||||
resps.UnAuthorized(c, resps.ErrInvalidCredentials)
|
||||
return
|
||||
}
|
||||
u.setTokenCookie(c, resp.Token, resp.RefreshToken)
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
|
||||
func (u *userType) Logout(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
u.clearTokenCookie(c)
|
||||
resps.Ok(c, resps.Success, nil)
|
||||
}
|
||||
|
||||
func (u *userType) OidcList(ctx context.Context, c *app.RequestContext) {
|
||||
@ -44,3 +84,28 @@ func (u *userType) Update(ctx context.Context, c *app.RequestContext) {
|
||||
func (u *userType) Delete(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Impl
|
||||
}
|
||||
|
||||
func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) {
|
||||
var verifyEmailReq *dto.VerifyEmailReq
|
||||
if err := c.BindAndValidate(verifyEmailReq); err != nil {
|
||||
resps.BadRequest(c, resps.ErrParamInvalid)
|
||||
return
|
||||
}
|
||||
resp, err := u.service.VerifyEmail(verifyEmailReq)
|
||||
if err != nil {
|
||||
serviceErr := errs.AsServiceError(err)
|
||||
resps.Custom(c, serviceErr.Code, serviceErr.Message, nil)
|
||||
return
|
||||
}
|
||||
resps.Ok(c, resps.Success, resp)
|
||||
}
|
||||
|
||||
func (u *userType) setTokenCookie(c *app.RequestContext, token, refreshToken string) {
|
||||
c.SetCookie("token", token, utils.Env.GetenvAsInt(constant.EnvKeyTokenDuration, constant.EnvKeyTokenDurationDefault), "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", refreshToken, -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
||||
|
||||
func (u *userType) clearTokenCookie(c *app.RequestContext) {
|
||||
c.SetCookie("token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
c.SetCookie("refresh_token", "", -1, "/", "", protocol.CookieSameSiteLaxMode, true, true)
|
||||
}
|
||||
|
@ -14,9 +14,9 @@ type UserLoginReq struct {
|
||||
}
|
||||
|
||||
type UserLoginResp struct {
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User UserDto `json:"user"`
|
||||
Token string `json:"token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
User *UserDto `json:"user"`
|
||||
}
|
||||
|
||||
type UserRegisterReq struct {
|
||||
@ -31,3 +31,11 @@ type UserRegisterResp struct {
|
||||
RefreshToken string `json:"refresh_token"` // 刷新令牌
|
||||
User UserDto `json:"user"` // 用户信息
|
||||
}
|
||||
|
||||
type VerifyEmailReq struct {
|
||||
Email string `json:"email"` // 邮箱地址
|
||||
}
|
||||
|
||||
type VerifyEmailResp struct {
|
||||
Success bool `json:"success"` // 验证码发送成功与否
|
||||
}
|
||||
|
@ -3,10 +3,35 @@ package middleware
|
||||
import (
|
||||
"context"
|
||||
"github.com/cloudwego/hertz/pkg/app"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/snowykami/neo-blog/pkg/resps"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
)
|
||||
|
||||
// UseCaptcha 中间件函数,用于X-Captcha-Token验证码
|
||||
func UseCaptcha() app.HandlerFunc {
|
||||
captchaConfig := utils.Captcha.GetCaptchaConfigFromEnv()
|
||||
return func(ctx context.Context, c *app.RequestContext) {
|
||||
// TODO: Implement captcha validation logic here
|
||||
CaptchaToken := string(c.GetHeader("X-Captcha-Token"))
|
||||
if utils.IsDevMode && CaptchaToken == utils.Env.Get("CAPTCHA_DEV_PASSCODE", "dev_passcode") {
|
||||
// 开发模式直接通过密钥
|
||||
c.Next(ctx)
|
||||
return
|
||||
}
|
||||
ok, err := utils.Captcha.VerifyCaptcha(captchaConfig, CaptchaToken)
|
||||
if err != nil {
|
||||
logrus.Error("Captcha verification error:", err)
|
||||
resps.InternalServerError(c, "Captcha verification failed")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
logrus.Warn("Captcha verification failed for token:", CaptchaToken)
|
||||
resps.Forbidden(c, "Captcha verification failed")
|
||||
c.Abort()
|
||||
return
|
||||
}
|
||||
c.Next(ctx) // 如果验证码验证成功,则继续下一个处理程序
|
||||
return
|
||||
}
|
||||
}
|
||||
|
8
internal/model/session.go
Normal file
8
internal/model/session.go
Normal file
@ -0,0 +1,8 @@
|
||||
package model
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type Session struct {
|
||||
gorm.Model
|
||||
SessionKey string `gorm:"uniqueIndex"` // 会话密钥,唯一索引
|
||||
}
|
@ -1,6 +1,9 @@
|
||||
package model
|
||||
|
||||
import "gorm.io/gorm"
|
||||
import (
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
@ -13,3 +16,14 @@ type User struct {
|
||||
|
||||
Password string // 密码,存储加密后的值
|
||||
}
|
||||
|
||||
func (user *User) ToDto() *dto.UserDto {
|
||||
return &dto.UserDto{
|
||||
Username: user.Username,
|
||||
Nickname: user.Nickname,
|
||||
AvatarUrl: user.AvatarUrl,
|
||||
Email: user.Email,
|
||||
Gender: user.Gender,
|
||||
Role: user.Role,
|
||||
}
|
||||
}
|
||||
|
23
internal/repo/session.go
Normal file
23
internal/repo/session.go
Normal file
@ -0,0 +1,23 @@
|
||||
package repo
|
||||
|
||||
import "github.com/snowykami/neo-blog/internal/model"
|
||||
|
||||
type sessionRepo struct{}
|
||||
|
||||
var Session = sessionRepo{}
|
||||
|
||||
func (s *sessionRepo) SaveSession(sessionKey string) error {
|
||||
session := &model.Session{
|
||||
SessionKey: sessionKey,
|
||||
}
|
||||
return db.Create(session).Error
|
||||
}
|
||||
|
||||
func (s *sessionRepo) IsSessionValid(sessionKey string) (bool, error) {
|
||||
var count int64
|
||||
err := db.Model(&model.Session{}).Where("session_key = ?", sessionKey).Count(&count).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return count > 0, nil
|
||||
}
|
@ -13,6 +13,7 @@ func registerUserRoutes(group *route.RouterGroup) {
|
||||
{
|
||||
userGroupWithoutAuthNeedsCaptcha.POST("/login", v1.User.Login)
|
||||
userGroupWithoutAuthNeedsCaptcha.POST("/register", v1.User.Register)
|
||||
userGroupWithoutAuthNeedsCaptcha.POST("/email/verify", v1.User.VerifyEmail) // Send email verification code
|
||||
userGroupWithoutAuth.GET("/oidc/list", v1.User.OidcList)
|
||||
userGroupWithoutAuth.GET("/oidc/login/:name", v1.User.OidcLogin)
|
||||
userGroupWithoutAuth.GET("/u/:id", v1.User.Get)
|
||||
|
@ -1,25 +1,22 @@
|
||||
package router
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/cloudwego/hertz/pkg/app/server"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/snowykami/neo-blog/internal/router/apiv1"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
)
|
||||
|
||||
var h *server.Hertz
|
||||
|
||||
func Run() error {
|
||||
mode := utils.Env.Get("MODE", constant.ModeProd) // dev | prod
|
||||
switch mode {
|
||||
case constant.ModeProd:
|
||||
if utils.IsDevMode {
|
||||
logrus.Infoln("Running in development mode")
|
||||
return h.Run()
|
||||
} else {
|
||||
logrus.Infoln("Running in production mode")
|
||||
h.Spin()
|
||||
return nil
|
||||
case constant.ModeDev:
|
||||
return h.Run()
|
||||
default:
|
||||
return errors.New("unknown mode: " + mode)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,16 +1,20 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/snowykami/neo-blog/internal/dto"
|
||||
"github.com/snowykami/neo-blog/internal/repo"
|
||||
"github.com/snowykami/neo-blog/internal/static"
|
||||
"github.com/snowykami/neo-blog/pkg/constant"
|
||||
"github.com/snowykami/neo-blog/pkg/resps"
|
||||
"github.com/snowykami/neo-blog/pkg/errs"
|
||||
"github.com/snowykami/neo-blog/pkg/utils"
|
||||
"time"
|
||||
)
|
||||
|
||||
type UserService interface {
|
||||
UserLogin(dto *dto.UserLoginReq) (*dto.UserLoginResp, error)
|
||||
UserLogin(*dto.UserLoginReq) (*dto.UserLoginResp, error)
|
||||
UserRegister(*dto.UserRegisterReq) (*dto.UserRegisterResp, error)
|
||||
VerifyEmail(*dto.VerifyEmailReq) (*dto.VerifyEmailResp, error)
|
||||
// TODO impl other user-related methods
|
||||
}
|
||||
|
||||
@ -20,17 +24,63 @@ func NewUserService() UserService {
|
||||
return &userService{}
|
||||
}
|
||||
|
||||
func (s *userService) UserLogin(dto *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||
user, err := repo.User.GetByUsernameOrEmail(dto.Username)
|
||||
func (s *userService) UserLogin(req *dto.UserLoginReq) (*dto.UserLoginResp, error) {
|
||||
user, err := repo.User.GetByUsernameOrEmail(req.Username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if user == nil {
|
||||
return nil, errors.New(resps.ErrNotFound)
|
||||
return nil, errs.ErrNotFound
|
||||
}
|
||||
if utils.Password.VerifyPassword(dto.Password, user.Password, utils.Env.Get(constant.EnvVarPasswordSalt, "default_salt")) {
|
||||
return nil, nil // TODO: Generate JWT token and return it in the response
|
||||
if utils.Password.VerifyPassword(req.Password, user.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) {
|
||||
|
||||
token := utils.Jwt.NewClaims(user.ID, "", false, time.Duration(utils.Env.GetenvAsInt(constant.EnvKeyTokenDuration, 24)*int(time.Hour)))
|
||||
tokenString, err := token.ToString()
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
|
||||
refreshToken := utils.Jwt.NewClaims(user.ID, utils.Strings.GenerateRandomString(64), true, time.Duration(utils.Env.GetenvAsInt(constant.EnvKeyRefreshTokenDuration, 30)*int(time.Hour)))
|
||||
refreshTokenString, err := refreshToken.ToString()
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
// 对refresh token进行持久化存储
|
||||
err = repo.Session.SaveSession(refreshToken.SessionKey)
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
resp := &dto.UserLoginResp{
|
||||
Token: tokenString,
|
||||
RefreshToken: refreshTokenString,
|
||||
User: user.ToDto(),
|
||||
}
|
||||
return resp, nil
|
||||
} else {
|
||||
return nil, errors.New(resps.ErrInvalidCredentials)
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterResp, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp, error) {
|
||||
generatedVerificationCode := utils.Strings.GenerateRandomStringWithCharset(6, "0123456789abcdef")
|
||||
kv := utils.KV.GetInstance()
|
||||
kv.Set(constant.KVKeyEmailVerificationCode+":"+req.Email, generatedVerificationCode, time.Minute*10)
|
||||
|
||||
template, err := static.RenderTemplate("email/verification-code.tmpl", map[string]interface{}{})
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
if utils.IsDevMode {
|
||||
logrus.Infoln("%s's verification code is %s", req.Email, generatedVerificationCode)
|
||||
}
|
||||
err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true)
|
||||
|
||||
if err != nil {
|
||||
return nil, errs.ErrInternalServer
|
||||
}
|
||||
return &dto.VerifyEmailResp{Success: true}, nil
|
||||
}
|
||||
|
70
internal/static/assets/email/verification-code.tmpl
Normal file
70
internal/static/assets/email/verification-code.tmpl
Normal file
@ -0,0 +1,70 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="zh-CN">
|
||||
<head>
|
||||
<meta charset="UTF-8">
|
||||
<title>{{.Title}}</title>
|
||||
<style>
|
||||
body {
|
||||
font-family: Arial, sans-serif;
|
||||
line-height: 1.6;
|
||||
color: #333;
|
||||
background-color: #f9f9f9;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
.container {
|
||||
max-width: 600px;
|
||||
margin: 20px auto;
|
||||
background: #fff;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 8px;
|
||||
padding: 20px;
|
||||
}
|
||||
.header {
|
||||
text-align: center;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.header h1 {
|
||||
font-size: 24px;
|
||||
color: #007BFF;
|
||||
}
|
||||
.content {
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.content p {
|
||||
margin: 10px 0;
|
||||
}
|
||||
.code {
|
||||
font-size: 20px;
|
||||
font-weight: bold;
|
||||
color: #007BFF;
|
||||
text-align: center;
|
||||
margin: 20px 0;
|
||||
padding: 10px;
|
||||
background-color: #f0f8ff;
|
||||
border-radius: 4px;
|
||||
}
|
||||
.footer {
|
||||
text-align: center;
|
||||
font-size: 12px;
|
||||
color: #888;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<div class="header">
|
||||
<h1>欢迎使用 {{.Title}}</h1>
|
||||
</div>
|
||||
<div class="content">
|
||||
<p>尊敬的用户 {{.Email}},您好!</p>
|
||||
<p>{{.Details}} 以下是您的验证码:</p>
|
||||
<div class="code">{{.VerifyCode}}</div>
|
||||
<p>请在 <strong>{{.Expire}}</strong> 分钟内使用此验证码完成验证。</p>
|
||||
</div>
|
||||
<div class="footer">
|
||||
<p>如果您未请求此邮件,请忽略。</p>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
31
internal/static/embed.go
Normal file
31
internal/static/embed.go
Normal file
@ -0,0 +1,31 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
)
|
||||
|
||||
//go:embed assets/*
|
||||
var AssetsFS embed.FS
|
||||
|
||||
// RenderTemplate 从嵌入的文件系统中读取模板并渲染
|
||||
func RenderTemplate(name string, data interface{}) (string, error) {
|
||||
templatePath := "assets/" + name
|
||||
templateContent, err := AssetsFS.ReadFile(templatePath)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("读取模板文件失败: %w", err)
|
||||
}
|
||||
// 解析模板
|
||||
tmpl, err := template.New(name).Parse(string(templateContent))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("解析模板失败: %w", err)
|
||||
}
|
||||
// 渲染模板
|
||||
var buf bytes.Buffer
|
||||
if err := tmpl.Execute(&buf, data); err != nil {
|
||||
return "", fmt.Errorf("渲染模板失败: %w", err)
|
||||
}
|
||||
return buf.String(), nil
|
||||
}
|
18
internal/static/embed_test.go
Normal file
18
internal/static/embed_test.go
Normal file
@ -0,0 +1,18 @@
|
||||
package static
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRenderTemplate(t *testing.T) {
|
||||
template, err := RenderTemplate("email/verification-code.tmpl", map[string]interface{}{
|
||||
"Title": "Test Page",
|
||||
"Email": "xxx@.comcom",
|
||||
"Details": "nihao",
|
||||
})
|
||||
t.Logf(template)
|
||||
if err != nil {
|
||||
t.Errorf("渲染模板失败: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user