clavitor/clavis/clavis-vault/api/integration_test.go

603 lines
18 KiB
Go

package api
// Integration tests for the Clavitor vault API.
//
// The test client authenticates exactly as production does:
// - 8-byte L1 key sent as base64url Bearer on every request
// - DB filename derived from L1[:4]: clavitor-{base64url(l1[:4])}
// - L1 normalized to 16 bytes for AES-128 vault encryption
//
// Each test gets an isolated vault (temp dir + fresh DB).
// Run: go test ./api/... -v
import (
"bytes"
"embed"
"encoding/base64"
"encoding/json"
"io"
"mime/multipart"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/johanj/clavitor/lib"
)
// --- test client ---
// tc is a per-test HTTP client bound to an isolated vault that is served by
// an httptest.Server (see newTestClient).
type tc struct {
	srv    *httptest.Server // test server wrapping NewRouter(cfg, emptyFS)
	bearer string           // base64url-encoded L1 key (8 bytes), sent as the Bearer token by req
	t      *testing.T       // owning test; request/assert helpers call Helper and Fatalf on it
}
// newTestClient creates an isolated vault and a test server for t.
//
// The vault lives under t.TempDir() and is addressed by a fixed 8-byte L1
// key, so tests are reproducible. The vault DB is created and migrated up
// front, then closed before the router is built. Presumably the router
// reopens it per request via L1Middleware — TODO confirm. The test server
// is shut down through t.Cleanup.
func newTestClient(t *testing.T) *tc {
	t.Helper()
	tmpDir := t.TempDir()
	cfg := &lib.Config{
		Port:       "0",
		DataDir:    tmpDir,
		SessionTTL: 86400,
	}

	// Fixed 8-byte L1 key for testing.
	l1Raw := []byte{0xAA, 0xBB, 0xCC, 0xDD, 0x11, 0x22, 0x33, 0x44}
	bearer := base64.RawURLEncoding.EncodeToString(l1Raw)

	// The DB filename must match what L1Middleware derives:
	// clavitor-{base64url(l1[:4])}. Keep this construction in sync with it.
	prefix := base64.RawURLEncoding.EncodeToString(l1Raw[:4])
	dbPath := tmpDir + "/clavitor-" + prefix

	db, err := lib.OpenDB(dbPath)
	if err != nil {
		t.Fatalf("opendb: %v", err)
	}
	if err := lib.MigrateDB(db); err != nil {
		// Close before failing: t.Fatalf stops this goroutine, so an open
		// handle would leak and can block TempDir cleanup (e.g. on Windows).
		db.Close()
		t.Fatalf("migrate: %v", err)
	}
	db.Close()

	var emptyFS embed.FS // API tests need no embedded static assets
	srv := httptest.NewServer(NewRouter(cfg, emptyFS))
	t.Cleanup(srv.Close)
	return &tc{srv: srv, bearer: bearer, t: t}
}
// req sends a request carrying the client's L1 Bearer token.
// A non-nil body is JSON-encoded and sent with Content-Type application/json.
// Any failure to build or send the request fails the test.
func (c *tc) req(method, path string, body any) *http.Response {
	c.t.Helper()
	return c.send(method, path, body, true)
}

// reqNoAuth is like req but sends no Authorization header.
func (c *tc) reqNoAuth(method, path string, body any) *http.Response {
	c.t.Helper()
	return c.send(method, path, body, false)
}

// send builds and executes a request against the test server. It adds the
// Bearer token only when withAuth is set.
func (c *tc) send(method, path string, body any, withAuth bool) *http.Response {
	c.t.Helper()
	var r io.Reader
	if body != nil {
		b, err := json.Marshal(body)
		if err != nil {
			c.t.Fatalf("marshal %s %s body: %v", method, path, err)
		}
		r = bytes.NewReader(b)
	}
	req, err := http.NewRequest(method, c.srv.URL+path, r)
	if err != nil {
		c.t.Fatalf("build %s %s: %v", method, path, err)
	}
	if body != nil {
		req.Header.Set("Content-Type", "application/json")
	}
	if withAuth {
		req.Header.Set("Authorization", "Bearer "+c.bearer)
	}
	resp, err := c.srv.Client().Do(req)
	if err != nil {
		c.t.Fatalf("req %s %s: %v", method, path, err)
	}
	return resp
}
// must asserts that resp has status wantStatus and decodes its body as a
// JSON object. An empty body yields a nil map. A body that is not a JSON
// object fails the test. The response body is always closed.
func (c *tc) must(resp *http.Response, wantStatus int) map[string]any {
	c.t.Helper()
	return decodeResponse[map[string]any](c, resp, wantStatus)
}

// mustList asserts that resp has status wantStatus and decodes its body as a
// JSON array of objects. An empty body or JSON null yields a nil slice. Any
// other shape fails the test. The response body is always closed.
func (c *tc) mustList(resp *http.Response, wantStatus int) []map[string]any {
	c.t.Helper()
	return decodeResponse[[]map[string]any](c, resp, wantStatus)
}

// decodeResponse is the shared implementation of must and mustList.
func decodeResponse[T any](c *tc, resp *http.Response, wantStatus int) T {
	c.t.Helper()
	defer resp.Body.Close()
	body, readErr := io.ReadAll(resp.Body)
	if resp.StatusCode != wantStatus {
		c.t.Fatalf("expected %d, got %d: %s", wantStatus, resp.StatusCode, body)
	}
	if readErr != nil {
		c.t.Fatalf("read response body: %v", readErr)
	}
	var out T
	// Some endpoints may legitimately reply with no body; keep the zero value.
	if len(bytes.TrimSpace(body)) == 0 {
		return out
	}
	if err := json.Unmarshal(body, &out); err != nil {
		c.t.Fatalf("decode response: %v: %s", err, body)
	}
	return out
}
// --- test data ---
// credentialEntry builds a create-entry payload of type "credential".
// It has a username (text) field and a password (password) field, and the
// given URLs (which may be nil). The title and type are repeated both at the
// top level and inside "data".
func credentialEntry(title, username, password string, urls []string) map[string]any {
	fields := []map[string]any{
		{"label": "username", "value": username, "kind": "text"},
		{"label": "password", "value": password, "kind": "password"},
	}
	data := map[string]any{
		"title":  title,
		"type":   "credential",
		"fields": fields,
		"urls":   urls,
	}
	return map[string]any{
		"title": title,
		"type":  "credential",
		"data":  data,
	}
}
// ---------------------------------------------------------------------------
// Health & Ping
// ---------------------------------------------------------------------------
// TestHealth checks that the unauthenticated /health endpoint reports ok.
func TestHealth(t *testing.T) {
	client := newTestClient(t)
	reply := client.must(client.reqNoAuth("GET", "/health", nil), 200)
	if got := reply["status"]; got != "ok" {
		t.Errorf("status = %v, want ok", got)
	}
}
// TestPing checks that the unauthenticated /ping endpoint reports ok:true,
// a non-empty node and a ts field.
func TestPing(t *testing.T) {
	client := newTestClient(t)
	reply := client.must(client.reqNoAuth("GET", "/ping", nil), 200)

	if ok := reply["ok"]; ok != true {
		t.Errorf("ok = %v, want true", ok)
	}
	node := reply["node"]
	if node == nil || node == "" {
		t.Error("node should not be empty")
	}
	if reply["ts"] == nil {
		t.Error("ts should be present")
	}
}
// ---------------------------------------------------------------------------
// L1 Auth
// ---------------------------------------------------------------------------
// TestL1Auth_valid_key checks that the client's own L1 key is accepted.
// An empty vault must answer 200 with a list, not an auth error.
func TestL1Auth_valid_key(t *testing.T) {
	client := newTestClient(t)
	resp := client.req("GET", "/api/entries?meta=1", nil)
	_ = client.mustList(resp, 200)
}
// TestL1Auth_bad_bearer_rejected verifies that a malformed Bearer token is
// rejected with 401.
//
// NOTE(review): "not-valid-base64" uses only base64url characters and decodes
// to 12 bytes rather than 8. It is therefore presumably rejected by a
// key-length check, not a decode failure — confirm this is the intended case.
func TestL1Auth_bad_bearer_rejected(t *testing.T) {
	c := newTestClient(t)
	req, err := http.NewRequest("GET", c.srv.URL+"/api/entries", nil)
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer not-valid-base64")
	resp, err := c.srv.Client().Do(req)
	if err != nil {
		// Without this check a transport error would nil-dereference resp below.
		t.Fatalf("GET /api/entries: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 401 {
		t.Errorf("bad bearer should return 401, got %d", resp.StatusCode)
	}
}
// TestL1Auth_wrong_key_vault_not_found verifies the response to a
// well-formed 8-byte L1 key whose prefix maps to no vault file in DataDir:
// the server must answer 404.
func TestL1Auth_wrong_key_vault_not_found(t *testing.T) {
	c := newTestClient(t)
	// Valid base64url but points to a non-existent vault.
	wrongL1 := base64.RawURLEncoding.EncodeToString([]byte{0xFF, 0xFE, 0xFD, 0xFC, 0xFB, 0xFA, 0xF9, 0xF8})
	req, err := http.NewRequest("GET", c.srv.URL+"/api/entries", nil)
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Authorization", "Bearer "+wrongL1)
	resp, err := c.srv.Client().Do(req)
	if err != nil {
		// Without this check a transport error would nil-dereference resp below.
		t.Fatalf("GET /api/entries: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 404 {
		t.Errorf("wrong L1 key should return 404 (vault not found), got %d", resp.StatusCode)
	}
}
// ---------------------------------------------------------------------------
// Entry CRUD
// ---------------------------------------------------------------------------
// TestCreateEntry checks that creating a credential answers 201 and returns
// a non-empty entry_id.
func TestCreateEntry(t *testing.T) {
	client := newTestClient(t)
	payload := credentialEntry("GitHub", "octocat", "hunter2", []string{"https://github.com"})
	created := client.must(client.req("POST", "/api/entries", payload), 201)
	if id := created["entry_id"]; id == nil || id == "" {
		t.Fatal("create should return entry_id")
	}
}
// TestCreateEntry_missing_title checks that a create payload without a title
// is rejected with 400.
func TestCreateEntry_missing_title(t *testing.T) {
	client := newTestClient(t)
	untitled := map[string]any{"type": "credential"}
	resp := client.req("POST", "/api/entries", untitled)
	defer resp.Body.Close()
	if got := resp.StatusCode; got != 400 {
		t.Errorf("missing title should return 400, got %d", got)
	}
}
// TestReadEntry_roundtrip creates a credential and reads it back. It checks
// that the username and password field values come back unchanged.
//
// Response shapes are checked with comma-ok assertions so that a malformed
// reply fails this test cleanly. A panic would abort the whole test binary.
func TestReadEntry_roundtrip(t *testing.T) {
	c := newTestClient(t)
	created := c.must(c.req("POST", "/api/entries", credentialEntry("GitHub", "octocat", "hunter2", nil)), 201)
	id, ok := created["entry_id"].(string)
	if !ok || id == "" {
		t.Fatalf("create returned no entry_id: %v", created)
	}

	got := c.must(c.req("GET", "/api/entries/"+id, nil), 200)
	data, ok := got["data"].(map[string]any)
	if !ok {
		t.Fatalf("entry has no data object: %v", got)
	}
	fields, ok := data["fields"].([]any)
	if !ok {
		t.Fatalf("entry data has no fields array: %v", data)
	}

	found := make(map[string]string, len(fields))
	for i, f := range fields {
		fm, ok := f.(map[string]any)
		if !ok {
			t.Fatalf("field %d is not an object: %v", i, f)
		}
		label, ok := fm["label"].(string)
		if !ok {
			t.Fatalf("field %d has no string label: %v", i, fm)
		}
		// A non-string value is recorded as "" and reported by the checks below.
		found[label], _ = fm["value"].(string)
	}

	if found["username"] != "octocat" {
		t.Errorf("username = %q, want octocat", found["username"])
	}
	if found["password"] != "hunter2" {
		t.Errorf("password = %q, want hunter2", found["password"])
	}
}
// TestUpdateEntry checks that a PUT at version 1 replaces the entry. The
// response must carry the new title.
func TestUpdateEntry(t *testing.T) {
	client := newTestClient(t)
	created := client.must(client.req("POST", "/api/entries", credentialEntry("Old", "user", "pass", nil)), 201)
	entryID := created["entry_id"].(string)

	replacement := map[string]any{
		"title":   "New",
		"version": 1,
		"data": map[string]any{
			"title": "New",
			"type":  "credential",
			"fields": []map[string]any{
				{"label": "username", "value": "newuser", "kind": "text"},
			},
		},
	}
	updated := client.must(client.req("PUT", "/api/entries/"+entryID, replacement), 200)
	if got := updated["title"]; got != "New" {
		t.Errorf("title = %v, want New", got)
	}
}
// TestUpdateEntry_version_conflict checks optimistic concurrency. After one
// successful update based on version 1, a second update that still claims
// version 1 must be rejected with 409.
func TestUpdateEntry_version_conflict(t *testing.T) {
	client := newTestClient(t)
	created := client.must(client.req("POST", "/api/entries", credentialEntry("Test", "u", "p", nil)), 201)
	path := "/api/entries/" + created["entry_id"].(string)

	// putAtVersion1 builds an update payload that claims to be based on version 1.
	putAtVersion1 := func(title string) map[string]any {
		return map[string]any{
			"title":   title,
			"version": 1,
			"data":    map[string]any{"title": title, "type": "credential"},
		}
	}

	client.must(client.req("PUT", path, putAtVersion1("V2")), 200)

	stale := client.req("PUT", path, putAtVersion1("Stale"))
	defer stale.Body.Close()
	if stale.StatusCode != 409 {
		t.Errorf("stale version should return 409, got %d", stale.StatusCode)
	}
}
// TestDeleteEntry checks that a deleted entry no longer appears in the
// metadata listing.
func TestDeleteEntry(t *testing.T) {
	client := newTestClient(t)
	created := client.must(client.req("POST", "/api/entries", credentialEntry("ToDelete", "u", "p", nil)), 201)
	entryID := created["entry_id"].(string)

	client.must(client.req("DELETE", "/api/entries/"+entryID, nil), 200)

	remaining := client.mustList(client.req("GET", "/api/entries?meta=1", nil), 200)
	for i := range remaining {
		if remaining[i]["entry_id"] == entryID {
			t.Error("deleted entry should not appear in list")
		}
	}
}
// TestListEntries_meta checks that ?meta=1 lists every entry without any
// "data" payload.
func TestListEntries_meta(t *testing.T) {
	client := newTestClient(t)
	for _, title := range []string{"One", "Two"} {
		client.must(client.req("POST", "/api/entries", credentialEntry(title, "u", "p", nil)), 201)
	}

	listing := client.mustList(client.req("GET", "/api/entries?meta=1", nil), 200)
	if n := len(listing); n != 2 {
		t.Errorf("expected 2 entries, got %d", n)
	}
	for _, item := range listing {
		if item["data"] == nil {
			continue
		}
		t.Error("meta mode should not include field data")
	}
}
// ---------------------------------------------------------------------------
// Search
// ---------------------------------------------------------------------------
// TestSearch checks that /api/search matches titles by substring. Of
// GitHub, GitLab and AWS, only the two "Git" titles should match.
func TestSearch(t *testing.T) {
	client := newTestClient(t)
	for _, title := range []string{"GitHub", "GitLab", "AWS"} {
		client.must(client.req("POST", "/api/entries", credentialEntry(title, "u", "p", nil)), 201)
	}

	hits := client.mustList(client.req("GET", "/api/search?q=Git", nil), 200)
	if got := len(hits); got != 2 {
		t.Errorf("search for 'Git' should return 2, got %d", got)
	}
}
// TestSearch_no_query checks that /api/search without a q parameter is
// rejected with 400.
func TestSearch_no_query(t *testing.T) {
	client := newTestClient(t)
	resp := client.req("GET", "/api/search", nil)
	defer resp.Body.Close()
	if got := resp.StatusCode; got != 400 {
		t.Errorf("missing query should return 400, got %d", got)
	}
}
// ---------------------------------------------------------------------------
// TOTP
// ---------------------------------------------------------------------------
// TestTOTP_valid_code stores a TOTP secret in an unprotected field. It then
// checks that /api/ext/totp returns a 6-character code and an expires_in
// value between 1 and 30 seconds.
func TestTOTP_valid_code(t *testing.T) {
	client := newTestClient(t)
	totpField := map[string]any{"label": "totp", "value": "JBSWY3DPEHPK3PXP", "kind": "totp"}
	payload := map[string]any{
		"title": "2FA Test",
		"type":  "credential",
		"data": map[string]any{
			"title":  "2FA Test",
			"type":   "credential",
			"fields": []map[string]any{totpField},
		},
	}
	entryID := client.must(client.req("POST", "/api/entries", payload), 201)["entry_id"].(string)

	reply := client.must(client.req("GET", "/api/ext/totp/"+entryID, nil), 200)
	if code, _ := reply["code"].(string); len(code) != 6 {
		t.Errorf("TOTP code = %q, want 6 digits", code)
	}
	if secs, _ := reply["expires_in"].(float64); secs <= 0 || secs > 30 {
		t.Errorf("expires_in = %v, want 1-30", secs)
	}
}
// TestTOTP_L2_returns_locked checks a TOTP field flagged l2:true. Instead of
// a code, /api/ext/totp must answer with l2:true.
func TestTOTP_L2_returns_locked(t *testing.T) {
	client := newTestClient(t)
	lockedField := map[string]any{"label": "totp", "value": "JBSWY3DPEHPK3PXP", "kind": "totp", "l2": true}
	payload := map[string]any{
		"title": "L2 TOTP",
		"type":  "credential",
		"data": map[string]any{
			"title":  "L2 TOTP",
			"type":   "credential",
			"fields": []map[string]any{lockedField},
		},
	}
	entryID := client.must(client.req("POST", "/api/entries", payload), 201)["entry_id"].(string)

	reply := client.must(client.req("GET", "/api/ext/totp/"+entryID, nil), 200)
	if locked := reply["l2"]; locked != true {
		t.Error("L2 TOTP should return l2:true")
	}
}
// ---------------------------------------------------------------------------
// URL Match
// ---------------------------------------------------------------------------
// TestURLMatch checks that an entry saved for https://github.com is offered
// for a deeper page on the same site.
func TestURLMatch(t *testing.T) {
	client := newTestClient(t)
	entry := credentialEntry("GitHub", "u", "p", []string{"https://github.com"})
	client.must(client.req("POST", "/api/entries", entry), 201)

	hits := client.mustList(client.req("GET", "/api/ext/match?url=https://github.com/login", nil), 200)
	if len(hits) == 0 {
		t.Error("should match github.com for github.com/login")
	}
}
// TestURLMatch_no_match checks that an entry saved for https://github.com is
// not offered for an unrelated site.
func TestURLMatch_no_match(t *testing.T) {
	client := newTestClient(t)
	entry := credentialEntry("GitHub", "u", "p", []string{"https://github.com"})
	client.must(client.req("POST", "/api/entries", entry), 201)

	hits := client.mustList(client.req("GET", "/api/ext/match?url=https://example.com", nil), 200)
	if n := len(hits); n != 0 {
		t.Errorf("should not match, got %d", n)
	}
}
// ---------------------------------------------------------------------------
// Import
// ---------------------------------------------------------------------------
// TestImport_ChromeCSV uploads a Chrome-style password CSV (name, url,
// username, password) as multipart field "file". It expects the import
// endpoint to answer 200.
func TestImport_ChromeCSV(t *testing.T) {
	c := newTestClient(t)
	csv := "name,url,username,password\nGitHub,https://github.com,octocat,hunter2\n"

	body := &bytes.Buffer{}
	w := multipart.NewWriter(body)
	part, err := w.CreateFormFile("file", "passwords.csv")
	if err != nil {
		t.Fatalf("create form file: %v", err)
	}
	if _, err := part.Write([]byte(csv)); err != nil {
		t.Fatalf("write form file: %v", err)
	}
	// Close writes the terminating boundary; without it the upload is truncated.
	if err := w.Close(); err != nil {
		t.Fatalf("close multipart writer: %v", err)
	}

	req, err := http.NewRequest("POST", c.srv.URL+"/api/import", body)
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", w.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+c.bearer)
	resp, err := c.srv.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		b, _ := io.ReadAll(resp.Body) // best effort: only used in the failure message
		t.Fatalf("import returned %d: %s", resp.StatusCode, b)
	}
}
// TestImport_unknown_format_rejected uploads a file whose contents match no
// known export format. It expects the import endpoint to reject it with 400.
func TestImport_unknown_format_rejected(t *testing.T) {
	c := newTestClient(t)

	body := &bytes.Buffer{}
	w := multipart.NewWriter(body)
	part, err := w.CreateFormFile("file", "garbage.txt")
	if err != nil {
		t.Fatalf("create form file: %v", err)
	}
	if _, err := part.Write([]byte("this is not a password export")); err != nil {
		t.Fatalf("write form file: %v", err)
	}
	// Close writes the terminating boundary; without it the upload is truncated
	// and a 400 could come from a malformed body rather than format detection.
	if err := w.Close(); err != nil {
		t.Fatalf("close multipart writer: %v", err)
	}

	req, err := http.NewRequest("POST", c.srv.URL+"/api/import", body)
	if err != nil {
		t.Fatalf("build request: %v", err)
	}
	req.Header.Set("Content-Type", w.FormDataContentType())
	req.Header.Set("Authorization", "Bearer "+c.bearer)
	resp, err := c.srv.Client().Do(req)
	if err != nil {
		t.Fatal(err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != 400 {
		t.Errorf("unknown format should return 400, got %d", resp.StatusCode)
	}
}
// ---------------------------------------------------------------------------
// Password Generator
// ---------------------------------------------------------------------------
// TestPasswordGenerator checks that /api/generate honours the requested
// length. The length is measured in bytes.
func TestPasswordGenerator(t *testing.T) {
	client := newTestClient(t)
	reply := client.must(client.req("GET", "/api/generate?length=24", nil), 200)
	generated, _ := reply["password"].(string)
	if got := len(generated); got != 24 {
		t.Errorf("password length = %d, want 24", got)
	}
}
// TestPasswordGenerator_passphrase checks that ?words=4 produces a
// passphrase of four "-"-separated words.
func TestPasswordGenerator_passphrase(t *testing.T) {
	client := newTestClient(t)
	reply := client.must(client.req("GET", "/api/generate?words=4", nil), 200)
	phrase, _ := reply["password"].(string)
	// n separators always yield n+1 pieces, matching len(strings.Split(phrase, "-")).
	if words := strings.Count(phrase, "-") + 1; words != 4 {
		t.Errorf("passphrase should have 4 words, got %d: %q", words, phrase)
	}
}
// ---------------------------------------------------------------------------
// Audit Log
// ---------------------------------------------------------------------------
// TestAuditLog creates and then reads an entry. It checks that the audit log
// holds at least two events, including a "create" action and a "read" action.
func TestAuditLog(t *testing.T) {
	client := newTestClient(t)
	created := client.must(client.req("POST", "/api/entries", credentialEntry("Audited", "u", "p", nil)), 201)
	entryID := created["entry_id"].(string)
	client.must(client.req("GET", "/api/entries/"+entryID, nil), 200)

	events := client.mustList(client.req("GET", "/api/audit", nil), 200)
	if n := len(events); n < 2 {
		t.Errorf("expected at least 2 events (create + read), got %d", n)
	}

	seen := make(map[string]bool, len(events))
	for _, ev := range events {
		action, isString := ev["action"].(string)
		if !isString {
			continue
		}
		seen[action] = true
	}
	if !seen["create"] {
		t.Error("missing 'create' in audit log")
	}
	if !seen["read"] {
		t.Error("missing 'read' in audit log")
	}
}
// ---------------------------------------------------------------------------
// WebAuthn Auth Flow
// ---------------------------------------------------------------------------
// TestAuthStatus_fresh checks the unauthenticated /api/auth/status endpoint
// on a newly created vault: it must report state "fresh".
func TestAuthStatus_fresh(t *testing.T) {
	client := newTestClient(t)
	reply := client.must(client.reqNoAuth("GET", "/api/auth/status", nil), 200)
	if state := reply["state"]; state != "fresh" {
		t.Errorf("state = %v, want fresh", state)
	}
}
// TestAuthRegisterBegin_fresh starts WebAuthn registration on a fresh vault.
// It expects a 200 reply whose publicKey object carries a challenge.
func TestAuthRegisterBegin_fresh(t *testing.T) {
	c := newTestClient(t)
	resp := c.reqNoAuth("POST", "/api/auth/register/begin", map[string]any{})
	defer resp.Body.Close()
	if resp.StatusCode != 200 {
		b, _ := io.ReadAll(resp.Body) // best effort: only used in the failure message
		t.Fatalf("expected 200, got %d: %s", resp.StatusCode, b)
	}

	var result map[string]any
	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		t.Fatalf("decode register/begin response: %v", err)
	}
	// Comma-ok so a missing publicKey fails this test instead of panicking
	// and aborting the whole test binary.
	pk, ok := result["publicKey"].(map[string]any)
	if !ok {
		t.Fatalf("response should contain a publicKey object, got %v", result)
	}
	if pk["challenge"] == nil {
		t.Fatal("response should contain a challenge")
	}
}
// ---------------------------------------------------------------------------
// Tier Isolation
// ---------------------------------------------------------------------------
// TestTierIsolation verifies that L2/L3 encrypted blobs survive the L1
// envelope encrypt/decrypt roundtrip intact. The server packs all fields
// into a single AES-GCM envelope (L1). L2/L3 field values are opaque
// ciphertext — the server stores them and never inspects them.
//
// Response shapes are checked with comma-ok assertions so that a malformed
// reply fails this test cleanly. A panic would abort the whole test binary.
func TestTierIsolation(t *testing.T) {
	c := newTestClient(t)
	l2Blob := "AQIDBAUGB5iL2EncryptedBlob+test=="
	l3Blob := "AQIDBAUGB5iL3EncryptedBlob+test=="
	created := c.must(c.req("POST", "/api/entries", map[string]any{
		"type": "credential", "title": "TierTest",
		"data": map[string]any{
			"title": "TierTest", "type": "credential",
			"fields": []map[string]any{
				{"label": "Username", "value": "testuser", "kind": "text"},
				{"label": "Password", "value": l2Blob, "kind": "password", "tier": 2},
				{"label": "SSN", "value": l3Blob, "kind": "text", "tier": 3, "l2": true},
			},
		},
	}), 201)
	id, ok := created["entry_id"].(string)
	if !ok || id == "" {
		t.Fatalf("create returned no entry_id: %v", created)
	}

	got := c.must(c.req("GET", "/api/entries/"+id, nil), 200)
	data, ok := got["data"].(map[string]any)
	if !ok {
		t.Fatalf("entry has no data object: %v", got)
	}
	fields, ok := data["fields"].([]any)
	if !ok {
		t.Fatalf("entry data has no fields array: %v", data)
	}

	found := make(map[string]string, len(fields))
	for i, raw := range fields {
		f, ok := raw.(map[string]any)
		if !ok {
			t.Fatalf("field %d is not an object: %v", i, raw)
		}
		label, ok := f["label"].(string)
		if !ok {
			t.Fatalf("field %d has no string label: %v", i, f)
		}
		// A non-string value is recorded as "" and reported by the checks below.
		found[label], _ = f["value"].(string)
	}

	if found["Username"] != "testuser" {
		t.Errorf("L1 Username = %q, want testuser", found["Username"])
	}
	if found["Password"] != l2Blob {
		t.Errorf("L2 Password blob changed: %q", found["Password"])
	}
	if found["SSN"] != l3Blob {
		t.Errorf("L3 SSN blob changed: %q", found["SSN"])
	}
}