diff --git a/internal/aria2/add.go b/internal/aria2/add.go index 5e73d7fe..5db998f0 100644 --- a/internal/aria2/add.go +++ b/internal/aria2/add.go @@ -6,22 +6,16 @@ import ( "github.com/alist-org/alist/v3/conf" "github.com/alist-org/alist/v3/internal/driver" "github.com/alist-org/alist/v3/internal/fs" - "github.com/alist-org/alist/v3/internal/model" "github.com/alist-org/alist/v3/internal/operations" "github.com/alist-org/alist/v3/pkg/task" "github.com/google/uuid" "github.com/pkg/errors" - "mime" - "os" - "path" "path/filepath" - "strconv" - "time" ) -func AddURI(ctx context.Context, uri string, dstPath string, parentPath string) error { +func AddURI(ctx context.Context, uri string, dstDirPath string) error { // check account - account, actualParentPath, err := operations.GetAccountAndActualPath(parentPath) + account, dstDirActualPath, err := operations.GetAccountAndActualPath(dstDirPath) if err != nil { return errors.WithMessage(err, "failed get account") } @@ -30,7 +24,7 @@ func AddURI(ctx context.Context, uri string, dstPath string, parentPath string) return errors.WithStack(fs.ErrUploadNotSupported) } // check path is valid - obj, err := operations.Get(ctx, account, actualParentPath) + obj, err := operations.Get(ctx, account, dstDirActualPath) if err != nil { if !errors.Is(errors.Cause(err), driver.ErrorObjectNotFound) { return errors.WithMessage(err, "failed get object") @@ -51,99 +45,17 @@ func AddURI(ctx context.Context, uri string, dstPath string, parentPath string) return errors.Wrapf(err, "failed to add uri %s", uri) } // TODO add to task manager - Aria2TaskManager.Submit(task.WithCancelCtx(&task.Task[string, OfflineDownload]{ + TaskManager.Submit(task.WithCancelCtx(&task.Task[string, interface{}]{ ID: gid, - Name: fmt.Sprintf("download %s to [%s](%s)", uri, account.GetAccount().VirtualPath, actualParentPath), - Func: func(tsk *task.Task[string, OfflineDownload]) error { - defer func() { - notify.Signals.Delete(gid) - // clear temp dir - _ = os.RemoveAll(tempDir) - }() - c := make(chan int) - notify.Signals.Store(gid, c) - retried := 0 - for { - select { - case <-tsk.Ctx.Done(): - _, err := client.Remove(gid) - if err != nil { - return err - } - case status := <-c: - switch status { - case Completed: - return nil - default: - info, err := client.TellStatus(gid) - if err != nil { - retried++ - } - if retried > 5 { - return errors.Errorf("failed to get status of %s, retried %d times", gid, retried) - } - retried = 0 - if len(info.FollowedBy) != 0 { - gid = info.FollowedBy[0] - - } - // update download status - total, err := strconv.ParseUint(info.TotalLength, 10, 64) - if err != nil { - total = 0 - } - downloaded, err := strconv.ParseUint(info.CompletedLength, 10, 64) - if err != nil { - downloaded = 0 - } - tsk.SetProgress(int(float64(downloaded) / float64(total))) - switch info.Status { - case "complete": - // get files - files, err := client.GetFiles(gid) - if err != nil { - return errors.Wrapf(err, "failed to get files of %s", gid) - } - // upload files - for _, file := range files { - size, _ := strconv.ParseUint(file.Length, 10, 64) - f, err := os.Open(file.Path) - mimetype := mime.TypeByExtension(path.Ext(file.Path)) - if mimetype == "" { - mimetype = "application/octet-stream" - } - if err != nil { - return errors.Wrapf(err, "failed to open file %s", file.Path) - } - stream := model.FileStream{ - Obj: model.Object{ - Name: path.Base(file.Path), - Size: size, - Modified: time.Now(), - IsFolder: false, - }, - ReadCloser: f, - Mimetype: "", - } - return operations.Put(tsk.Ctx, account, actualParentPath, stream, tsk.SetProgress) - } - case "error": - return errors.Errorf("failed to download %s, error: %s", gid, info.ErrorMessage) - case "active", "waiting", "paused": - // do nothing - case "removed": - return errors.Errorf("failed to download %s, removed", gid) - default: - return errors.Errorf("failed to download %s, unknown status %s", gid, info.Status) - } - } - } + Name: fmt.Sprintf("download %s to [%s](%s)", uri, account.GetAccount().VirtualPath, dstDirActualPath), + Func: func(tsk *task.Task[string, interface{}]) error { + m := &Monitor{ + tsk: tsk, + tempDir: tempDir, + retried: 0, + dstDirPath: dstDirPath, } - }, - Data: OfflineDownload{ - Gid: gid, - URI: uri, - DstPath: dstPath, + return m.Loop() }, })) return nil diff --git a/internal/aria2/aria2.go b/internal/aria2/aria2.go index 9bbb03ab..1aeaecf6 100644 --- a/internal/aria2/aria2.go +++ b/internal/aria2/aria2.go @@ -8,7 +8,7 @@ import ( "time" ) -var Aria2TaskManager = task.NewTaskManager[string, OfflineDownload](3) +var TaskManager = task.NewTaskManager[string, interface{}](3) var notify = NewNotify() var client rpc.Client diff --git a/internal/aria2/monitor.go b/internal/aria2/monitor.go new file mode 100644 index 00000000..9e8ca9d3 --- /dev/null +++ b/internal/aria2/monitor.go @@ -0,0 +1,149 @@ +package aria2 + +import ( + "fmt" + "github.com/alist-org/alist/v3/internal/model" + "github.com/alist-org/alist/v3/internal/operations" + "github.com/alist-org/alist/v3/pkg/task" + "github.com/pkg/errors" + log "github.com/sirupsen/logrus" + "mime" + "os" + "path" + "strconv" + "sync" + "sync/atomic" + "time" +) + +type Monitor struct { + tsk *task.Task[string, interface{}] + tempDir string + retried int + c chan int + dstDirPath string +} + +func (m *Monitor) Loop() error { + defer func() { + notify.Signals.Delete(m.tsk.ID) + // clear temp dir, should do while complete + //_ = os.RemoveAll(m.tempDir) + }() + m.c = make(chan int) + notify.Signals.Store(m.tsk.ID, m.c) + for { + select { + case <-m.tsk.Ctx.Done(): + _, err := client.Remove(m.tsk.ID) + return err + case <-m.c: + ok, err := m.Update() + if ok { + return err + } + case <-time.After(time.Second * 5): + ok, err := m.Update() + if ok { + return err + } + } + } +} + +func (m *Monitor) Update() (bool, error) { + info, err := client.TellStatus(m.tsk.ID) + if err != nil { + m.retried++ + } + if m.retried > 5 { + return true, errors.Errorf("failed to get status of %s, retried %d times", m.tsk.ID, m.retried) + } + m.retried = 0 + if len(info.FollowedBy) != 0 { + gid := info.FollowedBy[0] + notify.Signals.Delete(m.tsk.ID) + m.tsk.ID = gid + notify.Signals.Store(gid, m.c) + } + // update download status + total, err := strconv.ParseUint(info.TotalLength, 10, 64) + if err != nil { + total = 0 + } + downloaded, err := strconv.ParseUint(info.CompletedLength, 10, 64) + if err != nil { + downloaded = 0 + } + m.tsk.SetProgress(int(float64(downloaded) / float64(total))) + switch info.Status { + case "complete": + err := m.Complete() + return true, errors.WithMessage(err, "failed to transfer file") + case "error": + return true, errors.Errorf("failed to download %s, error: %s", m.tsk.ID, info.ErrorMessage) + case "active", "waiting", "paused": + m.tsk.SetStatus("aria2: " + info.Status) + return false, nil + case "removed": + return true, errors.Errorf("failed to download %s, removed", m.tsk.ID) + default: + return true, errors.Errorf("failed to download %s, unknown status %s", m.tsk.ID, info.Status) + } +} + +var transferTaskManager = task.NewTaskManager[uint64, interface{}](3, func(k *uint64) { + atomic.AddUint64(k, 1) +}) + +func (m *Monitor) Complete() error { + // check dstDir again + account, dstDirActualPath, err := operations.GetAccountAndActualPath(m.dstDirPath) + if err != nil { + return errors.WithMessage(err, "failed get account") + } + // get files + files, err := client.GetFiles(m.tsk.ID) + if err != nil { + return errors.Wrapf(err, "failed to get files of %s", m.tsk.ID) + } + // upload files + var wg sync.WaitGroup + wg.Add(len(files)) + go func() { + wg.Wait() + err := os.RemoveAll(m.tempDir) + if err != nil { + log.Errorf("failed to remove aria2 temp dir: %+v", err.Error()) + } + }() + for _, file := range files { + transferTaskManager.Submit(task.WithCancelCtx[uint64](&task.Task[uint64, interface{}]{ + Name: fmt.Sprintf("transfer %s to %s", file.Path, m.dstDirPath), + Func: func(tsk *task.Task[uint64, interface{}]) error { + defer wg.Done() + size, _ := strconv.ParseUint(file.Length, 10, 64) + mimetype := mime.TypeByExtension(path.Ext(file.Path)) + if mimetype == "" { + mimetype = "application/octet-stream" + } + f, err := os.Open(file.Path) + if err != nil { + return errors.Wrapf(err, "failed to open file %s", file.Path) + } + stream := model.FileStream{ + Obj: model.Object{ + Name: path.Base(file.Path), + Size: size, + Modified: time.Now(), + IsFolder: false, + }, + ReadCloser: f, + Mimetype: "", + } + return operations.Put(tsk.Ctx, account, dstDirActualPath, stream, tsk.SetProgress) + }, + })) + } + return nil +} diff --git a/internal/aria2/offlinedownload.go b/internal/aria2/offlinedownload.go deleted file mode 100644 index 2cbd0153..00000000 --- a/internal/aria2/offlinedownload.go +++ /dev/null @@ -1,7 +0,0 @@ -package aria2 - -type OfflineDownload struct { - Gid string - DstPath string - URI string -}