585 lines
16 KiB
Go
585 lines
16 KiB
Go
package main
|
|
|
|
import (
|
|
"bufio"
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"encoding/json"
|
|
"flag"
|
|
"fmt"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"syscall"
|
|
"time"
|
|
|
|
proton "github.com/henrybear327/go-proton-api"
|
|
proton_api_bridge "github.com/henrybear327/Proton-API-Bridge"
|
|
"github.com/henrybear327/Proton-API-Bridge/common"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"golang.org/x/term"
|
|
)
|
|
|
|
// Tool version plus the file names kept under the per-user state directory.
const (
	// version is printed by the -version flag.
	version = "0.1.0"

	// credentialFile caches the Proton session between runs.
	credentialFile = "credentials.json"
	// stateDBFile is the SQLite database recording which files were uploaded.
	stateDBFile = "state.db"
)
|
|
|
|
// Config holds the options for a single backup run, as parsed from the
// command line by main.
type Config struct {
	// SourceDir is the local directory tree to back up.
	SourceDir string
	// RemoteDir is the destination folder path on Proton Drive, relative to
	// the drive root (e.g. /backups/immich).
	RemoteDir string
	// ExcludePatterns are handed to shouldExclude for each walked file.
	ExcludePatterns []string

	// DryRun reports what would be uploaded without uploading anything;
	// Verbose additionally reports skipped files.
	DryRun, Verbose bool

	// RateLimitMs is the pause, in milliseconds, after each real upload attempt.
	RateLimitMs int
	// RetryCount is the number of extra attempts made after a failed upload.
	RetryCount int
	// RetryDelayMs is the base wait, in milliseconds, before each retry.
	RetryDelayMs int
}
|
|
|
|
// BackupState records which files have already been uploaded so that later
// runs can skip them. Values are created by NewBackupState.
type BackupState struct {
	db *sql.DB // handle to the uploaded_files database; released by Close
}
|
|
|
|
func main() {
|
|
// Parse flags
|
|
sourceDir := flag.String("source", "", "Source directory to backup")
|
|
remoteDir := flag.String("remote", "", "Remote directory on Proton Drive (e.g., /backups/immich)")
|
|
exclude := flag.String("exclude", "thumbs,encoded-video", "Comma-separated patterns to exclude")
|
|
dryRun := flag.Bool("dry-run", false, "Show what would be uploaded without uploading")
|
|
verbose := flag.Bool("verbose", false, "Verbose output")
|
|
rateLimit := flag.Int("rate-limit", 2000, "Milliseconds between uploads")
|
|
retryCount := flag.Int("retries", 5, "Number of retries on failure")
|
|
retryDelay := flag.Int("retry-delay", 30000, "Milliseconds to wait before retry")
|
|
showVersion := flag.Bool("version", false, "Show version")
|
|
|
|
flag.Parse()
|
|
|
|
if *showVersion {
|
|
fmt.Printf("proton-backup %s\n", version)
|
|
os.Exit(0)
|
|
}
|
|
|
|
if *sourceDir == "" || *remoteDir == "" {
|
|
fmt.Println("Usage: proton-backup -source <dir> -remote <remote-path>")
|
|
fmt.Println()
|
|
flag.PrintDefaults()
|
|
os.Exit(1)
|
|
}
|
|
|
|
// Parse exclude patterns
|
|
var excludePatterns []string
|
|
if *exclude != "" {
|
|
excludePatterns = strings.Split(*exclude, ",")
|
|
}
|
|
|
|
config := Config{
|
|
SourceDir: *sourceDir,
|
|
RemoteDir: *remoteDir,
|
|
ExcludePatterns: excludePatterns,
|
|
DryRun: *dryRun,
|
|
Verbose: *verbose,
|
|
RateLimitMs: *rateLimit,
|
|
RetryCount: *retryCount,
|
|
RetryDelayMs: *retryDelay,
|
|
}
|
|
|
|
if err := run(config); err != nil {
|
|
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
|
os.Exit(1)
|
|
}
|
|
}
|
|
|
|
func run(config Config) error {
|
|
ctx := context.Background()
|
|
|
|
// Initialize state directory
|
|
stateDir := filepath.Join(os.Getenv("HOME"), ".config", "proton-backup")
|
|
if err := os.MkdirAll(stateDir, 0700); err != nil {
|
|
return fmt.Errorf("failed to create state directory: %w", err)
|
|
}
|
|
|
|
// Initialize state database
|
|
state, err := NewBackupState(filepath.Join(stateDir, stateDBFile))
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize state: %w", err)
|
|
}
|
|
defer state.Close()
|
|
|
|
// Connect to Proton Drive
|
|
credPath := filepath.Join(stateDir, credentialFile)
|
|
protonDrive, err := connectProton(ctx, credPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to connect to Proton Drive: %w", err)
|
|
}
|
|
defer protonDrive.Logout(ctx)
|
|
|
|
fmt.Println("Connected to Proton Drive")
|
|
|
|
// Get root folder
|
|
rootLink := protonDrive.RootLink
|
|
if rootLink == nil {
|
|
return fmt.Errorf("failed to get root link")
|
|
}
|
|
|
|
// Create or get remote directory
|
|
remoteFolderID, err := ensureRemoteDir(ctx, protonDrive, config.RemoteDir)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to create remote directory: %w", err)
|
|
}
|
|
|
|
fmt.Printf("Remote folder: %s (ID: %s)\n", config.RemoteDir, remoteFolderID)
|
|
|
|
// Walk source directory and sync
|
|
return syncDirectory(ctx, protonDrive, state, config, remoteFolderID)
|
|
}
|
|
|
|
func connectProton(ctx context.Context, credPath string) (*proton_api_bridge.ProtonDrive, error) {
|
|
cfg := proton_api_bridge.NewDefaultConfig()
|
|
cfg.AppVersion = "web-drive@5.2.0"
|
|
cfg.UserAgent = "proton-backup/0.1.0"
|
|
cfg.ReplaceExistingDraft = true
|
|
cfg.ConcurrentBlockUploadCount = 2
|
|
// Reset credentials from default config
|
|
cfg.UseReusableLogin = false
|
|
cfg.ReusableCredential = &common.ReusableCredentialData{}
|
|
cfg.FirstLoginCredential = nil
|
|
|
|
// Check for cached credentials
|
|
if data, err := os.ReadFile(credPath); err == nil {
|
|
var savedCreds common.ProtonDriveCredential
|
|
if err := json.Unmarshal(data, &savedCreds); err == nil && savedCreds.UID != "" {
|
|
cfg.UseReusableLogin = true
|
|
cfg.ReusableCredential = &common.ReusableCredentialData{
|
|
UID: savedCreds.UID,
|
|
AccessToken: savedCreds.AccessToken,
|
|
RefreshToken: savedCreds.RefreshToken,
|
|
SaltedKeyPass: savedCreds.SaltedKeyPass,
|
|
}
|
|
fmt.Println("Using cached credentials...")
|
|
}
|
|
}
|
|
|
|
// If no cached credentials, prompt for login
|
|
if !cfg.UseReusableLogin {
|
|
fmt.Println("No cached credentials found, prompting for login...")
|
|
username, password, err := promptCredentials()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read credentials: %w", err)
|
|
}
|
|
if username == "" || password == "" {
|
|
return nil, fmt.Errorf("username and password are required")
|
|
}
|
|
cfg.FirstLoginCredential = &common.FirstLoginCredentialData{
|
|
Username: username,
|
|
Password: password,
|
|
}
|
|
fmt.Printf("Logging in as: %s\n", username)
|
|
}
|
|
|
|
// Auth handler - called when auth is refreshed
|
|
authHandler := func(auth proton.Auth) {
|
|
fmt.Println("Auth refreshed, will save on exit")
|
|
}
|
|
|
|
// Deauth handler
|
|
deauthHandler := func() {
|
|
fmt.Println("Deauthenticated, removing cached credentials")
|
|
os.Remove(credPath)
|
|
}
|
|
|
|
protonDrive, newCreds, err := proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler)
|
|
if err != nil {
|
|
// If cached credentials failed, try fresh login
|
|
if cfg.UseReusableLogin {
|
|
fmt.Println("Cached credentials failed, please login again")
|
|
os.Remove(credPath)
|
|
|
|
username, password, err := promptCredentials()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cfg.UseReusableLogin = false
|
|
cfg.ReusableCredential = nil
|
|
cfg.FirstLoginCredential = &common.FirstLoginCredentialData{
|
|
Username: username,
|
|
Password: password,
|
|
}
|
|
protonDrive, newCreds, err = proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("authentication failed: %w", err)
|
|
}
|
|
} else {
|
|
return nil, fmt.Errorf("authentication failed: %w", err)
|
|
}
|
|
}
|
|
|
|
// Save credentials for next time
|
|
if newCreds != nil {
|
|
data, _ := json.MarshalIndent(newCreds, "", " ")
|
|
if err := os.WriteFile(credPath, data, 0600); err != nil {
|
|
fmt.Printf("Warning: failed to save credentials: %v\n", err)
|
|
} else {
|
|
fmt.Println("Credentials saved for next run")
|
|
}
|
|
}
|
|
|
|
return protonDrive, nil
|
|
}
|
|
|
|
func promptCredentials() (string, string, error) {
|
|
// Check environment variables first (for non-interactive use)
|
|
username := os.Getenv("PROTON_USERNAME")
|
|
password := os.Getenv("PROTON_PASSWORD")
|
|
if username != "" && password != "" {
|
|
fmt.Println("Using credentials from environment variables")
|
|
return username, password, nil
|
|
}
|
|
|
|
reader := bufio.NewReader(os.Stdin)
|
|
|
|
fmt.Print("Proton username: ")
|
|
usernameInput, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
username = strings.TrimSpace(usernameInput)
|
|
|
|
fmt.Print("Proton password: ")
|
|
// Check if stdin is a terminal
|
|
if term.IsTerminal(int(syscall.Stdin)) {
|
|
passwordBytes, err := term.ReadPassword(int(syscall.Stdin))
|
|
fmt.Println()
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
password = string(passwordBytes)
|
|
} else {
|
|
// Non-terminal: read password from stdin (for piped input)
|
|
passwordInput, err := reader.ReadString('\n')
|
|
if err != nil {
|
|
return "", "", err
|
|
}
|
|
password = strings.TrimSpace(passwordInput)
|
|
}
|
|
|
|
return username, password, nil
|
|
}
|
|
|
|
func ensureRemoteDir(ctx context.Context, pd *proton_api_bridge.ProtonDrive, remotePath string) (string, error) {
|
|
parts := strings.Split(strings.Trim(remotePath, "/"), "/")
|
|
|
|
currentID := pd.RootLink.LinkID
|
|
|
|
for _, part := range parts {
|
|
if part == "" {
|
|
continue
|
|
}
|
|
|
|
// Try to find existing folder (searchForFile=false, searchForFolder=true, state=1 for active)
|
|
link, err := pd.SearchByNameInActiveFolderByID(ctx, currentID, part, false, true, 1)
|
|
if err == nil && link != nil {
|
|
currentID = link.LinkID
|
|
continue
|
|
}
|
|
|
|
// Create folder
|
|
newID, err := pd.CreateNewFolderByID(ctx, currentID, part)
|
|
if err != nil {
|
|
// Maybe it was created by another process, try to find it again
|
|
link, err2 := pd.SearchByNameInActiveFolderByID(ctx, currentID, part, false, true, 1)
|
|
if err2 == nil && link != nil {
|
|
currentID = link.LinkID
|
|
continue
|
|
}
|
|
return "", fmt.Errorf("failed to create folder %s: %w", part, err)
|
|
}
|
|
currentID = newID
|
|
fmt.Printf("Created remote folder: %s\n", part)
|
|
}
|
|
|
|
return currentID, nil
|
|
}
|
|
|
|
func syncDirectory(ctx context.Context, pd *proton_api_bridge.ProtonDrive, state *BackupState, config Config, remoteFolderID string) error {
|
|
var totalFiles, uploadedFiles, skippedFiles, errorFiles int
|
|
startTime := time.Now()
|
|
|
|
// First pass: count files
|
|
fmt.Println("Scanning source directory...")
|
|
err := filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil {
|
|
return nil // Skip errors
|
|
}
|
|
if !info.IsDir() && !shouldExclude(path, config.ExcludePatterns) {
|
|
totalFiles++
|
|
}
|
|
return nil
|
|
})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
fmt.Printf("Found %d files to process\n", totalFiles)
|
|
|
|
// Create a folder cache to avoid repeated API calls
|
|
folderCache := make(map[string]string)
|
|
folderCache[""] = remoteFolderID
|
|
folderCache["."] = remoteFolderID
|
|
|
|
// Second pass: upload files
|
|
processed := 0
|
|
err = filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// Skip directories
|
|
if info.IsDir() {
|
|
return nil
|
|
}
|
|
|
|
// Check exclusions
|
|
if shouldExclude(path, config.ExcludePatterns) {
|
|
if config.Verbose {
|
|
fmt.Printf("SKIP (excluded): %s\n", path)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
processed++
|
|
|
|
// Get relative path
|
|
relPath, err := filepath.Rel(config.SourceDir, path)
|
|
if err != nil {
|
|
return nil
|
|
}
|
|
|
|
// Check if already uploaded
|
|
fileHash := hashFileInfo(path, info)
|
|
if state.IsUploaded(relPath, fileHash) {
|
|
skippedFiles++
|
|
if config.Verbose {
|
|
fmt.Printf("SKIP (exists) [%d/%d]: %s\n", processed, totalFiles, relPath)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// Ensure parent folder exists on remote
|
|
parentRelPath := filepath.Dir(relPath)
|
|
parentID, err := ensureRemotePath(ctx, pd, folderCache, remoteFolderID, parentRelPath)
|
|
if err != nil {
|
|
fmt.Printf("ERROR (folder) [%d/%d]: %s - %v\n", processed, totalFiles, relPath, err)
|
|
errorFiles++
|
|
return nil
|
|
}
|
|
|
|
// Upload file
|
|
if config.DryRun {
|
|
fmt.Printf("DRY-RUN [%d/%d]: would upload %s\n", processed, totalFiles, relPath)
|
|
uploadedFiles++
|
|
} else {
|
|
err = uploadWithRetry(ctx, pd, path, parentID, info, config)
|
|
if err != nil {
|
|
fmt.Printf("ERROR [%d/%d]: %s - %v\n", processed, totalFiles, relPath, err)
|
|
errorFiles++
|
|
} else {
|
|
uploadedFiles++
|
|
state.MarkUploaded(relPath, fileHash)
|
|
fmt.Printf("OK [%d/%d]: %s (%s)\n", processed, totalFiles, relPath, humanSize(info.Size()))
|
|
}
|
|
|
|
// Rate limiting between uploads
|
|
time.Sleep(time.Duration(config.RateLimitMs) * time.Millisecond)
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
elapsed := time.Since(startTime)
|
|
fmt.Printf("\n=== Summary ===\n")
|
|
fmt.Printf("Duration: %s\n", elapsed.Round(time.Second))
|
|
fmt.Printf("Total files: %d\n", totalFiles)
|
|
fmt.Printf("Uploaded: %d\n", uploadedFiles)
|
|
fmt.Printf("Skipped (already uploaded): %d\n", skippedFiles)
|
|
fmt.Printf("Errors: %d\n", errorFiles)
|
|
|
|
return err
|
|
}
|
|
|
|
func ensureRemotePath(ctx context.Context, pd *proton_api_bridge.ProtonDrive, cache map[string]string, rootID, relPath string) (string, error) {
|
|
if relPath == "." || relPath == "" {
|
|
return rootID, nil
|
|
}
|
|
|
|
// Check cache
|
|
if id, ok := cache[relPath]; ok {
|
|
return id, nil
|
|
}
|
|
|
|
// Ensure parent exists first
|
|
parentPath := filepath.Dir(relPath)
|
|
parentID, err := ensureRemotePath(ctx, pd, cache, rootID, parentPath)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
folderName := filepath.Base(relPath)
|
|
|
|
// Try to find existing folder
|
|
link, err := pd.SearchByNameInActiveFolderByID(ctx, parentID, folderName, false, true, 1)
|
|
if err == nil && link != nil {
|
|
cache[relPath] = link.LinkID
|
|
return link.LinkID, nil
|
|
}
|
|
|
|
// Create folder
|
|
newID, err := pd.CreateNewFolderByID(ctx, parentID, folderName)
|
|
if err != nil {
|
|
// Maybe created concurrently, try to find again
|
|
link, err2 := pd.SearchByNameInActiveFolderByID(ctx, parentID, folderName, false, true, 1)
|
|
if err2 == nil && link != nil {
|
|
cache[relPath] = link.LinkID
|
|
return link.LinkID, nil
|
|
}
|
|
return "", fmt.Errorf("failed to create folder %s: %w", folderName, err)
|
|
}
|
|
|
|
cache[relPath] = newID
|
|
return newID, nil
|
|
}
|
|
|
|
func uploadWithRetry(ctx context.Context, pd *proton_api_bridge.ProtonDrive, localPath, parentID string, info os.FileInfo, config Config) error {
|
|
var lastErr error
|
|
|
|
for attempt := 0; attempt <= config.RetryCount; attempt++ {
|
|
if attempt > 0 {
|
|
delay := time.Duration(config.RetryDelayMs) * time.Millisecond
|
|
// Exponential backoff for rate limit errors
|
|
if lastErr != nil {
|
|
errStr := lastErr.Error()
|
|
if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") {
|
|
delay = delay * time.Duration(attempt+1)
|
|
}
|
|
}
|
|
fmt.Printf(" Retry %d/%d after %v...\n", attempt, config.RetryCount, delay)
|
|
time.Sleep(delay)
|
|
}
|
|
|
|
// Open file for reading
|
|
file, err := os.Open(localPath)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to open file: %w", err)
|
|
}
|
|
|
|
// Upload using reader API (takes parentLinkID string)
|
|
_, _, err = pd.UploadFileByReader(ctx, parentID, info.Name(), info.ModTime(), file, 0)
|
|
file.Close()
|
|
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
|
|
lastErr = err
|
|
|
|
// Check if it's a retryable error
|
|
errStr := err.Error()
|
|
if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") ||
|
|
strings.Contains(errStr, "retry") || strings.Contains(errStr, "500") ||
|
|
strings.Contains(errStr, "503") {
|
|
continue
|
|
}
|
|
|
|
// For non-retryable errors, check if file already exists
|
|
if strings.Contains(errStr, "already exists") || strings.Contains(errStr, "Draft already exists") {
|
|
// File exists, treat as success
|
|
return nil
|
|
}
|
|
}
|
|
|
|
return lastErr
|
|
}
|
|
|
|
// shouldExclude reports whether path should be left out of the backup.
//
// A path is excluded when its final element is hidden (starts with "."), or
// when one of patterns matches whole path elements. A pattern matches when it
// appears bounded by os.PathSeparator or by the start/end of path, so
// "thumbs" excludes "lib/thumbs/a.jpg" and "thumbs/a.jpg" but not
// "lib/thumbsup/a.jpg". Surrounding whitespace in patterns is ignored, and
// empty patterns are skipped.
func shouldExclude(path string, patterns []string) bool {
	// Hidden files are excluded regardless of patterns. The check used to sit
	// inside the pattern loop and was skipped when no patterns were given.
	if strings.HasPrefix(filepath.Base(path), ".") {
		return true
	}

	sep := string(os.PathSeparator)
	// Bracketing the path with separators lets a pattern naming the first or
	// last element match the same way as one in the middle.
	wrapped := sep + path + sep
	for _, pattern := range patterns {
		pattern = strings.TrimSpace(pattern)
		if pattern == "" {
			continue
		}
		if strings.Contains(wrapped, sep+pattern+sep) {
			return true
		}
	}
	return false
}
|
|
|
|
// hashFileInfo derives a cheap change-detection fingerprint for a file from
// its path, size and modification time (whole seconds); the file contents are
// never read. The result is the first 8 bytes of the SHA-256 digest of the
// concatenated decimal fields, hex-encoded (16 characters).
func hashFileInfo(path string, info os.FileInfo) string {
	digest := sha256.New()
	// Same bytes as writing each field separately: path, size, mtime.
	fmt.Fprintf(digest, "%s%d%d", path, info.Size(), info.ModTime().Unix())
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum[:8])
}
|
|
|
|
func humanSize(bytes int64) string {
|
|
const unit = 1024
|
|
if bytes < unit {
|
|
return fmt.Sprintf("%d B", bytes)
|
|
}
|
|
div, exp := int64(unit), 0
|
|
for n := bytes / unit; n >= unit; n /= unit {
|
|
div *= unit
|
|
exp++
|
|
}
|
|
return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
|
|
}
|
|
|
|
// BackupState tracks what has been uploaded
|
|
func NewBackupState(dbPath string) (*BackupState, error) {
|
|
db, err := sql.Open("sqlite3", dbPath)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
_, err = db.Exec(`
|
|
CREATE TABLE IF NOT EXISTS uploaded_files (
|
|
path TEXT PRIMARY KEY,
|
|
hash TEXT,
|
|
uploaded_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
|
)
|
|
`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return &BackupState{db: db}, nil
|
|
}
|
|
|
|
func (s *BackupState) IsUploaded(path, hash string) bool {
|
|
var storedHash string
|
|
err := s.db.QueryRow("SELECT hash FROM uploaded_files WHERE path = ?", path).Scan(&storedHash)
|
|
if err != nil {
|
|
return false
|
|
}
|
|
return storedHash == hash
|
|
}
|
|
|
|
func (s *BackupState) MarkUploaded(path, hash string) error {
|
|
_, err := s.db.Exec(
|
|
"INSERT OR REPLACE INTO uploaded_files (path, hash, uploaded_at) VALUES (?, ?, CURRENT_TIMESTAMP)",
|
|
path, hash,
|
|
)
|
|
return err
|
|
}
|
|
|
|
func (s *BackupState) Close() error {
|
|
return s.db.Close()
|
|
}
|