feat: file proxy handle

This commit is contained in:
Noah Hsu 2022-06-28 21:58:46 +08:00
parent d1efec4539
commit 96380a50da
11 changed files with 167 additions and 90 deletions

View File

@ -85,7 +85,7 @@ func copyFileBetween2Accounts(tsk *task.Task[uint64], srcAccount, dstAccount dri
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath)
} }
link, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{}) link, _, err := operations.Link(tsk.Ctx, srcAccount, srcFilePath, model.LinkArgs{})
if err != nil { if err != nil {
return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) return errors.WithMessagef(err, "failed get [%s] link", srcFilePath)
} }

View File

@ -30,13 +30,13 @@ func Get(ctx context.Context, path string) (model.Obj, error) {
return res, nil return res, nil
} }
func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) { func Link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
res, err := link(ctx, path, args) res, file, err := link(ctx, path, args)
if err != nil { if err != nil {
log.Errorf("failed link %s: %+v", path, err) log.Errorf("failed link %s: %+v", path, err)
return nil, err return nil, nil, err
} }
return res, nil return res, file, nil
} }
func MakeDir(ctx context.Context, path string) error { func MakeDir(ctx context.Context, path string) error {

View File

@ -7,10 +7,10 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, error) { func link(ctx context.Context, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
account, actualPath, err := operations.GetAccountAndActualPath(path) account, actualPath, err := operations.GetAccountAndActualPath(path)
if err != nil { if err != nil {
return nil, errors.WithMessage(err, "failed get account") return nil, nil, errors.WithMessage(err, "failed get account")
} }
return operations.Link(ctx, account, actualPath, args) return operations.Link(ctx, account, actualPath, args)
} }

View File

@ -115,19 +115,19 @@ var linkCache = cache.NewMemCache(cache.WithShards[*model.Link](16))
var linkG singleflight.Group[*model.Link] var linkG singleflight.Group[*model.Link]
// Link get link, if is an url. should have an expiry time // Link get link, if is an url. should have an expiry time
func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, error) { func Link(ctx context.Context, account driver.Driver, path string, args model.LinkArgs) (*model.Link, model.Obj, error) {
key := stdpath.Join(account.GetAccount().VirtualPath, path)
if link, ok := linkCache.Get(key); ok {
return link, nil
}
fn := func() (*model.Link, error) {
file, err := Get(ctx, account, path) file, err := Get(ctx, account, path)
if err != nil { if err != nil {
return nil, errors.WithMessage(err, "failed to get file") return nil, nil, errors.WithMessage(err, "failed to get file")
} }
if file.IsDir() { if file.IsDir() {
return nil, errors.WithStack(errs.NotFile) return nil, nil, errors.WithStack(errs.NotFile)
} }
key := stdpath.Join(account.GetAccount().VirtualPath, path)
if link, ok := linkCache.Get(key); ok {
return link, file, nil
}
fn := func() (*model.Link, error) {
link, err := account.Link(ctx, file, args) link, err := account.Link(ctx, file, args)
if err != nil { if err != nil {
return nil, errors.WithMessage(err, "failed get link") return nil, errors.WithMessage(err, "failed get link")
@ -138,7 +138,7 @@ func Link(ctx context.Context, account driver.Driver, path string, args model.Li
return link, nil return link, nil
} }
link, err, _ := linkG.Do(key, fn) link, err, _ := linkG.Do(key, fn)
return link, err return link, file, err
} }
func MakeDir(ctx context.Context, account driver.Driver, path string) error { func MakeDir(ctx context.Context, account driver.Driver, path string) error {

View File

@ -10,6 +10,15 @@ import (
var once sync.Once var once sync.Once
var instance sign.Sign var instance sign.Sign
func Sign(data string) string {
expire := setting.GetIntSetting("link_expiration", 0)
if expire == 0 {
return NotExpired(data)
} else {
return WithDuration(data, time.Duration(expire)*time.Hour)
}
}
func WithDuration(data string, d time.Duration) string { func WithDuration(data string, d time.Duration) string {
once.Do(Instance) once.Do(Instance)
return instance.Sign(data, time.Now().Add(d).Unix()) return instance.Sign(data, time.Now().Add(d).Unix())

13
server/common/sign.go Normal file
View File

@ -0,0 +1,13 @@
package common
import (
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/sign"
)
func Sign(obj model.Obj) string {
if obj.IsDir() {
return ""
}
return sign.Sign(obj.GetName())
}

View File

@ -1,66 +1,33 @@
package controllers package controllers
import ( import (
"fmt"
"github.com/alist-org/alist/v3/internal/sign"
stdpath "path" stdpath "path"
"strings" "strings"
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/pkg/errors"
) )
func Down(c *gin.Context) { func Down(c *gin.Context) {
rawPath := parsePath(c.Param("path")) rawPath := c.MustGet("path").(string)
filename := stdpath.Base(rawPath) filename := stdpath.Base(rawPath)
meta, err := db.GetNearestMeta(rawPath)
if err != nil {
if !errors.Is(errors.Cause(err), errs.MetaNotFound) {
common.ErrorResp(c, err, 500, true)
return
}
}
// verify sign
if needSign(meta, rawPath) {
s := c.Param("sign")
err = sign.Verify(filename, s)
if err != nil {
common.ErrorResp(c, err, 401)
return
}
}
account, err := fs.GetAccount(rawPath) account, err := fs.GetAccount(rawPath)
if err != nil { if err != nil {
common.ErrorResp(c, err, 500) common.ErrorResp(c, err, 500)
return return
} }
if needProxy(account, filename) { if shouldProxy(account, filename) {
link, err := fs.Link(c, rawPath, model.LinkArgs{ Proxy(c)
Header: c.Request.Header,
})
if err != nil {
common.ErrorResp(c, err, 500)
return return
}
obj, err := fs.Get(c, rawPath)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
err = common.Proxy(c.Writer, c.Request, link, obj)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
} else { } else {
link, err := fs.Link(c, rawPath, model.LinkArgs{ link, _, err := fs.Link(c, rawPath, model.LinkArgs{
IP: c.ClientIP(), IP: c.ClientIP(),
Header: c.Request.Header, Header: c.Request.Header,
}) })
@ -72,25 +39,49 @@ func Down(c *gin.Context) {
} }
} }
// TODO: implement func Proxy(c *gin.Context) {
// path maybe contains # ? etc. rawPath := c.MustGet("path").(string)
func parsePath(path string) string { filename := stdpath.Base(rawPath)
return utils.StandardizePath(path) account, err := fs.GetAccount(rawPath)
if err != nil {
common.ErrorResp(c, err, 500)
return
}
if canProxy(account, filename) {
downProxyUrl := account.GetAccount().DownProxyUrl
if downProxyUrl != "" {
_, ok := c.GetQuery("d")
if ok {
URL := fmt.Sprintf("%s%s?sign=%s", strings.Split(downProxyUrl, "\n")[0], rawPath, sign.Sign(filename))
c.Redirect(302, URL)
return
}
}
link, file, err := fs.Link(c, rawPath, model.LinkArgs{
Header: c.Request.Header,
})
if err != nil {
common.ErrorResp(c, err, 500)
return
}
err = common.Proxy(c.Writer, c.Request, link, file)
if err != nil {
common.ErrorResp(c, err, 500, true)
return
}
} else {
common.ErrorStrResp(c, "proxy not allowed", 403)
return
}
} }
func needSign(meta *model.Meta, path string) bool { // TODO need optimize
if meta == nil || meta.Password == "" { // when should be proxy?
return false // 1. config.MustProxy()
} // 2. account.WebProxy
if !meta.SubFolder && path != meta.Path { // 3. proxy_types
return false func shouldProxy(account driver.Driver, filename string) bool {
} if account.Config().MustProxy() || account.GetAccount().WebProxy {
return true
}
func needProxy(account driver.Driver, filename string) bool {
config := account.Config()
if config.MustProxy() {
return true return true
} }
proxyTypes := setting.GetByKey("proxy_types") proxyTypes := setting.GetByKey("proxy_types")
@ -99,3 +90,25 @@ func needProxy(account driver.Driver, filename string) bool {
} }
return false return false
} }
// TODO need optimize
// when can be proxy?
// 1. text file
// 2. config.MustProxy()
// 3. account.WebProxy
// 4. proxy_types
// solution: text_file + shouldProxy()
func canProxy(account driver.Driver, filename string) bool {
if account.Config().MustProxy() || account.GetAccount().WebProxy {
return true
}
proxyTypes := setting.GetByKey("proxy_types")
if strings.Contains(proxyTypes, utils.Ext(filename)) {
return true
}
textTypes := setting.GetByKey("text_types")
if strings.Contains(textTypes, utils.Ext(filename)) {
return true
}
return false
}

View File

@ -53,7 +53,7 @@ func FsGet(c *gin.Context) {
Size: obj.GetSize(), Size: obj.GetSize(),
IsDir: obj.IsDir(), IsDir: obj.IsDir(),
Modified: obj.ModTime(), Modified: obj.ModTime(),
Sign: Sign(obj), Sign: common.Sign(obj),
}, },
// TODO: set raw url // TODO: set raw url
}) })

View File

@ -9,7 +9,6 @@ import (
"github.com/alist-org/alist/v3/internal/fs" "github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/setting" "github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common" "github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@ -108,20 +107,8 @@ func toObjResp(objs []model.Obj, path string, baseURL string) []ObjResp {
Size: obj.GetSize(), Size: obj.GetSize(),
IsDir: obj.IsDir(), IsDir: obj.IsDir(),
Modified: obj.ModTime(), Modified: obj.ModTime(),
Sign: Sign(obj), Sign: common.Sign(obj),
}) })
} }
return resp return resp
} }
func Sign(obj model.Obj) string {
if obj.IsDir() {
return ""
}
expire := setting.GetIntSetting("link_expiration", 0)
if expire == 0 {
return sign.NotExpired(obj.GetName())
} else {
return sign.WithDuration(obj.GetName(), time.Duration(expire)*time.Hour)
}
}

View File

@ -0,0 +1,54 @@
package middlewares
import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/errs"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/sign"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/alist-org/alist/v3/server/common"
"github.com/gin-gonic/gin"
"github.com/pkg/errors"
stdpath "path"
)
func Down(c *gin.Context) {
rawPath := parsePath(c.Param("path"))
c.Set("path", rawPath)
filename := stdpath.Base(rawPath)
meta, err := db.GetNearestMeta(rawPath)
if err != nil {
if !errors.Is(errors.Cause(err), errs.MetaNotFound) {
common.ErrorResp(c, err, 500, true)
return
}
}
c.Set("meta", meta)
// verify sign
if needSign(meta, rawPath) {
s := c.Param("sign")
err = sign.Verify(filename, s)
if err != nil {
common.ErrorResp(c, err, 401)
c.Abort()
return
}
}
c.Next()
}
// TODO: implement
// path maybe contains # ? etc.
func parsePath(path string) string {
return utils.StandardizePath(path)
}
func needSign(meta *model.Meta, path string) bool {
if meta == nil || meta.Password == "" {
return false
}
if !meta.SubFolder && path != meta.Path {
return false
}
return true
}

View File

@ -13,7 +13,8 @@ func Init(r *gin.Engine) {
common.SecretKey = []byte(conf.Conf.JwtSecret) common.SecretKey = []byte(conf.Conf.JwtSecret)
Cors(r) Cors(r)
r.GET("/d/*path", controllers.Down) r.GET("/d/*path", middlewares.Down, controllers.Down)
r.GET("/p/*path", middlewares.Down, controllers.Proxy)
api := r.Group("/api", middlewares.Auth) api := r.Group("/api", middlewares.Auth)
api.POST("/auth/login", controllers.Login) api.POST("/auth/login", controllers.Login)