@ -3,6 +3,7 @@ package handles
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
@ -11,8 +12,10 @@ import (
|
||||
"github.com/alist-org/alist/v3/internal/setting"
|
||||
"github.com/alist-org/alist/v3/pkg/utils"
|
||||
"github.com/alist-org/alist/v3/server/common"
|
||||
"github.com/coreos/go-oidc"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/go-resty/resty/v2"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
func SSOLoginRedirect(c *gin.Context) {
|
||||
@ -53,6 +56,14 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||
r_url = endpoint + "/login/oauth/authorize?"
|
||||
urlValues.Add("scope", "profile")
|
||||
urlValues.Add("state", endpoint)
|
||||
case "OIDC":
|
||||
oauth2Config, err := GetOIDCClient(c)
|
||||
if err != nil {
|
||||
common.ErrorStrResp(c, err.Error(), 400)
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusFound, oauth2Config.AuthCodeURL("state"))
|
||||
return
|
||||
default:
|
||||
common.ErrorStrResp(c, "invalid platform", 400)
|
||||
return
|
||||
@ -65,6 +76,108 @@ func SSOLoginRedirect(c *gin.Context) {
|
||||
|
||||
var ssoClient = resty.New().SetRetryCount(3)
|
||||
|
||||
func GetOIDCClient(c *gin.Context) (*oauth2.Config, error) {
|
||||
argument := c.Query("method")
|
||||
redirect_uri := common.GetApiUrl(c.Request) + "/api/auth/sso_callback" + "?method=" + argument
|
||||
endpoint := setting.GetStr(conf.SSOEndpointName)
|
||||
provider, err := oidc.NewProvider(c, endpoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clientId := setting.GetStr(conf.SSOClientId)
|
||||
clientSecret := setting.GetStr(conf.SSOClientSecret)
|
||||
return &oauth2.Config{
|
||||
ClientID: clientId,
|
||||
ClientSecret: clientSecret,
|
||||
RedirectURL: redirect_uri,
|
||||
|
||||
// Discovery returns the OAuth2 endpoints.
|
||||
Endpoint: provider.Endpoint(),
|
||||
|
||||
// "openid" is a required scope for OpenID Connect flows.
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile"},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func OIDCLoginCallback(c *gin.Context) {
|
||||
argument := c.Query("method")
|
||||
enabled := setting.GetBool(conf.SSOLoginEnabled)
|
||||
clientId := setting.GetStr(conf.SSOClientId)
|
||||
if !enabled {
|
||||
common.ErrorResp(c, errors.New("invalid request"), 500)
|
||||
}
|
||||
endpoint := setting.GetStr(conf.SSOEndpointName)
|
||||
provider, err := oidc.NewProvider(c, endpoint)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
oauth2Config, err := GetOIDCClient(c)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
oauth2Token, err := oauth2Config.Exchange(c, c.Query("code"))
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
common.ErrorStrResp(c, "no id_token found in oauth2 token", 400)
|
||||
return
|
||||
}
|
||||
verifier := provider.Verifier(&oidc.Config{
|
||||
ClientID: clientId,
|
||||
})
|
||||
idToken, err := verifier.Verify(c, rawIDToken)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
return
|
||||
}
|
||||
type UserInfo struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
claims := UserInfo{}
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
c.Error(err)
|
||||
return
|
||||
}
|
||||
UserID := claims.Name
|
||||
if argument == "get_sso_id" {
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
<head></head>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage({"sso_id": "%s"}, "*")
|
||||
window.close()
|
||||
</script>
|
||||
</body>`, UserID)
|
||||
c.Data(200, "text/html; charset=utf-8", []byte(html))
|
||||
return
|
||||
}
|
||||
if argument == "sso_get_token" {
|
||||
user, err := db.GetUserBySSOID(UserID)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
}
|
||||
token, err := common.GenerateToken(user.Username)
|
||||
if err != nil {
|
||||
common.ErrorResp(c, err, 400)
|
||||
}
|
||||
html := fmt.Sprintf(`<!DOCTYPE html>
|
||||
<head></head>
|
||||
<body>
|
||||
<script>
|
||||
window.opener.postMessage({"token":"%s"}, "*")
|
||||
window.close()
|
||||
</script>
|
||||
</body>`, token)
|
||||
c.Data(200, "text/html; charset=utf-8", []byte(html))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func SSOLoginCallback(c *gin.Context) {
|
||||
argument := c.Query("method")
|
||||
if argument == "get_sso_id" || argument == "sso_get_token" {
|
||||
@ -108,6 +221,9 @@ func SSOLoginCallback(c *gin.Context) {
|
||||
scope = "profile"
|
||||
authstring = "code"
|
||||
idstring = "preferred_username"
|
||||
case "OIDC":
|
||||
OIDCLoginCallback(c)
|
||||
return
|
||||
default:
|
||||
common.ErrorStrResp(c, "invalid platform", 400)
|
||||
return
|
||||
|
Reference in New Issue
Block a user