package main import ( "bufio" "context" "crypto/sha256" "database/sql" "encoding/hex" "encoding/json" "flag" "fmt" "os" "path/filepath" "strings" "syscall" "time" proton "github.com/henrybear327/go-proton-api" proton_api_bridge "github.com/henrybear327/Proton-API-Bridge" "github.com/henrybear327/Proton-API-Bridge/common" _ "github.com/mattn/go-sqlite3" "golang.org/x/term" ) const ( credentialFile = "credentials.json" stateDBFile = "state.db" version = "0.1.0" ) type Config struct { SourceDir string RemoteDir string ExcludePatterns []string DryRun bool Verbose bool RateLimitMs int RetryCount int RetryDelayMs int } type BackupState struct { db *sql.DB } func main() { // Parse flags sourceDir := flag.String("source", "", "Source directory to backup") remoteDir := flag.String("remote", "", "Remote directory on Proton Drive (e.g., /backups/immich)") exclude := flag.String("exclude", "thumbs,encoded-video", "Comma-separated patterns to exclude") dryRun := flag.Bool("dry-run", false, "Show what would be uploaded without uploading") verbose := flag.Bool("verbose", false, "Verbose output") rateLimit := flag.Int("rate-limit", 2000, "Milliseconds between uploads") retryCount := flag.Int("retries", 5, "Number of retries on failure") retryDelay := flag.Int("retry-delay", 30000, "Milliseconds to wait before retry") showVersion := flag.Bool("version", false, "Show version") flag.Parse() if *showVersion { fmt.Printf("proton-backup %s\n", version) os.Exit(0) } if *sourceDir == "" || *remoteDir == "" { fmt.Println("Usage: proton-backup -source -remote ") fmt.Println() flag.PrintDefaults() os.Exit(1) } // Parse exclude patterns var excludePatterns []string if *exclude != "" { excludePatterns = strings.Split(*exclude, ",") } config := Config{ SourceDir: *sourceDir, RemoteDir: *remoteDir, ExcludePatterns: excludePatterns, DryRun: *dryRun, Verbose: *verbose, RateLimitMs: *rateLimit, RetryCount: *retryCount, RetryDelayMs: *retryDelay, } if err := run(config); err != nil { fmt.Fprintf(os.Stderr, "Error: %v\n", err) os.Exit(1) } } func run(config Config) error { ctx := context.Background() // Initialize state directory stateDir := filepath.Join(os.Getenv("HOME"), ".config", "proton-backup") if err := os.MkdirAll(stateDir, 0700); err != nil { return fmt.Errorf("failed to create state directory: %w", err) } // Initialize state database state, err := NewBackupState(filepath.Join(stateDir, stateDBFile)) if err != nil { return fmt.Errorf("failed to initialize state: %w", err) } defer state.Close() // Connect to Proton Drive credPath := filepath.Join(stateDir, credentialFile) protonDrive, err := connectProton(ctx, credPath) if err != nil { return fmt.Errorf("failed to connect to Proton Drive: %w", err) } defer protonDrive.Logout(ctx) fmt.Println("Connected to Proton Drive") // Get root folder rootLink := protonDrive.RootLink if rootLink == nil { return fmt.Errorf("failed to get root link") } // Create or get remote directory remoteFolderID, err := ensureRemoteDir(ctx, protonDrive, config.RemoteDir) if err != nil { return fmt.Errorf("failed to create remote directory: %w", err) } fmt.Printf("Remote folder: %s (ID: %s)\n", config.RemoteDir, remoteFolderID) // Walk source directory and sync return syncDirectory(ctx, protonDrive, state, config, remoteFolderID) } func connectProton(ctx context.Context, credPath string) (*proton_api_bridge.ProtonDrive, error) { cfg := proton_api_bridge.NewDefaultConfig() cfg.AppVersion = "web-drive@5.2.0" cfg.UserAgent = "proton-backup/0.1.0" cfg.ReplaceExistingDraft = true cfg.ConcurrentBlockUploadCount = 2 // Reset credentials from default config cfg.UseReusableLogin = false cfg.ReusableCredential = &common.ReusableCredentialData{} cfg.FirstLoginCredential = nil // Check for cached credentials if data, err := os.ReadFile(credPath); err == nil { var savedCreds common.ProtonDriveCredential if err := json.Unmarshal(data, &savedCreds); err == nil && savedCreds.UID != "" { cfg.UseReusableLogin = true cfg.ReusableCredential = &common.ReusableCredentialData{ UID: savedCreds.UID, AccessToken: savedCreds.AccessToken, RefreshToken: savedCreds.RefreshToken, SaltedKeyPass: savedCreds.SaltedKeyPass, } fmt.Println("Using cached credentials...") } } // If no cached credentials, prompt for login if !cfg.UseReusableLogin { fmt.Println("No cached credentials found, prompting for login...") username, password, err := promptCredentials() if err != nil { return nil, fmt.Errorf("failed to read credentials: %w", err) } if username == "" || password == "" { return nil, fmt.Errorf("username and password are required") } cfg.FirstLoginCredential = &common.FirstLoginCredentialData{ Username: username, Password: password, } fmt.Printf("Logging in as: %s\n", username) } // Auth handler - called when auth is refreshed authHandler := func(auth proton.Auth) { fmt.Println("Auth refreshed, will save on exit") } // Deauth handler deauthHandler := func() { fmt.Println("Deauthenticated, removing cached credentials") os.Remove(credPath) } protonDrive, newCreds, err := proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler) if err != nil { // If cached credentials failed, try fresh login if cfg.UseReusableLogin { fmt.Println("Cached credentials failed, please login again") os.Remove(credPath) username, password, err := promptCredentials() if err != nil { return nil, err } cfg.UseReusableLogin = false cfg.ReusableCredential = nil cfg.FirstLoginCredential = &common.FirstLoginCredentialData{ Username: username, Password: password, } protonDrive, newCreds, err = proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler) if err != nil { return nil, fmt.Errorf("authentication failed: %w", err) } } else { return nil, fmt.Errorf("authentication failed: %w", err) } } // Save credentials for next time if newCreds != nil { data, _ := json.MarshalIndent(newCreds, "", " ") if err := os.WriteFile(credPath, data, 0600); err != nil { fmt.Printf("Warning: failed to save credentials: %v\n", err) } else { fmt.Println("Credentials saved for next run") } } return protonDrive, nil } func promptCredentials() (string, string, error) { // Check environment variables first (for non-interactive use) username := os.Getenv("PROTON_USERNAME") password := os.Getenv("PROTON_PASSWORD") if username != "" && password != "" { fmt.Println("Using credentials from environment variables") return username, password, nil } reader := bufio.NewReader(os.Stdin) fmt.Print("Proton username: ") usernameInput, err := reader.ReadString('\n') if err != nil { return "", "", err } username = strings.TrimSpace(usernameInput) fmt.Print("Proton password: ") // Check if stdin is a terminal if term.IsTerminal(int(syscall.Stdin)) { passwordBytes, err := term.ReadPassword(int(syscall.Stdin)) fmt.Println() if err != nil { return "", "", err } password = string(passwordBytes) } else { // Non-terminal: read password from stdin (for piped input) passwordInput, err := reader.ReadString('\n') if err != nil { return "", "", err } password = strings.TrimSpace(passwordInput) } return username, password, nil } func ensureRemoteDir(ctx context.Context, pd *proton_api_bridge.ProtonDrive, remotePath string) (string, error) { parts := strings.Split(strings.Trim(remotePath, "/"), "/") currentID := pd.RootLink.LinkID for _, part := range parts { if part == "" { continue } // Try to find existing folder (searchForFile=false, searchForFolder=true, state=1 for active) link, err := pd.SearchByNameInActiveFolderByID(ctx, currentID, part, false, true, 1) if err == nil && link != nil { currentID = link.LinkID continue } // Create folder newID, err := pd.CreateNewFolderByID(ctx, currentID, part) if err != nil { // Maybe it was created by another process, try to find it again link, err2 := pd.SearchByNameInActiveFolderByID(ctx, currentID, part, false, true, 1) if err2 == nil && link != nil { currentID = link.LinkID continue } return "", fmt.Errorf("failed to create folder %s: %w", part, err) } currentID = newID fmt.Printf("Created remote folder: %s\n", part) } return currentID, nil } func syncDirectory(ctx context.Context, pd *proton_api_bridge.ProtonDrive, state *BackupState, config Config, remoteFolderID string) error { var totalFiles, uploadedFiles, skippedFiles, errorFiles int startTime := time.Now() // First pass: count files fmt.Println("Scanning source directory...") err := filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error { if err != nil { return nil // Skip errors } if !info.IsDir() && !shouldExclude(path, config.ExcludePatterns) { totalFiles++ } return nil }) if err != nil { return err } fmt.Printf("Found %d files to process\n", totalFiles) // Create a folder cache to avoid repeated API calls folderCache := make(map[string]string) folderCache[""] = remoteFolderID folderCache["."] = remoteFolderID // Second pass: upload files processed := 0 err = filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error { if err != nil { return nil } // Skip directories if info.IsDir() { return nil } // Check exclusions if shouldExclude(path, config.ExcludePatterns) { if config.Verbose { fmt.Printf("SKIP (excluded): %s\n", path) } return nil } processed++ // Get relative path relPath, err := filepath.Rel(config.SourceDir, path) if err != nil { return nil } // Check if already uploaded fileHash := hashFileInfo(path, info) if state.IsUploaded(relPath, fileHash) { skippedFiles++ if config.Verbose { fmt.Printf("SKIP (exists) [%d/%d]: %s\n", processed, totalFiles, relPath) } return nil } // Ensure parent folder exists on remote parentRelPath := filepath.Dir(relPath) parentID, err := ensureRemotePath(ctx, pd, folderCache, remoteFolderID, parentRelPath) if err != nil { fmt.Printf("ERROR (folder) [%d/%d]: %s - %v\n", processed, totalFiles, relPath, err) errorFiles++ return nil } // Upload file if config.DryRun { fmt.Printf("DRY-RUN [%d/%d]: would upload %s\n", processed, totalFiles, relPath) uploadedFiles++ } else { err = uploadWithRetry(ctx, pd, path, parentID, info, config) if err != nil { fmt.Printf("ERROR [%d/%d]: %s - %v\n", processed, totalFiles, relPath, err) errorFiles++ } else { uploadedFiles++ state.MarkUploaded(relPath, fileHash) fmt.Printf("OK [%d/%d]: %s (%s)\n", processed, totalFiles, relPath, humanSize(info.Size())) } // Rate limiting between uploads time.Sleep(time.Duration(config.RateLimitMs) * time.Millisecond) } return nil }) elapsed := time.Since(startTime) fmt.Printf("\n=== Summary ===\n") fmt.Printf("Duration: %s\n", elapsed.Round(time.Second)) fmt.Printf("Total files: %d\n", totalFiles) fmt.Printf("Uploaded: %d\n", uploadedFiles) fmt.Printf("Skipped (already uploaded): %d\n", skippedFiles) fmt.Printf("Errors: %d\n", errorFiles) return err } func ensureRemotePath(ctx context.Context, pd *proton_api_bridge.ProtonDrive, cache map[string]string, rootID, relPath string) (string, error) { if relPath == "." || relPath == "" { return rootID, nil } // Check cache if id, ok := cache[relPath]; ok { return id, nil } // Ensure parent exists first parentPath := filepath.Dir(relPath) parentID, err := ensureRemotePath(ctx, pd, cache, rootID, parentPath) if err != nil { return "", err } folderName := filepath.Base(relPath) // Try to find existing folder link, err := pd.SearchByNameInActiveFolderByID(ctx, parentID, folderName, false, true, 1) if err == nil && link != nil { cache[relPath] = link.LinkID return link.LinkID, nil } // Create folder newID, err := pd.CreateNewFolderByID(ctx, parentID, folderName) if err != nil { // Maybe created concurrently, try to find again link, err2 := pd.SearchByNameInActiveFolderByID(ctx, parentID, folderName, false, true, 1) if err2 == nil && link != nil { cache[relPath] = link.LinkID return link.LinkID, nil } return "", fmt.Errorf("failed to create folder %s: %w", folderName, err) } cache[relPath] = newID return newID, nil } func uploadWithRetry(ctx context.Context, pd *proton_api_bridge.ProtonDrive, localPath, parentID string, info os.FileInfo, config Config) error { var lastErr error for attempt := 0; attempt <= config.RetryCount; attempt++ { if attempt > 0 { delay := time.Duration(config.RetryDelayMs) * time.Millisecond // Exponential backoff for rate limit errors if lastErr != nil { errStr := lastErr.Error() if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") { delay = delay * time.Duration(attempt+1) } } fmt.Printf(" Retry %d/%d after %v...\n", attempt, config.RetryCount, delay) time.Sleep(delay) } // Open file for reading file, err := os.Open(localPath) if err != nil { return fmt.Errorf("failed to open file: %w", err) } // Upload using reader API (takes parentLinkID string) _, _, err = pd.UploadFileByReader(ctx, parentID, info.Name(), info.ModTime(), file, 0) file.Close() if err == nil { return nil } lastErr = err // Check if it's a retryable error errStr := err.Error() if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") || strings.Contains(errStr, "retry") || strings.Contains(errStr, "500") || strings.Contains(errStr, "503") { continue } // For non-retryable errors, check if file already exists if strings.Contains(errStr, "already exists") || strings.Contains(errStr, "Draft already exists") { // File exists, treat as success return nil } } return lastErr } func shouldExclude(path string, patterns []string) bool { for _, pattern := range patterns { pattern = strings.TrimSpace(pattern) if pattern == "" { continue } // Exclude if pattern appears as a path component if strings.Contains(path, string(os.PathSeparator)+pattern+string(os.PathSeparator)) || strings.HasSuffix(path, string(os.PathSeparator)+pattern) { return true } // Also exclude hidden files if strings.Contains(filepath.Base(path), "/.") || strings.HasPrefix(filepath.Base(path), ".") { return true } } return false } func hashFileInfo(path string, info os.FileInfo) string { h := sha256.New() h.Write([]byte(path)) h.Write([]byte(fmt.Sprintf("%d", info.Size()))) h.Write([]byte(fmt.Sprintf("%d", info.ModTime().Unix()))) return hex.EncodeToString(h.Sum(nil))[:16] } func humanSize(bytes int64) string { const unit = 1024 if bytes < unit { return fmt.Sprintf("%d B", bytes) } div, exp := int64(unit), 0 for n := bytes / unit; n >= unit; n /= unit { div *= unit exp++ } return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp]) } // BackupState tracks what has been uploaded func NewBackupState(dbPath string) (*BackupState, error) { db, err := sql.Open("sqlite3", dbPath) if err != nil { return nil, err } _, err = db.Exec(` CREATE TABLE IF NOT EXISTS uploaded_files ( path TEXT PRIMARY KEY, hash TEXT, uploaded_at DATETIME DEFAULT CURRENT_TIMESTAMP ) `) if err != nil { return nil, err } return &BackupState{db: db}, nil } func (s *BackupState) IsUploaded(path, hash string) bool { var storedHash string err := s.db.QueryRow("SELECT hash FROM uploaded_files WHERE path = ?", path).Scan(&storedHash) if err != nil { return false } return storedHash == hash } func (s *BackupState) MarkUploaded(path, hash string) error { _, err := s.db.Exec( "INSERT OR REPLACE INTO uploaded_files (path, hash, uploaded_at) VALUES (?, ?, CURRENT_TIMESTAMP)", path, hash, ) return err } func (s *BackupState) Close() error { return s.db.Close() }