256 lines
6.1 KiB
Go
256 lines
6.1 KiB
Go
package store
|
|
|
|
import (
|
|
"crypto/aes"
|
|
"crypto/cipher"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"fmt"
|
|
"io"
|
|
"os"
|
|
"path/filepath"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
)
|
|
|
|
// Store handles encrypted file storage. Payloads passed to the Store method
// are zstd-compressed, sealed with AES-GCM and written as files beneath
// basePath.
type Store struct {
	// basePath is the root directory of the store. Store returns paths
	// relative to it, and Retrieve, Delete and Exists join their argument
	// onto it. StoreTemp/CleanupTemp use its "temp" subdirectory.
	basePath string

	// key is the raw AES key passed to New.
	// NOTE(review): no method shown here reads it (gcm already holds the
	// keyed cipher) — confirm whether anything else still needs it.
	key []byte

	// gcm is the AES-GCM AEAD built from key in New; Store seals with it
	// and Retrieve opens with it.
	gcm cipher.AEAD

	// encoder compresses payloads in Store; released by Close.
	encoder *zstd.Encoder

	// decoder decompresses payloads in Retrieve; released by Close.
	decoder *zstd.Decoder
}
|
|
|
|
// New creates a new encrypted file store
|
|
func New(basePath string, key []byte) (*Store, error) {
|
|
// Ensure base path exists
|
|
if err := os.MkdirAll(basePath, 0755); err != nil {
|
|
return nil, fmt.Errorf("failed to create base path: %w", err)
|
|
}
|
|
|
|
// Create AES cipher
|
|
block, err := aes.NewCipher(key)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create AES cipher: %w", err)
|
|
}
|
|
|
|
gcm, err := cipher.NewGCM(block)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create GCM: %w", err)
|
|
}
|
|
|
|
// Create compression encoder/decoder
|
|
encoder, err := zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedFastest))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create zstd encoder: %w", err)
|
|
}
|
|
|
|
decoder, err := zstd.NewReader(nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create zstd decoder: %w", err)
|
|
}
|
|
|
|
return &Store{
|
|
basePath: basePath,
|
|
key: key,
|
|
gcm: gcm,
|
|
encoder: encoder,
|
|
decoder: decoder,
|
|
}, nil
|
|
}
|
|
|
|
// Store saves data to encrypted storage and returns the file path
|
|
func (s *Store) Store(entryID string, data []byte) (string, string, error) {
|
|
// Calculate hash of original data
|
|
hash := fmt.Sprintf("%x", sha256.Sum256(data))
|
|
|
|
// Compress data
|
|
compressed := s.encoder.EncodeAll(data, nil)
|
|
|
|
// Encrypt compressed data
|
|
nonce := make([]byte, s.gcm.NonceSize())
|
|
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
|
return "", "", fmt.Errorf("failed to generate nonce: %w", err)
|
|
}
|
|
|
|
encrypted := s.gcm.Seal(nonce, nonce, compressed, nil)
|
|
|
|
// Generate date-based path
|
|
now := time.Now()
|
|
datePath := fmt.Sprintf("%d/%02d", now.Year(), now.Month())
|
|
fullDir := filepath.Join(s.basePath, datePath)
|
|
|
|
// Ensure directory exists
|
|
if err := os.MkdirAll(fullDir, 0755); err != nil {
|
|
return "", "", fmt.Errorf("failed to create date directory: %w", err)
|
|
}
|
|
|
|
// Write encrypted data
|
|
filePath := filepath.Join(datePath, entryID+".enc")
|
|
fullPath := filepath.Join(s.basePath, filePath)
|
|
|
|
if err := os.WriteFile(fullPath, encrypted, 0644); err != nil {
|
|
return "", "", fmt.Errorf("failed to write encrypted file: %w", err)
|
|
}
|
|
|
|
return filePath, hash, nil
|
|
}
|
|
|
|
// Retrieve loads and decrypts data from storage
|
|
func (s *Store) Retrieve(filePath string) ([]byte, error) {
|
|
fullPath := filepath.Join(s.basePath, filePath)
|
|
|
|
// Read encrypted data
|
|
encrypted, err := os.ReadFile(fullPath)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to read encrypted file: %w", err)
|
|
}
|
|
|
|
// Extract nonce
|
|
if len(encrypted) < s.gcm.NonceSize() {
|
|
return nil, fmt.Errorf("encrypted data too short")
|
|
}
|
|
|
|
nonce := encrypted[:s.gcm.NonceSize()]
|
|
ciphertext := encrypted[s.gcm.NonceSize():]
|
|
|
|
// Decrypt
|
|
compressed, err := s.gcm.Open(nil, nonce, ciphertext, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decrypt data: %w", err)
|
|
}
|
|
|
|
// Decompress
|
|
data, err := s.decoder.DecodeAll(compressed, nil)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to decompress data: %w", err)
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
// Delete removes a file from storage
|
|
func (s *Store) Delete(filePath string) error {
|
|
fullPath := filepath.Join(s.basePath, filePath)
|
|
return os.Remove(fullPath)
|
|
}
|
|
|
|
// Exists checks if a file exists in storage
|
|
func (s *Store) Exists(filePath string) bool {
|
|
fullPath := filepath.Join(s.basePath, filePath)
|
|
_, err := os.Stat(fullPath)
|
|
return err == nil
|
|
}
|
|
|
|
// StoreTemp saves data to temporary storage for processing
|
|
func (s *Store) StoreTemp(filename string, data []byte) (string, error) {
|
|
tempDir := filepath.Join(s.basePath, "temp")
|
|
if err := os.MkdirAll(tempDir, 0755); err != nil {
|
|
return "", fmt.Errorf("failed to create temp directory: %w", err)
|
|
}
|
|
|
|
// Sanitize filename
|
|
sanitized := sanitizeFilename(filename)
|
|
tempPath := filepath.Join(tempDir, fmt.Sprintf("%d_%s", time.Now().Unix(), sanitized))
|
|
|
|
if err := os.WriteFile(tempPath, data, 0644); err != nil {
|
|
return "", fmt.Errorf("failed to write temp file: %w", err)
|
|
}
|
|
|
|
return tempPath, nil
|
|
}
|
|
|
|
// CleanupTemp removes temporary files older than the specified duration
|
|
func (s *Store) CleanupTemp(maxAge time.Duration) error {
|
|
tempDir := filepath.Join(s.basePath, "temp")
|
|
|
|
entries, err := os.ReadDir(tempDir)
|
|
if err != nil {
|
|
if os.IsNotExist(err) {
|
|
return nil // No temp directory, nothing to clean
|
|
}
|
|
return fmt.Errorf("failed to read temp directory: %w", err)
|
|
}
|
|
|
|
cutoff := time.Now().Add(-maxAge)
|
|
|
|
for _, entry := range entries {
|
|
if entry.IsDir() {
|
|
continue
|
|
}
|
|
|
|
info, err := entry.Info()
|
|
if err != nil {
|
|
continue
|
|
}
|
|
|
|
if info.ModTime().Before(cutoff) {
|
|
filePath := filepath.Join(tempDir, entry.Name())
|
|
os.Remove(filePath) // Ignore errors, best effort cleanup
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// GetStats returns storage statistics
|
|
func (s *Store) GetStats() (*StorageStats, error) {
|
|
stats := &StorageStats{}
|
|
|
|
err := filepath.Walk(s.basePath, func(path string, info os.FileInfo, err error) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if !info.IsDir() {
|
|
stats.TotalFiles++
|
|
stats.TotalSize += info.Size()
|
|
|
|
if strings.HasSuffix(info.Name(), ".enc") {
|
|
stats.EncryptedFiles++
|
|
stats.EncryptedSize += info.Size()
|
|
}
|
|
}
|
|
|
|
return nil
|
|
})
|
|
|
|
return stats, err
|
|
}
|
|
|
|
// StorageStats represents storage usage statistics as computed by GetStats.
type StorageStats struct {
	// TotalFiles is the number of non-directory entries under the base path.
	TotalFiles int `json:"total_files"`

	// TotalSize is the sum of those entries' sizes in bytes.
	TotalSize int64 `json:"total_size"`

	// EncryptedFiles counts the subset of entries whose name ends in ".enc".
	EncryptedFiles int `json:"encrypted_files"`

	// EncryptedSize is the sum of the ".enc" entries' sizes in bytes.
	EncryptedSize int64 `json:"encrypted_size"`
}
|
|
|
|
// Close cleanly shuts down the store
|
|
func (s *Store) Close() error {
|
|
s.encoder.Close()
|
|
s.decoder.Close()
|
|
return nil
|
|
}
|
|
|
|
// Helper functions
|
|
|
|
// sanitizeFilename makes filename safe to use as a single path component.
//
// Some bytes are replaced with '_':
//   - the characters / \ : * ? " < > |
//   - ASCII control characters (0x00–0x1F and 0x7F), e.g. NUL, which the OS
//     rejects in file names.
//
// All other bytes, including multi-byte UTF-8 sequences, are kept unchanged.
// The result is then truncated to at most 200 bytes, cutting only at a
// character boundary so a multi-byte UTF-8 character is never split.
func sanitizeFilename(filename string) string {
	const (
		reserved = `/\:*?"<>|`
		maxLen   = 200
	)

	// Every replaced character is ASCII, so a byte-wise pass leaves UTF-8
	// sequences intact and needs only one allocation.
	var b strings.Builder
	b.Grow(len(filename))
	for i := 0; i < len(filename); i++ {
		c := filename[i]
		if c < 0x20 || c == 0x7f || strings.IndexByte(reserved, c) >= 0 {
			c = '_'
		}
		b.WriteByte(c)
	}
	safe := b.String()

	if len(safe) > maxLen {
		// Keep the longest prefix ending on a character boundary: the last
		// rune start at or before maxLen (len(safe) > maxLen, so no later
		// boundary fits).
		cut := 0
		for i := range safe {
			if i > maxLen {
				break
			}
			cut = i
		}
		safe = safe[:cut]
	}

	return safe
}