package lib import ( "crypto/aes" "crypto/cipher" "crypto/hmac" "crypto/rand" "crypto/sha256" "errors" "io" "github.com/klauspost/compress/zstd" "golang.org/x/crypto/hkdf" ) var ( ErrDecryptionFailed = errors.New("decryption failed") ErrInvalidCiphertext = errors.New("invalid ciphertext") ) // DeriveEntryKey derives a per-entry AES-256 key from the vault key using HKDF-SHA256. func DeriveEntryKey(vaultKey []byte, entryID string) ([]byte, error) { info := []byte("clawvault-entry-" + entryID) reader := hkdf.New(sha256.New, vaultKey, nil, info) key := make([]byte, 32) // AES-256 if _, err := io.ReadFull(reader, key); err != nil { return nil, err } return key, nil } // DeriveHMACKey derives a separate HMAC key for blind indexes. func DeriveHMACKey(vaultKey []byte) ([]byte, error) { info := []byte("clawvault-hmac-index") reader := hkdf.New(sha256.New, vaultKey, nil, info) key := make([]byte, 32) if _, err := io.ReadFull(reader, key); err != nil { return nil, err } return key, nil } // BlindIndex computes an HMAC-SHA256 blind index for searchable encrypted fields. // Returns truncated hash (16 bytes) for storage efficiency. func BlindIndex(hmacKey []byte, plaintext string) []byte { h := hmac.New(sha256.New, hmacKey) h.Write([]byte(plaintext)) return h.Sum(nil)[:16] // truncate to 16 bytes } // Pack compresses with zstd then encrypts with AES-256-GCM (random nonce). func Pack(key []byte, plaintext string) ([]byte, error) { compressed, err := zstdCompress([]byte(plaintext)) if err != nil { return nil, err } block, err := aes.NewCipher(key) if err != nil { return nil, err } gcm, err := cipher.NewGCM(block) if err != nil { return nil, err } nonce := make([]byte, gcm.NonceSize()) if _, err := io.ReadFull(rand.Reader, nonce); err != nil { return nil, err } return gcm.Seal(nonce, nonce, compressed, nil), nil } // Unpack decrypts AES-256-GCM then decompresses zstd. func Unpack(key []byte, ciphertext []byte) (string, error) { if len(ciphertext) == 0 { return "", nil } block, err := aes.NewCipher(key) if err != nil { return "", err } gcm, err := cipher.NewGCM(block) if err != nil { return "", err } nonceSize := gcm.NonceSize() if len(ciphertext) < nonceSize { return "", ErrInvalidCiphertext } nonce, ct := ciphertext[:nonceSize], ciphertext[nonceSize:] compressed, err := gcm.Open(nil, nonce, ct, nil) if err != nil { return "", ErrDecryptionFailed } decompressed, err := zstdDecompress(compressed) if err != nil { return "", err } return string(decompressed), nil } // zstd encoder/decoder (reusable, goroutine-safe) var ( zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedDefault)) zstdDecoder, _ = zstd.NewReader(nil) ) func zstdCompress(data []byte) ([]byte, error) { return zstdEncoder.EncodeAll(data, nil), nil } func zstdDecompress(data []byte) ([]byte, error) { return zstdDecoder.DecodeAll(data, nil) } // GenerateToken generates a random hex token (32 bytes = 64 hex chars). func GenerateToken() string { b := make([]byte, 32) rand.Read(b) const hex = "0123456789abcdef" result := make([]byte, 64) for i, v := range b { result[i*2] = hex[v>>4] result[i*2+1] = hex[v&0x0f] } return string(result) }