diff --git a/drivers/base/driver.go b/drivers/base/driver.go index e9a77588..e7a8dce3 100644 --- a/drivers/base/driver.go +++ b/drivers/base/driver.go @@ -111,6 +111,13 @@ func GetDrivers() map[string][]Item { Required: true, Description: "Transfer the WebDAV of this account through the server", }, + { + Name: "webdav_direct", + Label: "webdav direct", + Type: TypeBool, + Required: true, + Description: "Transfer the WebDAV of this account through the native", + }, }, v.Items()...) } res[k] = append([]Item{ diff --git a/drivers/base/types.go b/drivers/base/types.go index 5cd2c8c4..5456b800 100644 --- a/drivers/base/types.go +++ b/drivers/base/types.go @@ -44,7 +44,8 @@ type Header struct { } type Link struct { - Url string `json:"url"` - Headers []Header `json:"headers"` - Data io.ReadCloser + Url string `json:"url"` + Headers []Header `json:"headers"` + Data io.ReadCloser + FilePath string `json:"path"` // for native } diff --git a/drivers/native/driver.go b/drivers/native/driver.go index 9d03d2ff..36b1573c 100644 --- a/drivers/native/driver.go +++ b/drivers/native/driver.go @@ -132,7 +132,7 @@ func (driver Native) Link(args base.Args, account *model.Account) (*base.Link, e return nil, base.ErrNotFile } link := base.Link{ - Url: fullPath, + FilePath: fullPath, } return &link, nil } diff --git a/model/account.go b/model/account.go index bc367b6c..9bb04f6b 100644 --- a/model/account.go +++ b/model/account.go @@ -33,8 +33,9 @@ type Account struct { SiteUrl string `json:"site_url"` SiteId string `json:"site_id"` InternalType string `json:"internal_type"` - WebdavProxy bool `json:"webdav_proxy"` // 开启之后只会webdav走中转 - Proxy bool `json:"proxy"` // 是否中转,开启之后web和webdav都会走中转 + WebdavProxy bool `json:"webdav_proxy"` // 开启之后只会webdav走中转 + Proxy bool `json:"proxy"` // 是否中转,开启之后web和webdav都会走中转 + WebdavDirect bool `json:"webdav_direct"` // webdav 下载不跳转 //AllowProxy bool `json:"allow_proxy"` // 是否允许中转下载 DownProxyUrl string `json:"down_proxy_url"` // 用于中转下载服务的URL 两处 1. path请求中返回的链接 2. down下载时进行302 APIProxyUrl string `json:"api_proxy_url"` // 用于中转api的地址 diff --git a/server/common/proxy.go b/server/common/proxy.go new file mode 100644 index 00000000..be09224b --- /dev/null +++ b/server/common/proxy.go @@ -0,0 +1,87 @@ +package common + +import ( + "errors" + "fmt" + "github.com/Xhofe/alist/drivers/base" + "github.com/Xhofe/alist/model" + log "github.com/sirupsen/logrus" + "io" + "io/ioutil" + "net/http" + "net/url" + "os" + "strconv" +) + +var HttpClient = &http.Client{} + +func Proxy(w http.ResponseWriter, r *http.Request, link *base.Link, file *model.File) error { + // 本机读取数据 + var err error + if link.Data != nil { + //c.Data(http.StatusOK, "application/octet-stream", link.Data) + defer func() { + _ = link.Data.Close() + }() + w.WriteHeader(http.StatusOK) + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename=%s`, url.QueryEscape(file.Name))) + w.Header().Set("Content-Length", strconv.FormatInt(file.Size, 10)) + _, err = io.Copy(w, link.Data) + if err != nil { + return err + } + return nil + } + // 本机文件直接返回文件 + if link.FilePath != "" { + f, err := os.Open(link.FilePath) + if err != nil { + return err + } + defer func() { + _ = f.Close() + }() + fileStat, err := os.Stat(link.FilePath) + if err != nil { + return err + } + http.ServeContent(w, r, file.Name, fileStat.ModTime(), f) + return nil + } else { + req, err := http.NewRequest(r.Method, link.Url, nil) + if err != nil { + return err + } + for h, val := range r.Header { + req.Header[h] = val + } + for _, header := range link.Headers { + req.Header.Set(header.Name, header.Value) + } + res, err := HttpClient.Do(req) + if err != nil { + return err + } + defer func() { + _ = res.Body.Close() + }() + log.Debugf("proxy status: %d", res.StatusCode) + w.WriteHeader(res.StatusCode) + if res.StatusCode >= 400 { + all, _ := ioutil.ReadAll(res.Body) + msg := string(all) + log.Debugln(msg) + return errors.New(msg) + } + for h, v := range res.Header { + w.Header()[h] = v + } + _, err = io.Copy(w, res.Body) + if err != nil { + return err + } + return nil + } +} diff --git a/server/controllers/proxy.go b/server/controllers/proxy.go index 1fd2d9f6..f2b5c2a2 100644 --- a/server/controllers/proxy.go +++ b/server/controllers/proxy.go @@ -10,13 +10,7 @@ import ( "github.com/gin-gonic/gin" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" - "io" - "io/ioutil" - "net/http" - "net/url" - "os" "path/filepath" - "strconv" ) func Proxy(c *gin.Context) { @@ -63,98 +57,13 @@ func Proxy(c *gin.Context) { common.ErrorResp(c, err, 500) return } - // 本机读取数据 - if link.Data != nil { - //c.Data(http.StatusOK, "application/octet-stream", link.Data) - defer func() { - _ = link.Data.Close() - }() - c.Status(http.StatusOK) - c.Header("Content-Type", "application/octet-stream") - c.Header("Content-Disposition", fmt.Sprintf(`attachment; filename=%s`, url.QueryEscape(file.Name))) - c.Header("Content-Length", strconv.FormatInt(file.Size, 10)) - _, err = io.Copy(c.Writer, link.Data) - if err != nil { - _, _ = c.Writer.WriteString(err.Error()) - } - return - } - // 本机文件直接返回文件 - if account.Type == "Native" { - // 对于名称为index.html的文件需要特殊处理 - if utils.Base(rawPath) == "index.html" { - file, err := os.Open(link.Url) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - defer func() { - _ = file.Close() - }() - fileStat, err := os.Stat(link.Url) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - http.ServeContent(c.Writer, c.Request, utils.Base(rawPath), fileStat.ModTime(), file) - return - } - c.File(link.Url) - return - } else { - //if utils.GetFileType(filepath.Ext(rawPath)) == conf.TEXT { - // Text(c, link) - // return - //} - r := c.Request - w := c.Writer - //target, err := url.Parse(link.Url) - //if err != nil { - // common.ErrorResp(c, err, 500) - // return - //} - req, err := http.NewRequest("GET", link.Url, nil) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - log.Debugf("%+v", r.Header) - for h, val := range r.Header { - req.Header[h] = val - } - for _, header := range link.Headers { - req.Header.Set(header.Name, header.Value) - } - log.Debugf("%+v", req.Header) - res, err := HttpClient.Do(req) - if err != nil { - common.ErrorResp(c, err, 500) - return - } - defer func() { - _ = res.Body.Close() - }() - log.Debugf("proxy status: %d", res.StatusCode) - w.WriteHeader(res.StatusCode) - if res.StatusCode >= 400 { - all, _ := ioutil.ReadAll(res.Body) - log.Debugln(string(all)) - common.ErrorStrResp(c, string(all), 500) - return - } - for h, v := range res.Header { - w.Header()[h] = v - } - _, err = io.Copy(w, res.Body) - if err != nil { - common.ErrorResp(c, err, 500) - return - } + err = common.Proxy(c.Writer, c.Request, link, file) + if err != nil { + common.ErrorResp(c, err, 500) } } var client *resty.Client -var HttpClient = &http.Client{} func init() { client = resty.New() diff --git a/server/webdav/file.go b/server/webdav/file.go index 1274fe31..479a16df 100644 --- a/server/webdav/file.go +++ b/server/webdav/file.go @@ -95,7 +95,7 @@ func ClientIP(r *http.Request) string { return "" } -func (fs *FileSystem) Link(r *http.Request, rawPath string) (string, error) { +func (fs *FileSystem) Link(w http.ResponseWriter, r *http.Request, rawPath string) (string, error) { rawPath = utils.ParsePath(rawPath) log.Debugf("get link path: %s", rawPath) if model.AccountsCount() > 1 && rawPath == "/" { @@ -110,6 +110,19 @@ func (fs *FileSystem) Link(r *http.Request, rawPath string) (string, error) { if r.TLS != nil { protocol = "https" } + // 直接返回 + if account.WebdavDirect { + file, err := fs.File(rawPath) + if err != nil { + return "", err + } + link_, err := driver.Link(base.Args{Path: path_}, account) + if err != nil { + return "", err + } + err = common.Proxy(w, r, link_, file) + return "", err + } if driver.Config().OnlyProxy || account.WebdavProxy { link = fmt.Sprintf("%s://%s/p%s", protocol, r.Host, rawPath) if conf.GetBool("check down link") { diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index c0b48220..cc24309b 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -237,11 +237,14 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request, fs * } w.Header().Set("ETag", etag) log.Debugf("url: %+v", r.URL) - link, err := fs.Link(r, reqPath) + link, err := fs.Link(w, r, reqPath) + log.Debugf("webdav link error: %s", err.Error()) if err != nil { return http.StatusInternalServerError, err } - http.Redirect(w, r, link, 302) + if link != "" { + http.Redirect(w, r, link, 302) + } return 0, nil }