diff --git a/drivers/123/driver.go b/drivers/123/driver.go index 18fd0a6b..ba3157f9 100644 --- a/drivers/123/driver.go +++ b/drivers/123/driver.go @@ -13,7 +13,6 @@ import ( "os" "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -193,7 +192,7 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr uploadFile = io.MultiReader(buf, stream) } else { // 计算完整文件MD5 - tempFile, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) if err != nil { return err } @@ -201,11 +200,9 @@ func (d *Pan123) Put(ctx context.Context, dstDir model.Obj, stream model.FileStr _ = tempFile.Close() _ = os.Remove(tempFile.Name()) }() - - if _, err = io.Copy(io.MultiWriter(tempFile, h), stream); err != nil { + if _, err = io.Copy(h, tempFile); err != nil { return err } - _, err = tempFile.Seek(0, io.SeekStart) if err != nil { return err diff --git a/drivers/189pc/utils.go b/drivers/189pc/utils.go index b75eddf6..2a67106e 100644 --- a/drivers/189pc/utils.go +++ b/drivers/189pc/utils.go @@ -506,14 +506,15 @@ func (y *Yun189PC) CommonUpload(ctx context.Context, dstDir model.Obj, file mode // 快传 func (y *Yun189PC) FastUpload(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up driver.UpdateProgress) (err error) { // 需要获取完整文件md5,必须支持 io.Seek - if _, ok := file.GetReadCloser().(*os.File); !ok { - r, err := utils.CreateTempFile(file) - if err != nil { - return err - } - file.Close() - file.SetReadCloser(r) + tempFile, err := utils.CreateTempFile(file.GetReadCloser()) + if err != nil { + return err } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() + file.SetReadCloser(tempFile) const DEFAULT int64 = 10485760 count := int(math.Ceil(float64(file.GetSize()) / float64(DEFAULT))) diff --git a/drivers/baidu_netdisk/driver.go b/drivers/baidu_netdisk/driver.go index 3779c771..8d9c13c8 100644 --- a/drivers/baidu_netdisk/driver.go +++ b/drivers/baidu_netdisk/driver.go @@ -14,7 +14,6 @@ import ( "strings" "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -119,28 +118,14 @@ func (d *BaiduNetdisk) Remove(ctx context.Context, obj model.Obj) error { } func (d *BaiduNetdisk) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - var tempFile *os.File - var err error - if f, ok := stream.GetReadCloser().(*os.File); ok { - tempFile = f - } else { - tempFile, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return err - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - }() - _, err = io.Copy(tempFile, stream) - if err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } + tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + if err != nil { + return err } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() var Default int64 = 4 * 1024 * 1024 defaultByteData := make([]byte, Default) count := int(math.Ceil(float64(stream.GetSize()) / float64(Default))) diff --git a/drivers/mediatrack/driver.go b/drivers/mediatrack/driver.go index 432a8f35..a5500ce6 100644 --- a/drivers/mediatrack/driver.go +++ b/drivers/mediatrack/driver.go @@ -12,7 +12,6 @@ import ( "strconv" "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -177,7 +176,7 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil if err != nil { return err } - tempFile, err := os.CreateTemp(conf.Conf.TempDir, "file-*") + tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) if err != nil { return err } @@ -185,14 +184,6 @@ func (d *MediaTrack) Put(ctx context.Context, dstDir model.Obj, stream model.Fil _ = tempFile.Close() _ = os.Remove(tempFile.Name()) }() - _, err = io.Copy(tempFile, stream) - if err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } uploader := s3manager.NewUploader(s) input := &s3manager.UploadInput{ Bucket: &resp.Data.Bucket, diff --git a/drivers/pikpak/driver.go b/drivers/pikpak/driver.go index ffdc1680..10da567e 100644 --- a/drivers/pikpak/driver.go +++ b/drivers/pikpak/driver.go @@ -11,7 +11,6 @@ import ( "strings" "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -134,28 +133,14 @@ func (d *PikPak) Remove(ctx context.Context, obj model.Obj) error { } func (d *PikPak) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - var tempFile *os.File - var err error - if f, ok := stream.GetReadCloser().(*os.File); ok { - tempFile = f - } else { - tempFile, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return err - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - }() - _, err = io.Copy(tempFile, stream) - if err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } + tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + if err != nil { + return err } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() // cal sha1 s := sha1.New() _, err = io.Copy(s, tempFile) diff --git a/drivers/quark/driver.go b/drivers/quark/driver.go index 08723791..dd5b286e 100644 --- a/drivers/quark/driver.go +++ b/drivers/quark/driver.go @@ -10,7 +10,6 @@ import ( "os" "github.com/alist-org/alist/v3/drivers/base" - "github.com/alist-org/alist/v3/internal/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/errs" "github.com/alist-org/alist/v3/internal/model" @@ -135,28 +134,14 @@ func (d *Quark) Remove(ctx context.Context, obj model.Obj) error { } func (d *Quark) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error { - var tempFile *os.File - var err error - if f, ok := stream.GetReadCloser().(*os.File); ok { - tempFile = f - } else { - tempFile, err = os.CreateTemp(conf.Conf.TempDir, "file-*") - if err != nil { - return err - } - defer func() { - _ = tempFile.Close() - _ = os.Remove(tempFile.Name()) - }() - _, err = io.Copy(tempFile, stream) - if err != nil { - return err - } - _, err = tempFile.Seek(0, io.SeekStart) - if err != nil { - return err - } + tempFile, err := utils.CreateTempFile(stream.GetReadCloser()) + if err != nil { + return err } + defer func() { + _ = tempFile.Close() + _ = os.Remove(tempFile.Name()) + }() m := md5.New() _, err = io.Copy(m, tempFile) if err != nil { diff --git a/pkg/utils/file.go b/pkg/utils/file.go index 7f9372b8..37702c42 100644 --- a/pkg/utils/file.go +++ b/pkg/utils/file.go @@ -34,16 +34,21 @@ func CreateNestedFile(path string) (*os.File, error) { // CreateTempFile create temp file from io.ReadCloser, and seek to 0 func CreateTempFile(r io.ReadCloser) (*os.File, error) { + if f, ok := r.(*os.File); ok { + return f, nil + } f, err := os.CreateTemp(conf.Conf.TempDir, "file-*") if err != nil { return nil, err } _, err = io.Copy(f, r) if err != nil { + _ = os.Remove(f.Name()) return nil, err } _, err = f.Seek(0, io.SeekStart) if err != nil { + _ = os.Remove(f.Name()) return nil, err } return f, nil