feat: add smb driver (close #1746) (#2114)

* feat: add smb driver (close #1746)

* Update driver.go
This commit is contained in:
BoYanZh
2022-10-25 23:00:23 +08:00
committed by GitHub
parent 0019959eec
commit dd4674e486
7 changed files with 322 additions and 0 deletions

View File

@ -21,6 +21,7 @@ import (
_ "github.com/alist-org/alist/v3/drivers/quark"
_ "github.com/alist-org/alist/v3/drivers/s3"
_ "github.com/alist-org/alist/v3/drivers/sftp"
_ "github.com/alist-org/alist/v3/drivers/smb"
_ "github.com/alist-org/alist/v3/drivers/teambition"
_ "github.com/alist-org/alist/v3/drivers/thunder"
_ "github.com/alist-org/alist/v3/drivers/uss"

163
drivers/smb/driver.go Normal file
View File

@ -0,0 +1,163 @@
package smb
import (
"context"
"errors"
"path/filepath"
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/model"
"github.com/alist-org/alist/v3/pkg/utils"
"github.com/hirochachacha/go-smb2"
)
type SMB struct {
model.Storage
Addition
fs *smb2.Share
}
func (d *SMB) Config() driver.Config {
return config
}
func (d *SMB) GetAddition() driver.Additional {
return d.Addition
}
func (d *SMB) Init(ctx context.Context, storage model.Storage) error {
d.Storage = storage
err := utils.Json.UnmarshalFromString(d.Storage.Addition, &d.Addition)
if err != nil {
return err
}
return d.initFS()
}
func (d *SMB) Drop(ctx context.Context) error {
if d.fs != nil {
_ = d.fs.Umount()
}
return nil
}
func (d *SMB) List(ctx context.Context, dir model.Obj, args model.ListArgs) ([]model.Obj, error) {
fullPath := d.getSMBPath(dir)
rawFiles, err := d.fs.ReadDir(fullPath)
if err != nil {
return nil, err
}
var files []model.Obj
for _, f := range rawFiles {
file := model.ObjThumb{
Object: model.Object{
Name: f.Name(),
Modified: f.ModTime(),
Size: f.Size(),
IsFolder: f.IsDir(),
},
}
files = append(files, &file)
}
return files, nil
}
//func (d *SMB) Get(ctx context.Context, path string) (model.Obj, error) {
// // this is optional
// return nil, errs.NotImplement
//}
func (d *SMB) Link(ctx context.Context, file model.Obj, args model.LinkArgs) (*model.Link, error) {
fullPath := d.getSMBPath(file)
remoteFile, err := d.fs.Open(fullPath)
if err != nil {
return nil, err
}
return &model.Link{
Data: remoteFile,
}, nil
}
func (d *SMB) MakeDir(ctx context.Context, parentDir model.Obj, dirName string) error {
fullPath := filepath.Join(d.getSMBPath(parentDir), dirName)
err := d.fs.MkdirAll(fullPath, 0700)
if err != nil {
return err
}
return nil
}
func (d *SMB) Move(ctx context.Context, srcObj, dstDir model.Obj) error {
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName())
err := d.fs.Rename(srcPath, dstPath)
if err != nil {
return err
}
return nil
}
func (d *SMB) Rename(ctx context.Context, srcObj model.Obj, newName string) error {
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(filepath.Dir(srcPath), newName)
err := d.fs.Rename(srcPath, dstPath)
if err != nil {
return err
}
return nil
}
func (d *SMB) Copy(ctx context.Context, srcObj, dstDir model.Obj) error {
srcPath := d.getSMBPath(srcObj)
dstPath := filepath.Join(d.getSMBPath(dstDir), srcObj.GetName())
var err error
if srcObj.IsDir() {
err = d.CopyDir(srcPath, dstPath)
} else {
err = d.CopyFile(srcPath, dstPath)
}
if err != nil {
return err
}
return nil
}
func (d *SMB) Remove(ctx context.Context, obj model.Obj) error {
var err error
fullPath := d.getSMBPath(obj)
if obj.IsDir() {
err = d.fs.RemoveAll(fullPath)
} else {
err = d.fs.Remove(fullPath)
}
if err != nil {
return err
}
return nil
}
func (d *SMB) Put(ctx context.Context, dstDir model.Obj, stream model.FileStreamer, up driver.UpdateProgress) error {
fullPath := filepath.Join(d.getSMBPath(dstDir), stream.GetName())
out, err := d.fs.Create(fullPath)
if err != nil {
return err
}
defer func() {
_ = out.Close()
if errors.Is(err, context.Canceled) {
_ = d.fs.Remove(fullPath)
}
}()
err = utils.CopyWithCtx(ctx, out, stream, stream.GetSize(), up)
if err != nil {
return err
}
return nil
}
//func (d *SMB) Other(ctx context.Context, args model.OtherArgs) (interface{}, error) {
// return nil, errs.NotSupport
//}
var _ driver.Driver = (*SMB)(nil)

28
drivers/smb/meta.go Normal file
View File

@ -0,0 +1,28 @@
package smb
import (
"github.com/alist-org/alist/v3/internal/driver"
"github.com/alist-org/alist/v3/internal/op"
)
type Addition struct {
driver.RootPath
Address string `json:"address" required:"true"`
Username string `json:"username" required:"true"`
Password string `json:"password"`
ShareName string `json:"share_name" required:"true"`
}
var config = driver.Config{
Name: "SMB",
LocalSort: true,
OnlyLocal: true,
DefaultRoot: ".",
NoCache: true,
}
func init() {
op.RegisterDriver(config, func() driver.Driver {
return &SMB{}
})
}

1
drivers/smb/types.go Normal file
View File

@ -0,0 +1 @@
package smb

122
drivers/smb/util.go Normal file
View File

@ -0,0 +1,122 @@
package smb
import (
"io"
"io/fs"
"net"
"os"
"path/filepath"
"github.com/alist-org/alist/v3/internal/model"
"github.com/hirochachacha/go-smb2"
)
func (d *SMB) initFS() error {
conn, err := net.Dial("tcp", d.Address)
if err != nil {
return err
}
dialer := &smb2.Dialer{
Initiator: &smb2.NTLMInitiator{
User: d.Username,
Password: d.Password,
},
}
s, err := dialer.Dial(conn)
if err != nil {
return err
}
d.fs, err = s.Mount(d.ShareName)
if err != nil {
return err
}
return err
}
func (d *SMB) getSMBPath(dir model.Obj) string {
fullPath := dir.GetPath()
if fullPath[0:1] != "." {
fullPath = "." + fullPath
}
return fullPath
}
// CopyFile File copies a single file from src to dst
func (d *SMB) CopyFile(src, dst string) error {
var err error
var srcfd *smb2.File
var dstfd *smb2.File
var srcinfo fs.FileInfo
if srcfd, err = d.fs.Open(src); err != nil {
return err
}
defer srcfd.Close()
if dstfd, err = d.CreateNestedFile(dst); err != nil {
return err
}
defer dstfd.Close()
if _, err = io.Copy(dstfd, srcfd); err != nil {
return err
}
if srcinfo, err = d.fs.Stat(src); err != nil {
return err
}
return d.fs.Chmod(dst, srcinfo.Mode())
}
// CopyDir Dir copies a whole directory recursively
func (d *SMB) CopyDir(src string, dst string) error {
var err error
var fds []fs.FileInfo
var srcinfo fs.FileInfo
if srcinfo, err = d.fs.Stat(src); err != nil {
return err
}
if err = d.fs.MkdirAll(dst, srcinfo.Mode()); err != nil {
return err
}
if fds, err = d.fs.ReadDir(src); err != nil {
return err
}
for _, fd := range fds {
srcfp := filepath.Join(src, fd.Name())
dstfp := filepath.Join(dst, fd.Name())
if fd.IsDir() {
if err = d.CopyDir(srcfp, dstfp); err != nil {
return err
}
} else {
if err = d.CopyFile(srcfp, dstfp); err != nil {
return err
}
}
}
return nil
}
// Exists determine whether the file exists
func (d *SMB) Exists(name string) bool {
if _, err := d.fs.Stat(name); err != nil {
if os.IsNotExist(err) {
return false
}
}
return true
}
// CreateNestedFile create nested file
func (d *SMB) CreateNestedFile(path string) (*smb2.File, error) {
basePath := filepath.Dir(path)
if !d.Exists(basePath) {
err := d.fs.MkdirAll(basePath, 0700)
if err != nil {
return nil, err
}
}
return d.fs.Create(path)
}