// dealroom/internal/store/store.go
//
// 256 lines
// 6.1 KiB
// Go

package store
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"fmt"
"io"
"os"
"path/filepath"
"strings"
"time"
"github.com/klauspost/compress/zstd"
)
// Store handles encrypted file storage.
//
// Payloads are zstd-compressed and then sealed with AES-GCM; the random
// nonce is stored in front of the ciphertext in each file written beneath
// basePath. A Store is intended to be created with New: the zero value has
// nil gcm/encoder/decoder fields and cannot store or retrieve anything.
type Store struct {
// basePath is the root directory under which all files are written.
basePath string
// key is the raw AES key exactly as passed to New (the slice is retained,
// not copied).
key []byte
// gcm is the AES-GCM AEAD built from key in New.
gcm cipher.AEAD
// encoder is created with a nil writer; in this file it is only used via
// EncodeAll and released in Close.
encoder *zstd.Encoder
// decoder is created with a nil reader; in this file it is only used via
// DecodeAll and released in Close.
decoder *zstd.Decoder
}
// New creates a new encrypted file store rooted at basePath.
//
// key must be a valid AES key (16, 24 or 32 bytes); it is retained by the
// Store without copying. basePath (and any missing parents) is created if it
// does not exist. The returned Store holds zstd codec resources and should be
// released with Close.
func New(basePath string, key []byte) (*Store, error) {
	// Validate the key before touching the filesystem so that a bad key does
	// not leave an empty directory behind.
	block, err := aes.NewCipher(key)
	if err != nil {
		return nil, fmt.Errorf("failed to create AES cipher: %w", err)
	}
	gcm, err := cipher.NewGCM(block)
	if err != nil {
		return nil, fmt.Errorf("failed to create GCM: %w", err)
	}

	// Ensure base path exists
	if err := os.MkdirAll(basePath, 0755); err != nil {
		return nil, fmt.Errorf("failed to create base path: %w", err)
	}

	// Create compression encoder/decoder
	encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
	if err != nil {
		return nil, fmt.Errorf("failed to create zstd encoder: %w", err)
	}
	decoder, err := zstd.NewReader(nil)
	if err != nil {
		// The encoder already holds resources (it may run background
		// goroutines); release it so a failed New does not leak them. Its
		// close error is secondary to the decoder failure being reported.
		_ = encoder.Close()
		return nil, fmt.Errorf("failed to create zstd decoder: %w", err)
	}

	return &Store{
		basePath: basePath,
		key:      key,
		gcm:      gcm,
		encoder:  encoder,
		decoder:  decoder,
	}, nil
}
// Store compresses, encrypts and writes data for entryID.
//
// It returns two values:
//   - the path of the written file relative to the store root, in the form
//     accepted by Retrieve, Delete and Exists;
//   - the lowercase hex SHA-256 of the original, uncompressed data.
//
// Files are placed at <year>/<month>/<entryID>.enc, using the local time at
// the moment of the call. An existing file with the same name is
// overwritten. The file contents are nonce || AES-GCM(zstd(data)), sealed
// without additional authenticated data.
//
// entryID is used verbatim as a file name, so it must be non-empty and must
// not contain '/' or '\'. Without this check an ID such as "../../x" would
// write outside the store root.
func (s *Store) Store(entryID string, data []byte) (string, string, error) {
	if entryID == "" || strings.ContainsAny(entryID, `/\`) {
		return "", "", fmt.Errorf("invalid entry ID %q", entryID)
	}

	// Calculate hash of original data
	hash := fmt.Sprintf("%x", sha256.Sum256(data))

	// Compress data
	compressed := s.encoder.EncodeAll(data, nil)

	// Encrypt with a fresh random nonce. Seal appends the ciphertext to the
	// nonce, so the nonce ends up at the front of the file where Retrieve
	// expects it.
	nonce := make([]byte, s.gcm.NonceSize())
	if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
		return "", "", fmt.Errorf("failed to generate nonce: %w", err)
	}
	encrypted := s.gcm.Seal(nonce, nonce, compressed, nil)

	// Generate date-based path
	now := time.Now()
	datePath := fmt.Sprintf("%d/%02d", now.Year(), now.Month())
	fullDir := filepath.Join(s.basePath, datePath)
	if err := os.MkdirAll(fullDir, 0755); err != nil {
		return "", "", fmt.Errorf("failed to create date directory: %w", err)
	}

	// Write encrypted data
	filePath := filepath.Join(datePath, entryID+".enc")
	fullPath := filepath.Join(s.basePath, filePath)
	if err := os.WriteFile(fullPath, encrypted, 0644); err != nil {
		return "", "", fmt.Errorf("failed to write encrypted file: %w", err)
	}
	return filePath, hash, nil
}
// resolvePath maps a store-relative path (as returned by Store) to its
// location under basePath.
//
// It rejects paths that, once cleaned, would point outside the store root
// (for example "../secret") or at the root itself (for example ""). This
// prevents Retrieve, Delete and Exists from touching arbitrary files on the
// host.
func (s *Store) resolvePath(filePath string) (string, error) {
	fullPath := filepath.Join(s.basePath, filePath)
	rel, err := filepath.Rel(s.basePath, fullPath)
	if err != nil || rel == "." || rel == ".." ||
		strings.HasPrefix(rel, ".."+string(filepath.Separator)) {
		return "", fmt.Errorf("invalid storage path %q", filePath)
	}
	return fullPath, nil
}

// Retrieve loads and decrypts data from storage.
//
// filePath is a store-relative path as returned by Store. The file is
// expected to contain nonce || AES-GCM ciphertext of zstd-compressed data.
// An error is returned if the path escapes the store root, or if the file
// cannot be read, authenticated or decompressed.
func (s *Store) Retrieve(filePath string) ([]byte, error) {
	fullPath, err := s.resolvePath(filePath)
	if err != nil {
		return nil, err
	}

	encrypted, err := os.ReadFile(fullPath)
	if err != nil {
		return nil, fmt.Errorf("failed to read encrypted file: %w", err)
	}

	// Anything shorter than nonce + authentication tag cannot be a valid
	// sealed payload; report it directly instead of as an auth failure.
	nonceSize := s.gcm.NonceSize()
	if len(encrypted) < nonceSize+s.gcm.Overhead() {
		return nil, fmt.Errorf("encrypted data too short")
	}
	nonce, ciphertext := encrypted[:nonceSize], encrypted[nonceSize:]

	compressed, err := s.gcm.Open(nil, nonce, ciphertext, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decrypt data: %w", err)
	}

	data, err := s.decoder.DecodeAll(compressed, nil)
	if err != nil {
		return nil, fmt.Errorf("failed to decompress data: %w", err)
	}
	return data, nil
}

// Delete removes a file from storage.
//
// filePath is store-relative. Paths that would resolve outside the store
// root (or to the root itself) are rejected with an error instead of being
// removed. Otherwise the error from os.Remove is returned unchanged.
func (s *Store) Delete(filePath string) error {
	fullPath, err := s.resolvePath(filePath)
	if err != nil {
		return err
	}
	return os.Remove(fullPath)
}

// Exists reports whether filePath (store-relative) can be stat'ed inside the
// store. Paths that would resolve outside the store root, or to the root
// itself, report false.
func (s *Store) Exists(filePath string) bool {
	fullPath, err := s.resolvePath(filePath)
	if err != nil {
		return false
	}
	_, err = os.Stat(fullPath)
	return err == nil
}
// StoreTemp saves data, unencrypted, to a new file in the "temp" directory
// under the store root and returns the file's full path (basePath-joined,
// not store-relative).
//
// The file name is "<unix-seconds>_<random>_<sanitized filename>". The random
// part, supplied by os.CreateTemp, ensures that uploads of the same name in
// the same second never overwrite each other. Because the payload is
// plaintext, the file is created with mode 0600. If writing fails, the
// partially written file is removed.
func (s *Store) StoreTemp(filename string, data []byte) (string, error) {
	tempDir := filepath.Join(s.basePath, "temp")
	if err := os.MkdirAll(tempDir, 0755); err != nil {
		return "", fmt.Errorf("failed to create temp directory: %w", err)
	}

	// sanitizeFilename replaces path separators and '*', so this pattern has
	// exactly one wildcard and no separators, as os.CreateTemp requires.
	pattern := fmt.Sprintf("%d_*_%s", time.Now().Unix(), sanitizeFilename(filename))
	f, err := os.CreateTemp(tempDir, pattern)
	if err != nil {
		return "", fmt.Errorf("failed to write temp file: %w", err)
	}
	tempPath := f.Name()

	_, writeErr := f.Write(data)
	if closeErr := f.Close(); writeErr == nil {
		writeErr = closeErr
	}
	if writeErr != nil {
		_ = os.Remove(tempPath) // don't leave a truncated file for processing
		return "", fmt.Errorf("failed to write temp file: %w", writeErr)
	}
	return tempPath, nil
}
// CleanupTemp deletes regular files in the store's "temp" directory whose
// modification time is older than maxAge.
//
// Cleanup is best effort:
//   - subdirectories are ignored;
//   - entries whose metadata cannot be read are skipped;
//   - failed removals are ignored, so such a file is simply seen again on a
//     later call.
//
// A missing temp directory is not an error. Only a failure to list the
// directory is reported.
func (s *Store) CleanupTemp(maxAge time.Duration) error {
	dir := filepath.Join(s.basePath, "temp")

	list, readErr := os.ReadDir(dir)
	switch {
	case readErr == nil:
		// proceed with the sweep below
	case os.IsNotExist(readErr):
		return nil // No temp directory, nothing to clean
	default:
		return fmt.Errorf("failed to read temp directory: %w", readErr)
	}

	threshold := time.Now().Add(-maxAge)
	for _, e := range list {
		if e.IsDir() {
			continue
		}
		fi, infoErr := e.Info()
		if infoErr != nil || !fi.ModTime().Before(threshold) {
			continue
		}
		_ = os.Remove(filepath.Join(dir, e.Name())) // best effort
	}
	return nil
}
// GetStats walks the entire store root, including the temp directory, and
// returns file counts and total sizes in bytes. Files whose name ends in
// ".enc" are additionally counted as encrypted.
//
// Entries that disappear while the walk is in progress are skipped rather
// than aborting the walk. For example, CleanupTemp may remove temp files
// concurrently. A missing store root, or any other error, is returned
// together with the statistics gathered up to that point.
func (s *Store) GetStats() (*StorageStats, error) {
	stats := &StorageStats{}
	err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			// A file or directory removed between being listed and being
			// stat'ed/read is not a failure of the stats query itself.
			if os.IsNotExist(err) && path != s.basePath {
				return nil
			}
			return err
		}
		if info.IsDir() {
			return nil
		}
		size := info.Size()
		stats.TotalFiles++
		stats.TotalSize += size
		if strings.HasSuffix(info.Name(), ".enc") {
			stats.EncryptedFiles++
			stats.EncryptedSize += size
		}
		return nil
	})
	return stats, err
}
// StorageStats represents storage usage statistics for a Store, as filled in
// by GetStats. Sizes are byte counts as reported by os.FileInfo.Size.
type StorageStats struct {
// TotalFiles is the number of non-directory entries under the store root,
// including files in the temp directory.
TotalFiles int `json:"total_files"`
// TotalSize is the combined size of the files counted in TotalFiles.
TotalSize int64 `json:"total_size"`
// EncryptedFiles is the number of files whose name ends in ".enc".
EncryptedFiles int `json:"encrypted_files"`
// EncryptedSize is the combined size of the files counted in EncryptedFiles.
EncryptedSize int64 `json:"encrypted_size"`
}
// Close cleanly shuts down the store by releasing its zstd encoder and
// decoder. The Store should not be used afterwards, since its codecs have
// been closed.
//
// Both codecs are always closed. The returned error, if any, is the one
// reported by the encoder's Close (the decoder's Close returns nothing).
func (s *Store) Close() error {
	encErr := s.encoder.Close()
	s.decoder.Close()
	if encErr != nil {
		return fmt.Errorf("failed to close zstd encoder: %w", encErr)
	}
	return nil
}
// Helper functions

// sanitizeFilename makes a client-supplied file name safe to use as a single
// path component.
//
// Each of the following is replaced with '_':
//   - path separators (/ and \);
//   - characters reserved on Windows (: * ? " < > |);
//   - ASCII control characters, including NUL (which the OS rejects in file
//     names) and DEL.
//
// The result is capped at 200 bytes. The cut is moved back to a UTF-8
// character boundary so a multi-byte character is never split into invalid
// UTF-8.
func sanitizeFilename(filename string) string {
	const (
		maxLen = 200
		unsafe = `/\:*?"<>|`
	)

	safe := []byte(filename)
	for i, c := range safe {
		// Every replaced character is ASCII, and ASCII bytes never occur
		// inside a multi-byte UTF-8 sequence, so byte-wise replacement
		// cannot corrupt non-ASCII characters.
		if c < 0x20 || c == 0x7f || strings.IndexByte(unsafe, c) >= 0 {
			safe[i] = '_'
		}
	}

	if len(safe) > maxLen {
		// If the byte at the limit is a UTF-8 continuation byte (10xxxxxx),
		// back up to the start of that character so it is dropped whole.
		// A valid character has at most 3 continuation bytes.
		cut := maxLen
		for step := 0; step < 3 && safe[cut]&0xC0 == 0x80; step++ {
			cut--
		}
		safe = safe[:cut]
	}
	return string(safe)
}