From 779c293f04a387cfef210b83632aeeb7c5fb69de Mon Sep 17 00:00:00 2001 From: KirCute_ECT <951206789@qq.com> Date: Sat, 1 Feb 2025 17:29:55 +0800 Subject: [PATCH] fix(driver): implement canceling and updating progress for putting for some drivers (#7847) * fix(driver): additionally implement canceling and updating progress for putting for some drivers * refactor: add driver archive api into template * fix(123): use built-in MD5 to avoid caching full * . * fix build failed --- drivers/115/driver.go | 4 +- drivers/115/util.go | 31 ++++++-- drivers/123/driver.go | 58 ++++++++------ drivers/123/upload.go | 4 +- drivers/alist_v3/driver.go | 18 +++-- drivers/chaoxing/driver.go | 16 +++- drivers/ftp/driver.go | 14 +++- drivers/github/driver.go | 17 +++-- drivers/github/util.go | 16 ---- drivers/ilanzou/driver.go | 36 +++++---- drivers/ipfs_api/driver.go | 11 ++- drivers/kodbox/driver.go | 16 ++-- drivers/lanzou/driver.go | 10 ++- drivers/mediatrack/driver.go | 25 +++--- drivers/netease_music/driver.go | 2 +- drivers/netease_music/types.go | 16 ++++ drivers/netease_music/upload.go | 13 +++- drivers/netease_music/util.go | 32 ++++++-- drivers/pikpak/driver.go | 4 +- drivers/pikpak/util.go | 32 ++++++-- drivers/quqi/driver.go | 13 +++- drivers/s3/driver.go | 19 +++-- drivers/seafile/driver.go | 12 ++- drivers/template/driver.go | 24 +++++- drivers/thunder/driver.go | 22 +++--- drivers/thunderx/driver.go | 22 +++--- drivers/trainbit/driver.go | 18 ++--- drivers/trainbit/util.go | 11 --- drivers/uss/driver.go | 14 +++- drivers/webdav/driver.go | 16 ++-- drivers/weiyun/driver.go | 130 +++++++++++++++++--------------- drivers/wopan/driver.go | 1 + drivers/yandex_disk/driver.go | 16 ++-- internal/driver/driver.go | 6 +- internal/stream/stream.go | 14 ++++ 35 files changed, 457 insertions(+), 256 deletions(-) diff --git a/drivers/115/driver.go b/drivers/115/driver.go index 0bf8a927..0dcb64d8 100644 --- a/drivers/115/driver.go +++ b/drivers/115/driver.go @@ -215,12 +215,12 @@ func (d *Pan115) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr var uploadResult *UploadResult // 闪传失败,上传 if stream.GetSize() <= 10*utils.MB { // 文件大小小于10MB,改用普通模式上传 - if uploadResult, err = d.UploadByOSS(&fastInfo.UploadOSSParams, stream, dirID); err != nil { + if uploadResult, err = d.UploadByOSS(ctx, &fastInfo.UploadOSSParams, stream, dirID, up); err != nil { return nil, err } } else { // 分片上传 - if uploadResult, err = d.UploadByMultipart(&fastInfo.UploadOSSParams, stream.GetSize(), stream, dirID); err != nil { + if uploadResult, err = d.UploadByMultipart(ctx, &fastInfo.UploadOSSParams, stream.GetSize(), stream, dirID, up); err != nil { return nil, err } } diff --git a/drivers/115/util.go b/drivers/115/util.go index 84cbd88f..4d3cdd93 100644 --- a/drivers/115/util.go +++ b/drivers/115/util.go @@ -2,17 +2,21 @@ package _115 import ( "bytes" + "context" "crypto/md5" "crypto/tls" "encoding/hex" "encoding/json" "fmt" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" "strconv" "strings" "sync" + "sync/atomic" "time" "github.com/alist-org/alist/v3/internal/conf" @@ -271,7 +275,7 @@ func UploadDigestRange(stream model.FileStreamer, rangeSpec string) (result stri } // UploadByOSS use aliyun sdk to upload -func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dirID string) (*UploadResult, error) { +func (c *Pan115) UploadByOSS(ctx context.Context, params *driver115.UploadOSSParams, s model.FileStreamer, dirID string, up driver.UpdateProgress) (*UploadResult, error) { ossToken, err := c.client.GetOSSToken() if err != nil { return nil, err @@ -286,6 +290,13 @@ func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dir } var bodyBytes []byte + r := &stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + } if err = bucket.PutObject(params.Object, r, append( driver115.OssOption(params, ossToken), oss.CallbackResult(&bodyBytes), @@ -301,7 +312,8 @@ func (c *Pan115) UploadByOSS(params *driver115.UploadOSSParams, r io.Reader, dir } // UploadByMultipart upload by mutipart blocks -func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize int64, stream model.FileStreamer, dirID string, opts ...driver115.UploadMultipartOption) (*UploadResult, error) { +func (d *Pan115) UploadByMultipart(ctx context.Context, params *driver115.UploadOSSParams, fileSize int64, s model.FileStreamer, + dirID string, up driver.UpdateProgress, opts ...driver115.UploadMultipartOption) (*UploadResult, error) { var ( chunks []oss.FileChunk parts []oss.UploadPart @@ -313,7 +325,7 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i err error ) - tmpF, err := stream.CacheFullInTempFile() + tmpF, err := s.CacheFullInTempFile() if err != nil { return nil, err } @@ -372,6 +384,7 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i quit <- struct{}{} }() + completedNum := atomic.Int32{} // consumers for i := 0; i < options.ThreadsNum; i++ { go func(threadId int) { @@ -384,6 +397,8 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i var part oss.UploadPart // 出现错误就继续尝试,共尝试3次 for retry := 0; retry < 3; retry++ { select { + case <-ctx.Done(): + break case <-ticker.C: if ossToken, err = d.client.GetOSSToken(); err != nil { // 到时重新获取ossToken errCh <- errors.Wrap(err, "刷新token时出现错误") @@ -396,12 +411,18 @@ func (d *Pan115) UploadByMultipart(params *driver115.UploadOSSParams, fileSize i continue } - if part, err = bucket.UploadPart(imur, bytes.NewBuffer(buf), chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { + if part, err = bucket.UploadPart(imur, &stream.ReaderWithCtx{ + Reader: bytes.NewBuffer(buf), + Ctx: ctx, + }, chunk.Size, chunk.Number, driver115.OssOption(params, ossToken)...); err == nil { break } } if err != nil { - errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", stream.GetName(), chunk.Number, err)) + errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", s.GetName(), chunk.Number, err)) + } else { + num := completedNum.Add(1) + up(float64(num) * 100.0 / float64(len(chunks))) } UploadedPartsCh <- part } diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 3828a59d..1bf71ae6 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -185,32 +186,35 @@ func (d *Pan123) Remove(ctx context.Context, obj model.Obj) error { } } -func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - // const DEFAULT int64 = 10485760 - h := md5.New() - // need to calculate md5 of the full content - tempFile, err := stream.CacheFullInTempFile() - if err != nil { - return err +func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + etag := file.GetHash().GetHash(utils.MD5) + if len(etag) < utils.MD5.Width { + // const DEFAULT int64 = 10485760 + h := md5.New() + // need to calculate md5 of the full content + tempFile, err := file.CacheFullInTempFile() + if err != nil { + return err + } + defer func() { + _ = tempFile.Close() + }() + if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { + return err + } + _, err = tempFile.Seek(0, io.SeekStart) + if err != nil { + return err + } + etag = hex.EncodeToString(h.Sum(nil)) } - defer func() { - _ = tempFile.Close() - }() - if _, err = utils.CopyWithBuffer(h, tempFile); err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } - etag := hex.EncodeToString(h.Sum(nil)) data := base.Json{ "driveId": 0, "duplicate": 2, // 2->覆盖 1->重命名 0->默认 "etag": etag, - "fileName": stream.GetName(), + "fileName": file.GetName(), "parentFileId": dstDir.GetID(), - "size": stream.GetSize(), + "size": file.GetSize(), "type": 0, } var resp UploadResp @@ -225,7 +229,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr return nil } if resp.Data.AccessKeyId == "" || resp.Data.SecretAccessKey == "" || resp.Data.SessionToken == "" { - err = d.newUpload(ctx, &resp, stream, tempFile, up) + err = d.newUpload(ctx, &resp, file, up) return err } else { cfg := &aws.Config{ @@ -239,15 +243,21 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr return err } uploader := s3manager.NewUploader(s) - if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { - uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) } input := &s3manager.UploadInput{ Bucket: &resp.Data.Bucket, Key: &resp.Data.Key, - Body: tempFile, + Body: &stream.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: up, + }, } _, err = uploader.UploadWithContext(ctx, input) + if err != nil { + return err + } } _, err = d.Request(UploadComplete, http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ diff --git a/drivers/123/upload.go b/drivers/123/upload.go index 66627b4c..a472df55 100644 --- a/drivers/123/upload.go +++ b/drivers/123/upload.go @@ -69,7 +69,7 @@ func (d *Pan123) completeS3(ctx context.Context, upReq *UploadResp, file model.F return err } -func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, reader io.Reader, up driver.UpdateProgress) error { +func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.FileStreamer, up driver.UpdateProgress) error { chunkSize := int64(1024 * 1024 * 16) // fetch s3 pre signed urls chunkCount := int(math.Ceil(float64(file.GetSize()) / float64(chunkSize))) @@ -103,7 +103,7 @@ func (d *Pan123) newUpload(ctx context.Context, upReq *UploadResp, file model.Fi if j == chunkCount { curSize = file.GetSize() - (int64(chunkCount)-1)*chunkSize } - err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(reader, chunkSize), curSize, false, getS3UploadUrl) + err = d.uploadS3Chunk(ctx, upReq, s3PreSignedUrls, j, end, io.LimitReader(file, chunkSize), curSize, false, getS3UploadUrl) if err != nil { return err } diff --git a/drivers/alist_v3/driver.go b/drivers/alist_v3/driver.go index 894bac64..679285e0 100644 --- a/drivers/alist_v3/driver.go +++ b/drivers/alist_v3/driver.go @@ -3,6 +3,7 @@ package alist_v3 import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "path" @@ -181,25 +182,28 @@ func (d *AListV3) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", stream) +func (d *AListV3) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { + req, err := http.NewRequestWithContext(ctx, http.MethodPut, d.Address+"/api/fs/put", &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }) if err != nil { return err } req.Header.Set("Authorization", d.Token) - req.Header.Set("File-Path", path.Join(dstDir.GetPath(), stream.GetName())) + req.Header.Set("File-Path", path.Join(dstDir.GetPath(), s.GetName())) req.Header.Set("Password", d.MetaPassword) - if md5 := stream.GetHash().GetHash(utils.MD5); len(md5) > 0 { + if md5 := s.GetHash().GetHash(utils.MD5); len(md5) > 0 { req.Header.Set("X-File-Md5", md5) } - if sha1 := stream.GetHash().GetHash(utils.SHA1); len(sha1) > 0 { + if sha1 := s.GetHash().GetHash(utils.SHA1); len(sha1) > 0 { req.Header.Set("X-File-Sha1", sha1) } - if sha256 := stream.GetHash().GetHash(utils.SHA256); len(sha256) > 0 { + if sha256 := s.GetHash().GetHash(utils.SHA256); len(sha256) > 0 { req.Header.Set("X-File-Sha256", sha256) } - req.ContentLength = stream.GetSize() + req.ContentLength = s.GetSize() // client := base.NewHttpClient() // client.Timeout = time.Hour * 6 res, err := base.HttpClient.Do(req) diff --git a/drivers/chaoxing/driver.go b/drivers/chaoxing/driver.go index 360c6e3d..9b526f8a 100644 --- a/drivers/chaoxing/driver.go +++ b/drivers/chaoxing/driver.go @@ -6,6 +6,7 @@ import ( "encoding/json" "errors" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "mime/multipart" "net/http" @@ -215,7 +216,7 @@ func (d *ChaoXing) Remove(ctx context.Context, obj model.Obj) error { return nil } -func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { var resp UploadDataRsp _, err := d.request("https://noteyd.chaoxing.com/pc/files/getUploadConfig", http.MethodGet, func(req *resty.Request) { }, &resp) @@ -227,11 +228,11 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileS } body := &bytes.Buffer{} writer := multipart.NewWriter(body) - filePart, err := writer.CreateFormFile("file", stream.GetName()) + filePart, err := writer.CreateFormFile("file", file.GetName()) if err != nil { return err } - _, err = utils.CopyWithBuffer(filePart, stream) + _, err = utils.CopyWithBuffer(filePart, file) if err != nil { return err } @@ -248,7 +249,14 @@ func (d *ChaoXing) Put(ctx context.Context, dstDir model.Obj, stream model.FileS if err != nil { return err } - req, err := http.NewRequest("POST", "https://pan-yz.chaoxing.com/upload", body) + r := &stream.ReaderUpdatingProgress{ + Reader: &stream.SimpleReaderWithSize{ + Reader: body, + Size: int64(body.Len()), + }, + UpdateProgress: up, + } + req, err := http.NewRequestWithContext(ctx, "POST", "https://pan-yz.chaoxing.com/upload", r) if err != nil { return err } diff --git a/drivers/ftp/driver.go b/drivers/ftp/driver.go index 05b9e49a..b3e95f93 100644 --- a/drivers/ftp/driver.go +++ b/drivers/ftp/driver.go @@ -2,6 +2,7 @@ package ftp import ( "context" + "github.com/alist-org/alist/v3/internal/stream" stdpath "path" "github.com/alist-org/alist/v3/internal/driver" @@ -114,13 +115,18 @@ func (d *FTP) Remove(ctx context.Context, obj model.Obj) error { } } -func (d *FTP) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *FTP) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { if err := d.login(); err != nil { return err } - // TODO: support cancel - path := stdpath.Join(dstDir.GetPath(), stream.GetName()) - return d.conn.Stor(encode(path, d.Encoding), stream) + path := stdpath.Join(dstDir.GetPath(), s.GetName()) + return d.conn.Stor(encode(path, d.Encoding), &stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + }) } var _ driver.Driver = (*FTP)(nil) diff --git a/drivers/github/driver.go b/drivers/github/driver.go index eed06882..996c79c7 100644 --- a/drivers/github/driver.go +++ b/drivers/github/driver.go @@ -16,6 +16,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" log "github.com/sirupsen/logrus" @@ -649,15 +650,15 @@ func (d *Github) createGitKeep(path, message string) error { return nil } -func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress) (string, error) { +func (d *Github) putBlob(ctx context.Context, s model.FileStreamer, up driver.UpdateProgress) (string, error) { beforeContent := "{\"encoding\":\"base64\",\"content\":\"" afterContent := "\"}" - length := int64(len(beforeContent)) + calculateBase64Length(stream.GetSize()) + int64(len(afterContent)) + length := int64(len(beforeContent)) + calculateBase64Length(s.GetSize()) + int64(len(afterContent)) beforeContentReader := strings.NewReader(beforeContent) contentReader, contentWriter := io.Pipe() go func() { encoder := base64.NewEncoder(base64.StdEncoding, contentWriter) - if _, err := utils.CopyWithBuffer(encoder, stream); err != nil { + if _, err := utils.CopyWithBuffer(encoder, s); err != nil { _ = contentWriter.CloseWithError(err) return } @@ -667,10 +668,12 @@ func (d *Github) putBlob(ctx context.Context, stream model.FileStreamer, up driv afterContentReader := strings.NewReader(afterContent) req, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("https://api.github.com/repos/%s/%s/git/blobs", d.Owner, d.Repo), - &ReaderWithProgress{ - Reader: io.MultiReader(beforeContentReader, contentReader, afterContentReader), - Length: length, - Progress: up, + &stream.ReaderUpdatingProgress{ + Reader: &stream.SimpleReaderWithSize{ + Reader: io.MultiReader(beforeContentReader, contentReader, afterContentReader), + Size: length, + }, + UpdateProgress: up, }) if err != nil { return "", err diff --git a/drivers/github/util.go b/drivers/github/util.go index 1e7f7fdb..85bc3cb9 100644 --- a/drivers/github/util.go +++ b/drivers/github/util.go @@ -7,26 +7,10 @@ import ( "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" - "io" - "math" "strings" "text/template" ) -type ReaderWithProgress struct { - Reader io.Reader - Length int64 - Progress func(percentage float64) - offset int64 -} - -func (r *ReaderWithProgress) Read(p []byte) (int, error) { - n, err := r.Reader.Read(p) - r.offset += int64(n) - r.Progress(math.Min(100.0, float64(r.offset)/float64(r.Length)*100.0)) - return n, err -} - type MessageTemplateVars struct { UserName string ObjName string diff --git a/drivers/ilanzou/driver.go b/drivers/ilanzou/driver.go index 90ef7c1a..8681fed4 100644 --- a/drivers/ilanzou/driver.go +++ b/drivers/ilanzou/driver.go @@ -6,6 +6,7 @@ import ( "encoding/base64" "encoding/hex" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -266,10 +267,10 @@ func (d *ILanZou) Remove(ctx context.Context, obj model.Obj) error { const DefaultPartSize = 1024 * 1024 * 8 -func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { +func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { h := md5.New() // need to calculate md5 of the full content - tempFile, err := stream.CacheFullInTempFile() + tempFile, err := s.CacheFullInTempFile() if err != nil { return nil, err } @@ -288,8 +289,8 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt res, err := d.proved("/7n/getUpToken", http.MethodPost, func(req *resty.Request) { req.SetBody(base.Json{ "fileId": "", - "fileName": stream.GetName(), - "fileSize": stream.GetSize()/1024 + 1, + "fileName": s.GetName(), + "fileSize": s.GetSize()/1024 + 1, "folderId": dstDir.GetID(), "md5": etag, "type": 1, @@ -301,13 +302,20 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt upToken := utils.Json.Get(res, "upToken").ToString() now := time.Now() key := fmt.Sprintf("disk/%d/%d/%d/%s/%016d", now.Year(), now.Month(), now.Day(), d.account, now.UnixMilli()) + reader := &stream.ReaderUpdatingProgress{ + Reader: &stream.SimpleReaderWithSize{ + Reader: tempFile, + Size: s.GetSize(), + }, + UpdateProgress: up, + } var token string - if stream.GetSize() <= DefaultPartSize { - res, err := d.upClient.R().SetMultipartFormData(map[string]string{ + if s.GetSize() <= DefaultPartSize { + res, err := d.upClient.R().SetContext(ctx).SetMultipartFormData(map[string]string{ "token": upToken, "key": key, - "fname": stream.GetName(), - }).SetMultipartField("file", stream.GetName(), stream.GetMimetype(), tempFile). + "fname": s.GetName(), + }).SetMultipartField("file", s.GetName(), s.GetMimetype(), reader). Post("https://upload.qiniup.com/") if err != nil { return nil, err @@ -321,10 +329,10 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt } uploadId := utils.Json.Get(res.Body(), "uploadId").ToString() parts := make([]Part, 0) - partNum := (stream.GetSize() + DefaultPartSize - 1) / DefaultPartSize + partNum := (s.GetSize() + DefaultPartSize - 1) / DefaultPartSize for i := 1; i <= int(partNum); i++ { u := fmt.Sprintf("https://upload.qiniup.com/buckets/%s/objects/%s/uploads/%s/%d", d.conf.bucket, keyBase64, uploadId, i) - res, err = d.upClient.R().SetHeader("Authorization", "UpToken "+upToken).SetBody(io.LimitReader(tempFile, DefaultPartSize)).Put(u) + res, err = d.upClient.R().SetContext(ctx).SetHeader("Authorization", "UpToken "+upToken).SetBody(io.LimitReader(reader, DefaultPartSize)).Put(u) if err != nil { return nil, err } @@ -335,7 +343,7 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt }) } res, err = d.upClient.R().SetHeader("Authorization", "UpToken "+upToken).SetBody(base.Json{ - "fnmae": stream.GetName(), + "fnmae": s.GetName(), "parts": parts, }).Post(fmt.Sprintf("https://upload.qiniup.com/buckets/%s/objects/%s/uploads/%s", d.conf.bucket, keyBase64, uploadId)) if err != nil { @@ -373,9 +381,9 @@ func (d *ILanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt ID: strconv.FormatInt(file.FileId, 10), //Path: , Name: file.FileName, - Size: stream.GetSize(), - Modified: stream.ModTime(), - Ctime: stream.CreateTime(), + Size: s.GetSize(), + Modified: s.ModTime(), + Ctime: s.CreateTime(), IsFolder: false, HashInfo: utils.NewHashInfo(utils.MD5, etag), }, nil diff --git a/drivers/ipfs_api/driver.go b/drivers/ipfs_api/driver.go index f6f81305..61886b38 100644 --- a/drivers/ipfs_api/driver.go +++ b/drivers/ipfs_api/driver.go @@ -3,6 +3,7 @@ package ipfs import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "net/url" stdpath "path" "path/filepath" @@ -108,9 +109,15 @@ func (d *IPFS) Remove(ctx context.Context, obj model.Obj) error { return d.sh.FilesRm(ctx, obj.GetPath(), true) } -func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *IPFS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { // TODO upload file, optional - _, err := d.sh.Add(stream, ToFiles(stdpath.Join(dstDir.GetPath(), stream.GetName()))) + _, err := d.sh.Add(&stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + }, ToFiles(stdpath.Join(dstDir.GetPath(), s.GetName()))) return err } diff --git a/drivers/kodbox/driver.go b/drivers/kodbox/driver.go index eb5120a6..ff48ffb2 100644 --- a/drivers/kodbox/driver.go +++ b/drivers/kodbox/driver.go @@ -3,6 +3,7 @@ package kodbox import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/go-resty/resty/v2" "net/http" @@ -225,14 +226,19 @@ func (d *KodBox) Remove(ctx context.Context, obj model.Obj) error { return nil } -func (d *KodBox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { +func (d *KodBox) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { var resp *CommonResp _, err := d.request(http.MethodPost, "/?explorer/upload/fileUpload", func(req *resty.Request) { - req.SetFileReader("file", stream.GetName(), stream). + r := &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + } + req.SetFileReader("file", s.GetName(), r). SetResult(&resp). SetFormData(map[string]string{ "path": dstDir.GetPath(), - }) + }). + SetContext(ctx) }) if err != nil { return nil, err @@ -244,8 +250,8 @@ func (d *KodBox) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr return &model.ObjThumb{ Object: model.Object{ Path: resp.Info.(string), - Name: stream.GetName(), - Size: stream.GetSize(), + Name: s.GetName(), + Size: s.GetSize(), IsFolder: false, Modified: time.Now(), Ctime: time.Now(), diff --git a/drivers/lanzou/driver.go b/drivers/lanzou/driver.go index 9e73f052..90635d16 100644 --- a/drivers/lanzou/driver.go +++ b/drivers/lanzou/driver.go @@ -2,6 +2,7 @@ package lanzou import ( "context" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "github.com/alist-org/alist/v3/drivers/base" @@ -208,7 +209,7 @@ func (d *LanZou) Remove(ctx context.Context, obj model.Obj) error { return errs.NotSupport } -func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { +func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { if d.IsCookie() || d.IsAccount() { var resp RespText[[]FileOrFolder] _, err := d._post(d.BaseUrl+"/html5up.php", func(req *resty.Request) { @@ -217,9 +218,12 @@ func (d *LanZou) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr "vie": "2", "ve": "2", "id": "WU_FILE_0", - "name": stream.GetName(), + "name": s.GetName(), "folder_id_bb_n": dstDir.GetID(), - }).SetFileReader("upload_file", stream.GetName(), stream).SetContext(ctx) + }).SetFileReader("upload_file", s.GetName(), &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }).SetContext(ctx) }, &resp, true) if err != nil { return nil, err diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index f0f1ded0..ed53f8ee 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -5,6 +5,7 @@ import ( "crypto/md5" "encoding/hex" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "strconv" @@ -161,7 +162,7 @@ func (d *MediaTrack) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { src := "assets/" + uuid.New().String() var resp UploadResp _, err := d.request("https://jayce.api.mediatrack.cn/v3/storage/tokens/asset", http.MethodGet, func(req *resty.Request) { @@ -180,7 +181,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - tempFile, err := stream.CacheFullInTempFile() + tempFile, err := file.CacheFullInTempFile() if err != nil { return err } @@ -188,13 +189,19 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil _ = tempFile.Close() }() uploader := s3manager.NewUploader(s) - if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { - uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) } input := &s3manager.UploadInput{ Bucket: &resp.Data.Bucket, Key: &resp.Data.Object, - Body: tempFile, + Body: &stream.ReaderUpdatingProgress{ + Reader: &stream.SimpleReaderWithSize{ + Reader: tempFile, + Size: file.GetSize(), + }, + UpdateProgress: up, + }, } _, err = uploader.UploadWithContext(ctx, input) if err != nil { @@ -213,12 +220,12 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil hash := hex.EncodeToString(h.Sum(nil)) data := base.Json{ "category": 0, - "description": stream.GetName(), + "description": file.GetName(), "hash": hash, - "mime": stream.GetMimetype(), - "size": stream.GetSize(), + "mime": file.GetMimetype(), + "size": file.GetSize(), "src": src, - "title": stream.GetName(), + "title": file.GetName(), "type": 0, } _, err = d.request(url, http.MethodPost, func(req *resty.Request) { diff --git a/drivers/netease_music/driver.go b/drivers/netease_music/driver.go index c0d103de..08460cce 100644 --- a/drivers/netease_music/driver.go +++ b/drivers/netease_music/driver.go @@ -88,7 +88,7 @@ func (d *NeteaseMusic) Remove(ctx context.Context, obj model.Obj) error { } func (d *NeteaseMusic) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - return d.putSongStream(stream) + return d.putSongStream(ctx, stream, up) } func (d *NeteaseMusic) Copy(ctx context.Context, srcObj, dstDir model.Obj) error { diff --git a/drivers/netease_music/types.go b/drivers/netease_music/types.go index 0e156ad1..332f75e9 100644 --- a/drivers/netease_music/types.go +++ b/drivers/netease_music/types.go @@ -2,6 +2,7 @@ package netease_music import ( "context" + "github.com/alist-org/alist/v3/internal/driver" "io" "net/http" "strconv" @@ -71,6 +72,8 @@ func (lrc *LyricObj) getLyricLink() *model.Link { type ReqOption struct { crypto string stream model.FileStreamer + up driver.UpdateProgress + ctx context.Context data map[string]string headers map[string]string cookies []*http.Cookie @@ -113,3 +116,16 @@ func (ch *Characteristic) merge(data map[string]string) map[string]interface{} { } return body } + +type InlineReadCloser struct { + io.Reader + io.Closer +} + +func (rc *InlineReadCloser) Read(p []byte) (int, error) { + return rc.Reader.Read(p) +} + +func (rc *InlineReadCloser) Close() error { + return rc.Closer.Close() +} diff --git a/drivers/netease_music/upload.go b/drivers/netease_music/upload.go index 7f580bd1..3ff6216b 100644 --- a/drivers/netease_music/upload.go +++ b/drivers/netease_music/upload.go @@ -1,8 +1,10 @@ package netease_music import ( + "context" "crypto/md5" "encoding/hex" + "github.com/alist-org/alist/v3/internal/driver" "io" "net/http" "strconv" @@ -47,9 +49,12 @@ func (u *uploader) init(stream model.FileStreamer) error { } h := md5.New() - utils.CopyWithBuffer(h, stream) + _, err := utils.CopyWithBuffer(h, stream) + if err != nil { + return err + } u.md5 = hex.EncodeToString(h.Sum(nil)) - _, err := u.file.Seek(0, io.SeekStart) + _, err = u.file.Seek(0, io.SeekStart) if err != nil { return err } @@ -167,7 +172,7 @@ func (u *uploader) publishInfo(resourceId string) error { return nil } -func (u *uploader) upload(stream model.FileStreamer) error { +func (u *uploader) upload(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress) error { bucket := "jd-musicrep-privatecloud-audio-public" token, err := u.allocToken(bucket) if err != nil { @@ -192,6 +197,8 @@ func (u *uploader) upload(stream model.FileStreamer) error { http.MethodPost, ReqOption{ stream: stream, + up: up, + ctx: ctx, headers: map[string]string{ "x-nos-token": token.token, "Content-Type": "audio/mpeg", diff --git a/drivers/netease_music/util.go b/drivers/netease_music/util.go index 4d0696eb..25efde77 100644 --- a/drivers/netease_music/util.go +++ b/drivers/netease_music/util.go @@ -1,7 +1,9 @@ package netease_music import ( - "io" + "context" + "github.com/alist-org/alist/v3/internal/driver" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "path" "regexp" @@ -58,20 +60,38 @@ func (d *NeteaseMusic) request(url, method string, opt ReqOption) ([]byte, error url = "https://music.163.com/api/linux/forward" } + if opt.ctx != nil { + req.SetContext(opt.ctx) + } if method == http.MethodPost { if opt.stream != nil { + if opt.up == nil { + opt.up = func(_ float64) {} + } req.SetContentLength(true) - req.SetBody(io.ReadCloser(opt.stream)) + req.SetBody(&InlineReadCloser{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: opt.stream, + UpdateProgress: opt.up, + }, + Closer: opt.stream, + }) } else { req.SetFormData(data) } res, err := req.Post(url) - return res.Body(), err + if err != nil { + return nil, err + } + return res.Body(), nil } if method == http.MethodGet { res, err := req.Get(url) - return res.Body(), err + if err != nil { + return nil, err + } + return res.Body(), nil } return nil, errs.NotImplement @@ -206,7 +226,7 @@ func (d *NeteaseMusic) removeSongObj(file model.Obj) error { return err } -func (d *NeteaseMusic) putSongStream(stream model.FileStreamer) error { +func (d *NeteaseMusic) putSongStream(ctx context.Context, stream model.FileStreamer, up driver.UpdateProgress) error { tmp, err := stream.CacheFullInTempFile() if err != nil { return err @@ -231,7 +251,7 @@ func (d *NeteaseMusic) putSongStream(stream model.FileStreamer) error { } if u.meta.needUpload { - err = u.upload(stream) + err = u.upload(ctx, stream, up) if err != nil { return err } diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index 3db273d6..504b1d0e 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -255,10 +255,10 @@ func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr } if stream.GetSize() <= 10*utils.MB { // 文件大小 小于10MB,改用普通模式上传 - return d.UploadByOSS(¶ms, stream, up) + return d.UploadByOSS(ctx, ¶ms, stream, up) } // 分片上传 - return d.UploadByMultipart(¶ms, stream.GetSize(), stream, up) + return d.UploadByMultipart(ctx, ¶ms, stream.GetSize(), stream, up) } // 离线下载文件 diff --git a/drivers/pikpak/util.go b/drivers/pikpak/util.go index e8f3c854..eb96a42a 100644 --- a/drivers/pikpak/util.go +++ b/drivers/pikpak/util.go @@ -2,6 +2,7 @@ package pikpak import ( "bytes" + "context" "crypto/md5" "crypto/sha1" "encoding/hex" @@ -9,6 +10,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/aliyun/aliyun-oss-go-sdk/oss" jsoniter "github.com/json-iterator/go" @@ -19,6 +21,7 @@ import ( "regexp" "strings" "sync" + "sync/atomic" "time" "github.com/alist-org/alist/v3/drivers/base" @@ -417,7 +420,7 @@ func (d *PikPak) refreshCaptchaToken(action string, metas map[string]string) err return nil } -func (d *PikPak) UploadByOSS(params *S3Params, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *PikPak) UploadByOSS(ctx context.Context, params *S3Params, s model.FileStreamer, up driver.UpdateProgress) error { ossClient, err := oss.New(params.Endpoint, params.AccessKeyID, params.AccessKeySecret) if err != nil { return err @@ -427,14 +430,20 @@ func (d *PikPak) UploadByOSS(params *S3Params, stream model.FileStreamer, up dri return err } - err = bucket.PutObject(params.Key, stream, OssOption(params)...) + err = bucket.PutObject(params.Key, &stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + }, OssOption(params)...) if err != nil { return err } return nil } -func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *PikPak) UploadByMultipart(ctx context.Context, params *S3Params, fileSize int64, s model.FileStreamer, up driver.UpdateProgress) error { var ( chunks []oss.FileChunk parts []oss.UploadPart @@ -444,7 +453,7 @@ func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream mode err error ) - tmpF, err := stream.CacheFullInTempFile() + tmpF, err := s.CacheFullInTempFile() if err != nil { return err } @@ -488,6 +497,7 @@ func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream mode quit <- struct{}{} }() + completedNum := atomic.Int32{} // consumers for i := 0; i < ThreadsNum; i++ { go func(threadId int) { @@ -500,6 +510,8 @@ func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream mode var part oss.UploadPart // 出现错误就继续尝试,共尝试3次 for retry := 0; retry < 3; retry++ { select { + case <-ctx.Done(): + break case <-ticker.C: errCh <- errors.Wrap(err, "ossToken 过期") default: @@ -511,12 +523,18 @@ func (d *PikPak) UploadByMultipart(params *S3Params, fileSize int64, stream mode } b := bytes.NewBuffer(buf) - if part, err = bucket.UploadPart(imur, b, chunk.Size, chunk.Number, OssOption(params)...); err == nil { + if part, err = bucket.UploadPart(imur, &stream.ReaderWithCtx{ + Reader: b, + Ctx: ctx, + }, chunk.Size, chunk.Number, OssOption(params)...); err == nil { break } } if err != nil { - errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", stream.GetName(), chunk.Number, err)) + errCh <- errors.Wrap(err, fmt.Sprintf("上传 %s 的第%d个分片时出现错误:%v", s.GetName(), chunk.Number, err)) + } else { + num := completedNum.Add(1) + up(float64(num) * 100.0 / float64(len(chunks))) } UploadedPartsCh <- part } @@ -547,7 +565,7 @@ LOOP: // EOF错误是xml的Unmarshal导致的,响应其实是json格式,所以实际上上传是成功的 if _, err = bucket.CompleteMultipartUpload(imur, parts, OssOption(params)...); err != nil && !errors.Is(err, io.EOF) { // 当文件名含有 &< 这两个字符之一时响应的xml解析会出现错误,实际上上传是成功的 - if filename := filepath.Base(stream.GetName()); !strings.ContainsAny(filename, "&<") { + if filename := filepath.Base(s.GetName()); !strings.ContainsAny(filename, "&<") { return err } } diff --git a/drivers/quqi/driver.go b/drivers/quqi/driver.go index 51e54981..2ab972ca 100644 --- a/drivers/quqi/driver.go +++ b/drivers/quqi/driver.go @@ -3,6 +3,7 @@ package quqi import ( "bytes" "context" + "errors" "io" "strconv" "strings" @@ -11,6 +12,7 @@ import ( "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" + istream "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" "github.com/alist-org/alist/v3/pkg/utils/random" "github.com/aws/aws-sdk-go/aws" @@ -385,9 +387,16 @@ func (d *Quqi) Put(ctx context.Context, dstDir model.Obj, stream model.FileStrea } uploader := s3manager.NewUploader(s) buf := make([]byte, 1024*1024*2) + fup := &istream.ReaderUpdatingProgress{ + Reader: &istream.SimpleReaderWithSize{ + Reader: f, + Size: int64(len(buf)), + }, + UpdateProgress: up, + } for partNumber := int64(1); ; partNumber++ { - n, err := io.ReadFull(f, buf) - if err != nil && err != io.ErrUnexpectedEOF { + n, err := io.ReadFull(fup, buf) + if err != nil && !errors.Is(err, io.ErrUnexpectedEOF) { if err == io.EOF { break } diff --git a/drivers/s3/driver.go b/drivers/s3/driver.go index 82c050a1..a7e924e2 100644 --- a/drivers/s3/driver.go +++ b/drivers/s3/driver.go @@ -163,18 +163,21 @@ func (d *S3) Remove(ctx context.Context, obj model.Obj) error { return d.removeFile(obj.GetPath()) } -func (d *S3) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *S3) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { uploader := s3manager.NewUploader(d.Session) - if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { - uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + if s.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = s.GetSize() / (s3manager.MaxUploadParts - 1) } - key := getKey(stdpath.Join(dstDir.GetPath(), stream.GetName()), false) - contentType := stream.GetMimetype() + key := getKey(stdpath.Join(dstDir.GetPath(), s.GetName()), false) + contentType := s.GetMimetype() log.Debugln("key:", key) input := &s3manager.UploadInput{ - Bucket: &d.Bucket, - Key: &key, - Body: stream, + Bucket: &d.Bucket, + Key: &key, + Body: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, ContentType: &contentType, } _, err := uploader.UploadWithContext(ctx, input) diff --git a/drivers/seafile/driver.go b/drivers/seafile/driver.go index 6d1f16da..f23038d1 100644 --- a/drivers/seafile/driver.go +++ b/drivers/seafile/driver.go @@ -3,6 +3,7 @@ package seafile import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "strings" "time" @@ -197,7 +198,7 @@ func (d *Seafile) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *Seafile) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *Seafile) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { repo, path, err := d.getRepoAndPath(dstDir.GetPath()) if err != nil { return err @@ -214,11 +215,16 @@ func (d *Seafile) Put(ctx context.Context, dstDir model.Obj, stream model.FileSt u := string(res) u = u[1 : len(u)-1] // remove quotes _, err = d.request(http.MethodPost, u, func(req *resty.Request) { - req.SetFileReader("file", stream.GetName(), stream). + r := &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + } + req.SetFileReader("file", s.GetName(), r). SetFormData(map[string]string{ "parent_dir": path, "replace": "1", - }) + }). + SetContext(ctx) }) return err } diff --git a/drivers/template/driver.go b/drivers/template/driver.go index 439f57f3..ff3648db 100644 --- a/drivers/template/driver.go +++ b/drivers/template/driver.go @@ -66,11 +66,33 @@ func (d *Template) Remove(ctx context.Context, obj model.Obj) error { return errs.NotImplement } -func (d *Template) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { +func (d *Template) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (model.Obj, error) { // TODO upload file, optional return nil, errs.NotImplement } +func (d *Template) GetArchiveMeta(ctx context.Context, obj model.Obj, args model.ArchiveArgs) (model.ArchiveMeta, error) { + // TODO get archive file meta-info, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Template) ListArchive(ctx context.Context, obj model.Obj, args model.ArchiveInnerArgs) ([]model.Obj, error) { + // TODO list args.InnerPath in the archive obj, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Template) Extract(ctx context.Context, obj model.Obj, args model.ArchiveInnerArgs) (*model.Link, error) { + // TODO return link of file args.InnerPath in the archive obj, return errs.NotImplement to use an internal archive tool, optional + return nil, errs.NotImplement +} + +func (d *Template) ArchiveDecompress(ctx context.Context, srcObj, dstDir model.Obj, args model.ArchiveDecompressArgs) ([]model.Obj, error) { + // TODO extract args.InnerPath path in the archive srcObj to the dstDir location, optional + // a folder with the same name as the archive file needs to be created to store the extracted results if args.PutIntoNewDir + // return errs.NotImplement to use an internal archive tool + return nil, errs.NotImplement +} + //func (d *Template) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { // return nil, errs.NotSupport //} diff --git a/drivers/thunder/driver.go b/drivers/thunder/driver.go index 8403f261..1b7f0af6 100644 --- a/drivers/thunder/driver.go +++ b/drivers/thunder/driver.go @@ -3,6 +3,7 @@ package thunder import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "strconv" "strings" @@ -332,16 +333,16 @@ func (xc *XunLeiCommon) Remove(ctx context.Context, obj model.Obj) error { return err } -func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - hi := stream.GetHash() +func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + hi := file.GetHash() gcid := hi.GetHash(hash_extend.GCID) if len(gcid) < hash_extend.GCID.Width { - tFile, err := stream.CacheFullInTempFile() + tFile, err := file.CacheFullInTempFile() if err != nil { return err } - gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize()) if err != nil { return err } @@ -353,8 +354,8 @@ func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model. r.SetBody(&base.Json{ "kind": FILE, "parent_id": dstDir.GetID(), - "name": stream.GetName(), - "size": stream.GetSize(), + "name": file.GetName(), + "size": file.GetSize(), "hash": gcid, "upload_type": UPLOAD_TYPE_RESUMABLE, }) @@ -375,14 +376,17 @@ func (xc *XunLeiCommon) Put(ctx context.Context, dstDir model.Obj, stream model. return err } uploader := s3manager.NewUploader(s) - if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { - uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) } _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: aws.String(param.Bucket), Key: aws.String(param.Key), Expires: aws.Time(param.Expiration), - Body: stream, + Body: &stream.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: up, + }, }) return err } diff --git a/drivers/thunderx/driver.go b/drivers/thunderx/driver.go index b9ee668c..93e07ca9 100644 --- a/drivers/thunderx/driver.go +++ b/drivers/thunderx/driver.go @@ -8,6 +8,7 @@ import ( "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/op" + "github.com/alist-org/alist/v3/internal/stream" "github.com/alist-org/alist/v3/pkg/utils" hash_extend "github.com/alist-org/alist/v3/pkg/utils/hash" "github.com/aws/aws-sdk-go/aws" @@ -363,16 +364,16 @@ func (xc *XunLeiXCommon) Remove(ctx context.Context, obj model.Obj) error { return err } -func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - hi := stream.GetHash() +func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) error { + hi := file.GetHash() gcid := hi.GetHash(hash_extend.GCID) if len(gcid) < hash_extend.GCID.Width { - tFile, err := stream.CacheFullInTempFile() + tFile, err := file.CacheFullInTempFile() if err != nil { return err } - gcid, err = utils.HashFile(hash_extend.GCID, tFile, stream.GetSize()) + gcid, err = utils.HashFile(hash_extend.GCID, tFile, file.GetSize()) if err != nil { return err } @@ -384,8 +385,8 @@ func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, stream model r.SetBody(&base.Json{ "kind": FILE, "parent_id": dstDir.GetID(), - "name": stream.GetName(), - "size": stream.GetSize(), + "name": file.GetName(), + "size": file.GetSize(), "hash": gcid, "upload_type": UPLOAD_TYPE_RESUMABLE, }) @@ -406,14 +407,17 @@ func (xc *XunLeiXCommon) Put(ctx context.Context, dstDir model.Obj, stream model return err } uploader := s3manager.NewUploader(s) - if stream.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { - uploader.PartSize = stream.GetSize() / (s3manager.MaxUploadParts - 1) + if file.GetSize() > s3manager.MaxUploadParts*s3manager.DefaultUploadPartSize { + uploader.PartSize = file.GetSize() / (s3manager.MaxUploadParts - 1) } _, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{ Bucket: aws.String(param.Bucket), Key: aws.String(param.Key), Expires: aws.Time(param.Expiration), - Body: stream, + Body: &stream.ReaderUpdatingProgress{ + Reader: file, + UpdateProgress: up, + }, }) return err } diff --git a/drivers/trainbit/driver.go b/drivers/trainbit/driver.go index 795b2fb8..2b1815ed 100644 --- a/drivers/trainbit/driver.go +++ b/drivers/trainbit/driver.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "io" "net/http" "net/url" @@ -114,23 +115,18 @@ func (d *Trainbit) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *Trainbit) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { endpoint, _ := url.Parse("https://tb28.trainbit.com/api/upload/send_raw/") query := &url.Values{} query.Add("q", strings.Split(dstDir.GetID(), "_")[1]) query.Add("guid", guid) - query.Add("name", url.QueryEscape(local2provider(stream.GetName(), false)+".")) + query.Add("name", url.QueryEscape(local2provider(s.GetName(), false)+".")) endpoint.RawQuery = query.Encode() - var total int64 - total = 0 - progressReader := &ProgressReader{ - stream, - func(byteNum int) { - total += int64(byteNum) - up(float64(total) / float64(stream.GetSize()) * 100) - }, + progressReader := &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, } - req, err := http.NewRequest(http.MethodPost, endpoint.String(), progressReader) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint.String(), progressReader) if err != nil { return err } diff --git a/drivers/trainbit/util.go b/drivers/trainbit/util.go index afc111a8..486e8851 100644 --- a/drivers/trainbit/util.go +++ b/drivers/trainbit/util.go @@ -13,17 +13,6 @@ import ( "github.com/alist-org/alist/v3/internal/model" ) -type ProgressReader struct { - io.Reader - reporter func(byteNum int) -} - -func (progressReader *ProgressReader) Read(data []byte) (int, error) { - byteNum, err := progressReader.Reader.Read(data) - progressReader.reporter(byteNum) - return byteNum, err -} - func get(url string, apiKey string, AUSHELLPORTAL string) (*http.Response, error) { req, err := http.NewRequest(http.MethodGet, url, nil) if err != nil { diff --git a/drivers/uss/driver.go b/drivers/uss/driver.go index 447515d8..3c54797c 100644 --- a/drivers/uss/driver.go +++ b/drivers/uss/driver.go @@ -3,6 +3,7 @@ package uss import ( "context" "fmt" + "github.com/alist-org/alist/v3/internal/stream" "net/url" "path" "strings" @@ -122,11 +123,16 @@ func (d *USS) Remove(ctx context.Context, obj model.Obj) error { }) } -func (d *USS) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - // TODO not support cancel?? +func (d *USS) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { return d.client.Put(&upyun.PutObjectConfig{ - Path: getKey(path.Join(dstDir.GetPath(), stream.GetName()), false), - Reader: stream, + Path: getKey(path.Join(dstDir.GetPath(), s.GetName()), false), + Reader: &stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + }, }) } diff --git a/drivers/webdav/driver.go b/drivers/webdav/driver.go index b402b1db..35240c49 100644 --- a/drivers/webdav/driver.go +++ b/drivers/webdav/driver.go @@ -2,6 +2,7 @@ package webdav import ( "context" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "os" "path" @@ -93,13 +94,18 @@ func (d *WebDav) Remove(ctx context.Context, obj model.Obj) error { return d.client.RemoveAll(getPath(obj)) } -func (d *WebDav) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *WebDav) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { callback := func(r *http.Request) { - r.Header.Set("Content-Type", stream.GetMimetype()) - r.ContentLength = stream.GetSize() + r.Header.Set("Content-Type", s.GetMimetype()) + r.ContentLength = s.GetSize() } - // TODO: support cancel - err := d.client.WriteStream(path.Join(dstDir.GetPath(), stream.GetName()), stream, 0644, callback) + err := d.client.WriteStream(path.Join(dstDir.GetPath(), s.GetName()), &stream.ReaderWithCtx{ + Reader: &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }, + Ctx: ctx, + }, 0644, callback) return err } diff --git a/drivers/weiyun/driver.go b/drivers/weiyun/driver.go index e6d5897c..59bd7237 100644 --- a/drivers/weiyun/driver.go +++ b/drivers/weiyun/driver.go @@ -7,6 +7,7 @@ import ( "math" "net/http" "strconv" + "sync/atomic" "time" "github.com/alist-org/alist/v3/drivers/base" @@ -311,77 +312,82 @@ func (d *WeiYun) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr // NOTE: // 秒传需要sha1最后一个状态,但sha1无法逆运算需要读完整个文件(或许可以??) // 服务器支持上传进度恢复,不需要额外实现 - if folder, ok := dstDir.(*Folder); ok { - file, err := stream.CacheFullInTempFile() - if err != nil { - return nil, err - } + var folder *Folder + var ok bool + if folder, ok = dstDir.(*Folder); !ok { + return nil, errs.NotSupport + } + file, err := stream.CacheFullInTempFile() + if err != nil { + return nil, err + } - // step 1. - preData, err := d.client.PreUpload(ctx, weiyunsdkgo.UpdloadFileParam{ - PdirKey: folder.GetPKey(), - DirKey: folder.DirKey, + // step 1. + preData, err := d.client.PreUpload(ctx, weiyunsdkgo.UpdloadFileParam{ + PdirKey: folder.GetPKey(), + DirKey: folder.DirKey, - FileName: stream.GetName(), - FileSize: stream.GetSize(), - File: file, + FileName: stream.GetName(), + FileSize: stream.GetSize(), + File: file, - ChannelCount: 4, - FileExistOption: 1, - }) - if err != nil { - return nil, err - } + ChannelCount: 4, + FileExistOption: 1, + }) + if err != nil { + return nil, err + } - // not fast upload - if !preData.FileExist { - // step.2 增加上传通道 - if len(preData.ChannelList) < d.uploadThread { - newCh, err := d.client.AddUploadChannel(len(preData.ChannelList), d.uploadThread, preData.UploadAuthData) - if err != nil { - return nil, err - } - preData.ChannelList = append(preData.ChannelList, newCh.AddChannels...) - } - // step.3 上传 - threadG, upCtx := errgroup.NewGroupWithContext(ctx, len(preData.ChannelList), - retry.Attempts(3), - retry.Delay(time.Second), - retry.DelayType(retry.BackOffDelay)) - - for _, channel := range preData.ChannelList { - if utils.IsCanceled(upCtx) { - break - } - - var channel = channel - threadG.Go(func(ctx context.Context) error { - for { - channel.Len = int(math.Min(float64(stream.GetSize()-channel.Offset), float64(channel.Len))) - upData, err := d.client.UploadFile(upCtx, channel, preData.UploadAuthData, - io.NewSectionReader(file, channel.Offset, int64(channel.Len))) - if err != nil { - return err - } - // 上传完成 - if upData.UploadState != 1 { - return nil - } - channel = upData.Channel - } - }) - } - if err = threadG.Wait(); err != nil { + // not fast upload + if !preData.FileExist { + // step.2 增加上传通道 + if len(preData.ChannelList) < d.uploadThread { + newCh, err := d.client.AddUploadChannel(len(preData.ChannelList), d.uploadThread, preData.UploadAuthData) + if err != nil { return nil, err } + preData.ChannelList = append(preData.ChannelList, newCh.AddChannels...) } + // step.3 上传 + threadG, upCtx := errgroup.NewGroupWithContext(ctx, len(preData.ChannelList), + retry.Attempts(3), + retry.Delay(time.Second), + retry.DelayType(retry.BackOffDelay)) - return &File{ - PFolder: folder, - File: preData.File, - }, nil + total := atomic.Int64{} + for _, channel := range preData.ChannelList { + if utils.IsCanceled(upCtx) { + break + } + + var channel = channel + threadG.Go(func(ctx context.Context) error { + for { + channel.Len = int(math.Min(float64(stream.GetSize()-channel.Offset), float64(channel.Len))) + upData, err := d.client.UploadFile(upCtx, channel, preData.UploadAuthData, + io.NewSectionReader(file, channel.Offset, int64(channel.Len))) + if err != nil { + return err + } + cur := total.Add(int64(channel.Len)) + up(float64(cur) * 100.0 / float64(stream.GetSize())) + // 上传完成 + if upData.UploadState != 1 { + return nil + } + channel = upData.Channel + } + }) + } + if err = threadG.Wait(); err != nil { + return nil, err + } } - return nil, errs.NotSupport + + return &File{ + PFolder: folder, + File: preData.File, + }, nil } // func (d *WeiYun) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) { diff --git a/drivers/wopan/driver.go b/drivers/wopan/driver.go index bccce4b1..86093fc1 100644 --- a/drivers/wopan/driver.go +++ b/drivers/wopan/driver.go @@ -161,6 +161,7 @@ func (d *Wopan) Put(ctx context.Context, dstDir model.Obj, stream model.FileStre OnProgress: func(current, total int64) { up(100 * float64(current) / float64(total)) }, + Ctx: ctx, }) return err } diff --git a/drivers/yandex_disk/driver.go b/drivers/yandex_disk/driver.go index 5af9f2e4..fe858519 100644 --- a/drivers/yandex_disk/driver.go +++ b/drivers/yandex_disk/driver.go @@ -2,6 +2,7 @@ package yandex_disk import ( "context" + "github.com/alist-org/alist/v3/internal/stream" "net/http" "path" "strconv" @@ -106,25 +107,30 @@ func (d *YandexDisk) Remove(ctx context.Context, obj model.Obj) error { return err } -func (d *YandexDisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { +func (d *YandexDisk) Put(ctx context.Context, dstDir model.Obj, s model.FileStreamer, up driver.UpdateProgress) error { var resp UploadResp _, err := d.request("/upload", http.MethodGet, func(req *resty.Request) { req.SetQueryParams(map[string]string{ - "path": path.Join(dstDir.GetPath(), stream.GetName()), + "path": path.Join(dstDir.GetPath(), s.GetName()), "overwrite": "true", }) }, &resp) if err != nil { return err } - req, err := http.NewRequest(resp.Method, resp.Href, stream) + req, err := http.NewRequestWithContext(ctx, resp.Method, resp.Href, &stream.ReaderUpdatingProgress{ + Reader: s, + UpdateProgress: up, + }) if err != nil { return err } - req = req.WithContext(ctx) - req.Header.Set("Content-Length", strconv.FormatInt(stream.GetSize(), 10)) + req.Header.Set("Content-Length", strconv.FormatInt(s.GetSize(), 10)) req.Header.Set("Content-Type", "application/octet-stream") res, err := base.HttpClient.Do(req) + if err != nil { + return err + } _ = res.Body.Close() return err } diff --git a/internal/driver/driver.go b/internal/driver/driver.go index 09fd42e7..292f8e6a 100644 --- a/internal/driver/driver.go +++ b/internal/driver/driver.go @@ -77,7 +77,7 @@ type Remove interface { } type Put interface { - Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up UpdateProgress) error + Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) error } type PutURL interface { @@ -113,7 +113,7 @@ type CopyResult interface { } type PutResult interface { - Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up UpdateProgress) (model.Obj, error) + Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) (model.Obj, error) } type PutURLResult interface { @@ -159,7 +159,7 @@ type ArchiveDecompressResult interface { ArchiveDecompress(ctx context.Context, srcObj, dstDir model.Obj, args model.ArchiveDecompressArgs) ([]model.Obj, error) } -type UpdateProgress model.UpdateProgress +type UpdateProgress = model.UpdateProgress type Progress struct { Total int64 diff --git a/internal/stream/stream.go b/internal/stream/stream.go index 1962fb46..74646bfb 100644 --- a/internal/stream/stream.go +++ b/internal/stream/stream.go @@ -562,3 +562,17 @@ func (f *FileReadAtSeeker) Seek(offset int64, whence int) (int64, error) { func (f *FileReadAtSeeker) Close() error { return f.ss.Close() } + +type ReaderWithCtx struct { + io.Reader + Ctx context.Context +} + +func (r *ReaderWithCtx) Read(p []byte) (n int, err error) { + select { + case <-r.Ctx.Done(): + return 0, r.Ctx.Err() + default: + return r.Reader.Read(p) + } +}