252 lines
5.8 KiB
Go
252 lines
5.8 KiB
Go
package lib
|
|
|
|
import (
	"bytes"
	"errors"
	"strings"
	"testing"
)
|
|
|
|
func TestPackUnpack(t *testing.T) {
|
|
key := make([]byte, 32)
|
|
for i := range key {
|
|
key[i] = byte(i)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
input string
|
|
}{
|
|
{"simple", "hello world"},
|
|
{"empty", ""},
|
|
{"unicode", "こんにちは世界 🌍 مرحبا"},
|
|
{"json", `{"key": "value", "nested": {"data": 123}}`},
|
|
{"large", strings.Repeat("a", 1024*1024)}, // 1MB
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
packed, err := Pack(key, tc.input)
|
|
if err != nil {
|
|
t.Fatalf("Pack failed: %v", err)
|
|
}
|
|
|
|
unpacked, err := Unpack(key, packed)
|
|
if err != nil {
|
|
t.Fatalf("Unpack failed: %v", err)
|
|
}
|
|
|
|
if unpacked != tc.input {
|
|
if len(tc.input) > 100 {
|
|
t.Errorf("round-trip failed: lengths differ (got %d, want %d)", len(unpacked), len(tc.input))
|
|
} else {
|
|
t.Errorf("round-trip failed: got %q, want %q", unpacked, tc.input)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestPackUnpackEmptyInput(t *testing.T) {
|
|
key := make([]byte, 32)
|
|
for i := range key {
|
|
key[i] = byte(i)
|
|
}
|
|
|
|
// Unpack nil/empty ciphertext should return empty
|
|
result, err := Unpack(key, nil)
|
|
if err != nil {
|
|
t.Fatalf("Unpack nil failed: %v", err)
|
|
}
|
|
if result != "" {
|
|
t.Errorf("expected empty for nil input, got %q", result)
|
|
}
|
|
|
|
result, err = Unpack(key, []byte{})
|
|
if err != nil {
|
|
t.Fatalf("Unpack empty bytes failed: %v", err)
|
|
}
|
|
if result != "" {
|
|
t.Errorf("expected empty for empty bytes, got %q", result)
|
|
}
|
|
}
|
|
|
|
func TestBlindIndex(t *testing.T) {
|
|
key1 := make([]byte, 32)
|
|
key2 := make([]byte, 32)
|
|
for i := range key1 {
|
|
key1[i] = byte(i)
|
|
key2[i] = byte(i + 1) // Different key
|
|
}
|
|
|
|
plaintext := "searchable-term"
|
|
|
|
// Same input + same key = same index
|
|
index1 := BlindIndex(key1, plaintext)
|
|
index2 := BlindIndex(key1, plaintext)
|
|
if !bytes.Equal(index1, index2) {
|
|
t.Error("same input + key should produce same index")
|
|
}
|
|
|
|
// Same input + different key = different index
|
|
index3 := BlindIndex(key2, plaintext)
|
|
if bytes.Equal(index1, index3) {
|
|
t.Error("different keys should produce different indexes")
|
|
}
|
|
|
|
// Different input + same key = different index
|
|
index4 := BlindIndex(key1, "different-term")
|
|
if bytes.Equal(index1, index4) {
|
|
t.Error("different inputs should produce different indexes")
|
|
}
|
|
|
|
// Index should be 32 bytes (SHA-256)
|
|
if len(index1) != 32 {
|
|
t.Errorf("index length should be 32, got %d", len(index1))
|
|
}
|
|
}
|
|
|
|
func TestDeriveProjectKey(t *testing.T) {
|
|
masterKey := make([]byte, 32)
|
|
for i := range masterKey {
|
|
masterKey[i] = byte(i)
|
|
}
|
|
|
|
// Deterministic: same master + projectID = same key
|
|
key1, err := DeriveProjectKey(masterKey, "project-123")
|
|
if err != nil {
|
|
t.Fatalf("DeriveProjectKey failed: %v", err)
|
|
}
|
|
key2, err := DeriveProjectKey(masterKey, "project-123")
|
|
if err != nil {
|
|
t.Fatalf("DeriveProjectKey failed: %v", err)
|
|
}
|
|
if !bytes.Equal(key1, key2) {
|
|
t.Error("same master + projectID should produce same key")
|
|
}
|
|
|
|
// Different projectID = different key
|
|
key3, err := DeriveProjectKey(masterKey, "project-456")
|
|
if err != nil {
|
|
t.Fatalf("DeriveProjectKey failed: %v", err)
|
|
}
|
|
if bytes.Equal(key1, key3) {
|
|
t.Error("different projectID should produce different key")
|
|
}
|
|
|
|
// Key should be 32 bytes (AES-256)
|
|
if len(key1) != 32 {
|
|
t.Errorf("key length should be 32, got %d", len(key1))
|
|
}
|
|
}
|
|
|
|
func TestDeriveHMACKey(t *testing.T) {
|
|
masterKey := make([]byte, 32)
|
|
for i := range masterKey {
|
|
masterKey[i] = byte(i)
|
|
}
|
|
|
|
// HMAC key should be different from project key for same projectID
|
|
projectKey, _ := DeriveProjectKey(masterKey, "project-123")
|
|
hmacKey, err := DeriveHMACKey(masterKey, "project-123")
|
|
if err != nil {
|
|
t.Fatalf("DeriveHMACKey failed: %v", err)
|
|
}
|
|
if bytes.Equal(projectKey, hmacKey) {
|
|
t.Error("HMAC key should differ from project key")
|
|
}
|
|
|
|
// HMAC key should be 32 bytes
|
|
if len(hmacKey) != 32 {
|
|
t.Errorf("HMAC key length should be 32, got %d", len(hmacKey))
|
|
}
|
|
}
|
|
|
|
func TestAESGCM(t *testing.T) {
|
|
key := make([]byte, 32)
|
|
for i := range key {
|
|
key[i] = byte(i)
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
data []byte
|
|
}{
|
|
{"simple", []byte("hello world")},
|
|
{"binary", []byte{0x00, 0x01, 0x02, 0xff, 0xfe}},
|
|
{"large", bytes.Repeat([]byte("x"), 1024*1024)}, // 1MB
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
encrypted, err := ObjectEncrypt(key, tc.data)
|
|
if err != nil {
|
|
t.Fatalf("ObjectEncrypt failed: %v", err)
|
|
}
|
|
|
|
decrypted, err := ObjectDecrypt(key, encrypted)
|
|
if err != nil {
|
|
t.Fatalf("ObjectDecrypt failed: %v", err)
|
|
}
|
|
|
|
if !bytes.Equal(decrypted, tc.data) {
|
|
t.Errorf("round-trip failed")
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestObjectEncryptDecryptWrongKey(t *testing.T) {
|
|
key1 := make([]byte, 32)
|
|
key2 := make([]byte, 32)
|
|
for i := range key1 {
|
|
key1[i] = byte(i)
|
|
key2[i] = byte(i + 1)
|
|
}
|
|
|
|
data := []byte("secret data")
|
|
encrypted, err := ObjectEncrypt(key1, data)
|
|
if err != nil {
|
|
t.Fatalf("ObjectEncrypt failed: %v", err)
|
|
}
|
|
|
|
_, err = ObjectDecrypt(key2, encrypted)
|
|
if err == nil {
|
|
t.Error("decrypt with wrong key should fail")
|
|
}
|
|
}
|
|
|
|
func TestObjectDecryptInvalidCiphertext(t *testing.T) {
|
|
key := make([]byte, 32)
|
|
|
|
// Too short ciphertext
|
|
_, err := ObjectDecrypt(key, []byte{1, 2, 3})
|
|
if err == nil {
|
|
t.Error("decrypt too-short ciphertext should fail")
|
|
}
|
|
|
|
// Nil ciphertext
|
|
_, err = ObjectDecrypt(key, nil)
|
|
if err != ErrInvalidCiphertext {
|
|
t.Error("decrypt nil should return ErrInvalidCiphertext")
|
|
}
|
|
}
|
|
|
|
func TestContentHash(t *testing.T) {
|
|
data := []byte("test data")
|
|
hash1 := ContentHash(data)
|
|
hash2 := ContentHash(data)
|
|
|
|
if !bytes.Equal(hash1, hash2) {
|
|
t.Error("same data should produce same hash")
|
|
}
|
|
|
|
hash3 := ContentHash([]byte("different data"))
|
|
if bytes.Equal(hash1, hash3) {
|
|
t.Error("different data should produce different hash")
|
|
}
|
|
|
|
if len(hash1) != 32 {
|
|
t.Errorf("hash length should be 32, got %d", len(hash1))
|
|
}
|
|
}
|