diff --git a/internal/controller/v1/user.go b/internal/controller/v1/user.go index a36b6e8..4ebadba 100644 --- a/internal/controller/v1/user.go +++ b/internal/controller/v1/user.go @@ -21,18 +21,21 @@ var User = &userType{ } func (u *userType) Login(ctx context.Context, c *app.RequestContext) { - var userLoginReq *dto.UserLoginReq - if err := c.BindAndValidate(userLoginReq); err != nil { + var userLoginReq dto.UserLoginReq + if err := c.BindAndValidate(&userLoginReq); err != nil { resps.BadRequest(c, resps.ErrParamInvalid) + return } - resp, err := u.service.UserLogin(userLoginReq) + resp, err := u.service.UserLogin(&userLoginReq) if err != nil { serviceErr := errs.AsServiceError(err) resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) + return } if resp == nil { resps.UnAuthorized(c, resps.ErrInvalidCredentials) + return } else { u.setTokenCookie(c, resp.Token, resp.RefreshToken) resps.Ok(c, resps.Success, resp) @@ -40,12 +43,12 @@ func (u *userType) Login(ctx context.Context, c *app.RequestContext) { } func (u *userType) Register(ctx context.Context, c *app.RequestContext) { - var userRegisterReq *dto.UserRegisterReq - if err := c.BindAndValidate(userRegisterReq); err != nil { + var userRegisterReq dto.UserRegisterReq + if err := c.BindAndValidate(&userRegisterReq); err != nil { resps.BadRequest(c, resps.ErrParamInvalid) return } - resp, err := u.service.UserRegister(userRegisterReq) + resp, err := u.service.UserRegister(&userRegisterReq) if err != nil { serviceErr := errs.AsServiceError(err) @@ -86,12 +89,12 @@ func (u *userType) Delete(ctx context.Context, c *app.RequestContext) { } func (u *userType) VerifyEmail(ctx context.Context, c *app.RequestContext) { - var verifyEmailReq *dto.VerifyEmailReq - if err := c.BindAndValidate(verifyEmailReq); err != nil { + var verifyEmailReq dto.VerifyEmailReq + if err := c.BindAndValidate(&verifyEmailReq); err != nil { resps.BadRequest(c, resps.ErrParamInvalid) return } - resp, err := u.service.VerifyEmail(verifyEmailReq) + resp, err := u.service.VerifyEmail(&verifyEmailReq) if err != nil { serviceErr := errs.AsServiceError(err) resps.Custom(c, serviceErr.Code, serviceErr.Message, nil) diff --git a/internal/model/user.go b/internal/model/user.go index e5e1b90..525cfe7 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -13,8 +13,7 @@ type User struct { Email string `gorm:"uniqueIndex"` Gender string Role string `gorm:"default:'user'"` - - Password string // 密码,存储加密后的值 + Password string // 密码,存储加密后的值 } func (user *User) ToDto() *dto.UserDto { diff --git a/internal/repo/init.go b/internal/repo/init.go index 1487618..25bbac9 100644 --- a/internal/repo/init.go +++ b/internal/repo/init.go @@ -71,14 +71,12 @@ func InitDatabase() error { return errors.New("unsupported database driver, only sqlite and postgres are supported") } - return nil + // 迁移模型 + if err = migrate(); err != nil { + logrus.Error("Failed to migrate models:", err) + return err + } // TODO: impl - - //// 迁移模型 - //if err = models.Migrate(db); err != nil { - // logrus.Error("Failed to migrate models:", err) - // return err - //} //// 执行初始化数据 //// 创建管理员账户 //hashedPassword, err := utils.Password.HashPassword(config.AdminPassword, config.JwtSecret) @@ -95,7 +93,7 @@ func InitDatabase() error { // logrus.Error("Failed to update admin user:", err) // return err //} - //return nil + return nil } // initPostgres 初始化PostgreSQL连接 @@ -128,6 +126,8 @@ func migrate() error { return GetDB().AutoMigrate( &model.Comment{}, &model.Label{}, + &model.OidcConfig{}, &model.Post{}, + &model.Session{}, &model.User{}) } diff --git a/internal/repo/user.go b/internal/repo/user.go index 4128d78..dcab751 100644 --- a/internal/repo/user.go +++ b/internal/repo/user.go @@ -43,3 +43,19 @@ func (user *userRepo) Update(userModel *model.User) error { } return nil } + +func (user *userRepo) CheckUsernameExists(username string) (bool, error) { + var count int64 + if err := GetDB().Model(&model.User{}).Where("username = ?", username).Count(&count).Error; err != nil { + return false, err + } + return count > 0, nil +} + +func (user *userRepo) CheckEmailExists(email string) (bool, error) { + var count int64 + if err := GetDB().Model(&model.User{}).Where("email = ?", email).Count(&count).Error; err != nil { + return false, err + } + return count > 0, nil +} diff --git a/internal/service/user.go b/internal/service/user.go index 6f6dc63..b666bf2 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -76,21 +76,30 @@ func (s *userService) UserRegister(req *dto.UserRegisterReq) (*dto.UserRegisterR } } // 检查用户名或邮箱是否已存在 - existingUser, err := repo.User.GetByUsernameOrEmail(req.Username) + usernameExist, err := repo.User.CheckUsernameExists(req.Username) if err != nil { return nil, errs.ErrInternalServer } - if existingUser != nil { + emailExist, err := repo.User.CheckEmailExists(req.Email) + if err != nil { + return nil, errs.ErrInternalServer + } + if usernameExist || emailExist { return nil, errs.New(http.StatusConflict, "Username or email already exists", nil) } // 创建新用户 + hashedPassword, err := utils.Password.HashPassword(req.Password, utils.Env.Get(constant.EnvKeyPasswordSalt, "default_salt")) + if err != nil { + logrus.Errorln("Failed to hash password:", err) + return nil, errs.ErrInternalServer + } newUser := &model.User{ Username: req.Username, Nickname: req.Nickname, Email: req.Email, Gender: "", Role: "user", - Password: "", + Password: hashedPassword, } err = repo.User.Create(newUser) if err != nil { @@ -131,7 +140,7 @@ func (s *userService) VerifyEmail(req *dto.VerifyEmailReq) (*dto.VerifyEmailResp return nil, errs.ErrInternalServer } if utils.IsDevMode { - logrus.Infoln("%s's verification code is %s", req.Email, generatedVerificationCode) + logrus.Infof("%s's verification code is %s", req.Email, generatedVerificationCode) } err = utils.Email.SendEmail(utils.Email.GetEmailConfigFromEnv(), req.Email, "验证你的电子邮件 / Verify your email", template, true) diff --git a/pkg/utils/env.go b/pkg/utils/env.go index da3d6b8..b5430e9 100644 --- a/pkg/utils/env.go +++ b/pkg/utils/env.go @@ -15,7 +15,7 @@ func init() { _ = godotenv.Load() // Init env - IsDevMode = Env.Get(constant.EnvKeyMode, constant.ModeDev) == constant.ModeDev + IsDevMode = Env.Get(constant.EnvKeyMode, constant.ModeProd) == constant.ModeDev } type envUtils struct{} diff --git a/pkg/utils/json_web_token.go b/pkg/utils/jwt.go similarity index 100% rename from pkg/utils/json_web_token.go rename to pkg/utils/jwt.go