130 lines
2.8 KiB
Go

package transfer
import (
"context"
"fmt"
"io"
"log"
"os"
"path/filepath"
)
// NASTransfer copies files into the directory tree rooted at config.Path using
// ordinary filesystem calls (presumably a mounted NAS share — confirm).
type NASTransfer struct {
	// config supplies the destination root (Path), the per-transfer Timeout
	// and whether to VerifySize after copying.
	config NASConfig
	// connected is set to true by TestConnection after a successful probe and
	// is what IsConnected reports. No locking guards it.
	connected bool
}
// NewNASTransfer returns a NASTransfer rooted at config.Path, creating that
// directory (and any missing parents) if necessary.
//
// If the directory cannot be created the process is terminated via
// log.Fatalf. Callers that want to handle that failure themselves should use
// OpenNASTransfer instead.
func NewNASTransfer(config NASConfig) *NASTransfer {
	nt, err := OpenNASTransfer(config)
	if err != nil {
		log.Fatalf("Failed to create directory %s: %v", config.Path, err)
	}
	return nt
}

// OpenNASTransfer is like NewNASTransfer but reports a failure to create
// config.Path as an error instead of exiting the process. On error the
// returned *NASTransfer is nil.
func OpenNASTransfer(config NASConfig) (*NASTransfer, error) {
	nt := &NASTransfer{config: config}
	if err := nt.ensureDirectoryExists(nt.config.Path); err != nil {
		return nil, err
	}
	return nt, nil
}
// TransferFile copies item.SourcePath to item.DestinationPath, resolved
// beneath the NAS root config.Path, creating any missing parent directories.
//
// The copy honours ctx and, when config.Timeout is positive, an additional
// per-transfer deadline of that length. A zero or negative Timeout adds no
// deadline: context.WithTimeout with a non-positive duration expires at once,
// which would make every transfer fail with context.DeadlineExceeded.
//
// When config.VerifySize is set, the destination size is compared with the
// source afterwards and, on mismatch, the destination is removed
// (best-effort) and an error is returned.
func (nt *NASTransfer) TransferFile(ctx context.Context, item *TransferItem) error {
	// NOTE(review): filepath.Join cleans ".." elements, so a DestinationPath
	// such as "../x" resolves outside config.Path. Validate it if it can come
	// from untrusted input.
	destPath := filepath.Join(nt.config.Path, item.DestinationPath)
	destDir := filepath.Dir(destPath)
	if err := nt.ensureDirectoryExists(destDir); err != nil {
		return fmt.Errorf("Failed to create directory %s: %w", destDir, err)
	}

	transferCtx := ctx
	if nt.config.Timeout > 0 {
		var cancel context.CancelFunc
		transferCtx, cancel = context.WithTimeout(ctx, nt.config.Timeout)
		defer cancel()
	}

	if err := nt.copyFile(transferCtx, item.SourcePath, destPath); err != nil {
		return fmt.Errorf("Failed to copy file %s to %s: %w", item.SourcePath, destPath, err)
	}

	if nt.config.VerifySize {
		if err := nt.VerifyTransfer(item.SourcePath, destPath); err != nil {
			// Best-effort: don't leave a file of the wrong size on the NAS;
			// the verification error is what the caller needs to see.
			_ = os.Remove(destPath)
			return fmt.Errorf("Failed to verify transfer: %w", err)
		}
	}

	log.Printf("File transfer completed: %s -> %s", item.SourcePath, destPath)
	return nil
}
// copyFile streams srcPath into destPath (created, or truncated if it
// exists), then fsyncs and closes destPath, returning the first error.
//
// The copy runs in its own goroutine so that an expired or cancelled ctx
// makes copyFile return promptly even if a read or write is blocked. In that
// case both files are closed on the way out, which makes the still-running
// io.Copy fail on its next read or write. done is buffered, so the
// goroutine's final send never blocks and the goroutine can exit.
//
// Once destPath has been created, every failure path closes and removes it
// (best-effort) so a truncated file is not left behind. If ctx is already
// done on entry nothing is opened, so an existing destPath is left untouched.
func (nt *NASTransfer) copyFile(ctx context.Context, srcPath, destPath string) error {
	// os.Create truncates, so check cancellation before touching destPath.
	if err := ctx.Err(); err != nil {
		return err
	}

	src, err := os.Open(srcPath)
	if err != nil {
		return fmt.Errorf("Failed to open source file: %w", err)
	}
	defer src.Close() // read-only; a close error cannot lose data

	dest, err := os.Create(destPath)
	if err != nil {
		return fmt.Errorf("Failed to create destination file: %w", err)
	}

	// abort releases and deletes the partial destination, then returns cause.
	// Errors from Close/Remove here are secondary to cause and are dropped.
	abort := func(cause error) error {
		_ = dest.Close()
		_ = os.Remove(destPath)
		return cause
	}

	done := make(chan error, 1)
	go func() {
		_, copyErr := io.Copy(dest, src)
		done <- copyErr
	}()

	select {
	case <-ctx.Done():
		return abort(ctx.Err())
	case copyErr := <-done:
		if copyErr != nil {
			return abort(copyErr)
		}
	}

	if err := dest.Sync(); err != nil {
		return abort(err)
	}
	// Close is checked rather than deferred: errors from earlier writes may
	// only be reported at close, notably on network filesystems.
	if err := dest.Close(); err != nil {
		_ = os.Remove(destPath)
		return err
	}
	return nil
}
// VerifyTransfer compares the sizes that os.Stat reports for srcPath and
// destPath. It returns an error if either path cannot be stat'ed or if the
// two sizes differ, and nil otherwise. File contents are not compared.
func (nt *NASTransfer) VerifyTransfer(srcPath, destPath string) error {
	wantInfo, statErr := os.Stat(srcPath)
	if statErr != nil {
		return fmt.Errorf("Failed to stat source file: %w", statErr)
	}

	gotInfo, statErr := os.Stat(destPath)
	if statErr != nil {
		return fmt.Errorf("Failed to stat destination file: %w", statErr)
	}

	if want, got := wantInfo.Size(), gotInfo.Size(); want != got {
		return fmt.Errorf("size mismatch: source=%d, dest=%d", want, got)
	}
	return nil
}
// ensureDirectoryExists creates path along with any missing parents, using
// permission bits 0755 (before umask). A path that is already a directory is
// not an error. Failures are wrapped with a "Failed to create directory"
// prefix.
func (nt *NASTransfer) ensureDirectoryExists(path string) error {
	mkErr := os.MkdirAll(path, 0o755)
	if mkErr == nil {
		return nil
	}
	return fmt.Errorf("Failed to create directory: %w", mkErr)
}
// TestConnection probes the NAS root by creating and closing an empty
// ".connection_test" file in config.Path and then deleting it.
//
// The connected flag reported by IsConnected is cleared before probing and
// set again only if both the create and the close succeed. A failed probe
// therefore never leaves a stale true value from an earlier success.
func (nt *NASTransfer) TestConnection() error {
	nt.connected = false

	testFile := filepath.Join(nt.config.Path, ".connection_test")
	f, err := os.Create(testFile)
	if err != nil {
		return fmt.Errorf("Failed to create test file: %w", err)
	}
	// Some filesystems (e.g. network mounts) only surface errors at close.
	if err := f.Close(); err != nil {
		_ = os.Remove(testFile)
		return fmt.Errorf("Failed to close test file: %w", err)
	}
	// Cleanup is best-effort: a leftover empty probe file is harmless and
	// does not mean the share is unusable.
	_ = os.Remove(testFile)

	nt.connected = true
	log.Printf("Connected to NAS at %s", nt.config.Path)
	return nil
}
// IsConnected returns the connected flag, which TestConnection sets to true
// after it successfully creates its probe file in config.Path.
func (nt *NASTransfer) IsConnected() (connected bool) {
	connected = nt.connected
	return connected
}