feat(traffic): support limit task worker count & file stream rate (#7948)

* feat: set task workers num & client stream rate limit

* feat: server stream rate limit

* upgrade xhofe/tache

* .
This commit is contained in:
KirCute_ECT
2025-02-16 12:22:11 +08:00
committed by GitHub
parent 399336b33c
commit 3b71500f23
79 changed files with 803 additions and 327 deletions

View File

@ -11,6 +11,7 @@ import (
"github.com/alist-org/alist/v3/pkg/utils/random"
"github.com/pkg/errors"
"gorm.io/gorm"
"strconv"
)
var initialSettingItems []model.SettingItem
@ -191,12 +192,12 @@ func InitialSettings() []model.SettingItem {
{Key: conf.LdapDefaultPermission, Value: "0", Type: conf.TypeNumber, Group: model.LDAP, Flag: model.PRIVATE},
{Key: conf.LdapLoginTips, Value: "login with ldap", Type: conf.TypeString, Group: model.LDAP, Flag: model.PUBLIC},
//s3 settings
// s3 settings
{Key: conf.S3AccessKeyId, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE},
{Key: conf.S3SecretAccessKey, Value: "", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE},
{Key: conf.S3Buckets, Value: "[]", Type: conf.TypeString, Group: model.S3, Flag: model.PRIVATE},
//ftp settings
// ftp settings
{Key: conf.FTPPublicHost, Value: "127.0.0.1", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE},
{Key: conf.FTPPasvPortMap, Value: "", Type: conf.TypeText, Group: model.FTP, Flag: model.PRIVATE},
{Key: conf.FTPProxyUserAgent, Value: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " +
@ -205,6 +206,18 @@ func InitialSettings() []model.SettingItem {
{Key: conf.FTPImplicitTLS, Value: "false", Type: conf.TypeBool, Group: model.FTP, Flag: model.PRIVATE},
{Key: conf.FTPTLSPrivateKeyPath, Value: "", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE},
{Key: conf.FTPTLSPublicCertPath, Value: "", Type: conf.TypeString, Group: model.FTP, Flag: model.PRIVATE},
// traffic settings
{Key: conf.TaskOfflineDownloadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Download.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.TaskOfflineDownloadTransferThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Transfer.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.TaskUploadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Upload.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.TaskCopyThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Copy.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.TaskDecompressDownloadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.Decompress.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.TaskDecompressUploadThreadsNum, Value: strconv.Itoa(conf.Conf.Tasks.DecompressUpload.Workers), Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.StreamMaxClientDownloadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.StreamMaxClientUploadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.StreamMaxServerDownloadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
{Key: conf.StreamMaxServerUploadSpeed, Value: "-1", Type: conf.TypeNumber, Group: model.TRAFFIC, Flag: model.PRIVATE},
}
initialSettingItems = append(initialSettingItems, tool.Tools.Items()...)
if flags.Dev {

View File

@ -0,0 +1,53 @@
package bootstrap
import (
"context"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/alist-org/alist/v3/internal/stream"
"golang.org/x/time/rate"
)
type blockBurstLimiter struct {
*rate.Limiter
}
func (l blockBurstLimiter) WaitN(ctx context.Context, total int) error {
for total > 0 {
n := l.Burst()
if l.Limiter.Limit() == rate.Inf || n > total {
n = total
}
err := l.Limiter.WaitN(ctx, n)
if err != nil {
return err
}
total -= n
}
return nil
}
func streamFilterNegative(limit int) (rate.Limit, int) {
if limit < 0 {
return rate.Inf, 0
}
return rate.Limit(limit) * 1024.0, limit * 1024
}
func initLimiter(limiter *stream.Limiter, s string) {
clientDownLimit, burst := streamFilterNegative(setting.GetInt(s, -1))
*limiter = blockBurstLimiter{Limiter: rate.NewLimiter(clientDownLimit, burst)}
op.RegisterSettingChangingCallback(func() {
newLimit, newBurst := streamFilterNegative(setting.GetInt(s, -1))
(*limiter).SetLimit(newLimit)
(*limiter).SetBurst(newBurst)
})
}
func InitStreamLimit() {
initLimiter(&stream.ClientDownloadLimit, conf.StreamMaxClientDownloadSpeed)
initLimiter(&stream.ClientUploadLimit, conf.StreamMaxClientUploadSpeed)
initLimiter(&stream.ServerDownloadLimit, conf.StreamMaxServerDownloadSpeed)
initLimiter(&stream.ServerUploadLimit, conf.StreamMaxServerUploadSpeed)
}

View File

@ -5,17 +5,44 @@ import (
"github.com/alist-org/alist/v3/internal/db"
"github.com/alist-org/alist/v3/internal/fs"
"github.com/alist-org/alist/v3/internal/offline_download/tool"
"github.com/alist-org/alist/v3/internal/op"
"github.com/alist-org/alist/v3/internal/setting"
"github.com/xhofe/tache"
)
func taskFilterNegative(num int) int64 {
if num < 0 {
num = 0
}
return int64(num)
}
func InitTaskManager() {
fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(conf.Conf.Tasks.Upload.Workers), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry)) //upload will not support persist
fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(conf.Conf.Tasks.Copy.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant), db.UpdateTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry))
tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(conf.Conf.Tasks.Download.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant), db.UpdateTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry))
tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(conf.Conf.Tasks.Transfer.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant), db.UpdateTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry))
fs.UploadTaskManager = tache.NewManager[*fs.UploadTask](tache.WithWorks(setting.GetInt(conf.TaskUploadThreadsNum, conf.Conf.Tasks.Upload.Workers)), tache.WithMaxRetry(conf.Conf.Tasks.Upload.MaxRetry)) //upload will not support persist
op.RegisterSettingChangingCallback(func() {
fs.UploadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskUploadThreadsNum, conf.Conf.Tasks.Upload.Workers)))
})
fs.CopyTaskManager = tache.NewManager[*fs.CopyTask](tache.WithWorks(setting.GetInt(conf.TaskCopyThreadsNum, conf.Conf.Tasks.Copy.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant), db.UpdateTaskDataFunc("copy", conf.Conf.Tasks.Copy.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Copy.MaxRetry))
op.RegisterSettingChangingCallback(func() {
fs.CopyTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskCopyThreadsNum, conf.Conf.Tasks.Copy.Workers)))
})
tool.DownloadTaskManager = tache.NewManager[*tool.DownloadTask](tache.WithWorks(setting.GetInt(conf.TaskOfflineDownloadThreadsNum, conf.Conf.Tasks.Download.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant), db.UpdateTaskDataFunc("download", conf.Conf.Tasks.Download.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Download.MaxRetry))
op.RegisterSettingChangingCallback(func() {
tool.DownloadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskOfflineDownloadThreadsNum, conf.Conf.Tasks.Download.Workers)))
})
tool.TransferTaskManager = tache.NewManager[*tool.TransferTask](tache.WithWorks(setting.GetInt(conf.TaskOfflineDownloadTransferThreadsNum, conf.Conf.Tasks.Transfer.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant), db.UpdateTaskDataFunc("transfer", conf.Conf.Tasks.Transfer.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Transfer.MaxRetry))
op.RegisterSettingChangingCallback(func() {
tool.TransferTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskOfflineDownloadTransferThreadsNum, conf.Conf.Tasks.Transfer.Workers)))
})
if len(tool.TransferTaskManager.GetAll()) == 0 { //prevent offline downloaded files from being deleted
CleanTempDir()
}
fs.ArchiveDownloadTaskManager = tache.NewManager[*fs.ArchiveDownloadTask](tache.WithWorks(conf.Conf.Tasks.Decompress.Workers), tache.WithPersistFunction(db.GetTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant), db.UpdateTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Decompress.MaxRetry))
fs.ArchiveContentUploadTaskManager.Manager = tache.NewManager[*fs.ArchiveContentUploadTask](tache.WithWorks(conf.Conf.Tasks.DecompressUpload.Workers), tache.WithMaxRetry(conf.Conf.Tasks.DecompressUpload.MaxRetry)) //decompress upload will not support persist
fs.ArchiveDownloadTaskManager = tache.NewManager[*fs.ArchiveDownloadTask](tache.WithWorks(setting.GetInt(conf.TaskDecompressDownloadThreadsNum, conf.Conf.Tasks.Decompress.Workers)), tache.WithPersistFunction(db.GetTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant), db.UpdateTaskDataFunc("decompress", conf.Conf.Tasks.Decompress.TaskPersistant)), tache.WithMaxRetry(conf.Conf.Tasks.Decompress.MaxRetry))
op.RegisterSettingChangingCallback(func() {
fs.ArchiveDownloadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskDecompressDownloadThreadsNum, conf.Conf.Tasks.Decompress.Workers)))
})
fs.ArchiveContentUploadTaskManager.Manager = tache.NewManager[*fs.ArchiveContentUploadTask](tache.WithWorks(setting.GetInt(conf.TaskDecompressUploadThreadsNum, conf.Conf.Tasks.DecompressUpload.Workers)), tache.WithMaxRetry(conf.Conf.Tasks.DecompressUpload.MaxRetry)) //decompress upload will not support persist
op.RegisterSettingChangingCallback(func() {
fs.ArchiveContentUploadTaskManager.SetWorkersNumActive(taskFilterNegative(setting.GetInt(conf.TaskDecompressUploadThreadsNum, conf.Conf.Tasks.DecompressUpload.Workers)))
})
}

View File

@ -115,6 +115,18 @@ const (
FTPImplicitTLS = "ftp_implicit_tls"
FTPTLSPrivateKeyPath = "ftp_tls_private_key_path"
FTPTLSPublicCertPath = "ftp_tls_public_cert_path"
// traffic
TaskOfflineDownloadThreadsNum = "offline_download_task_threads_num"
TaskOfflineDownloadTransferThreadsNum = "offline_download_transfer_task_threads_num"
TaskUploadThreadsNum = "upload_task_threads_num"
TaskCopyThreadsNum = "copy_task_threads_num"
TaskDecompressDownloadThreadsNum = "decompress_download_task_threads_num"
TaskDecompressUploadThreadsNum = "decompress_upload_task_threads_num"
StreamMaxClientDownloadSpeed = "max_client_download_speed"
StreamMaxClientUploadSpeed = "max_client_upload_speed"
StreamMaxServerDownloadSpeed = "max_server_download_speed"
StreamMaxServerUploadSpeed = "max_server_upload_speed"
)
const (

View File

@ -77,6 +77,29 @@ type Remove interface {
}
type Put interface {
// Put a file (provided as a FileStreamer) into the driver
// Besides the most basic upload functionality, the following features also need to be implemented:
// 1. Canceling (when `<-ctx.Done()` returns), by the following methods:
// (1) Use request methods that carry context, such as the following:
// a. http.NewRequestWithContext
// b. resty.Request.SetContext
// c. s3manager.Uploader.UploadWithContext
// d. utils.CopyWithCtx
// (2) Use a `driver.ReaderWithCtx` or a `driver.NewLimitedUploadStream`
// (3) Use `utils.IsCanceled` to check if the upload has been canceled during the upload process,
// this is typically applicable to chunked uploads.
// 2. Submit upload progress (via `up`) in real-time. There are three recommended ways as follows:
// (1) Use `utils.CopyWithCtx`
// (2) Use `driver.ReaderUpdatingProgress`
// (3) Use `driver.Progress` with `io.TeeReader`
// 3. Slow down upload speed (via `stream.ServerUploadLimit`). It requires you to wrap the read stream
// in a `driver.RateLimitReader` or a `driver.RateLimitFile` after calculating the file's hash and
// before uploading the file or file chunks. Or you can directly call `driver.ServerUploadLimitWaitN`
// if your file chunks are sufficiently small (less than about 50KB).
// NOTE that the network speed may be significantly slower than the stream's read speed. Therefore, if
// you use a `errgroup.Group` to upload each chunk in parallel, you should consider using a recursive
// mutex like `semaphore.Weighted` to limit the maximum number of upload threads, preventing excessive
// memory usage caused by buffering too many file chunks awaiting upload.
Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) error
}
@ -113,6 +136,29 @@ type CopyResult interface {
}
type PutResult interface {
// Put a file (provided as a FileStreamer) into the driver and return the put obj
// Besides the most basic upload functionality, the following features also need to be implemented:
// 1. Canceling (when `<-ctx.Done()` returns), which can be supported by the following methods:
// (1) Use request methods that carry context, such as the following:
// a. http.NewRequestWithContext
// b. resty.Request.SetContext
// c. s3manager.Uploader.UploadWithContext
// d. utils.CopyWithCtx
// (2) Use a `driver.ReaderWithCtx` or `driver.NewLimitedUploadStream`
// (3) Use `utils.IsCanceled` to check if the upload has been canceled during the upload process,
// this is typically applicable to chunked uploads.
// 2. Submit upload progress (via `up`) in real-time. There are three recommended ways as follows:
// (1) Use `utils.CopyWithCtx`
// (2) Use `driver.ReaderUpdatingProgress`
// (3) Use `driver.Progress` with `io.TeeReader`
// 3. Slow down upload speed (via `stream.ServerUploadLimit`). It requires you to wrap the read stream
// in a `driver.RateLimitReader` or a `driver.RateLimitFile` after calculating the file's hash and
// before uploading the file or file chunks. Or you can directly call `driver.ServerUploadLimitWaitN`
// if your file chunks are sufficiently small (less than about 50KB).
// NOTE that the network speed may be significantly slower than the stream's read speed. Therefore, if
// you use a `errgroup.Group` to upload each chunk in parallel, you should consider using a recursive
// mutex like `semaphore.Weighted` to limit the maximum number of upload threads, preventing excessive
// memory usage caused by buffering too many file chunks awaiting upload.
Put(ctx context.Context, dstDir model.Obj, file model.FileStreamer, up UpdateProgress) (model.Obj, error)
}
@ -159,28 +205,6 @@ type ArchiveDecompressResult interface {
ArchiveDecompress(ctx context.Context, srcObj, dstDir model.Obj, args model.ArchiveDecompressArgs) ([]model.Obj, error)
}
type UpdateProgress = model.UpdateProgress
type Progress struct {
Total int64
Done int64
up UpdateProgress
}
func (p *Progress) Write(b []byte) (n int, err error) {
n = len(b)
p.Done += int64(n)
p.up(float64(p.Done) / float64(p.Total) * 100)
return
}
func NewProgress(total int64, up UpdateProgress) *Progress {
return &Progress{
Total: total,
up: up,
}
}
type Reference interface {
InitReference(storage Driver) error
}

62
internal/driver/utils.go Normal file
View File

@ -0,0 +1,62 @@
package driver
import (
"context"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/internal/stream"
"io"
)
type UpdateProgress = model.UpdateProgress
type Progress struct {
Total int64
Done int64
up UpdateProgress
}
func (p *Progress) Write(b []byte) (n int, err error) {
n = len(b)
p.Done += int64(n)
p.up(float64(p.Done) / float64(p.Total) * 100)
return
}
func NewProgress(total int64, up UpdateProgress) *Progress {
return &Progress{
Total: total,
up: up,
}
}
type RateLimitReader = stream.RateLimitReader
type RateLimitWriter = stream.RateLimitWriter
type RateLimitFile = stream.RateLimitFile
func NewLimitedUploadStream(ctx context.Context, r io.Reader) *RateLimitReader {
return &RateLimitReader{
Reader: r,
Limiter: stream.ServerUploadLimit,
Ctx: ctx,
}
}
func NewLimitedUploadFile(ctx context.Context, f model.File) *RateLimitFile {
return &RateLimitFile{
File: f,
Limiter: stream.ServerUploadLimit,
Ctx: ctx,
}
}
func ServerUploadLimitWaitN(ctx context.Context, n int) error {
return stream.ServerUploadLimit.WaitN(ctx, n)
}
type ReaderWithCtx = stream.ReaderWithCtx
type ReaderUpdatingProgress = stream.ReaderUpdatingProgress
type SimpleReaderWithSize = stream.SimpleReaderWithSize

View File

@ -12,6 +12,7 @@ const (
LDAP
S3
FTP
TRAFFIC
)
const (

View File

@ -3,6 +3,7 @@ package net
import (
"compress/gzip"
"context"
"crypto/tls"
"fmt"
"io"
"mime"
@ -14,7 +15,6 @@ import (
"sync"
"time"
"github.com/alist-org/alist/v3/drivers/base"
"github.com/alist-org/alist/v3/internal/conf"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/http_range"
@ -264,7 +264,7 @@ var httpClient *http.Client
func HttpClient() *http.Client {
once.Do(func() {
httpClient = base.NewHttpClient()
httpClient = NewHttpClient()
httpClient.CheckRedirect = func(req *http.Request, via []*http.Request) error {
if len(via) >= 10 {
return errors.New("stopped after 10 redirects")
@ -275,3 +275,13 @@ func HttpClient() *http.Client {
})
return httpClient
}
func NewHttpClient() *http.Client {
return &http.Client{
Timeout: time.Hour * 48,
Transport: &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{InsecureSkipVerify: conf.Conf.TlsInsecureSkipVerify},
},
}
}

View File

@ -26,9 +26,18 @@ var settingGroupCacheF = func(key string, item []model.SettingItem) {
settingGroupCache.Set(key, item, cache.WithEx[[]model.SettingItem](time.Hour))
}
var settingChangingCallbacks = make([]func(), 0)
func RegisterSettingChangingCallback(f func()) {
settingChangingCallbacks = append(settingChangingCallbacks, f)
}
func SettingCacheUpdate() {
settingCache.Clear()
settingGroupCache.Clear()
for _, cb := range settingChangingCallbacks {
cb()
}
}
func GetPublicSettingsMap() map[string]string {

152
internal/stream/limit.go Normal file
View File

@ -0,0 +1,152 @@
package stream
import (
"context"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/http_range"
"github.com/alist-org/alist/v3/pkg/utils"
"golang.org/x/time/rate"
"io"
"time"
)
type Limiter interface {
Limit() rate.Limit
Burst() int
TokensAt(time.Time) float64
Tokens() float64
Allow() bool
AllowN(time.Time, int) bool
Reserve() *rate.Reservation
ReserveN(time.Time, int) *rate.Reservation
Wait(context.Context) error
WaitN(context.Context, int) error
SetLimit(rate.Limit)
SetLimitAt(time.Time, rate.Limit)
SetBurst(int)
SetBurstAt(time.Time, int)
}
var (
ClientDownloadLimit Limiter
ClientUploadLimit Limiter
ServerDownloadLimit Limiter
ServerUploadLimit Limiter
)
type RateLimitReader struct {
io.Reader
Limiter Limiter
Ctx context.Context
}
func (r *RateLimitReader) Read(p []byte) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
}
n, err = r.Reader.Read(p)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
}
func (r *RateLimitReader) Close() error {
if c, ok := r.Reader.(io.Closer); ok {
return c.Close()
}
return nil
}
type RateLimitWriter struct {
io.Writer
Limiter Limiter
Ctx context.Context
}
func (w *RateLimitWriter) Write(p []byte) (n int, err error) {
if w.Ctx != nil && utils.IsCanceled(w.Ctx) {
return 0, w.Ctx.Err()
}
n, err = w.Writer.Write(p)
if err != nil {
return
}
if w.Limiter != nil {
if w.Ctx == nil {
w.Ctx = context.Background()
}
err = w.Limiter.WaitN(w.Ctx, n)
}
return
}
func (w *RateLimitWriter) Close() error {
if c, ok := w.Writer.(io.Closer); ok {
return c.Close()
}
return nil
}
type RateLimitFile struct {
model.File
Limiter Limiter
Ctx context.Context
}
func (r *RateLimitFile) Read(p []byte) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
}
n, err = r.File.Read(p)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
}
func (r *RateLimitFile) ReadAt(p []byte, off int64) (n int, err error) {
if r.Ctx != nil && utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
}
n, err = r.File.ReadAt(p, off)
if err != nil {
return
}
if r.Limiter != nil {
if r.Ctx == nil {
r.Ctx = context.Background()
}
err = r.Limiter.WaitN(r.Ctx, n)
}
return
}
type RateLimitRangeReadCloser struct {
model.RangeReadCloserIF
Limiter Limiter
}
func (rrc RateLimitRangeReadCloser) RangeRead(ctx context.Context, httpRange http_range.Range) (io.ReadCloser, error) {
rc, err := rrc.RangeReadCloserIF.RangeRead(ctx, httpRange)
if err != nil {
return nil, err
}
return &RateLimitReader{
Reader: rc,
Limiter: rrc.Limiter,
Ctx: ctx,
}, nil
}

View File

@ -182,14 +182,24 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
}
if ss.Link != nil {
if ss.Link.MFile != nil {
ss.mFile = ss.Link.MFile
ss.Reader = ss.Link.MFile
ss.Closers.Add(ss.Link.MFile)
mFile := ss.Link.MFile
if _, ok := mFile.(*os.File); !ok {
mFile = &RateLimitFile{
File: mFile,
Limiter: ServerDownloadLimit,
Ctx: fs.Ctx,
}
}
ss.mFile = mFile
ss.Reader = mFile
ss.Closers.Add(mFile)
return &ss, nil
}
if ss.Link.RangeReadCloser != nil {
ss.rangeReadCloser = ss.Link.RangeReadCloser
ss.rangeReadCloser = RateLimitRangeReadCloser{
RangeReadCloserIF: ss.Link.RangeReadCloser,
Limiter: ServerDownloadLimit,
}
ss.Add(ss.rangeReadCloser)
return &ss, nil
}
@ -198,6 +208,10 @@ func NewSeekableStream(fs FileStream, link *model.Link) (*SeekableStream, error)
if err != nil {
return nil, err
}
rrc = RateLimitRangeReadCloser{
RangeReadCloserIF: rrc,
Limiter: ServerDownloadLimit,
}
ss.rangeReadCloser = rrc
ss.Add(rrc)
return &ss, nil
@ -259,7 +273,7 @@ func (ss *SeekableStream) CacheFullInTempFile() (model.File, error) {
if ss.tmpFile != nil {
return ss.tmpFile, nil
}
if ss.mFile != nil {
if _, ok := ss.mFile.(*os.File); ok {
return ss.mFile, nil
}
tmpF, err := utils.CreateTempFile(ss, ss.GetSize())
@ -276,7 +290,7 @@ func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdatePr
if ss.tmpFile != nil {
return ss.tmpFile, nil
}
if ss.mFile != nil {
if _, ok := ss.mFile.(*os.File); ok {
return ss.mFile, nil
}
tmpF, err := utils.CreateTempFile(&ReaderUpdatingProgress{
@ -293,12 +307,13 @@ func (ss *SeekableStream) CacheFullInTempFileAndUpdateProgress(up model.UpdatePr
}
func (f *FileStream) SetTmpFile(r *os.File) {
f.Reader = r
f.Add(r)
f.tmpFile = r
f.Reader = r
}
type ReaderWithSize interface {
io.Reader
io.ReadCloser
GetSize() int64
}
@ -311,6 +326,13 @@ func (r *SimpleReaderWithSize) GetSize() int64 {
return r.Size
}
func (r *SimpleReaderWithSize) Close() error {
if c, ok := r.Reader.(io.Closer); ok {
return c.Close()
}
return nil
}
type ReaderUpdatingProgress struct {
Reader ReaderWithSize
model.UpdateProgress
@ -324,6 +346,10 @@ func (r *ReaderUpdatingProgress) Read(p []byte) (n int, err error) {
return n, err
}
func (r *ReaderUpdatingProgress) Close() error {
return r.Reader.Close()
}
type SStreamReadAtSeeker interface {
model.File
GetRawStream() *SeekableStream
@ -534,7 +560,7 @@ func (r *RangeReadReadAtSeeker) Read(p []byte) (n int, err error) {
func (r *RangeReadReadAtSeeker) Close() error {
if r.headCache != nil {
r.headCache.close()
_ = r.headCache.close()
}
return r.ss.Close()
}
@ -562,17 +588,3 @@ func (f *FileReadAtSeeker) Seek(offset int64, whence int) (int64, error) {
func (f *FileReadAtSeeker) Close() error {
return f.ss.Close()
}
type ReaderWithCtx struct {
io.Reader
Ctx context.Context
}
func (r *ReaderWithCtx) Read(p []byte) (n int, err error) {
select {
case <-r.Ctx.Done():
return 0, r.Ctx.Err()
default:
return r.Reader.Read(p)
}
}

View File

@ -3,6 +3,7 @@ package stream
import (
"context"
"fmt"
"github.com/alist-org/alist/v3/pkg/utils"
"io"
"net/http"
@ -76,3 +77,22 @@ func checkContentRange(header *http.Header, offset int64) bool {
}
return false
}
type ReaderWithCtx struct {
io.Reader
Ctx context.Context
}
func (r *ReaderWithCtx) Read(p []byte) (n int, err error) {
if utils.IsCanceled(r.Ctx) {
return 0, r.Ctx.Err()
}
return r.Reader.Read(p)
}
func (r *ReaderWithCtx) Close() error {
if c, ok := r.Reader.(io.Closer); ok {
return c.Close()
}
return nil
}