clavitor/clavis/clavis-vault/lib/dbcore_test.go

325 lines
7.3 KiB
Go

package lib
import (
"testing"
)
// testDB opens a database file inside a per-test temporary directory,
// migrates it, and returns the handle. The handle is closed automatically
// when the test finishes. If opening or migrating fails, the test is aborted.
func testDB(t *testing.T) *DB {
	t.Helper()
	db, err := OpenDB(t.TempDir() + "/test.db")
	if err != nil {
		t.Fatalf("open db: %v", err)
	}
	// Register the close before migrating so the handle is released even
	// when MigrateDB fails and t.Fatal aborts the test. Cleanups run
	// last-in-first-out, so this close still runs before t.TempDir removes
	// its directory.
	t.Cleanup(func() { db.Close() })
	if err := MigrateDB(db); err != nil {
		t.Fatalf("migrate db: %v", err)
	}
	return db
}
// testVaultKey returns a deterministic 16-byte key for tests: the bytes
// 1 through 8, repeated twice. Each call builds a fresh slice, so callers
// may modify the result without affecting other callers.
func testVaultKey() []byte {
	half := []byte{1, 2, 3, 4, 5, 6, 7, 8}
	return append(half, half...)
}
// TestEntryCreate_and_Get stores a credential with two fields and a URL,
// reads it back by ID, and checks four things:
//   - an ID was assigned;
//   - the initial version is 1;
//   - the title round-trips;
//   - the field values come back in the order they were written.
func TestEntryCreate_and_Get(t *testing.T) {
	db, vk := testDB(t), testVaultKey()

	created := &Entry{
		Type:  TypeCredential,
		Title: "GitHub",
		VaultData: &VaultData{
			Title: "GitHub",
			Type:  "credential",
			Fields: []VaultField{
				{Label: "username", Value: "octocat", Kind: "text"},
				{Label: "password", Value: "ghp_abc123", Kind: "password"},
			},
			URLs: []string{"https://github.com"},
		},
	}
	if err := EntryCreate(db, vk, created); err != nil {
		t.Fatalf("create: %v", err)
	}
	switch {
	case created.EntryID == 0:
		t.Fatal("entry ID should be assigned")
	case created.Version != 1:
		t.Errorf("initial version should be 1, got %d", created.Version)
	}

	fetched, err := EntryGet(db, vk, int64(created.EntryID))
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	if fetched.Title != "GitHub" {
		t.Errorf("title = %q, want GitHub", fetched.Title)
	}
	if fetched.VaultData == nil {
		t.Fatal("VaultData should be unpacked")
	}

	fields := fetched.VaultData.Fields
	if len(fields) != 2 {
		t.Fatalf("expected 2 fields, got %d", len(fields))
	}
	// Positions are checked, so field order must survive the round trip.
	for i, want := range []struct{ label, value string }{
		{"username", "octocat"},
		{"password", "ghp_abc123"},
	} {
		if string(fields[i].Value) != want.value {
			t.Errorf("%s = %q, want %s", want.label, fields[i].Value, want.value)
		}
	}
}
// TestEntryUpdate_optimistic_locking checks three things:
//   - EntryUpdate bumps the version on a write that carries the current version;
//   - a write that carries a stale version is rejected with ErrVersionConflict;
//   - the rejected write does not change the stored entry.
func TestEntryUpdate_optimistic_locking(t *testing.T) {
	db := testDB(t)
	vk := testVaultKey()
	entry := &Entry{
		Type:      TypeCredential,
		Title:     "Original",
		VaultData: &VaultData{Title: "Original", Type: "credential"},
	}
	if err := EntryCreate(db, vk, entry); err != nil {
		t.Fatalf("create: %v", err)
	}

	// Update with correct version
	entry.Title = "Updated"
	entry.VaultData.Title = "Updated"
	if err := EntryUpdate(db, vk, entry); err != nil {
		t.Fatalf("update: %v", err)
	}
	if entry.Version != 2 {
		t.Errorf("version after update should be 2, got %d", entry.Version)
	}

	// Update with stale version should fail.
	// NOTE(review): compared with != ; switch to errors.Is if EntryUpdate
	// ever wraps the sentinel.
	entry.Version = 1 // stale
	entry.Title = "Stale"
	err := EntryUpdate(db, vk, entry)
	if err != ErrVersionConflict {
		t.Errorf("expected ErrVersionConflict, got %v", err)
	}

	// The rejected write must not have reached the database.
	got, err := EntryGet(db, vk, int64(entry.EntryID))
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	if got.Title != "Updated" {
		t.Errorf("title after rejected update = %q, want Updated", got.Title)
	}
}
// TestEntryDelete_soft_delete deletes an entry and then checks two things:
// the entry no longer appears in EntryList, and EntryGet can still read it
// with DeletedAt set.
func TestEntryDelete_soft_delete(t *testing.T) {
	db := testDB(t)
	vk := testVaultKey()
	entry := &Entry{
		Type:      TypeCredential,
		Title:     "ToDelete",
		VaultData: &VaultData{Title: "ToDelete", Type: "credential"},
	}
	if err := EntryCreate(db, vk, entry); err != nil {
		t.Fatalf("create: %v", err)
	}
	if err := EntryDelete(db, int64(entry.EntryID)); err != nil {
		t.Fatalf("delete: %v", err)
	}

	// Should not appear in list
	entries, err := EntryList(db, vk, nil)
	if err != nil {
		t.Fatalf("list: %v", err)
	}
	for _, e := range entries {
		if e.EntryID == entry.EntryID {
			t.Error("deleted entry should not appear in list")
		}
	}

	// Direct get should still work but have DeletedAt set
	got, err := EntryGet(db, vk, int64(entry.EntryID))
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	if got.DeletedAt == nil {
		t.Error("deleted entry should have DeletedAt set")
	}
}
// TestEntryList_filters_by_parent creates three entries: a folder, a
// credential inside the folder, and a credential with no parent. Listing by
// the folder's ID must return only the child.
func TestEntryList_filters_by_parent(t *testing.T) {
	db := testDB(t)
	vk := testVaultKey()

	// mustCreate aborts the test on a failed insert. Without this, a failed
	// folder create would leave folder.EntryID at 0 and the child would be
	// created with a bogus parent, hiding the real cause.
	mustCreate := func(e *Entry) {
		t.Helper()
		if err := EntryCreate(db, vk, e); err != nil {
			t.Fatalf("create %q: %v", e.Title, err)
		}
	}

	folder := &Entry{Type: TypeFolder, Title: "Work", VaultData: &VaultData{Title: "Work", Type: "folder"}}
	mustCreate(folder)
	mustCreate(&Entry{
		Type:      TypeCredential,
		Title:     "WorkGitHub",
		ParentID:  folder.EntryID,
		VaultData: &VaultData{Title: "WorkGitHub", Type: "credential"},
	})
	mustCreate(&Entry{
		Type:      TypeCredential,
		Title:     "Personal",
		VaultData: &VaultData{Title: "Personal", Type: "credential"},
	})

	parentID := int64(folder.EntryID)
	children, err := EntryList(db, vk, &parentID)
	if err != nil {
		t.Fatalf("list: %v", err)
	}
	if len(children) != 1 {
		t.Fatalf("expected 1 child, got %d", len(children))
	}
	if children[0].Title != "WorkGitHub" {
		t.Errorf("child title = %q", children[0].Title)
	}
}
// TestEntrySearchFuzzy creates three credentials and checks that searching
// for "Git" returns exactly the two whose titles contain it: GitHub and
// GitLab, but not AWS Console.
func TestEntrySearchFuzzy(t *testing.T) {
	db := testDB(t)
	vk := testVaultKey()
	for _, title := range []string{"GitHub", "GitLab", "AWS Console"} {
		err := EntryCreate(db, vk, &Entry{
			Type:      TypeCredential,
			Title:     title,
			VaultData: &VaultData{Title: title, Type: "credential"},
		})
		if err != nil {
			t.Fatalf("create %q: %v", title, err)
		}
	}

	results, err := EntrySearchFuzzy(db, vk, "Git")
	if err != nil {
		t.Fatalf("search: %v", err)
	}
	if len(results) != 2 {
		t.Errorf("search for 'Git' should return 2 results, got %d", len(results))
	}
}
// TestEntryCount checks that EntryCount reports 0 on a fresh database and
// 2 after two entries are created. Errors from EntryCount and EntryCreate
// fail the test instead of being silently treated as a count mismatch.
func TestEntryCount(t *testing.T) {
	db := testDB(t)
	vk := testVaultKey()

	count, err := EntryCount(db)
	if err != nil {
		t.Fatalf("count: %v", err)
	}
	if count != 0 {
		t.Errorf("empty db should have 0 entries, got %d", count)
	}

	for _, title := range []string{"One", "Two"} {
		err := EntryCreate(db, vk, &Entry{
			Type: TypeCredential, Title: title,
			VaultData: &VaultData{Title: title, Type: "credential"},
		})
		if err != nil {
			t.Fatalf("create %q: %v", title, err)
		}
	}

	count, err = EntryCount(db)
	if err != nil {
		t.Fatalf("count: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 entries, got %d", count)
	}
}
// TestAuditLog_write_and_read writes two audit events and checks that
// AuditList returns both: a create from the web actor and a read from an
// agent. Only the set of actions is compared, because the order of events
// written back to back depends on timestamp resolution.
func TestAuditLog_write_and_read(t *testing.T) {
	db := testDB(t)
	for _, ev := range []*AuditEvent{
		{Action: ActionCreate, Actor: ActorWeb, Title: "GitHub", IPAddr: "127.0.0.1"},
		{Action: ActionRead, Actor: ActorAgent, Title: "GitHub", IPAddr: "10.0.0.1"},
	} {
		AuditLog(db, ev)
	}

	events, err := AuditList(db, 10)
	if err != nil {
		t.Fatal(err)
	}
	if len(events) != 2 {
		t.Fatalf("expected 2 audit events, got %d", len(events))
	}

	seen := make(map[string]bool, len(events))
	for _, ev := range events {
		seen[ev.Action] = true
	}
	for _, check := range []struct{ action, msg string }{
		{ActionCreate, "missing create action"},
		{ActionRead, "missing read action"},
	} {
		if !seen[check.action] {
			t.Error(check.msg)
		}
	}
}
// TestSessionCreate_and_Get creates a web session with TTL 3600 (presumably
// seconds; confirm against SessionCreate). It then looks the session up by
// its token and checks that it exists and keeps the web actor.
func TestSessionCreate_and_Get(t *testing.T) {
	db := testDB(t)

	created, err := SessionCreate(db, 3600, ActorWeb)
	switch {
	case err != nil:
		t.Fatal(err)
	case created.Token == "":
		t.Fatal("session token should not be empty")
	}

	found, err := SessionGet(db, created.Token)
	switch {
	case err != nil:
		t.Fatal(err)
	case found == nil:
		t.Fatal("session should exist")
	case found.Actor != ActorWeb:
		t.Errorf("actor = %q, want web", found.Actor)
	}
}
// TestSessionGet_expired checks that SessionGet returns a nil session, and
// no error, for a session whose TTL has already elapsed.
func TestSessionGet_expired(t *testing.T) {
	db := testDB(t)
	// Create session with negative TTL so it is already expired.
	session, err := SessionCreate(db, -1, ActorWeb)
	if err != nil {
		// Without this check a failed create leaves session unusable and
		// the Token access below would panic instead of reporting the cause.
		t.Fatalf("create: %v", err)
	}
	got, err := SessionGet(db, session.Token)
	if err != nil {
		t.Fatalf("get: %v", err)
	}
	if got != nil {
		t.Error("expired session should return nil")
	}
}
// TestWebAuthnCredential_store_and_list stores one WebAuthn credential and
// checks two things: GetWebAuthnCredentials returns it with its name intact,
// and WebAuthnCredentialCount reports 1.
func TestWebAuthnCredential_store_and_list(t *testing.T) {
	db := testDB(t)
	cred := &WebAuthnCredential{
		CredID:       HexID(NewID()),
		Name:         "YubiKey",
		PublicKey:    []byte{1, 2, 3},
		CredentialID: []byte{4, 5, 6},
		PRFSalt:      []byte{7, 8, 9},
	}
	if err := StoreWebAuthnCredential(db, cred); err != nil {
		t.Fatalf("store: %v", err)
	}

	creds, err := GetWebAuthnCredentials(db)
	if err != nil {
		t.Fatalf("list: %v", err)
	}
	if len(creds) != 1 {
		t.Fatalf("expected 1 credential, got %d", len(creds))
	}
	if creds[0].Name != "YubiKey" {
		t.Errorf("name = %q", creds[0].Name)
	}

	// A failing count query previously showed up only as "count = 0".
	count, err := WebAuthnCredentialCount(db)
	if err != nil {
		t.Fatalf("count: %v", err)
	}
	if count != 1 {
		t.Errorf("count = %d, want 1", count)
	}
}