package main
import (
"bufio"
"context"
"crypto/sha256"
"database/sql"
"encoding/hex"
"encoding/json"
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"syscall"
"time"
proton "github.com/henrybear327/go-proton-api"
proton_api_bridge "github.com/henrybear327/Proton-API-Bridge"
"github.com/henrybear327/Proton-API-Bridge/common"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/term"
)
// File names kept under the per-user state directory, plus the tool version.
const (
// credentialFile caches reusable Proton session credentials as JSON (0600).
credentialFile = "credentials.json"
// stateDBFile is the SQLite database that records already-uploaded files.
stateDBFile = "state.db"
version = "0.1.0"
)
// Config holds the command-line options for one backup run.
type Config struct {
SourceDir string // local directory tree to back up
RemoteDir string // destination path on Proton Drive, e.g. /backups/immich
ExcludePatterns []string // path components to skip (matched by shouldExclude)
DryRun bool // report what would be uploaded without uploading
Verbose bool // also print skipped files
RateLimitMs int // sleep between uploads, in milliseconds
RetryCount int // max retries per file on transient errors
RetryDelayMs int // base delay before a retry, in milliseconds
}
// BackupState tracks which files have already been uploaded, backed by a
// SQLite database so runs can resume without re-uploading.
type BackupState struct {
db *sql.DB
}
func main() {
// Parse flags
sourceDir := flag.String("source", "", "Source directory to backup")
remoteDir := flag.String("remote", "", "Remote directory on Proton Drive (e.g., /backups/immich)")
exclude := flag.String("exclude", "thumbs,encoded-video", "Comma-separated patterns to exclude")
dryRun := flag.Bool("dry-run", false, "Show what would be uploaded without uploading")
verbose := flag.Bool("verbose", false, "Verbose output")
rateLimit := flag.Int("rate-limit", 2000, "Milliseconds between uploads")
retryCount := flag.Int("retries", 5, "Number of retries on failure")
retryDelay := flag.Int("retry-delay", 30000, "Milliseconds to wait before retry")
showVersion := flag.Bool("version", false, "Show version")
flag.Parse()
if *showVersion {
fmt.Printf("proton-backup %s\n", version)
os.Exit(0)
}
if *sourceDir == "" || *remoteDir == "" {
fmt.Println("Usage: proton-backup -source
-remote ")
fmt.Println()
flag.PrintDefaults()
os.Exit(1)
}
// Parse exclude patterns
var excludePatterns []string
if *exclude != "" {
excludePatterns = strings.Split(*exclude, ",")
}
config := Config{
SourceDir: *sourceDir,
RemoteDir: *remoteDir,
ExcludePatterns: excludePatterns,
DryRun: *dryRun,
Verbose: *verbose,
RateLimitMs: *rateLimit,
RetryCount: *retryCount,
RetryDelayMs: *retryDelay,
}
if err := run(config); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
os.Exit(1)
}
}
// run performs one full backup pass: it prepares the local state directory
// and database, authenticates to Proton Drive, resolves the remote target
// folder, and then syncs the source tree into it.
func run(config Config) error {
	ctx := context.Background()

	// All local state lives under ~/.config/proton-backup.
	stateDir := filepath.Join(os.Getenv("HOME"), ".config", "proton-backup")
	if err := os.MkdirAll(stateDir, 0700); err != nil {
		return fmt.Errorf("failed to create state directory: %w", err)
	}

	// Open (or create) the upload-tracking database.
	state, err := NewBackupState(filepath.Join(stateDir, stateDBFile))
	if err != nil {
		return fmt.Errorf("failed to initialize state: %w", err)
	}
	defer state.Close()

	// Authenticate, preferring cached credentials in the state directory.
	drive, err := connectProton(ctx, filepath.Join(stateDir, credentialFile))
	if err != nil {
		return fmt.Errorf("failed to connect to Proton Drive: %w", err)
	}
	defer drive.Logout(ctx)
	fmt.Println("Connected to Proton Drive")

	// A nil root link means the share was not resolved; nothing can be synced.
	if drive.RootLink == nil {
		return fmt.Errorf("failed to get root link")
	}

	// Resolve (creating as needed) the configured remote directory.
	remoteFolderID, err := ensureRemoteDir(ctx, drive, config.RemoteDir)
	if err != nil {
		return fmt.Errorf("failed to create remote directory: %w", err)
	}
	fmt.Printf("Remote folder: %s (ID: %s)\n", config.RemoteDir, remoteFolderID)

	// Walk the source tree and upload anything new or changed.
	return syncDirectory(ctx, drive, state, config, remoteFolderID)
}
// connectProton authenticates to Proton Drive and returns a ready client.
// It first tries cached reusable credentials stored at credPath; if those
// are missing or rejected, it falls back to a username/password login via
// promptCredentials (env vars or interactive prompt). Freshly issued
// reusable credentials are persisted back to credPath (mode 0600) so the
// next run can skip the password prompt.
func connectProton(ctx context.Context, credPath string) (*proton_api_bridge.ProtonDrive, error) {
cfg := proton_api_bridge.NewDefaultConfig()
cfg.AppVersion = "web-drive@5.2.0"
cfg.UserAgent = "proton-backup/0.1.0"
// Replace half-finished upload drafts left behind by interrupted runs.
cfg.ReplaceExistingDraft = true
cfg.ConcurrentBlockUploadCount = 2
// Reset credentials from default config
cfg.UseReusableLogin = false
cfg.ReusableCredential = &common.ReusableCredentialData{}
cfg.FirstLoginCredential = nil
// Check for cached credentials
if data, err := os.ReadFile(credPath); err == nil {
var savedCreds common.ProtonDriveCredential
// Only trust the cache if it parses and actually carries a session UID.
if err := json.Unmarshal(data, &savedCreds); err == nil && savedCreds.UID != "" {
cfg.UseReusableLogin = true
cfg.ReusableCredential = &common.ReusableCredentialData{
UID: savedCreds.UID,
AccessToken: savedCreds.AccessToken,
RefreshToken: savedCreds.RefreshToken,
SaltedKeyPass: savedCreds.SaltedKeyPass,
}
fmt.Println("Using cached credentials...")
}
}
// If no cached credentials, prompt for login
if !cfg.UseReusableLogin {
fmt.Println("No cached credentials found, prompting for login...")
username, password, err := promptCredentials()
if err != nil {
return nil, fmt.Errorf("failed to read credentials: %w", err)
}
if username == "" || password == "" {
return nil, fmt.Errorf("username and password are required")
}
cfg.FirstLoginCredential = &common.FirstLoginCredentialData{
Username: username,
Password: password,
}
fmt.Printf("Logging in as: %s\n", username)
}
// Auth handler - called when auth is refreshed
// NOTE(review): refreshed tokens are only announced here, never written to
// credPath; tokens refreshed mid-run are lost if the process exits before
// a save happens elsewhere — confirm whether that is acceptable.
authHandler := func(auth proton.Auth) {
fmt.Println("Auth refreshed, will save on exit")
}
// Deauth handler
// Server-side session invalidation: drop the now-useless credential cache.
deauthHandler := func() {
fmt.Println("Deauthenticated, removing cached credentials")
os.Remove(credPath)
}
protonDrive, newCreds, err := proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler)
if err != nil {
// If cached credentials failed, try fresh login
if cfg.UseReusableLogin {
fmt.Println("Cached credentials failed, please login again")
os.Remove(credPath)
// This err deliberately shadows the outer one; it is checked immediately.
username, password, err := promptCredentials()
if err != nil {
return nil, err
}
cfg.UseReusableLogin = false
cfg.ReusableCredential = nil
cfg.FirstLoginCredential = &common.FirstLoginCredentialData{
Username: username,
Password: password,
}
// Second attempt with fresh first-login credentials.
protonDrive, newCreds, err = proton_api_bridge.NewProtonDrive(ctx, cfg, authHandler, deauthHandler)
if err != nil {
return nil, fmt.Errorf("authentication failed: %w", err)
}
} else {
return nil, fmt.Errorf("authentication failed: %w", err)
}
}
// Save credentials for next time
if newCreds != nil {
// Marshal error ignored: newCreds is a plain data struct, and a failed
// save only costs a re-login next run (warned about below).
data, _ := json.MarshalIndent(newCreds, "", " ")
if err := os.WriteFile(credPath, data, 0600); err != nil {
fmt.Printf("Warning: failed to save credentials: %v\n", err)
} else {
fmt.Println("Credentials saved for next run")
}
}
return protonDrive, nil
}
func promptCredentials() (string, string, error) {
// Check environment variables first (for non-interactive use)
username := os.Getenv("PROTON_USERNAME")
password := os.Getenv("PROTON_PASSWORD")
if username != "" && password != "" {
fmt.Println("Using credentials from environment variables")
return username, password, nil
}
reader := bufio.NewReader(os.Stdin)
fmt.Print("Proton username: ")
usernameInput, err := reader.ReadString('\n')
if err != nil {
return "", "", err
}
username = strings.TrimSpace(usernameInput)
fmt.Print("Proton password: ")
// Check if stdin is a terminal
if term.IsTerminal(int(syscall.Stdin)) {
passwordBytes, err := term.ReadPassword(int(syscall.Stdin))
fmt.Println()
if err != nil {
return "", "", err
}
password = string(passwordBytes)
} else {
// Non-terminal: read password from stdin (for piped input)
passwordInput, err := reader.ReadString('\n')
if err != nil {
return "", "", err
}
password = strings.TrimSpace(passwordInput)
}
return username, password, nil
}
// ensureRemoteDir walks remotePath component by component beneath the drive
// root, creating any folder that does not yet exist, and returns the link ID
// of the final folder.
func ensureRemoteDir(ctx context.Context, pd *proton_api_bridge.ProtonDrive, remotePath string) (string, error) {
	currentID := pd.RootLink.LinkID
	for _, name := range strings.Split(strings.Trim(remotePath, "/"), "/") {
		if name == "" {
			continue
		}
		// Reuse an existing active folder when present
		// (searchForFile=false, searchForFolder=true, state=1 for active).
		if link, err := pd.SearchByNameInActiveFolderByID(ctx, currentID, name, false, true, 1); err == nil && link != nil {
			currentID = link.LinkID
			continue
		}
		createdID, createErr := pd.CreateNewFolderByID(ctx, currentID, name)
		if createErr != nil {
			// Another process may have created it concurrently; look again.
			if link, err := pd.SearchByNameInActiveFolderByID(ctx, currentID, name, false, true, 1); err == nil && link != nil {
				currentID = link.LinkID
				continue
			}
			return "", fmt.Errorf("failed to create folder %s: %w", name, createErr)
		}
		currentID = createdID
		fmt.Printf("Created remote folder: %s\n", name)
	}
	return currentID, nil
}
// syncDirectory mirrors config.SourceDir into the remote folder identified
// by remoteFolderID. It makes two passes over the local tree: one to count
// eligible files (for [n/total] progress output) and one to upload anything
// not yet recorded in the state database. Per-file failures are reported and
// counted but do not abort the walk; the summary is printed regardless.
func syncDirectory(ctx context.Context, pd *proton_api_bridge.ProtonDrive, state *BackupState, config Config, remoteFolderID string) error {
	var totalFiles, uploadedFiles, skippedFiles, errorFiles int
	startTime := time.Now()

	// First pass: count files so progress can be shown as [n/total].
	fmt.Println("Scanning source directory...")
	err := filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil // Skip unreadable entries rather than aborting the scan
		}
		if !info.IsDir() && !shouldExclude(path, config.ExcludePatterns) {
			totalFiles++
		}
		return nil
	})
	if err != nil {
		return err
	}
	fmt.Printf("Found %d files to process\n", totalFiles)

	// Cache of relative dir path -> remote folder link ID, so each remote
	// folder is resolved via the API at most once per run.
	folderCache := make(map[string]string)
	folderCache[""] = remoteFolderID
	folderCache["."] = remoteFolderID

	// Second pass: upload files.
	processed := 0
	err = filepath.Walk(config.SourceDir, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return nil
		}
		// Skip directories
		if info.IsDir() {
			return nil
		}
		// Check exclusions
		if shouldExclude(path, config.ExcludePatterns) {
			if config.Verbose {
				fmt.Printf("SKIP (excluded): %s\n", path)
			}
			return nil
		}
		processed++
		// Get relative path
		relPath, err := filepath.Rel(config.SourceDir, path)
		if err != nil {
			return nil
		}
		// Skip files whose path/size/mtime fingerprint is already recorded.
		fileHash := hashFileInfo(path, info)
		if state.IsUploaded(relPath, fileHash) {
			skippedFiles++
			if config.Verbose {
				fmt.Printf("SKIP (exists) [%d/%d]: %s\n", processed, totalFiles, relPath)
			}
			return nil
		}
		// Ensure the file's parent folder exists on the remote side.
		parentRelPath := filepath.Dir(relPath)
		parentID, err := ensureRemotePath(ctx, pd, folderCache, remoteFolderID, parentRelPath)
		if err != nil {
			fmt.Printf("ERROR (folder) [%d/%d]: %s - %v\n", processed, totalFiles, relPath, err)
			errorFiles++
			return nil
		}
		// Upload file (dry-run just reports and, as before, does not sleep).
		if config.DryRun {
			fmt.Printf("DRY-RUN [%d/%d]: would upload %s\n", processed, totalFiles, relPath)
			uploadedFiles++
			return nil
		}
		if uploadErr := uploadWithRetry(ctx, pd, path, parentID, info, config); uploadErr != nil {
			fmt.Printf("ERROR [%d/%d]: %s - %v\n", processed, totalFiles, relPath, uploadErr)
			errorFiles++
		} else {
			uploadedFiles++
			// BUG FIX: the MarkUploaded error was silently discarded; a failed
			// state write means this file is re-uploaded next run, so at least
			// surface the problem to the user.
			if stateErr := state.MarkUploaded(relPath, fileHash); stateErr != nil {
				fmt.Printf("WARNING: failed to record upload of %s: %v\n", relPath, stateErr)
			}
			fmt.Printf("OK [%d/%d]: %s (%s)\n", processed, totalFiles, relPath, humanSize(info.Size()))
		}
		// Throttle between uploads to stay within API rate limits.
		time.Sleep(time.Duration(config.RateLimitMs) * time.Millisecond)
		return nil
	})

	elapsed := time.Since(startTime)
	fmt.Printf("\n=== Summary ===\n")
	fmt.Printf("Duration: %s\n", elapsed.Round(time.Second))
	fmt.Printf("Total files: %d\n", totalFiles)
	fmt.Printf("Uploaded: %d\n", uploadedFiles)
	fmt.Printf("Skipped (already uploaded): %d\n", skippedFiles)
	fmt.Printf("Errors: %d\n", errorFiles)
	return err
}
// ensureRemotePath resolves (creating if necessary) the remote folder
// corresponding to relPath beneath rootID, memoizing every resolved path in
// cache. Parents are resolved recursively before their children.
func ensureRemotePath(ctx context.Context, pd *proton_api_bridge.ProtonDrive, cache map[string]string, rootID, relPath string) (string, error) {
	if relPath == "" || relPath == "." {
		return rootID, nil
	}
	if cached, ok := cache[relPath]; ok {
		return cached, nil
	}

	// Resolve the parent first; recursion bottoms out at rootID.
	parentID, err := ensureRemotePath(ctx, pd, cache, rootID, filepath.Dir(relPath))
	if err != nil {
		return "", err
	}

	name := filepath.Base(relPath)

	// Prefer an existing active folder over creating a duplicate.
	if link, searchErr := pd.SearchByNameInActiveFolderByID(ctx, parentID, name, false, true, 1); searchErr == nil && link != nil {
		cache[relPath] = link.LinkID
		return link.LinkID, nil
	}

	createdID, createErr := pd.CreateNewFolderByID(ctx, parentID, name)
	if createErr != nil {
		// The folder may have been created concurrently; look it up once more.
		if link, searchErr := pd.SearchByNameInActiveFolderByID(ctx, parentID, name, false, true, 1); searchErr == nil && link != nil {
			cache[relPath] = link.LinkID
			return link.LinkID, nil
		}
		return "", fmt.Errorf("failed to create folder %s: %w", name, createErr)
	}
	cache[relPath] = createdID
	return createdID, nil
}
// uploadWithRetry uploads localPath into the remote folder parentID,
// retrying up to config.RetryCount times on transient (rate-limit / server)
// errors, with an extra linear backoff for 422/429 responses. An
// "already exists" error is treated as success; any other non-transient
// error aborts immediately.
func uploadWithRetry(ctx context.Context, pd *proton_api_bridge.ProtonDrive, localPath, parentID string, info os.FileInfo, config Config) error {
	var lastErr error
	for attempt := 0; attempt <= config.RetryCount; attempt++ {
		if attempt > 0 {
			delay := time.Duration(config.RetryDelayMs) * time.Millisecond
			// Back off harder on rate-limit responses (422/429).
			if lastErr != nil {
				errStr := lastErr.Error()
				if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") {
					delay = delay * time.Duration(attempt+1)
				}
			}
			fmt.Printf(" Retry %d/%d after %v...\n", attempt, config.RetryCount, delay)
			time.Sleep(delay)
		}
		// Open the file fresh for each attempt so the reader starts at 0.
		file, err := os.Open(localPath)
		if err != nil {
			return fmt.Errorf("failed to open file: %w", err)
		}
		// Upload using reader API (takes parentLinkID string)
		_, _, err = pd.UploadFileByReader(ctx, parentID, info.Name(), info.ModTime(), file, 0)
		file.Close()
		if err == nil {
			return nil
		}
		lastErr = err
		errStr := err.Error()
		// Transient failures (rate limits, server errors) get another attempt.
		if strings.Contains(errStr, "422") || strings.Contains(errStr, "429") ||
			strings.Contains(errStr, "retry") || strings.Contains(errStr, "500") ||
			strings.Contains(errStr, "503") {
			continue
		}
		// A pre-existing file (or upload draft) counts as success.
		if strings.Contains(errStr, "already exists") || strings.Contains(errStr, "Draft already exists") {
			return nil
		}
		// BUG FIX: non-retryable errors previously fell through to the next
		// loop iteration and were retried anyway; abort immediately instead.
		return lastErr
	}
	return lastErr
}
func shouldExclude(path string, patterns []string) bool {
for _, pattern := range patterns {
pattern = strings.TrimSpace(pattern)
if pattern == "" {
continue
}
// Exclude if pattern appears as a path component
if strings.Contains(path, string(os.PathSeparator)+pattern+string(os.PathSeparator)) ||
strings.HasSuffix(path, string(os.PathSeparator)+pattern) {
return true
}
// Also exclude hidden files
if strings.Contains(filepath.Base(path), "/.") || strings.HasPrefix(filepath.Base(path), ".") {
return true
}
}
return false
}
func hashFileInfo(path string, info os.FileInfo) string {
h := sha256.New()
h.Write([]byte(path))
h.Write([]byte(fmt.Sprintf("%d", info.Size())))
h.Write([]byte(fmt.Sprintf("%d", info.ModTime().Unix())))
return hex.EncodeToString(h.Sum(nil))[:16]
}
// humanSize renders a byte count with binary (1024-based) units, e.g.
// "512 B", "1.5 KB", "2.0 MB".
func humanSize(bytes int64) string {
	const unit int64 = 1024
	if bytes < unit {
		return fmt.Sprintf("%d B", bytes)
	}
	div := unit
	exp := 0
	for remaining := bytes / unit; remaining >= unit; remaining /= unit {
		div *= unit
		exp++
	}
	return fmt.Sprintf("%.1f %cB", float64(bytes)/float64(div), "KMGTPE"[exp])
}
// NewBackupState opens (creating if needed) the SQLite state database at
// dbPath and ensures the uploaded_files schema exists. The caller owns the
// returned state and must Close it.
func NewBackupState(dbPath string) (*BackupState, error) {
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, err
	}
	// path is the backup-relative file path; hash is the path/size/mtime
	// fingerprint produced by hashFileInfo.
	_, err = db.Exec(`
CREATE TABLE IF NOT EXISTS uploaded_files (
path TEXT PRIMARY KEY,
hash TEXT,
uploaded_at DATETIME DEFAULT CURRENT_TIMESTAMP
)
`)
	if err != nil {
		// BUG FIX: close the handle on schema failure instead of leaking it.
		db.Close()
		return nil, err
	}
	return &BackupState{db: db}, nil
}
// IsUploaded reports whether path was previously recorded with the same
// fingerprint hash. Any lookup failure (including sql.ErrNoRows) is treated
// as not-uploaded, so the file is simply uploaded again.
func (s *BackupState) IsUploaded(path, hash string) bool {
	var previous string
	row := s.db.QueryRow("SELECT hash FROM uploaded_files WHERE path = ?", path)
	if err := row.Scan(&previous); err != nil {
		return false
	}
	return previous == hash
}
// MarkUploaded records (or refreshes) the fingerprint for path together
// with the current timestamp.
func (s *BackupState) MarkUploaded(path, hash string) error {
	const stmt = "INSERT OR REPLACE INTO uploaded_files (path, hash, uploaded_at) VALUES (?, ?, CURRENT_TIMESTAMP)"
	_, err := s.db.Exec(stmt, path, hash)
	return err
}
// Close releases the underlying database handle.
func (s *BackupState) Close() error {
return s.db.Close()
}