// dealspace/lib/crypto_test.go
//
// 252 lines
// 5.8 KiB
// Go

package lib
import (
	"bytes"
	"errors"
	"strings"
	"testing"
)
// TestPackUnpack checks that Unpack(Pack(s)) returns s unchanged for plain
// ASCII, the empty string, multi-script Unicode, JSON text, and a 1 MiB string.
func TestPackUnpack(t *testing.T) {
	key := make([]byte, 32)
	for idx := 0; idx < len(key); idx++ {
		key[idx] = byte(idx)
	}

	cases := []struct {
		label string
		plain string
	}{
		{label: "simple", plain: "hello world"},
		{label: "empty", plain: ""},
		{label: "unicode", plain: "こんにちは世界 🌍 مرحبا"},
		{label: "json", plain: `{"key": "value", "nested": {"data": 123}}`},
		{label: "large", plain: strings.Repeat("a", 1024*1024)}, // 1 MiB
	}

	for _, c := range cases {
		t.Run(c.label, func(t *testing.T) {
			sealed, err := Pack(key, c.plain)
			if err != nil {
				t.Fatalf("Pack failed: %v", err)
			}
			opened, err := Unpack(key, sealed)
			if err != nil {
				t.Fatalf("Unpack failed: %v", err)
			}
			if opened == c.plain {
				return
			}
			// Quoting a megabyte of text is useless noise; report lengths for
			// long inputs and the full values only for short ones.
			if len(c.plain) <= 100 {
				t.Errorf("round-trip failed: got %q, want %q", opened, c.plain)
				return
			}
			t.Errorf("round-trip failed: lengths differ (got %d, want %d)", len(opened), len(c.plain))
		})
	}
}
// TestPackUnpackEmptyInput checks that Unpack accepts both a nil slice and a
// zero-length slice without error and yields the empty string for each.
func TestPackUnpackEmptyInput(t *testing.T) {
	key := make([]byte, 32)
	for idx := range key {
		key[idx] = byte(idx)
	}

	// Unpack nil/empty ciphertext should return empty
	got, err := Unpack(key, nil)
	switch {
	case err != nil:
		t.Fatalf("Unpack nil failed: %v", err)
	case got != "":
		t.Errorf("expected empty for nil input, got %q", got)
	}

	got, err = Unpack(key, []byte{})
	switch {
	case err != nil:
		t.Fatalf("Unpack empty bytes failed: %v", err)
	case got != "":
		t.Errorf("expected empty for empty bytes, got %q", got)
	}
}
// TestBlindIndex checks that BlindIndex is deterministic for a fixed key and
// term, that its output changes when either the key or the term changes, and
// that it returns 32 bytes.
func TestBlindIndex(t *testing.T) {
	keyA := make([]byte, 32)
	keyB := make([]byte, 32)
	for i := range keyA {
		keyA[i] = byte(i)
		keyB[i] = byte(i + 1) // Different key
	}
	const term = "searchable-term"

	base := BlindIndex(keyA, term)

	// Same input + same key = same index
	if again := BlindIndex(keyA, term); !bytes.Equal(base, again) {
		t.Error("same input + key should produce same index")
	}
	// Same input + different key = different index
	if otherKey := BlindIndex(keyB, term); bytes.Equal(base, otherKey) {
		t.Error("different keys should produce different indexes")
	}
	// Different input + same key = different index
	if otherTerm := BlindIndex(keyA, "different-term"); bytes.Equal(base, otherTerm) {
		t.Error("different inputs should produce different indexes")
	}
	// Index should be 32 bytes (SHA-256 output size)
	if n := len(base); n != 32 {
		t.Errorf("index length should be 32, got %d", n)
	}
}
// TestDeriveProjectKey checks that DeriveProjectKey is deterministic for a
// given master key and project ID, that distinct project IDs yield distinct
// keys, and that derived keys are 32 bytes (AES-256 sized).
func TestDeriveProjectKey(t *testing.T) {
	masterKey := make([]byte, 32)
	for i := range masterKey {
		masterKey[i] = byte(i)
	}

	// derive aborts the test on any derivation error so that every later
	// comparison operates on a real key.
	derive := func(projectID string) []byte {
		k, err := DeriveProjectKey(masterKey, projectID)
		if err != nil {
			t.Fatalf("DeriveProjectKey failed: %v", err)
		}
		return k
	}

	// Deterministic: same master + projectID = same key
	first := derive("project-123")
	second := derive("project-123")
	if !bytes.Equal(first, second) {
		t.Error("same master + projectID should produce same key")
	}

	// Different projectID = different key
	other := derive("project-456")
	if bytes.Equal(first, other) {
		t.Error("different projectID should produce different key")
	}

	// Key should be 32 bytes (AES-256)
	if n := len(first); n != 32 {
		t.Errorf("key length should be 32, got %d", n)
	}
}
// TestDeriveHMACKey checks that the HMAC key derived for a project is not
// equal to the encryption key derived for the same project ID, and that it is
// 32 bytes long.
func TestDeriveHMACKey(t *testing.T) {
	masterKey := make([]byte, 32)
	for i := range masterKey {
		masterKey[i] = byte(i)
	}

	// Both derivation errors must be checked: if DeriveProjectKey failed and
	// returned a nil key, the "keys differ" assertion below would pass
	// vacuously and hide the failure.
	projectKey, err := DeriveProjectKey(masterKey, "project-123")
	if err != nil {
		t.Fatalf("DeriveProjectKey failed: %v", err)
	}
	hmacKey, err := DeriveHMACKey(masterKey, "project-123")
	if err != nil {
		t.Fatalf("DeriveHMACKey failed: %v", err)
	}

	// HMAC key should be different from project key for same projectID
	if bytes.Equal(projectKey, hmacKey) {
		t.Error("HMAC key should differ from project key")
	}

	// HMAC key should be 32 bytes
	if len(hmacKey) != 32 {
		t.Errorf("HMAC key length should be 32, got %d", len(hmacKey))
	}
}
// TestAESGCM round-trips short text, arbitrary binary bytes, and a 1 MiB
// payload through ObjectEncrypt and ObjectDecrypt and checks that each
// payload comes back byte-for-byte identical.
func TestAESGCM(t *testing.T) {
	key := make([]byte, 32)
	for i := range key {
		key[i] = byte(i)
	}

	type vector struct {
		name    string
		payload []byte
	}
	vectors := []vector{
		{name: "simple", payload: []byte("hello world")},
		{name: "binary", payload: []byte{0x00, 0x01, 0x02, 0xff, 0xfe}},
		{name: "large", payload: bytes.Repeat([]byte("x"), 1024*1024)}, // 1 MiB
	}

	for _, v := range vectors {
		t.Run(v.name, func(t *testing.T) {
			ct, err := ObjectEncrypt(key, v.payload)
			if err != nil {
				t.Fatalf("ObjectEncrypt failed: %v", err)
			}
			pt, err := ObjectDecrypt(key, ct)
			if err != nil {
				t.Fatalf("ObjectDecrypt failed: %v", err)
			}
			if bytes.Equal(pt, v.payload) {
				return
			}
			t.Errorf("round-trip failed")
		})
	}
}
// TestObjectEncryptDecryptWrongKey checks that a ciphertext produced under one
// key is rejected (returns an error) when decrypted under a different key.
func TestObjectEncryptDecryptWrongKey(t *testing.T) {
	right := make([]byte, 32)
	wrong := make([]byte, 32)
	for i := range right {
		right[i] = byte(i)
		wrong[i] = byte(i + 1)
	}

	sealed, err := ObjectEncrypt(right, []byte("secret data"))
	if err != nil {
		t.Fatalf("ObjectEncrypt failed: %v", err)
	}

	if _, err := ObjectDecrypt(wrong, sealed); err == nil {
		t.Error("decrypt with wrong key should fail")
	}
}
// TestObjectDecryptInvalidCiphertext checks that ObjectDecrypt returns an error
// for a ciphertext too short to be valid, and that a nil ciphertext is
// reported as ErrInvalidCiphertext.
func TestObjectDecryptInvalidCiphertext(t *testing.T) {
	key := make([]byte, 32)

	// Too short ciphertext
	if _, err := ObjectDecrypt(key, []byte{1, 2, 3}); err == nil {
		t.Error("decrypt too-short ciphertext should fail")
	}

	// Nil ciphertext. errors.Is (rather than ==) keeps this assertion valid
	// if ObjectDecrypt wraps the sentinel with additional context; the actual
	// error is printed so a mismatch is diagnosable.
	if _, err := ObjectDecrypt(key, nil); !errors.Is(err, ErrInvalidCiphertext) {
		t.Errorf("decrypt nil should return ErrInvalidCiphertext, got %v", err)
	}
}
func TestContentHash(t *testing.T) {
data := []byte("test data")
hash1 := ContentHash(data)
hash2 := ContentHash(data)
if !bytes.Equal(hash1, hash2) {
t.Error("same data should produce same hash")
}
hash3 := ContentHash([]byte("different data"))
if bytes.Equal(hash1, hash3) {
t.Error("different data should produce different hash")
}
if len(hash1) != 32 {
t.Errorf("hash length should be 32, got %d", len(hash1))
}
}