clavitor/clavis/clavis-vault/lib/dbcore.go

709 lines
20 KiB
Go

package lib
import (
"crypto/rand"
"database/sql"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// Sentinel errors returned by the storage functions in this file. Compare
// against them with errors.Is.
var (
	// ErrNotFound reports that the requested row does not exist. For entry
	// operations that filter on deleted_at IS NULL, a soft-deleted entry is
	// also reported as not found.
	ErrNotFound = errors.New("not found")

	// ErrVersionConflict reports that an optimistic-locking update matched no
	// live row with the expected version.
	ErrVersionConflict = errors.New("version conflict: entry was modified")
)
// schema is the complete DDL applied by MigrateDB. Every statement uses
// IF NOT EXISTS, so it can be re-run against an existing database. The flip
// side is that it never alters a table that already exists: columns added
// here are not added to older databases.
//
// Tables:
//   - entries: vault items, including agent and scope entries (see the
//     TypeAgent/TypeScope filters in this file). Rows are soft-deleted via
//     deleted_at; the Entry* functions write created_at, updated_at,
//     deleted_at and replicated_at as Unix milliseconds. version backs the
//     optimistic locking in EntryUpdate, and title_idx holds the blind index
//     used for lookups (by lower-cased title, or by agent_id for agents).
//   - audit_log: events written by AuditLog; created_at defaults to Unix
//     milliseconds there.
//   - webauthn_credentials / webauthn_challenges: passkey state. The WebAuthn
//     functions write created_at in Unix seconds, unlike the tables above.
const schema = `
CREATE TABLE IF NOT EXISTS entries (
entry_id INTEGER PRIMARY KEY,
parent_id INTEGER NOT NULL DEFAULT 0,
type TEXT NOT NULL,
title TEXT NOT NULL,
title_idx BLOB NOT NULL,
data BLOB NOT NULL,
data_level INTEGER NOT NULL DEFAULT 1,
scopes TEXT NOT NULL DEFAULT '0000',
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
deleted_at INTEGER,
checksum INTEGER,
replicated_at INTEGER
);
CREATE INDEX IF NOT EXISTS idx_entries_parent ON entries(parent_id);
CREATE INDEX IF NOT EXISTS idx_entries_type ON entries(type);
CREATE INDEX IF NOT EXISTS idx_entries_title_idx ON entries(title_idx);
CREATE INDEX IF NOT EXISTS idx_entries_deleted ON entries(deleted_at);
CREATE TABLE IF NOT EXISTS audit_log (
event_id INTEGER PRIMARY KEY,
entry_id INTEGER,
title TEXT,
action TEXT NOT NULL,
actor TEXT NOT NULL,
ip_addr TEXT,
created_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_entry ON audit_log(entry_id);
CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_log(created_at);
CREATE TABLE IF NOT EXISTS webauthn_credentials (
cred_id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
public_key BLOB NOT NULL,
credential_id BLOB NOT NULL DEFAULT X'',
prf_salt BLOB NOT NULL,
sign_count INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS webauthn_challenges (
challenge BLOB PRIMARY KEY,
type TEXT NOT NULL,
created_at INTEGER NOT NULL
);
`
// OpenDB opens the SQLite database at dbPath and verifies it with a ping.
//
// The DSN enables WAL journaling, foreign-key enforcement and a busy timeout
// of 5000 (ms). The options are appended verbatim as "?...", so dbPath must
// not already carry its own query string.
//
// If the ping fails, the connection pool is closed before the error is
// returned, so a failed open does not leak a *sql.DB.
func OpenDB(dbPath string) (*DB, error) {
	conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000")
	if err != nil {
		return nil, fmt.Errorf("open db: %w", err)
	}
	if err := conn.Ping(); err != nil {
		// Best-effort cleanup; the ping error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("ping db: %w", err)
	}
	return &DB{Conn: conn, DBPath: dbPath}, nil
}
// MigrateDB applies the schema DDL to db. Every statement in the schema uses
// IF NOT EXISTS, so running it against an already-initialised database
// changes nothing. It does not alter tables that already exist.
func MigrateDB(db *DB) error {
	if _, err := db.Conn.Exec(schema); err != nil {
		return err
	}
	return nil
}
// Close releases the underlying *sql.DB connection pool opened by OpenDB.
func (db *DB) Close() error {
	conn := db.Conn
	return conn.Close()
}
// ---------------------------------------------------------------------------
// Entry operations
// ---------------------------------------------------------------------------
// EntryCreate inserts e as a new row in the entries table.
//
// Before inserting, it fills in server-managed fields on e, so the caller
// sees them afterwards:
//   - a fresh EntryID when none was supplied
//   - CreatedAt and UpdatedAt set to the current time in Unix milliseconds
//   - Version reset to 1
//   - default DataLevel and Scopes when left zero
//
// When e.VaultData is non-nil, it is JSON-encoded, packed with a per-entry
// key derived from vaultKey, and the result replaces e.Data. TitleIdx becomes
// the blind index of the lower-cased title unless the caller already set it.
func EntryCreate(db *DB, vaultKey []byte, e *Entry) error {
	if e.EntryID == 0 {
		e.EntryID = HexID(NewID())
	}

	stamp := time.Now().UnixMilli()
	e.CreatedAt, e.UpdatedAt, e.Version = stamp, stamp, 1

	if e.DataLevel == 0 {
		e.DataLevel = DataLevelL1
	}
	if e.Scopes == "" {
		e.Scopes = ScopeOwner
	}

	id := int64(e.EntryID)
	entryKey, err := DeriveEntryKey(vaultKey, id)
	if err != nil {
		return err
	}
	hmacKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return err
	}

	// Respect a pre-set TitleIdx: agent entries are indexed by agent_id, not title.
	if len(e.TitleIdx) == 0 {
		e.TitleIdx = BlindIndex(hmacKey, strings.ToLower(e.Title))
	}

	if vd := e.VaultData; vd != nil {
		plain, err := json.Marshal(vd)
		if err != nil {
			return err
		}
		sealed, err := Pack(entryKey, string(plain))
		if err != nil {
			return err
		}
		e.Data = sealed
	}

	const insertEntry = `INSERT INTO entries (entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	_, err = db.Conn.Exec(insertEntry,
		id, int64(e.ParentID), e.Type, e.Title, e.TitleIdx, e.Data,
		e.DataLevel, e.Scopes, e.CreatedAt, e.UpdatedAt, e.Version)
	return err
}
// EntryGet loads the entry with the given ID. Soft-deleted rows are returned
// too, with DeletedAt set. It returns ErrNotFound when no row exists.
//
// For DataLevelL1 entries with a non-empty data blob, the blob is unpacked
// with the per-entry key derived from vaultKey and decoded into VaultData.
// Any failure in that step is returned as an error. Other entries come back
// with Data still packed and VaultData nil.
func EntryGet(db *DB, vaultKey []byte, entryID int64) (*Entry, error) {
	var (
		e            Entry
		deletedAt    sql.NullInt64
		replicatedAt sql.NullInt64
	)
	row := db.Conn.QueryRow(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version, deleted_at, replicated_at
FROM entries WHERE entry_id = ?`, entryID)

	switch err := row.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.Scopes, &e.CreatedAt, &e.UpdatedAt, &e.Version, &deletedAt, &replicatedAt); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	if deletedAt.Valid {
		ts := deletedAt.Int64
		e.DeletedAt = &ts
	}
	if replicatedAt.Valid {
		ts := replicatedAt.Int64
		e.ReplicatedAt = &ts
	}

	// Only L1 entries with data are decrypted server-side.
	if e.DataLevel != DataLevelL1 || len(e.Data) == 0 {
		return &e, nil
	}

	entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return nil, err
	}
	plain, err := Unpack(entryKey, e.Data)
	if err != nil {
		return nil, err
	}
	vd := new(VaultData)
	if err := json.Unmarshal([]byte(plain), vd); err != nil {
		return nil, err
	}
	e.VaultData = vd
	return &e, nil
}
// EntryUpdate writes e over the stored row using optimistic locking.
//
// The UPDATE only matches a live (not soft-deleted) row whose version equals
// e.Version. On success, the stored version is incremented, and e.Version and
// e.UpdatedAt (Unix milliseconds) are updated to match. If no row matched,
// ErrVersionConflict is returned and e.Version/e.UpdatedAt are left unchanged.
// No match can mean a stale version, a missing entry, or a deleted entry.
//
// When e.VaultData is non-nil, it is re-encoded and re-packed into e.Data.
//
// TitleIdx is normally recomputed from the lower-cased title. Agent entries
// are the exception: AgentLookup finds them by the blind index of their agent
// ID. For those, the index is rebuilt from VaultData.AgentID when available;
// otherwise a caller-supplied TitleIdx is kept.
func EntryUpdate(db *DB, vaultKey []byte, e *Entry) error {
	now := time.Now().UnixMilli()
	entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return err
	}
	hmacKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return err
	}

	switch {
	case e.Type == TypeAgent && e.VaultData != nil && e.VaultData.AgentID != "":
		// Rebuild the agent_id index (as AgentCreate does) so the agent stays
		// findable by AgentLookup after the update.
		e.TitleIdx = BlindIndex(hmacKey, e.VaultData.AgentID)
	case e.Type == TypeAgent && len(e.TitleIdx) > 0:
		// No decrypted agent data to rebuild from: keep the existing index
		// rather than replacing it with a title index AgentLookup cannot use.
	default:
		e.TitleIdx = BlindIndex(hmacKey, strings.ToLower(e.Title))
	}

	if e.VaultData != nil {
		dataJSON, err := json.Marshal(e.VaultData)
		if err != nil {
			return err
		}
		packed, err := Pack(entryKey, string(dataJSON))
		if err != nil {
			return err
		}
		e.Data = packed
	}

	result, err := db.Conn.Exec(
		`UPDATE entries SET parent_id=?, type=?, title=?, title_idx=?, data=?, data_level=?, scopes=?, updated_at=?, version=version+1
WHERE entry_id = ? AND version = ? AND deleted_at IS NULL`,
		int64(e.ParentID), e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, e.Scopes, now,
		int64(e.EntryID), e.Version,
	)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrVersionConflict
	}
	e.Version++
	e.UpdatedAt = now
	return nil
}
// EntryDelete soft-deletes an entry by stamping deleted_at and updated_at
// with the current Unix-millisecond time; the row itself is kept. It returns
// ErrNotFound if the entry does not exist or is already deleted.
func EntryDelete(db *DB, entryID int64) error {
	ts := time.Now().UnixMilli()
	res, err := db.Conn.Exec(
		`UPDATE entries SET deleted_at = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`,
		ts, ts, entryID)
	if err != nil {
		return err
	}
	// The RowsAffected error is deliberately ignored; zero rows means nothing to delete.
	if n, _ := res.RowsAffected(); n == 0 {
		return ErrNotFound
	}
	return nil
}
// EntryList returns all entries that are not soft-deleted, ordered by type
// and then title. When parentID is non-nil, only entries whose parent_id
// equals *parentID are returned.
//
// Decryption is best-effort. For DataLevelL1 entries with data, VaultData is
// populated when the blob unpacks and decodes cleanly, and left nil
// otherwise, so one unreadable entry does not fail the whole listing. Entries
// at other data levels are returned with Data still packed.
func EntryList(db *DB, vaultKey []byte, parentID *int64) ([]Entry, error) {
	var rows *sql.Rows
	var err error
	if parentID != nil {
		rows, err = db.Conn.Query(
			`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND parent_id = ? ORDER BY type, title`, *parentID)
	} else {
		rows, err = db.Conn.Query(
			`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// decode unpacks and decodes one L1 blob. It returns nil on any failure,
	// and never feeds an invalid key or an unpack failure into the next step.
	decode := func(e *Entry) *VaultData {
		entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
		if err != nil {
			return nil
		}
		plain, err := Unpack(entryKey, e.Data)
		if err != nil {
			return nil
		}
		var vd VaultData
		if err := json.Unmarshal([]byte(plain), &vd); err != nil {
			return nil
		}
		return &vd
	}

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.Scopes, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			e.VaultData = decode(&e)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// EntryListMeta returns every live (not soft-deleted) entry with metadata
// columns only, ordered by type and then title. title_idx and data are not
// read, so nothing is decrypted and no vault key is needed.
func EntryListMeta(db *DB) ([]Entry, error) {
	const q = `SELECT entry_id, parent_id, type, title, data_level, scopes, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`
	rows, err := db.Conn.Query(q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []Entry
	for rows.Next() {
		var meta Entry
		if err := rows.Scan(&meta.EntryID, &meta.ParentID, &meta.Type, &meta.Title, &meta.DataLevel, &meta.Scopes, &meta.CreatedAt, &meta.UpdatedAt, &meta.Version); err != nil {
			return nil, err
		}
		list = append(list, meta)
	}
	return list, rows.Err()
}
// EntrySearchFuzzy returns live (not soft-deleted) entries whose plaintext
// title contains query as a substring, ordered by title. Matching uses
// SQLite LIKE, which is case-insensitive for ASCII letters only.
//
// LIKE wildcards in query are escaped, so "%" and "_" match themselves
// instead of "any run" or "any one character". An empty query matches every
// live entry.
//
// Decryption is best-effort: DataLevelL1 entries get VaultData when their
// blob unpacks and decodes cleanly, and nil otherwise.
func EntrySearchFuzzy(db *DB, vaultKey []byte, query string) ([]Entry, error) {
	// Escape the escape character itself plus both wildcards, so the user's
	// text is matched literally inside the surrounding %...% pattern.
	escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(query)
	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND title LIKE ? ESCAPE '\' ORDER BY title`, "%"+escaped+"%")
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	// decode unpacks and decodes one L1 blob, returning nil on any failure.
	decode := func(e *Entry) *VaultData {
		entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
		if err != nil {
			return nil
		}
		plain, err := Unpack(entryKey, e.Data)
		if err != nil {
			return nil
		}
		var vd VaultData
		if err := json.Unmarshal([]byte(plain), &vd); err != nil {
			return nil
		}
		return &vd
	}

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.Scopes, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			e.VaultData = decode(&e)
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// EntryUpdateScopes overwrites only the scopes (and updated_at) of a live
// entry. It does not touch version, so it bypasses EntryUpdate's optimistic
// locking. It returns ErrNotFound if the entry is missing or soft-deleted.
func EntryUpdateScopes(db *DB, entryID int64, scopes string) error {
	stamp := time.Now().UnixMilli()
	res, err := db.Conn.Exec(
		`UPDATE entries SET scopes = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`,
		scopes, stamp, entryID)
	if err != nil {
		return err
	}
	// The RowsAffected error is deliberately ignored; zero rows means no live entry.
	if n, _ := res.RowsAffected(); n == 0 {
		return ErrNotFound
	}
	return nil
}
// EntryMarkReplicated stamps replicated_at with the current Unix-millisecond
// time. It also matches soft-deleted rows, and it does not report whether any
// row was updated.
//
// NOTE(review): an update that lands between the replication read and this
// call gets an updated_at earlier than this stamp. EntryListUnreplicated may
// then treat that update as already replicated. Passing through the
// updated_at that was actually replicated would close the gap.
func EntryMarkReplicated(db *DB, entryID int64) error {
	_, err := db.Conn.Exec(`UPDATE entries SET replicated_at = ? WHERE entry_id = ?`,
		time.Now().UnixMilli(), entryID)
	return err
}
// EntryListUnreplicated returns every entry whose replicated_at is NULL or
// older than its updated_at. Soft-deleted entries are included, with
// DeletedAt set; presumably this lets deletions replicate too. Data is
// returned packed: nothing is decrypted and ReplicatedAt is not loaded.
func EntryListUnreplicated(db *DB) ([]Entry, error) {
	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, scopes, created_at, updated_at, version, deleted_at
FROM entries WHERE replicated_at IS NULL OR replicated_at < updated_at`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var pending []Entry
	for rows.Next() {
		var (
			e   Entry
			del sql.NullInt64
		)
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.Scopes, &e.CreatedAt, &e.UpdatedAt, &e.Version, &del); err != nil {
			return nil, err
		}
		if del.Valid {
			at := del.Int64
			e.DeletedAt = &at
		}
		pending = append(pending, e)
	}
	return pending, rows.Err()
}
// ---------------------------------------------------------------------------
// Agent lookup (agents are entries)
// ---------------------------------------------------------------------------
// AgentLookup resolves an agent from its hex agent ID via the blind index.
//
// It looks for a live (not soft-deleted) TypeAgent entry whose title_idx
// equals BlindIndex(hmacKey, agentIDHex). The entry's data is then unpacked
// and decoded, and the agent fields are copied into an AgentData.
//
// It returns (nil, nil), not an error, when no agent matches or when the
// matching entry is not DataLevelL1 or has no data. Key-derivation, unpack
// and JSON errors are returned as errors.
func AgentLookup(db *DB, vaultKey []byte, agentIDHex string) (*AgentData, error) {
	hmacKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return nil, err
	}

	var e Entry
	err = db.Conn.QueryRow(
		`SELECT entry_id, type, title, data, data_level
FROM entries WHERE title_idx = ? AND type = ? AND deleted_at IS NULL`,
		BlindIndex(hmacKey, agentIDHex), TypeAgent,
	).Scan(&e.EntryID, &e.Type, &e.Title, &e.Data, &e.DataLevel)

	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil // unknown agent is not an error
	case err != nil:
		return nil, err
	case e.DataLevel != DataLevelL1 || len(e.Data) == 0:
		return nil, nil // nothing decryptable to build an agent from
	}

	entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return nil, err
	}
	plain, err := Unpack(entryKey, e.Data)
	if err != nil {
		return nil, err
	}
	var vd VaultData
	if err := json.Unmarshal([]byte(plain), &vd); err != nil {
		return nil, err
	}

	agent := &AgentData{
		EntryID:    e.EntryID,
		AgentID:    vd.AgentID,
		Name:       vd.Title,
		Scopes:     vd.Scopes,
		AllAccess:  vd.AllAccess,
		Admin:      vd.Admin,
		AllowedIPs: vd.AllowedIPs,
		RateLimit:  vd.RateLimit,
	}
	return agent, nil
}
// AgentUpdateAllowedIPs replaces the AllowedIPs list stored inside the agent
// entry identified by agent.EntryID. It re-packs the entry data with the same
// per-entry key. All other stored agent fields are kept as decoded from the
// database; only AllowedIPs is taken from agent.
//
// If the entry is missing or soft-deleted, the error from the initial SELECT
// (sql.ErrNoRows) is returned.
//
// NOTE(review): this does not bump version. A concurrent EntryUpdate holding
// an older copy of the entry can therefore overwrite the new AllowedIPs;
// confirm that is acceptable.
func AgentUpdateAllowedIPs(db *DB, vaultKey []byte, agent *AgentData) error {
	var e Entry
	err := db.Conn.QueryRow(
		`SELECT entry_id, data, data_level FROM entries WHERE entry_id = ? AND deleted_at IS NULL`,
		int64(agent.EntryID),
	).Scan(&e.EntryID, &e.Data, &e.DataLevel)
	if err != nil {
		return err
	}
	entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return err
	}
	dataText, err := Unpack(entryKey, e.Data)
	if err != nil {
		return err
	}
	var vd VaultData
	if err := json.Unmarshal([]byte(dataText), &vd); err != nil {
		return err
	}
	vd.AllowedIPs = agent.AllowedIPs
	updated, err := json.Marshal(vd)
	if err != nil {
		return err
	}
	packed, err := Pack(entryKey, string(updated))
	if err != nil {
		return err
	}
	// updated_at must be in Unix milliseconds like every other entries writer.
	// EntryListUnreplicated compares it with replicated_at (also milliseconds),
	// so a seconds value would keep this change from ever being replicated.
	_, err = db.Conn.Exec(
		`UPDATE entries SET data = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`,
		packed, time.Now().UnixMilli(), int64(agent.EntryID))
	return err
}
// AgentCreate registers a new agent. It returns the agent's metadata and the
// client credential token minted for it.
//
// A random 16-byte agent ID is generated. When scopes is "" or "auto", a
// random 16-byte scope ID (hex-encoded) is assigned instead. The agent's
// settings are stored in a DataLevelL1, owner-only entry. That entry's
// TitleIdx is the blind index of the hex agent ID, which is what AgentLookup
// searches on.
//
// The credential is minted from l0, an L2 value (see TODO below) and the raw
// agent ID. vaultKey must be at least 16 bytes. The returned AgentData
// carries the EntryID assigned by EntryCreate.
func AgentCreate(db *DB, vaultKey, l0 []byte, name string, scopes string, allAccess, admin bool) (*AgentData, string, error) {
	// A failed random read must abort: continuing would silently create an
	// agent with an all-zero, predictable ID.
	agentID := make([]byte, 16)
	if _, err := rand.Read(agentID); err != nil {
		return nil, "", fmt.Errorf("generate agent id: %w", err)
	}
	agentIDHex := hex.EncodeToString(agentID)

	// Auto-assign a random scope when none was requested.
	if scopes == "" || scopes == "auto" {
		scopeID := make([]byte, 16)
		if _, err := rand.Read(scopeID); err != nil {
			return nil, "", fmt.Errorf("generate scope id: %w", err)
		}
		scopes = hex.EncodeToString(scopeID)
	}

	// TODO(review): L2 should come from the admin's PRF and be passed in. For
	// now, the first 16 bytes of vaultKey stand in for it.
	if len(vaultKey) < 16 {
		return nil, "", fmt.Errorf("vault key too short for L2: %d", len(vaultKey))
	}
	l2 := vaultKey[:16]

	credential, err := MintCredential(l0, l2, agentID)
	if err != nil {
		return nil, "", err
	}

	hmacKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return nil, "", err
	}
	vd := &VaultData{
		Title:     name,
		Type:      TypeAgent,
		AgentID:   agentIDHex,
		Scopes:    scopes,
		AllAccess: allAccess,
		Admin:     admin,
	}
	// Agent entries are owner-only and are indexed by agent_id (what
	// AgentLookup searches on) instead of by title.
	entry := &Entry{
		Type:      TypeAgent,
		Title:     name,
		TitleIdx:  BlindIndex(hmacKey, agentIDHex),
		DataLevel: DataLevelL1,
		Scopes:    ScopeOwner,
		VaultData: vd,
	}
	if err := EntryCreate(db, vaultKey, entry); err != nil {
		return nil, "", err
	}

	// EntryID is assigned by EntryCreate. AgentUpdateAllowedIPs locates the
	// agent's entry by it, so it must be returned.
	agent := &AgentData{
		AgentID:   agentIDHex,
		Name:      name,
		Scopes:    scopes,
		AllAccess: allAccess,
		Admin:     admin,
		EntryID:   entry.EntryID,
	}
	return agent, credential, nil
}
// ---------------------------------------------------------------------------
// Audit operations
// ---------------------------------------------------------------------------
// AuditLog inserts ev into audit_log. A zero EventID is replaced with a fresh
// ID, and a zero CreatedAt with the current Unix-millisecond time; both are
// written back into ev. A zero EntryID is stored as 0, not NULL.
func AuditLog(db *DB, ev *AuditEvent) error {
	if ev.EventID == 0 {
		ev.EventID = HexID(NewID())
	}
	if ev.CreatedAt == 0 {
		ev.CreatedAt = time.Now().UnixMilli()
	}

	const insertEvent = `INSERT INTO audit_log (event_id, entry_id, title, action, actor, ip_addr, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	if _, err := db.Conn.Exec(insertEvent,
		int64(ev.EventID), int64(ev.EntryID), ev.Title, ev.Action, ev.Actor, ev.IPAddr, ev.CreatedAt,
	); err != nil {
		return err
	}
	return nil
}
// AuditList returns up to limit audit events, newest first by created_at. A
// non-positive limit falls back to 100. NULL entry_id, title and ip_addr
// columns leave the corresponding fields at their zero values.
func AuditList(db *DB, limit int) ([]AuditEvent, error) {
	n := limit
	if n <= 0 {
		n = 100
	}
	rows, err := db.Conn.Query(
		`SELECT event_id, entry_id, title, action, actor, ip_addr, created_at
FROM audit_log ORDER BY created_at DESC LIMIT ?`, n)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var events []AuditEvent
	for rows.Next() {
		var ev AuditEvent
		var entry sql.NullInt64
		var title sql.NullString
		var ip sql.NullString
		if err := rows.Scan(&ev.EventID, &entry, &title, &ev.Action, &ev.Actor, &ip, &ev.CreatedAt); err != nil {
			return nil, err
		}
		// The Null* types hold their zero value when the column is NULL, so
		// unconditional assignment leaves NULL columns at zero.
		ev.EntryID = HexID(entry.Int64)
		ev.Title = title.String
		ev.IPAddr = ip.String
		events = append(events, ev)
	}
	return events, rows.Err()
}
// EntryCount returns the number of live (not soft-deleted) entries,
// excluding agent and scope entries.
func EntryCount(db *DB) (int, error) {
	row := db.Conn.QueryRow(
		`SELECT COUNT(*) FROM entries WHERE deleted_at IS NULL AND type NOT IN (?, ?)`,
		TypeAgent, TypeScope)
	var n int
	err := row.Scan(&n)
	return n, err
}
// ---------------------------------------------------------------------------
// WebAuthn
// ---------------------------------------------------------------------------
// StoreWebAuthnCredential inserts c into webauthn_credentials. A zero
// CreatedAt is filled in, and written back to c, with the current Unix time
// in seconds. CredID is stored as given; it is not generated here.
func StoreWebAuthnCredential(db *DB, c *WebAuthnCredential) error {
	if c.CreatedAt == 0 {
		c.CreatedAt = time.Now().Unix()
	}
	const stmt = `INSERT INTO webauthn_credentials (cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	_, err := db.Conn.Exec(stmt,
		int64(c.CredID), c.Name, c.PublicKey, c.CredentialID, c.PRFSalt, c.SignCount, c.CreatedAt)
	return err
}
// GetWebAuthnCredentials returns all stored WebAuthn credentials, newest
// first by created_at. The slice is nil when none are stored.
func GetWebAuthnCredentials(db *DB) ([]WebAuthnCredential, error) {
	const q = `SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
FROM webauthn_credentials ORDER BY created_at DESC`
	rows, err := db.Conn.Query(q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []WebAuthnCredential
	for rows.Next() {
		var cred WebAuthnCredential
		if err := rows.Scan(&cred.CredID, &cred.Name, &cred.PublicKey, &cred.CredentialID, &cred.PRFSalt, &cred.SignCount, &cred.CreatedAt); err != nil {
			return nil, err
		}
		list = append(list, cred)
	}
	return list, rows.Err()
}
// WebAuthnCredentialCount reports how many WebAuthn credentials are stored.
func WebAuthnCredentialCount(db *DB) (int, error) {
	var total int
	if err := db.Conn.QueryRow(`SELECT COUNT(*) FROM webauthn_credentials`).Scan(&total); err != nil {
		return total, err
	}
	return total, nil
}
// GetWebAuthnCredentialByRawID looks up a credential by its raw WebAuthn
// credential ID (the credential_id column, not cred_id). It returns
// ErrNotFound when no credential matches. On any error the returned
// credential is nil.
func GetWebAuthnCredentialByRawID(db *DB, credentialID []byte) (*WebAuthnCredential, error) {
	var c WebAuthnCredential
	err := db.Conn.QueryRow(
		`SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
FROM webauthn_credentials WHERE credential_id = ?`, credentialID,
	).Scan(&c.CredID, &c.Name, &c.PublicKey, &c.CredentialID, &c.PRFSalt, &c.SignCount, &c.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		// Don't hand back a partially scanned credential alongside the error.
		return nil, err
	}
	return &c, nil
}
// DeleteWebAuthnCredential permanently removes the credential with the given
// cred_id. It returns ErrNotFound when no such credential exists.
func DeleteWebAuthnCredential(db *DB, credID int64) error {
	res, err := db.Conn.Exec(`DELETE FROM webauthn_credentials WHERE cred_id = ?`, credID)
	if err != nil {
		return err
	}
	// The RowsAffected error is deliberately ignored; zero rows means not found.
	if removed, _ := res.RowsAffected(); removed == 0 {
		return ErrNotFound
	}
	return nil
}
// UpdateWebAuthnSignCount stores count as the credential's sign_count. An
// unknown credID is not reported; the UPDATE simply matches no rows.
func UpdateWebAuthnSignCount(db *DB, credID int64, count int) error {
	const stmt = `UPDATE webauthn_credentials SET sign_count = ? WHERE cred_id = ?`
	if _, err := db.Conn.Exec(stmt, count, credID); err != nil {
		return err
	}
	return nil
}
// StoreWebAuthnChallenge records an issued challenge of the given type,
// stamped with the current Unix time in seconds. ConsumeWebAuthnChallenge
// later removes it.
func StoreWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
	issuedAt := time.Now().Unix()
	_, err := db.Conn.Exec(
		`INSERT INTO webauthn_challenges (challenge, type, created_at) VALUES (?, ?, ?)`,
		challenge, challengeType, issuedAt)
	return err
}
// ConsumeWebAuthnChallenge checks for and removes a stored challenge of the
// given type in a single DELETE, so each challenge can be consumed at most
// once. Challenges older than 300 seconds are treated as absent; they are not
// removed here (CleanExpiredChallenges does that). It returns an error when
// no fresh matching challenge was found.
func ConsumeWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
	const ttlSeconds = 300
	cutoff := time.Now().Unix() - ttlSeconds
	res, err := db.Conn.Exec(
		`DELETE FROM webauthn_challenges WHERE challenge = ? AND type = ? AND created_at > ?`,
		challenge, challengeType, cutoff)
	if err != nil {
		return err
	}
	// The RowsAffected error is deliberately ignored; zero rows means no usable challenge.
	if n, _ := res.RowsAffected(); n == 0 {
		return errors.New("challenge not found or expired")
	}
	return nil
}
// CleanExpiredChallenges deletes WebAuthn challenges created more than 300
// seconds ago. It is best-effort: any database error is discarded.
func CleanExpiredChallenges(db *DB) {
	cutoff := time.Now().Unix() - 300
	// Errors are intentionally ignored. ConsumeWebAuthnChallenge already
	// rejects challenges this old, so rows left behind cannot be used.
	_, _ = db.Conn.Exec(`DELETE FROM webauthn_challenges WHERE created_at < ?`, cutoff)
}