feat(sso): custom username key for OIDC (close #5169)

This commit is contained in:
Andy Hsu 2023-10-02 14:42:40 +08:00
parent 40a6fcbdff
commit e719a1a456
3 changed files with 31 additions and 15 deletions

View File

@ -158,6 +158,7 @@ func InitialSettings() []model.SettingItem {
{Key: conf.SSOLoginPlatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC}, {Key: conf.SSOLoginPlatform, Type: conf.TypeSelect, Options: "Casdoor,Github,Microsoft,Google,Dingtalk,OIDC", Group: model.SSO, Flag: model.PUBLIC},
{Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOClientId, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOClientSecret, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOOIDCUsernameKey, Value: "name", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOOrganizationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOApplicationName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},
{Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE}, {Key: conf.SSOEndpointName, Value: "", Type: conf.TypeString, Group: model.SSO, Flag: model.PRIVATE},

View File

@ -62,6 +62,7 @@ const (
SSOClientSecret = "sso_client_secret" SSOClientSecret = "sso_client_secret"
SSOLoginEnabled = "sso_login_enabled" SSOLoginEnabled = "sso_login_enabled"
SSOLoginPlatform = "sso_login_platform" SSOLoginPlatform = "sso_login_platform"
SSOOIDCUsernameKey = "sso_oidc_username_key"
SSOOrganizationName = "sso_organization_name" SSOOrganizationName = "sso_organization_name"
SSOApplicationName = "sso_application_name" SSOApplicationName = "sso_application_name"
SSOEndpointName = "sso_endpoint_name" SSOEndpointName = "sso_endpoint_name"

View File

@ -2,6 +2,7 @@ package handles
import ( import (
"encoding/base32" "encoding/base32"
"encoding/base64"
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
@ -166,10 +167,22 @@ func autoRegister(username, userID string, err error) (*model.User, error) {
return user, nil return user, nil
} }
func parseJWT(p string) ([]byte, error) {
parts := strings.Split(p, ".")
if len(parts) < 2 {
return nil, fmt.Errorf("oidc: malformed jwt, expected 3 parts got %d", len(parts))
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("oidc: malformed jwt payload: %v", err)
}
return payload, nil
}
func OIDCLoginCallback(c *gin.Context) { func OIDCLoginCallback(c *gin.Context) {
usecompatibility := setting.GetBool(conf.SSOCompatibilityMode) useCompatibility := setting.GetBool(conf.SSOCompatibilityMode)
argument := c.Query("method") argument := c.Query("method")
if usecompatibility { if useCompatibility {
argument = path.Base(c.Request.URL.Path) argument = path.Base(c.Request.URL.Path)
} }
clientId := setting.GetStr(conf.SSOClientId) clientId := setting.GetStr(conf.SSOClientId)
@ -208,23 +221,24 @@ func OIDCLoginCallback(c *gin.Context) {
verifier := provider.Verifier(&oidc.Config{ verifier := provider.Verifier(&oidc.Config{
ClientID: clientId, ClientID: clientId,
}) })
idToken, err := verifier.Verify(c, rawIDToken) _, err = verifier.Verify(c, rawIDToken)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
} }
type UserInfo struct { payload, err := parseJWT(rawIDToken)
Name string `json:"name"` if err != nil {
}
claims := UserInfo{}
if err := idToken.Claims(&claims); err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
return return
} }
UserID := claims.Name userID := utils.Json.Get(payload, conf.SSOOIDCUsernameKey).ToString()
if userID == "" {
common.ErrorStrResp(c, "cannot get username from OIDC provider", 400)
return
}
if argument == "get_sso_id" { if argument == "get_sso_id" {
if usecompatibility { if useCompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+UserID) c.Redirect(302, common.GetApiUrl(c.Request)+"/@manage?sso_id="+userID)
return return
} }
html := fmt.Sprintf(`<!DOCTYPE html> html := fmt.Sprintf(`<!DOCTYPE html>
@ -234,14 +248,14 @@ func OIDCLoginCallback(c *gin.Context) {
window.opener.postMessage({"sso_id": "%s"}, "*") window.opener.postMessage({"sso_id": "%s"}, "*")
window.close() window.close()
</script> </script>
</body>`, UserID) </body>`, userID)
c.Data(200, "text/html; charset=utf-8", []byte(html)) c.Data(200, "text/html; charset=utf-8", []byte(html))
return return
} }
if argument == "sso_get_token" { if argument == "sso_get_token" {
user, err := db.GetUserBySSOID(UserID) user, err := db.GetUserBySSOID(userID)
if err != nil { if err != nil {
user, err = autoRegister(UserID, UserID, err) user, err = autoRegister(userID, userID, err)
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
} }
@ -250,7 +264,7 @@ func OIDCLoginCallback(c *gin.Context) {
if err != nil { if err != nil {
common.ErrorResp(c, err, 400) common.ErrorResp(c, err, 400)
} }
if usecompatibility { if useCompatibility {
c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token) c.Redirect(302, common.GetApiUrl(c.Request)+"/@login?token="+token)
return return
} }