325 lines
7.3 KiB
Go
325 lines
7.3 KiB
Go
package lib
|
|
|
|
import (
|
|
"testing"
|
|
)
|
|
|
|
// testDB creates a temp database, migrates it, returns DB + cleanup.
|
|
func testDB(t *testing.T) *DB {
|
|
t.Helper()
|
|
db, err := OpenDB(t.TempDir() + "/test.db")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if err := MigrateDB(db); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
t.Cleanup(func() { db.Close() })
|
|
return db
|
|
}
|
|
|
|
// testVaultKey returns a fixed 16-byte key for testing.
|
|
func testVaultKey() []byte {
|
|
return []byte{1, 2, 3, 4, 5, 6, 7, 8, 1, 2, 3, 4, 5, 6, 7, 8}
|
|
}
|
|
|
|
func TestEntryCreate_and_Get(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
entry := &Entry{
|
|
Type: TypeCredential,
|
|
Title: "GitHub",
|
|
VaultData: &VaultData{
|
|
Title: "GitHub",
|
|
Type: "credential",
|
|
Fields: []VaultField{
|
|
{Label: "username", Value: "octocat", Kind: "text"},
|
|
{Label: "password", Value: "ghp_abc123", Kind: "password"},
|
|
},
|
|
URLs: []string{"https://github.com"},
|
|
},
|
|
}
|
|
|
|
if err := EntryCreate(db, vk, entry); err != nil {
|
|
t.Fatalf("create: %v", err)
|
|
}
|
|
if entry.EntryID == 0 {
|
|
t.Fatal("entry ID should be assigned")
|
|
}
|
|
if entry.Version != 1 {
|
|
t.Errorf("initial version should be 1, got %d", entry.Version)
|
|
}
|
|
|
|
got, err := EntryGet(db, vk, int64(entry.EntryID))
|
|
if err != nil {
|
|
t.Fatalf("get: %v", err)
|
|
}
|
|
if got.Title != "GitHub" {
|
|
t.Errorf("title = %q, want GitHub", got.Title)
|
|
}
|
|
if got.VaultData == nil {
|
|
t.Fatal("VaultData should be unpacked")
|
|
}
|
|
if len(got.VaultData.Fields) != 2 {
|
|
t.Fatalf("expected 2 fields, got %d", len(got.VaultData.Fields))
|
|
}
|
|
if got.VaultData.Fields[0].Value != "octocat" {
|
|
t.Errorf("username = %q, want octocat", got.VaultData.Fields[0].Value)
|
|
}
|
|
if got.VaultData.Fields[1].Value != "ghp_abc123" {
|
|
t.Errorf("password = %q, want ghp_abc123", got.VaultData.Fields[1].Value)
|
|
}
|
|
}
|
|
|
|
func TestEntryUpdate_optimistic_locking(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
entry := &Entry{
|
|
Type: TypeCredential,
|
|
Title: "Original",
|
|
VaultData: &VaultData{Title: "Original", Type: "credential"},
|
|
}
|
|
EntryCreate(db, vk, entry)
|
|
|
|
// Update with correct version
|
|
entry.Title = "Updated"
|
|
entry.VaultData.Title = "Updated"
|
|
if err := EntryUpdate(db, vk, entry); err != nil {
|
|
t.Fatalf("update: %v", err)
|
|
}
|
|
if entry.Version != 2 {
|
|
t.Errorf("version after update should be 2, got %d", entry.Version)
|
|
}
|
|
|
|
// Update with stale version should fail
|
|
entry.Version = 1 // stale
|
|
entry.Title = "Stale"
|
|
err := EntryUpdate(db, vk, entry)
|
|
if err != ErrVersionConflict {
|
|
t.Errorf("expected ErrVersionConflict, got %v", err)
|
|
}
|
|
}
|
|
|
|
func TestEntryDelete_soft_delete(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
entry := &Entry{
|
|
Type: TypeCredential,
|
|
Title: "ToDelete",
|
|
VaultData: &VaultData{Title: "ToDelete", Type: "credential"},
|
|
}
|
|
EntryCreate(db, vk, entry)
|
|
|
|
if err := EntryDelete(db, int64(entry.EntryID)); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
// Should not appear in list
|
|
entries, err := EntryList(db, vk, nil)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
for _, e := range entries {
|
|
if e.EntryID == entry.EntryID {
|
|
t.Error("deleted entry should not appear in list")
|
|
}
|
|
}
|
|
|
|
// Direct get should still work but have DeletedAt set
|
|
got, err := EntryGet(db, vk, int64(entry.EntryID))
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got.DeletedAt == nil {
|
|
t.Error("deleted entry should have DeletedAt set")
|
|
}
|
|
}
|
|
|
|
func TestEntryList_filters_by_parent(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
folder := &Entry{Type: TypeFolder, Title: "Work", VaultData: &VaultData{Title: "Work", Type: "folder"}}
|
|
EntryCreate(db, vk, folder)
|
|
|
|
child := &Entry{
|
|
Type: TypeCredential,
|
|
Title: "WorkGitHub",
|
|
ParentID: folder.EntryID,
|
|
VaultData: &VaultData{Title: "WorkGitHub", Type: "credential"},
|
|
}
|
|
EntryCreate(db, vk, child)
|
|
|
|
orphan := &Entry{
|
|
Type: TypeCredential,
|
|
Title: "Personal",
|
|
VaultData: &VaultData{Title: "Personal", Type: "credential"},
|
|
}
|
|
EntryCreate(db, vk, orphan)
|
|
|
|
parentID := int64(folder.EntryID)
|
|
children, err := EntryList(db, vk, &parentID)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(children) != 1 {
|
|
t.Fatalf("expected 1 child, got %d", len(children))
|
|
}
|
|
if children[0].Title != "WorkGitHub" {
|
|
t.Errorf("child title = %q", children[0].Title)
|
|
}
|
|
}
|
|
|
|
func TestEntrySearchFuzzy(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
for _, title := range []string{"GitHub", "GitLab", "AWS Console"} {
|
|
EntryCreate(db, vk, &Entry{
|
|
Type: TypeCredential,
|
|
Title: title,
|
|
VaultData: &VaultData{Title: title, Type: "credential"},
|
|
})
|
|
}
|
|
|
|
results, err := EntrySearchFuzzy(db, vk, "Git")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(results) != 2 {
|
|
t.Errorf("search for 'Git' should return 2 results, got %d", len(results))
|
|
}
|
|
}
|
|
|
|
func TestEntryCount(t *testing.T) {
|
|
db := testDB(t)
|
|
vk := testVaultKey()
|
|
|
|
count, _ := EntryCount(db)
|
|
if count != 0 {
|
|
t.Errorf("empty db should have 0 entries, got %d", count)
|
|
}
|
|
|
|
EntryCreate(db, vk, &Entry{
|
|
Type: TypeCredential, Title: "One",
|
|
VaultData: &VaultData{Title: "One", Type: "credential"},
|
|
})
|
|
EntryCreate(db, vk, &Entry{
|
|
Type: TypeCredential, Title: "Two",
|
|
VaultData: &VaultData{Title: "Two", Type: "credential"},
|
|
})
|
|
|
|
count, _ = EntryCount(db)
|
|
if count != 2 {
|
|
t.Errorf("expected 2 entries, got %d", count)
|
|
}
|
|
}
|
|
|
|
func TestAuditLog_write_and_read(t *testing.T) {
|
|
db := testDB(t)
|
|
|
|
AuditLog(db, &AuditEvent{
|
|
Action: ActionCreate,
|
|
Actor: ActorWeb,
|
|
Title: "GitHub",
|
|
IPAddr: "127.0.0.1",
|
|
})
|
|
AuditLog(db, &AuditEvent{
|
|
Action: ActionRead,
|
|
Actor: ActorAgent,
|
|
Title: "GitHub",
|
|
IPAddr: "10.0.0.1",
|
|
})
|
|
|
|
events, err := AuditList(db, 10)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(events) != 2 {
|
|
t.Fatalf("expected 2 audit events, got %d", len(events))
|
|
}
|
|
// Both actions should be present (order depends on timestamp resolution)
|
|
actions := map[string]bool{}
|
|
for _, e := range events {
|
|
actions[e.Action] = true
|
|
}
|
|
if !actions[ActionCreate] {
|
|
t.Error("missing create action")
|
|
}
|
|
if !actions[ActionRead] {
|
|
t.Error("missing read action")
|
|
}
|
|
}
|
|
|
|
func TestSessionCreate_and_Get(t *testing.T) {
|
|
db := testDB(t)
|
|
|
|
session, err := SessionCreate(db, 3600, ActorWeb)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if session.Token == "" {
|
|
t.Fatal("session token should not be empty")
|
|
}
|
|
|
|
got, err := SessionGet(db, session.Token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got == nil {
|
|
t.Fatal("session should exist")
|
|
}
|
|
if got.Actor != ActorWeb {
|
|
t.Errorf("actor = %q, want web", got.Actor)
|
|
}
|
|
}
|
|
|
|
func TestSessionGet_expired(t *testing.T) {
|
|
db := testDB(t)
|
|
|
|
// Create session with negative TTL (guaranteed expired)
|
|
session, _ := SessionCreate(db, -1, ActorWeb)
|
|
|
|
got, err := SessionGet(db, session.Token)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if got != nil {
|
|
t.Error("expired session should return nil")
|
|
}
|
|
}
|
|
|
|
func TestWebAuthnCredential_store_and_list(t *testing.T) {
|
|
db := testDB(t)
|
|
|
|
cred := &WebAuthnCredential{
|
|
CredID: HexID(NewID()),
|
|
Name: "YubiKey",
|
|
PublicKey: []byte{1, 2, 3},
|
|
CredentialID: []byte{4, 5, 6},
|
|
PRFSalt: []byte{7, 8, 9},
|
|
}
|
|
if err := StoreWebAuthnCredential(db, cred); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
|
|
creds, err := GetWebAuthnCredentials(db)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if len(creds) != 1 {
|
|
t.Fatalf("expected 1 credential, got %d", len(creds))
|
|
}
|
|
if creds[0].Name != "YubiKey" {
|
|
t.Errorf("name = %q", creds[0].Name)
|
|
}
|
|
|
|
count, _ := WebAuthnCredentialCount(db)
|
|
if count != 1 {
|
|
t.Errorf("count = %d, want 1", count)
|
|
}
|
|
}
|