clavitor/clavis/clavis-vault/lib/dbcore.go

981 lines
29 KiB
Go

package lib
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
_ "github.com/mattn/go-sqlite3"
)
// ErrNotFound is returned when no row matches the requested key.
var ErrNotFound = errors.New("not found")

// ErrVersionConflict is returned by optimistic-locking updates when no live
// row exists at the expected version (it was changed or soft-deleted).
var ErrVersionConflict = errors.New("version conflict: entry was modified")
// schema is the DDL executed by MigrateDB on every start-up.
//
// Every statement uses IF NOT EXISTS, so running it against an existing
// database is a no-op for tables and indexes that are already present. It
// never alters an existing table; column additions for databases created by
// older versions are applied separately in MigrateDB.
//
// entries.title holds the title as given, next to title_idx, which holds the
// blind index of the lower-cased title.
//
// Timestamp columns are plain integers. The writers in this file use Unix
// milliseconds for entries, sessions, audit_log, agents and vault_lock. They
// use Unix seconds for mcp_tokens (created_at, last_used),
// webauthn_credentials, webauthn_challenges and agent_requests.
const schema = `
CREATE TABLE IF NOT EXISTS entries (
entry_id INTEGER PRIMARY KEY,
parent_id INTEGER NOT NULL DEFAULT 0,
type TEXT NOT NULL,
title TEXT NOT NULL,
title_idx BLOB NOT NULL,
data BLOB NOT NULL,
data_level INTEGER NOT NULL DEFAULT 1,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
deleted_at INTEGER,
checksum INTEGER
);
CREATE INDEX IF NOT EXISTS idx_entries_parent ON entries(parent_id);
CREATE INDEX IF NOT EXISTS idx_entries_type ON entries(type);
CREATE INDEX IF NOT EXISTS idx_entries_title_idx ON entries(title_idx);
CREATE INDEX IF NOT EXISTS idx_entries_deleted ON entries(deleted_at);
CREATE TABLE IF NOT EXISTS sessions (
token TEXT PRIMARY KEY,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
actor TEXT NOT NULL DEFAULT 'web'
);
CREATE TABLE IF NOT EXISTS audit_log (
event_id INTEGER PRIMARY KEY,
entry_id INTEGER,
title TEXT,
action TEXT NOT NULL,
actor TEXT NOT NULL,
ip_addr TEXT,
created_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_entry ON audit_log(entry_id);
CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_log(created_at);
CREATE TABLE IF NOT EXISTS webauthn_credentials (
cred_id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
public_key BLOB NOT NULL,
credential_id BLOB NOT NULL DEFAULT X'',
prf_salt BLOB NOT NULL,
sign_count INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS webauthn_challenges (
challenge BLOB PRIMARY KEY,
type TEXT NOT NULL,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS mcp_tokens (
id INTEGER PRIMARY KEY,
label TEXT NOT NULL,
token TEXT UNIQUE NOT NULL,
tags TEXT,
entry_ids TEXT,
read_only INTEGER NOT NULL DEFAULT 0,
expires_at INTEGER NOT NULL DEFAULT 0,
last_used INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS agents (
id INTEGER PRIMARY KEY,
name TEXT UNIQUE NOT NULL,
ip_whitelist TEXT DEFAULT '["0.0.0.0/0"]',
rate_limit_minute INTEGER DEFAULT 5,
rate_limit_hour INTEGER DEFAULT 10,
status TEXT DEFAULT 'active',
locked_reason TEXT,
locked_at INTEGER DEFAULT 0,
last_used INTEGER DEFAULT 0,
last_ip TEXT,
created_at INTEGER NOT NULL
);
CREATE TABLE IF NOT EXISTS agent_requests (
id INTEGER PRIMARY KEY,
agent_id INTEGER NOT NULL,
ip TEXT NOT NULL,
path TEXT NOT NULL DEFAULT '',
timestamp INTEGER NOT NULL,
FOREIGN KEY (agent_id) REFERENCES agents(id)
);
CREATE INDEX IF NOT EXISTS idx_agent_requests ON agent_requests(agent_id, timestamp);
CREATE TABLE IF NOT EXISTS vault_lock (
id INTEGER PRIMARY KEY CHECK (id = 1),
locked INTEGER DEFAULT 0,
locked_reason TEXT,
locked_at INTEGER DEFAULT 0
);
`
// OpenDB opens the SQLite database at dbPath and verifies that it is reachable.
//
// The DSN enables WAL journaling, foreign-key enforcement and a 5000 ms busy
// timeout. The schema is not applied here; MigrateDB does that.
//
// If the initial ping fails, the connection pool created by sql.Open is
// closed before returning, so no handle is leaked on the error path.
func OpenDB(dbPath string) (*DB, error) {
	conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000")
	if err != nil {
		return nil, fmt.Errorf("open db: %w", err)
	}
	if err := conn.Ping(); err != nil {
		// sql.Open succeeded, so conn owns resources that must be released.
		_ = conn.Close()
		return nil, fmt.Errorf("ping db: %w", err)
	}
	return &DB{Conn: conn, DBPath: dbPath}, nil
}
// MigrateDB applies the schema and the in-place upgrades needed by databases
// created by older versions.
//
// It is safe to call on every start-up, for three reasons:
//   - the DDL uses IF NOT EXISTS;
//   - the credential_id column addition tolerates a column that already exists;
//   - the vault_lock seed row is inserted with OR IGNORE.
//
// Any other failure is returned with context instead of being dropped.
func MigrateDB(db *DB) error {
	if _, err := db.Conn.Exec(schema); err != nil {
		return fmt.Errorf("apply schema: %w", err)
	}

	// Older databases predate webauthn_credentials.credential_id. Fresh ones
	// already get it from schema, and SQLite then reports "duplicate column
	// name", which means the migration is already applied.
	_, err := db.Conn.Exec(`ALTER TABLE webauthn_credentials ADD COLUMN credential_id BLOB NOT NULL DEFAULT X''`)
	if err != nil && !strings.Contains(err.Error(), "duplicate column") {
		return fmt.Errorf("add webauthn_credentials.credential_id: %w", err)
	}

	// vault_lock is a single-row table (CHECK id = 1); VaultLockGet reads that row.
	if _, err := db.Conn.Exec(`INSERT OR IGNORE INTO vault_lock (id) VALUES (1)`); err != nil {
		return fmt.Errorf("seed vault_lock: %w", err)
	}
	return nil
}
// Close shuts down the underlying *sql.DB pool held in db.Conn and reports
// any error from doing so.
func (db *DB) Close() error {
	pool := db.Conn
	return pool.Close()
}
// ---------------------------------------------------------------------------
// Entry operations
// ---------------------------------------------------------------------------
// EntryCreate inserts e as a new row in entries, first filling in the
// server-managed fields on e itself. In order:
//   - a fresh EntryID is assigned when e.EntryID is zero;
//   - CreatedAt and UpdatedAt are set to the current Unix-millisecond time;
//   - Version is reset to 1;
//   - DataLevel defaults to DataLevelL1 when zero;
//   - TitleIdx is recomputed as the blind index of the lower-cased title;
//   - when e.VaultData is non-nil, e.Data is replaced by VaultData encoded as
//     JSON and packed with the per-entry key.
//
// Because the fields are assigned as the function proceeds, an error partway
// through leaves e partially updated.
func EntryCreate(db *DB, vaultKey []byte, e *Entry) error {
	if e.EntryID == 0 {
		e.EntryID = HexID(NewID())
	}
	ts := time.Now().UnixMilli()
	e.CreatedAt, e.UpdatedAt = ts, ts
	e.Version = 1
	if e.DataLevel == 0 {
		e.DataLevel = DataLevelL1
	}

	key, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return err
	}
	idxKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return err
	}
	e.TitleIdx = BlindIndex(idxKey, strings.ToLower(e.Title))

	if e.VaultData != nil {
		raw, err := json.Marshal(e.VaultData)
		if err != nil {
			return err
		}
		blob, err := Pack(key, string(raw))
		if err != nil {
			return err
		}
		e.Data = blob
	}

	const insertSQL = `INSERT INTO entries (entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`
	if _, err := db.Conn.Exec(insertSQL,
		int64(e.EntryID), int64(e.ParentID), e.Type, e.Title, e.TitleIdx,
		e.Data, e.DataLevel, e.CreatedAt, e.UpdatedAt, e.Version,
	); err != nil {
		return err
	}
	return nil
}
// EntryGet loads the entry with the given ID. Soft-deleted rows are returned
// too, with DeletedAt set. It returns ErrNotFound when no row has that ID.
//
// For a DataLevelL1 entry with a non-empty Data blob, the blob is unpacked
// with the per-entry key and decoded into VaultData. A failure in key
// derivation, unpacking or JSON decoding is returned as an error.
func EntryGet(db *DB, vaultKey []byte, entryID int64) (*Entry, error) {
	var (
		ent     Entry
		deleted sql.NullInt64
	)
	row := db.Conn.QueryRow(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version, deleted_at
FROM entries WHERE entry_id = ?`, entryID,
	)
	switch err := row.Scan(&ent.EntryID, &ent.ParentID, &ent.Type, &ent.Title, &ent.TitleIdx, &ent.Data, &ent.DataLevel, &ent.CreatedAt, &ent.UpdatedAt, &ent.Version, &deleted); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	if deleted.Valid {
		ts := deleted.Int64
		ent.DeletedAt = &ts
	}

	// Only L1 payloads with data are decrypted here.
	if len(ent.Data) == 0 || ent.DataLevel != DataLevelL1 {
		return &ent, nil
	}
	key, err := DeriveEntryKey(vaultKey, int64(ent.EntryID))
	if err != nil {
		return nil, err
	}
	plain, err := Unpack(key, ent.Data)
	if err != nil {
		return nil, err
	}
	vd := new(VaultData)
	if err := json.Unmarshal([]byte(plain), vd); err != nil {
		return nil, err
	}
	ent.VaultData = vd
	return &ent, nil
}
// EntryUpdate writes e back to the database using optimistic locking on
// e.Version.
//
// The UPDATE matches only a row that is not soft-deleted and whose stored
// version equals e.Version; it also bumps the stored version by one. If no
// row matches (wrong version, deleted, or missing), ErrVersionConflict is
// returned. On success, e.Version is incremented and e.UpdatedAt is set to the
// Unix-millisecond write time.
//
// TitleIdx is always recomputed. When e.VaultData is non-nil, e.Data is
// re-packed from it; otherwise e.Data is written unchanged.
func EntryUpdate(db *DB, vaultKey []byte, e *Entry) error {
	stamp := time.Now().UnixMilli()

	key, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
	if err != nil {
		return err
	}
	idxKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return err
	}
	e.TitleIdx = BlindIndex(idxKey, strings.ToLower(e.Title))

	if e.VaultData != nil {
		raw, err := json.Marshal(e.VaultData)
		if err != nil {
			return err
		}
		blob, err := Pack(key, string(raw))
		if err != nil {
			return err
		}
		e.Data = blob
	}

	res, err := db.Conn.Exec(
		`UPDATE entries SET parent_id=?, type=?, title=?, title_idx=?, data=?, data_level=?, updated_at=?, version=version+1
WHERE entry_id = ? AND version = ? AND deleted_at IS NULL`,
		int64(e.ParentID), e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, stamp,
		int64(e.EntryID), e.Version,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return ErrVersionConflict
	}

	// Mirror the database change in the caller's copy.
	e.Version++
	e.UpdatedAt = stamp
	return nil
}
// EntryDelete soft-deletes an entry. It stamps deleted_at and updated_at with
// the current Unix-millisecond time. It returns ErrNotFound when the entry
// does not exist or is already deleted.
func EntryDelete(db *DB, entryID int64) error {
	ts := time.Now().UnixMilli()
	res, err := db.Conn.Exec(
		`UPDATE entries SET deleted_at = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`,
		ts, ts, entryID,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// EntryList returns every entry that is not soft-deleted, ordered by type and
// then title. When parentID is non-nil, only children of *parentID are listed.
//
// L1 payloads are decrypted on a best-effort basis. An entry whose key
// derivation, unpacking or JSON decoding fails is still returned, with
// VaultData left nil. Scan and iteration errors are returned.
func EntryList(db *DB, vaultKey []byte, parentID *int64) ([]Entry, error) {
	const (
		allSQL = `SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`
		byParentSQL = `SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND parent_id = ? ORDER BY type, title`
	)

	var (
		rows *sql.Rows
		err  error
	)
	if parentID == nil {
		rows, err = db.Conn.Query(allSQL)
	} else {
		rows, err = db.Conn.Query(byParentSQL, *parentID)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []Entry
	for rows.Next() {
		var ent Entry
		if err := rows.Scan(&ent.EntryID, &ent.ParentID, &ent.Type, &ent.Title, &ent.TitleIdx, &ent.Data, &ent.DataLevel, &ent.CreatedAt, &ent.UpdatedAt, &ent.Version); err != nil {
			return nil, err
		}
		if len(ent.Data) > 0 && ent.DataLevel == DataLevelL1 {
			if key, kerr := DeriveEntryKey(vaultKey, int64(ent.EntryID)); kerr == nil {
				if plain, uerr := Unpack(key, ent.Data); uerr == nil {
					vd := new(VaultData)
					if json.Unmarshal([]byte(plain), vd) == nil {
						ent.VaultData = vd
					}
				}
			}
		}
		out = append(out, ent)
	}
	return out, rows.Err()
}
// EntryListMeta lists metadata for every entry that is not soft-deleted,
// ordered by type and then title.
//
// Neither title_idx nor data is selected, so no key is needed and nothing is
// decrypted. It is meant for list views; full entries are fetched on demand
// with EntryGet.
func EntryListMeta(db *DB) ([]Entry, error) {
	const listSQL = `SELECT entry_id, parent_id, type, title, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`

	rows, err := db.Conn.Query(listSQL)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []Entry
	for rows.Next() {
		var ent Entry
		err := rows.Scan(&ent.EntryID, &ent.ParentID, &ent.Type, &ent.Title, &ent.DataLevel, &ent.CreatedAt, &ent.UpdatedAt, &ent.Version)
		if err != nil {
			return nil, err
		}
		out = append(out, ent)
	}
	return out, rows.Err()
}
// EntrySearch finds entries (not soft-deleted) by title through the blind index.
//
// It matches rows whose title_idx equals the blind index of the lower-cased
// query, which is the same value EntryCreate and EntryUpdate store for a
// lower-cased title. This is an equality lookup, not a substring search.
// Results are ordered by title.
//
// L1 payloads are decrypted on a best-effort basis, like EntryList: an entry
// whose key derivation, unpacking or JSON decoding fails is still returned,
// with VaultData left nil. Each step is checked before the next runs, so a
// failed key derivation never hands an invalid key to Unpack.
func EntrySearch(db *DB, vaultKey []byte, query string) ([]Entry, error) {
	hmacKey, err := DeriveHMACKey(vaultKey)
	if err != nil {
		return nil, err
	}
	idx := BlindIndex(hmacKey, strings.ToLower(query))

	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND title_idx = ? ORDER BY title`, idx,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			if entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID)); err == nil {
				if dataText, err := Unpack(entryKey, e.Data); err == nil {
					var vd VaultData
					if json.Unmarshal([]byte(dataText), &vd) == nil {
						e.VaultData = &vd
					}
				}
			}
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// EntrySearchFuzzy returns entries (not soft-deleted) whose stored title
// contains query as a substring, ordered by title.
//
// It uses SQL LIKE on the stored title rather than the blind index, which is
// why it is less secure but more practical than EntrySearch. SQLite's default
// LIKE is case-insensitive for ASCII letters only.
//
// The LIKE metacharacters %, _ and the escape character \ are escaped in
// query, so user input is matched literally. Without this, a query such as
// "%" or "_" would act as a wildcard.
//
// L1 payloads are decrypted on a best-effort basis, like EntryList: an entry
// whose key derivation, unpacking or JSON decoding fails is still returned,
// with VaultData left nil.
func EntrySearchFuzzy(db *DB, vaultKey []byte, query string) ([]Entry, error) {
	likeEscaper := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`)
	pattern := "%" + likeEscaper.Replace(query) + "%"

	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND title LIKE ? ESCAPE '\' ORDER BY title`, pattern,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			if entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID)); err == nil {
				if dataText, err := Unpack(entryKey, e.Data); err == nil {
					var vd VaultData
					if json.Unmarshal([]byte(dataText), &vd) == nil {
						e.VaultData = &vd
					}
				}
			}
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// ---------------------------------------------------------------------------
// Session operations
// ---------------------------------------------------------------------------
// SessionCreate stores a new session for actor with a freshly generated token.
// The session expires ttl seconds from now; stored timestamps are Unix
// milliseconds.
//
// The returned *Session is populated even when the insert fails, so callers
// must check the error before using it.
func SessionCreate(db *DB, ttl int64, actor string) (*Session, error) {
	created := time.Now().UnixMilli()
	sess := &Session{
		Token:     GenerateToken(),
		CreatedAt: created,
		ExpiresAt: created + ttl*1000,
		Actor:     actor,
	}
	_, err := db.Conn.Exec(`INSERT INTO sessions (token, created_at, expires_at, actor) VALUES (?, ?, ?, ?)`,
		sess.Token, sess.CreatedAt, sess.ExpiresAt, sess.Actor)
	return sess, err
}
// SessionGet returns the session for token. It returns (nil, nil) when no such
// session exists or when it has expired. Expired rows are not removed here.
func SessionGet(db *DB, token string) (*Session, error) {
	sess := new(Session)
	err := db.Conn.QueryRow(`SELECT token, created_at, expires_at, actor FROM sessions WHERE token = ?`, token).
		Scan(&sess.Token, &sess.CreatedAt, &sess.ExpiresAt, &sess.Actor)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	case sess.ExpiresAt < time.Now().UnixMilli():
		// Still valid in the exact millisecond it expires; rejected afterwards.
		return nil, nil
	}
	return sess, nil
}
// SessionDelete removes the session row for token. Deleting an unknown token
// is not an error.
func SessionDelete(db *DB, token string) error {
	if _, err := db.Conn.Exec(`DELETE FROM sessions WHERE token = ?`, token); err != nil {
		return err
	}
	return nil
}
// ---------------------------------------------------------------------------
// Audit operations
// ---------------------------------------------------------------------------
// AuditLog appends ev to the audit log.
//
// A zero EventID is replaced with a fresh ID, and a zero CreatedAt with the
// current Unix-millisecond time; both are written back into ev. EntryID is
// stored as given, so a zero EntryID is stored as 0 rather than NULL.
func AuditLog(db *DB, ev *AuditEvent) error {
	if ev.EventID == 0 {
		ev.EventID = HexID(NewID())
	}
	if ev.CreatedAt == 0 {
		ev.CreatedAt = time.Now().UnixMilli()
	}

	const stmt = `INSERT INTO audit_log (event_id, entry_id, title, action, actor, ip_addr, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	_, err := db.Conn.Exec(stmt,
		int64(ev.EventID), int64(ev.EntryID), ev.Title, ev.Action, ev.Actor, ev.IPAddr, ev.CreatedAt)
	return err
}
// AuditList returns up to limit audit events, newest first. A non-positive
// limit falls back to 100. NULL entry_id, title or ip_addr columns leave the
// corresponding fields at their zero values.
func AuditList(db *DB, limit int) ([]AuditEvent, error) {
	n := limit
	if n <= 0 {
		n = 100
	}

	rows, err := db.Conn.Query(`SELECT event_id, entry_id, title, action, actor, ip_addr, created_at
FROM audit_log ORDER BY created_at DESC LIMIT ?`, n)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []AuditEvent
	for rows.Next() {
		var (
			ev        AuditEvent
			entry     sql.NullInt64
			title, ip sql.NullString
		)
		if err := rows.Scan(&ev.EventID, &entry, &title, &ev.Action, &ev.Actor, &ip, &ev.CreatedAt); err != nil {
			return nil, err
		}
		if entry.Valid {
			ev.EntryID = HexID(entry.Int64)
		}
		// NullString.String is "" for a NULL column, matching the zero value.
		ev.Title = title.String
		ev.IPAddr = ip.String
		out = append(out, ev)
	}
	return out, rows.Err()
}
// EntryCount reports how many entries are not soft-deleted. It is used by the
// health check.
func EntryCount(db *DB) (int, error) {
	var n int
	if err := db.Conn.QueryRow(`SELECT COUNT(*) FROM entries WHERE deleted_at IS NULL`).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// ---------------------------------------------------------------------------
// MCP Token operations
// ---------------------------------------------------------------------------
// CreateMCPToken stores t as a new scoped MCP token.
//
// A zero ID, an empty token value and a zero CreatedAt (Unix seconds) are
// generated and written back into t. The tags column is always written as
// "[]". EntryIDs is stored as JSON and ReadOnly as 0 or 1.
func CreateMCPToken(db *DB, t *MCPToken) error {
	if t.ID == 0 {
		t.ID = HexID(NewID())
	}
	if t.Token == "" {
		t.Token = GenerateToken()
	}
	if t.CreatedAt == 0 {
		t.CreatedAt = time.Now().Unix()
	}

	ro := 0
	if t.ReadOnly {
		ro = 1
	}
	// NOTE(review): the json.Marshal error is discarded. EntryIDs is presumably
	// a plain ID slice that cannot fail to encode; confirm.
	scope, _ := json.Marshal(t.EntryIDs)

	_, err := db.Conn.Exec(`INSERT INTO mcp_tokens (id, label, token, tags, entry_ids, read_only, expires_at, last_used, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
		int64(t.ID), t.Label, t.Token, "[]", string(scope), ro, t.ExpiresAt, t.LastUsed, t.CreatedAt)
	return err
}
// ListMCPTokens returns all MCP tokens, newest (by created_at) first.
//
// tags and entry_ids are nullable columns. They are read through COALESCE so
// that a NULL value yields an empty string; scanning NULL into a string would
// otherwise fail and abort the whole listing.
//
// A malformed entry_ids value is tolerated here (the decode error is
// discarded), so one bad row cannot hide every other token from the listing.
func ListMCPTokens(db *DB) ([]MCPToken, error) {
	rows, err := db.Conn.Query(
		`SELECT id, label, token, COALESCE(tags,''), COALESCE(entry_ids,''), read_only, expires_at, last_used, created_at
FROM mcp_tokens ORDER BY created_at DESC`,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var tokens []MCPToken
	for rows.Next() {
		var t MCPToken
		var tagsStr, idsStr string
		var readOnly int
		if err := rows.Scan(&t.ID, &t.Label, &t.Token, &tagsStr, &idsStr, &readOnly, &t.ExpiresAt, &t.LastUsed, &t.CreatedAt); err != nil {
			return nil, err
		}
		t.ReadOnly = readOnly != 0
		if idsStr != "" {
			// Best effort for listing only; see the function comment.
			_ = json.Unmarshal([]byte(idsStr), &t.EntryIDs)
		}
		tokens = append(tokens, t)
	}
	return tokens, rows.Err()
}
// GetMCPTokenByValue looks up an MCP token by its raw token string. It returns
// (nil, nil) when no token has that value.
//
// tags and entry_ids are nullable columns. They are read through COALESCE so
// a NULL does not become a Scan error.
//
// A stored entry_ids value that is not valid JSON is returned as an error
// rather than ignored. This lookup presumably backs token authentication, and
// a token whose entry scope cannot be decoded must not be handed back with an
// empty or partial scope.
func GetMCPTokenByValue(db *DB, tokenValue string) (*MCPToken, error) {
	var t MCPToken
	var tagsStr, idsStr string
	var readOnly int
	err := db.Conn.QueryRow(
		`SELECT id, label, token, COALESCE(tags,''), COALESCE(entry_ids,''), read_only, expires_at, last_used, created_at
FROM mcp_tokens WHERE token = ?`, tokenValue,
	).Scan(&t.ID, &t.Label, &t.Token, &tagsStr, &idsStr, &readOnly, &t.ExpiresAt, &t.LastUsed, &t.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil
	}
	if err != nil {
		return nil, err
	}
	t.ReadOnly = readOnly != 0
	if idsStr != "" {
		if err := json.Unmarshal([]byte(idsStr), &t.EntryIDs); err != nil {
			// Fail closed: never return a token with an undecodable scope.
			return nil, fmt.Errorf("decode entry_ids of mcp token: %w", err)
		}
	}
	return &t, nil
}
// DeleteMCPToken permanently removes the MCP token with the given ID. It
// returns ErrNotFound when no token has that ID.
func DeleteMCPToken(db *DB, id int64) error {
	res, err := db.Conn.Exec(`DELETE FROM mcp_tokens WHERE id = ?`, id)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return ErrNotFound
	}
	return nil
}
// UpdateMCPTokenLastUsed stamps the token's last_used column with the current
// Unix time in seconds. An unknown id is silently ignored.
func UpdateMCPTokenLastUsed(db *DB, id int64) error {
	now := time.Now().Unix()
	_, err := db.Conn.Exec(`UPDATE mcp_tokens SET last_used = ? WHERE id = ?`, now, id)
	return err
}
// ---------------------------------------------------------------------------
// WebAuthn credential operations
// ---------------------------------------------------------------------------
// StoreWebAuthnCredential inserts c as a new WebAuthn credential row. CredID is
// used exactly as given. A zero CreatedAt is replaced with the current Unix
// time in seconds and written back into c.
func StoreWebAuthnCredential(db *DB, c *WebAuthnCredential) error {
	if c.CreatedAt == 0 {
		c.CreatedAt = time.Now().Unix()
	}
	const stmt = `INSERT INTO webauthn_credentials (cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	if _, err := db.Conn.Exec(stmt,
		int64(c.CredID), c.Name, c.PublicKey, c.CredentialID, c.PRFSalt, c.SignCount, c.CreatedAt,
	); err != nil {
		return err
	}
	return nil
}
// GetWebAuthnCredentials returns every stored WebAuthn credential, newest (by
// created_at) first. The slice is nil when none are registered.
func GetWebAuthnCredentials(db *DB) ([]WebAuthnCredential, error) {
	const q = `SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
FROM webauthn_credentials ORDER BY created_at DESC`

	rows, err := db.Conn.Query(q)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var list []WebAuthnCredential
	for rows.Next() {
		var cred WebAuthnCredential
		if err := rows.Scan(&cred.CredID, &cred.Name, &cred.PublicKey, &cred.CredentialID, &cred.PRFSalt, &cred.SignCount, &cred.CreatedAt); err != nil {
			return nil, err
		}
		list = append(list, cred)
	}
	return list, rows.Err()
}
// WebAuthnCredentialCount reports how many WebAuthn credentials are stored.
func WebAuthnCredentialCount(db *DB) (int, error) {
	var n int
	if err := db.Conn.QueryRow(`SELECT COUNT(*) FROM webauthn_credentials`).Scan(&n); err != nil {
		return 0, err
	}
	return n, nil
}
// GetFirstCredentialPublicKey returns the public key of the earliest
// registered credential (lowest created_at). The order among credentials with
// equal created_at is not specified. It returns (nil, nil) when no credential
// exists yet.
func GetFirstCredentialPublicKey(db *DB) ([]byte, error) {
	row := db.Conn.QueryRow(`SELECT public_key FROM webauthn_credentials ORDER BY created_at ASC LIMIT 1`)
	var key []byte
	if err := row.Scan(&key); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			return nil, nil
		}
		return key, err
	}
	return key, nil
}
// GetWebAuthnCredentialByRawID looks up a credential by its raw WebAuthn
// credential ID. It returns ErrNotFound when no credential matches, and a nil
// credential whenever it returns an error.
//
// An empty credentialID is rejected up front with ErrNotFound. Rows that
// received the credential_id migration default X'' (an empty blob) must never
// be selectable by an empty ID.
func GetWebAuthnCredentialByRawID(db *DB, credentialID []byte) (*WebAuthnCredential, error) {
	if len(credentialID) == 0 {
		return nil, ErrNotFound
	}
	var c WebAuthnCredential
	err := db.Conn.QueryRow(
		`SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
FROM webauthn_credentials WHERE credential_id = ?`, credentialID,
	).Scan(&c.CredID, &c.Name, &c.PublicKey, &c.CredentialID, &c.PRFSalt, &c.SignCount, &c.CreatedAt)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, ErrNotFound
	}
	if err != nil {
		// Don't hand back a partially scanned credential alongside an error.
		return nil, err
	}
	return &c, nil
}
// DeleteWebAuthnCredential removes the credential whose cred_id is credID. It
// returns ErrNotFound when no such credential exists.
func DeleteWebAuthnCredential(db *DB, credID int64) error {
	res, err := db.Conn.Exec(`DELETE FROM webauthn_credentials WHERE cred_id = ?`, credID)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return ErrNotFound
	}
	return nil
}
// UpdateWebAuthnSignCount stores count as the signature counter of the
// credential whose cred_id is credID.
//
// It overwrites the stored value rather than incrementing it. Callers
// presumably pass the counter reported by the authenticator. An unknown credID
// is silently ignored, because there is no rows-affected check.
func UpdateWebAuthnSignCount(db *DB, credID int64, count int) error {
	_, err := db.Conn.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE cred_id = ?`, count, credID)
	return err
}
// ---------------------------------------------------------------------------
// WebAuthn challenge operations
// ---------------------------------------------------------------------------
// StoreWebAuthnChallenge saves challenge with its type and the current Unix
// time in seconds, so ConsumeWebAuthnChallenge can verify it later. challenge
// is the table's primary key, so storing the same bytes twice fails.
func StoreWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
	created := time.Now().Unix()
	_, err := db.Conn.Exec(
		`INSERT INTO webauthn_challenges (challenge, type, created_at) VALUES (?, ?, ?)`,
		challenge, challengeType, created,
	)
	return err
}
// ConsumeWebAuthnChallenge deletes the stored challenge of the given type. It
// succeeds only if the challenge was created less than 300 seconds ago.
//
// The check and the removal are a single DELETE, so a challenge can be
// consumed at most once. If nothing matches (unknown, wrong type or expired),
// it returns an error. Expired rows are left in place for
// CleanExpiredChallenges.
func ConsumeWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
	cutoff := time.Now().Unix() - 300
	res, err := db.Conn.Exec(
		`DELETE FROM webauthn_challenges WHERE challenge = ? AND type = ? AND created_at > ?`,
		challenge, challengeType, cutoff,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return errors.New("challenge not found or expired")
	}
	return nil
}
// CleanExpiredChallenges deletes challenges created more than 300 seconds ago.
//
// It is best-effort: any database error is discarded, and stale rows remain
// until a later call removes them. ConsumeWebAuthnChallenge never accepts
// them in the meantime.
func CleanExpiredChallenges(db *DB) {
	cutoff := time.Now().Unix() - 300
	_, _ = db.Conn.Exec(`DELETE FROM webauthn_challenges WHERE created_at < ?`, cutoff)
}
// ---------------------------------------------------------------------------
// Agent operations
// ---------------------------------------------------------------------------
// AgentCreate inserts a new agent. These defaults are applied to a before the
// insert, and the caller sees them:
//   - a fresh ID when ID is zero;
//   - CreatedAt set to now in Unix milliseconds (always overwritten);
//   - Status AgentStatusActive when empty;
//   - rate limits of 5 per minute and 10 per hour when zero.
//
// IPWhitelist is stored as JSON. A nil whitelist (assuming it is a slice)
// encodes as JSON null, which is stored instead of the column default.
func AgentCreate(db *DB, a *Agent) error {
	if a.ID == 0 {
		a.ID = HexID(NewID())
	}
	a.CreatedAt = time.Now().UnixMilli()
	if a.Status == "" {
		a.Status = AgentStatusActive
	}
	if a.RateLimitMinute == 0 {
		a.RateLimitMinute = 5
	}
	if a.RateLimitHour == 0 {
		a.RateLimitHour = 10
	}

	// NOTE(review): the json.Marshal error is discarded; a string-slice
	// whitelist cannot fail to encode — confirm IPWhitelist's type.
	whitelist, _ := json.Marshal(a.IPWhitelist)

	const stmt = `INSERT INTO agents (id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`
	_, err := db.Conn.Exec(stmt,
		int64(a.ID), a.Name, string(whitelist), a.RateLimitMinute, a.RateLimitHour, a.Status, a.CreatedAt)
	return err
}
// AgentGet returns the agent with the given ID, or (nil, nil) if there is none.
// NULL locked_reason and last_ip are read as empty strings.
//
// The JSON whitelist is decoded on a best-effort basis. NOTE(review): an
// unparsable ip_whitelist leaves IPWhitelist empty or partial without any
// error; confirm how callers treat an empty whitelist.
func AgentGet(db *DB, agentID int64) (*Agent, error) {
	agent := new(Agent)
	var whitelist string
	err := db.Conn.QueryRow(
		`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
FROM agents WHERE id = ?`, agentID,
	).Scan(&agent.ID, &agent.Name, &whitelist, &agent.RateLimitMinute, &agent.RateLimitHour, &agent.Status, &agent.LockedReason, &agent.LockedAt, &agent.LastUsed, &agent.LastIP, &agent.CreatedAt)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	_ = json.Unmarshal([]byte(whitelist), &agent.IPWhitelist)
	return agent, nil
}
// AgentGetByName returns the agent with the given unique name, or (nil, nil)
// if there is none. NULL locked_reason and last_ip are read as empty strings.
//
// The JSON whitelist is decoded on a best-effort basis. NOTE(review): an
// unparsable ip_whitelist leaves IPWhitelist empty or partial without any
// error; confirm how callers treat an empty whitelist.
func AgentGetByName(db *DB, name string) (*Agent, error) {
	agent := new(Agent)
	var whitelist string
	err := db.Conn.QueryRow(
		`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
FROM agents WHERE name = ?`, name,
	).Scan(&agent.ID, &agent.Name, &whitelist, &agent.RateLimitMinute, &agent.RateLimitHour, &agent.Status, &agent.LockedReason, &agent.LockedAt, &agent.LastUsed, &agent.LastIP, &agent.CreatedAt)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	_ = json.Unmarshal([]byte(whitelist), &agent.IPWhitelist)
	return agent, nil
}
// AgentList returns every agent, newest (by created_at) first. Whitelists are
// decoded on a best-effort basis: decode errors are discarded.
func AgentList(db *DB) ([]Agent, error) {
	rows, err := db.Conn.Query(
		`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
FROM agents ORDER BY created_at DESC`)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []Agent
	for rows.Next() {
		var (
			ag Agent
			wl string
		)
		if err := rows.Scan(&ag.ID, &ag.Name, &wl, &ag.RateLimitMinute, &ag.RateLimitHour, &ag.Status, &ag.LockedReason, &ag.LockedAt, &ag.LastUsed, &ag.LastIP, &ag.CreatedAt); err != nil {
			return nil, err
		}
		_ = json.Unmarshal([]byte(wl), &ag.IPWhitelist)
		out = append(out, ag)
	}
	return out, rows.Err()
}
// AgentUpdateStatus sets an agent's status and reason.
//
// locked_at is stamped with the current Unix-millisecond time when status is
// AgentStatusLocked, and reset to 0 for any other status. reason is stored as
// given, so an empty reason is stored as "" rather than NULL. An unknown
// agentID is silently ignored.
func AgentUpdateStatus(db *DB, agentID int64, status, reason string) error {
	var lockedAt int64
	if status == AgentStatusLocked {
		lockedAt = time.Now().UnixMilli()
	}
	_, err := db.Conn.Exec(`UPDATE agents SET status = ?, locked_reason = ?, locked_at = ? WHERE id = ?`,
		status, reason, lockedAt, agentID)
	return err
}
// AgentUpdateWhitelist replaces an agent's IP whitelist, stored as a JSON
// array. A nil slice is stored as JSON null. An unknown agentID is silently
// ignored.
func AgentUpdateWhitelist(db *DB, agentID int64, whitelist []string) error {
	// json.Marshal cannot fail for a []string, so the error is dropped.
	encoded, _ := json.Marshal(whitelist)
	if _, err := db.Conn.Exec(`UPDATE agents SET ip_whitelist = ? WHERE id = ?`, string(encoded), agentID); err != nil {
		return err
	}
	return nil
}
// AgentUpdateRateLimits sets an agent's per-minute and per-hour request
// limits. The values are stored as given, with no validation. An unknown
// agentID is silently ignored.
func AgentUpdateRateLimits(db *DB, agentID int64, perMin, perHour int) error {
	const stmt = `UPDATE agents SET rate_limit_minute = ?, rate_limit_hour = ? WHERE id = ?`
	if _, err := db.Conn.Exec(stmt, perMin, perHour, agentID); err != nil {
		return err
	}
	return nil
}
// AgentUpdateLastUsed records that the agent was just used from ip. It sets
// last_used to the current Unix-millisecond time and last_ip to ip.
func AgentUpdateLastUsed(db *DB, agentID int64, ip string) error {
	usedAt := time.Now().UnixMilli()
	_, err := db.Conn.Exec(`UPDATE agents SET last_used = ?, last_ip = ? WHERE id = ?`, usedAt, ip, agentID)
	return err
}
// AgentDelete hard-deletes an agent together with its request log.
//
// Both deletes run in one transaction, so either both are applied or neither
// is. agent_requests.agent_id references agents(id), and the connection DSN
// enables foreign keys. If the request-log delete failed and the agent delete
// still ran, the agent delete would hit a constraint error. Previously the
// first error was silently ignored.
//
// Deleting an unknown agent is not an error.
func AgentDelete(db *DB, agentID int64) error {
	tx, err := db.Conn.Begin()
	if err != nil {
		return fmt.Errorf("begin agent delete: %w", err)
	}
	// Rollback after a successful Commit is a harmless no-op (sql.ErrTxDone).
	defer func() { _ = tx.Rollback() }()

	if _, err := tx.Exec(`DELETE FROM agent_requests WHERE agent_id = ?`, agentID); err != nil {
		return fmt.Errorf("delete agent requests: %w", err)
	}
	if _, err := tx.Exec(`DELETE FROM agents WHERE id = ?`, agentID); err != nil {
		return fmt.Errorf("delete agent: %w", err)
	}
	return tx.Commit()
}
// AgentRequestLog records a request by agentID for rate limiting. Timestamps
// are Unix seconds.
//
// A request to a path this agent already requested within the last 60
// seconds is not recorded, so repeated hits on the same path do not count
// against the limits.
//
// The duplicate check and the insert are a single
// INSERT ... SELECT ... WHERE NOT EXISTS statement. This has two effects:
//   - the check cannot interleave with another writer, so two identical
//     concurrent requests are not both logged;
//   - a failure of the check is returned instead of being silently ignored.
func AgentRequestLog(db *DB, agentID int64, ip string, path string) error {
	now := time.Now().Unix()
	_, err := db.Conn.Exec(
		`INSERT INTO agent_requests (agent_id, ip, path, timestamp)
SELECT ?, ?, ?, ?
WHERE NOT EXISTS (
SELECT 1 FROM agent_requests WHERE agent_id = ? AND path = ? AND timestamp > ?
)`,
		agentID, ip, path, now,
		agentID, path, now-60,
	)
	return err
}
// AgentRequestCountMinute counts the agent_requests rows logged for agentID
// with a timestamp (Unix seconds) in the last 60 seconds.
func AgentRequestCountMinute(db *DB, agentID int64) (int, error) {
	since := time.Now().Unix() - 60
	var n int
	err := db.Conn.QueryRow(`SELECT COUNT(*) FROM agent_requests WHERE agent_id = ? AND timestamp > ?`,
		agentID, since).Scan(&n)
	return n, err
}
// AgentRequestCountHour counts the agent_requests rows logged for agentID
// with a timestamp (Unix seconds) in the last 3600 seconds.
func AgentRequestCountHour(db *DB, agentID int64) (int, error) {
	since := time.Now().Unix() - 3600
	var n int
	err := db.Conn.QueryRow(`SELECT COUNT(*) FROM agent_requests WHERE agent_id = ? AND timestamp > ?`,
		agentID, since).Scan(&n)
	return n, err
}
// AgentRequestCleanup deletes request-log rows older than two hours (7200
// seconds). It is best-effort: any database error is discarded.
func AgentRequestCleanup(db *DB) {
	cutoff := time.Now().Unix() - 7200
	_, _ = db.Conn.Exec(`DELETE FROM agent_requests WHERE timestamp < ?`, cutoff)
}
// ---------------------------------------------------------------------------
// Vault lock operations
// ---------------------------------------------------------------------------
// VaultLockGet reads the single vault_lock row (id 1). A NULL locked_reason
// is read as "". Any Scan error is returned as is, including sql.ErrNoRows
// when the row was never seeded.
func VaultLockGet(db *DB) (*VaultLock, error) {
	var (
		flag  int
		state VaultLock
	)
	row := db.Conn.QueryRow(`SELECT locked, COALESCE(locked_reason,''), locked_at FROM vault_lock WHERE id = 1`)
	if err := row.Scan(&flag, &state.LockedReason, &state.LockedAt); err != nil {
		return nil, err
	}
	state.Locked = flag != 0
	return &state, nil
}
// VaultLockSet records the vault lock state. Locking stamps locked_at with the
// current Unix-millisecond time; unlocking resets it to 0. reason is stored
// as given.
//
// The write is an INSERT OR REPLACE on id 1, so the state is persisted even
// if the seed row is missing. A plain UPDATE would match zero rows and report
// success without actually locking the vault. vault_lock has only these four
// columns, so replacing the row loses nothing.
func VaultLockSet(db *DB, locked bool, reason string) error {
	lockedInt := 0
	lockedAt := int64(0)
	if locked {
		lockedInt = 1
		lockedAt = time.Now().UnixMilli()
	}
	_, err := db.Conn.Exec(
		`INSERT OR REPLACE INTO vault_lock (id, locked, locked_reason, locked_at) VALUES (1, ?, ?, ?)`,
		lockedInt, reason, lockedAt)
	return err
}