From b2f5757f8d57c803358e7ba47083ba8354a11cd8 Mon Sep 17 00:00:00 2001 From: Andy Hsu Date: Fri, 26 May 2023 21:54:57 +0800 Subject: [PATCH] fix(copy): copy from driver that return `writer` (close #4291) --- drivers/quark/driver.go | 96 +++++++++++++++++++------------------- internal/fs/copy.go | 5 +- internal/fs/util.go | 29 +++++------- internal/model/args.go | 17 ++++--- server/common/proxy.go | 17 ++++++- server/handles/fsmanage.go | 2 +- server/webdav/webdav.go | 2 +- 7 files changed, 91 insertions(+), 77 deletions(-) diff --git a/drivers/quark/driver.go b/drivers/quark/driver.go index 4185ffd5..f404f54f 100644 --- a/drivers/quark/driver.go +++ b/drivers/quark/driver.go @@ -69,58 +69,60 @@ func (d *Quark) Link(ctx context.Context, file model.Obj, args model.LinkArgs) ( } u := resp.Data[0].DownloadUrl start, end := int64(0), file.GetSize() - return &model.Link{ - Handle: func(w http.ResponseWriter, r *http.Request) error { - if rg := r.Header.Get("Range"); rg != "" { - parseRange, err := http_range.ParseRange(rg, file.GetSize()) - if err != nil { - return err - } - start, end = parseRange[0].Start, parseRange[0].Start+parseRange[0].Length - w.Header().Set("Content-Range", parseRange[0].ContentRange(file.GetSize())) - w.Header().Set("Content-Length", strconv.FormatInt(parseRange[0].Length, 10)) - w.WriteHeader(http.StatusPartialContent) - } else { - w.Header().Set("Content-Length", strconv.FormatInt(file.GetSize(), 10)) - w.WriteHeader(http.StatusOK) + link := model.Link{ + Header: http.Header{}, + } + if rg := args.Header.Get("Range"); rg != "" { + parseRange, err := http_range.ParseRange(rg, file.GetSize()) + if err != nil { + return nil, err + } + start, end = parseRange[0].Start, parseRange[0].Start+parseRange[0].Length + link.Header.Set("Content-Range", parseRange[0].ContentRange(file.GetSize())) + link.Header.Set("Content-Length", strconv.FormatInt(parseRange[0].Length, 10)) + link.Status = http.StatusPartialContent + } else { + link.Header.Set("Content-Length", strconv.FormatInt(file.GetSize(), 10)) + link.Status = http.StatusOK + } + link.Writer = func(w io.Writer) error { + // request 10 MB at a time + chunkSize := int64(10 * 1024 * 1024) + for start < end { + _end := start + chunkSize + if _end > end { + _end = end } - // request 10 MB at a time - chunkSize := int64(10 * 1024 * 1024) - for start < end { - _end := start + chunkSize - if _end > end { - _end = end - } - _range := "bytes=" + strconv.FormatInt(start, 10) + "-" + strconv.FormatInt(_end-1, 10) - start = _end - err = func() error { - req, err := http.NewRequest(r.Method, u, nil) - if err != nil { - return err - } - req.Header.Set("Range", _range) - req.Header.Set("User-Agent", ua) - req.Header.Set("Cookie", d.Cookie) - req.Header.Set("Referer", "https://pan.quark.cn") - resp, err := base.HttpClient.Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusPartialContent { - return fmt.Errorf("unexpected status code: %d", resp.StatusCode) - } - _, err = io.Copy(w, resp.Body) - return err - }() + _range := "bytes=" + strconv.FormatInt(start, 10) + "-" + strconv.FormatInt(_end-1, 10) + start = _end + err = func() error { + req, err := http.NewRequest(http.MethodGet, u, nil) if err != nil { return err } + req.Header.Set("Range", _range) + req.Header.Set("User-Agent", ua) + req.Header.Set("Cookie", d.Cookie) + req.Header.Set("Referer", "https://pan.quark.cn") + resp, err := base.HttpClient.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusPartialContent { + return fmt.Errorf("unexpected status code: %d", resp.StatusCode) + } + _, err = io.Copy(w, resp.Body) + return err + }() + if err != nil { + return err + } - } - return nil - }, - }, nil + } + return nil + } + return &link, nil } func (d *Quark) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error { diff --git a/internal/fs/copy.go b/internal/fs/copy.go index e18a6127..87735f2a 100644 --- a/internal/fs/copy.go +++ b/internal/fs/copy.go @@ -3,6 +3,7 @@ package fs import ( "context" "fmt" + "net/http" stdpath "path" "sync/atomic" @@ -87,7 +88,9 @@ func copyFileBetween2Storages(tsk *task.Task[uint64], srcStorage, dstStorage dri if err != nil { return errors.WithMessagef(err, "failed get src [%s] file", srcFilePath) } - link, _, err := op.Link(tsk.Ctx, srcStorage, srcFilePath, model.LinkArgs{}) + link, _, err := op.Link(tsk.Ctx, srcStorage, srcFilePath, model.LinkArgs{ + Header: http.Header{}, + }) if err != nil { return errors.WithMessagef(err, "failed get [%s] link", srcFilePath) } diff --git a/internal/fs/util.go b/internal/fs/util.go index a8af9614..947f8090 100644 --- a/internal/fs/util.go +++ b/internal/fs/util.go @@ -10,30 +10,13 @@ import ( "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/model" - "github.com/alist-org/alist/v3/internal/op" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/server/common" "github.com/google/uuid" "github.com/pkg/errors" + log "github.com/sirupsen/logrus" ) -func ClearCache(path string) { - storage, actualPath, err := op.GetStorageAndActualPath(path) - if err != nil { - return - } - op.ClearCache(storage, actualPath) -} - -func containsByName(files []model.Obj, file model.Obj) bool { - for _, f := range files { - if f.GetName() == file.GetName() { - return true - } - } - return false -} - func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, error) { var rc io.ReadCloser mimetype := utils.GetMimeType(file.GetName()) @@ -51,6 +34,16 @@ func getFileStreamFromLink(file model.Obj, link *model.Link) (*model.FileStream, return nil, errors.Wrapf(err, "failed to open file %s", *link.FilePath) } rc = f + } else if link.Writer != nil { + r, w := io.Pipe() + go func() { + err := link.Writer(w) + err = w.CloseWithError(err) + if err != nil { + log.Errorf("[getFileStreamFromLink] failed to write: %v", err) + } + }() + rc = r } else { req, err := http.NewRequest(http.MethodGet, link.URL, nil) if err != nil { diff --git a/internal/model/args.go b/internal/model/args.go index e9c3d12f..f40adad8 100644 --- a/internal/model/args.go +++ b/internal/model/args.go @@ -18,13 +18,14 @@ type LinkArgs struct { } type Link struct { - URL string `json:"url"` - Header http.Header `json:"header"` // needed header - Data io.ReadCloser // return file reader directly - Status int // status maybe 200 or 206, etc - FilePath *string // local file, return the filepath - Expiration *time.Duration // url expiration time - Handle func(w http.ResponseWriter, r *http.Request) error `json:"-"` // custom handler + URL string `json:"url"` + Header http.Header `json:"header"` // needed header + Data io.ReadCloser // return file reader directly + Status int // status maybe 200 or 206, etc + FilePath *string // local file, return the filepath + Expiration *time.Duration // url expiration time + //Handle func(w http.ResponseWriter, r *http.Request) error `json:"-"` // custom handler + Writer WriterFunc `json:"-"` // custom writer } type OtherArgs struct { @@ -38,3 +39,5 @@ type FsOtherArgs struct { Method string `json:"method" form:"method"` Data interface{} `json:"data" form:"data"` } + +type WriterFunc func(w io.Writer) error diff --git a/server/common/proxy.go b/server/common/proxy.go index 66724b67..2ad8b473 100644 --- a/server/common/proxy.go +++ b/server/common/proxy.go @@ -81,8 +81,21 @@ func Proxy(w http.ResponseWriter, r *http.Request, link *model.Link, file model. w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, filename, url.PathEscape(filename))) http.ServeContent(w, r, file.GetName(), fileStat.ModTime(), f) return nil - } else if link.Handle != nil { - return link.Handle(w, r) + } else if link.Writer != nil { + if link.Header != nil { + for h, v := range link.Header { + w.Header()[h] = v + } + } + if cd := w.Header().Get("Content-Disposition"); cd == "" { + w.Header().Set("Content-Disposition", fmt.Sprintf(`attachment; filename="%s"; filename*=UTF-8''%s`, file.GetName(), url.PathEscape(file.GetName()))) + } + if link.Status == 0 { + w.WriteHeader(http.StatusOK) + } else { + w.WriteHeader(link.Status) + } + return link.Writer(w) } else { req, err := http.NewRequest(r.Method, link.URL, nil) if err != nil { diff --git a/server/handles/fsmanage.go b/server/handles/fsmanage.go index c1ea915e..adeca863 100644 --- a/server/handles/fsmanage.go +++ b/server/handles/fsmanage.go @@ -476,7 +476,7 @@ func Link(c *gin.Context) { }) return } - link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), HttpReq: c.Request}) + link, _, err := fs.Link(c, rawPath, model.LinkArgs{IP: c.ClientIP(), Header: c.Request.Header, HttpReq: c.Request}) if err != nil { common.ErrorResp(c, err, 500) return diff --git a/server/webdav/webdav.go b/server/webdav/webdav.go index a5b362d6..e01d2efc 100644 --- a/server/webdav/webdav.go +++ b/server/webdav/webdav.go @@ -249,7 +249,7 @@ func (h *Handler) handleGetHeadPost(w http.ResponseWriter, r *http.Request) (sta w.Header().Set("Cache-Control", "max-age=0, no-cache, no-store, must-revalidate") http.Redirect(w, r, u, http.StatusFound) } else { - link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), HttpReq: r}) + link, _, err := fs.Link(ctx, reqPath, model.LinkArgs{IP: utils.ClientIP(r), Header: r.Header, HttpReq: r}) if err != nil { return http.StatusInternalServerError, err }