958 lines
28 KiB
Go
958 lines
28 KiB
Go
package lib
|
|
|
|
import (
|
|
"database/sql"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// Sentinel errors returned by the storage functions in this file. Compare
// against them with errors.Is.
var (
	// ErrNotFound reports that no matching row exists.
	ErrNotFound = errors.New("not found")

	// ErrVersionConflict reports that an optimistic-locking update matched no
	// row because the stored version no longer equals the caller's.
	ErrVersionConflict = errors.New("version conflict: entry was modified")
)
|
|
|
|
// schema is the DDL applied by MigrateDB.
//
// Every CREATE uses IF NOT EXISTS, so re-running it against an existing
// database is harmless. The same property means columns added to a table
// after that table was first created are NOT picked up by existing databases:
// such columns need an explicit ALTER TABLE in MigrateDB (as is done for
// webauthn_credentials.credential_id).
//
// Timestamp units are not uniform: the timestamps generated by the Go code in
// this file are Unix milliseconds for entries, sessions, audit_log, agents and
// vault_lock, but Unix seconds for mcp_tokens, webauthn_credentials,
// webauthn_challenges and agent_requests.
const schema = `
|
|
CREATE TABLE IF NOT EXISTS entries (
|
|
entry_id INTEGER PRIMARY KEY,
|
|
parent_id INTEGER NOT NULL DEFAULT 0,
|
|
type TEXT NOT NULL,
|
|
title TEXT NOT NULL,
|
|
title_idx BLOB NOT NULL,
|
|
data BLOB NOT NULL,
|
|
data_level INTEGER NOT NULL DEFAULT 1,
|
|
created_at INTEGER NOT NULL,
|
|
updated_at INTEGER NOT NULL,
|
|
version INTEGER NOT NULL DEFAULT 1,
|
|
deleted_at INTEGER,
|
|
checksum INTEGER
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_entries_parent ON entries(parent_id);
|
|
CREATE INDEX IF NOT EXISTS idx_entries_type ON entries(type);
|
|
CREATE INDEX IF NOT EXISTS idx_entries_title_idx ON entries(title_idx);
|
|
CREATE INDEX IF NOT EXISTS idx_entries_deleted ON entries(deleted_at);
|
|
|
|
CREATE TABLE IF NOT EXISTS sessions (
|
|
token TEXT PRIMARY KEY,
|
|
created_at INTEGER NOT NULL,
|
|
expires_at INTEGER NOT NULL,
|
|
actor TEXT NOT NULL DEFAULT 'web'
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS audit_log (
|
|
event_id INTEGER PRIMARY KEY,
|
|
entry_id INTEGER,
|
|
title TEXT,
|
|
action TEXT NOT NULL,
|
|
actor TEXT NOT NULL,
|
|
ip_addr TEXT,
|
|
created_at INTEGER NOT NULL
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_audit_entry ON audit_log(entry_id);
|
|
CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_log(created_at);
|
|
|
|
CREATE TABLE IF NOT EXISTS webauthn_credentials (
|
|
cred_id INTEGER PRIMARY KEY,
|
|
name TEXT NOT NULL,
|
|
public_key BLOB NOT NULL,
|
|
credential_id BLOB NOT NULL DEFAULT X'',
|
|
prf_salt BLOB NOT NULL,
|
|
sign_count INTEGER NOT NULL DEFAULT 0,
|
|
created_at INTEGER NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS webauthn_challenges (
|
|
challenge BLOB PRIMARY KEY,
|
|
type TEXT NOT NULL,
|
|
created_at INTEGER NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS mcp_tokens (
|
|
id INTEGER PRIMARY KEY,
|
|
label TEXT NOT NULL,
|
|
token TEXT UNIQUE NOT NULL,
|
|
tags TEXT,
|
|
entry_ids TEXT,
|
|
read_only INTEGER NOT NULL DEFAULT 0,
|
|
expires_at INTEGER NOT NULL DEFAULT 0,
|
|
last_used INTEGER NOT NULL DEFAULT 0,
|
|
created_at INTEGER NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS agents (
|
|
id INTEGER PRIMARY KEY,
|
|
name TEXT UNIQUE NOT NULL,
|
|
ip_whitelist TEXT DEFAULT '["0.0.0.0/0"]',
|
|
rate_limit_minute INTEGER DEFAULT 5,
|
|
rate_limit_hour INTEGER DEFAULT 10,
|
|
status TEXT DEFAULT 'active',
|
|
locked_reason TEXT,
|
|
locked_at INTEGER DEFAULT 0,
|
|
last_used INTEGER DEFAULT 0,
|
|
last_ip TEXT,
|
|
created_at INTEGER NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS agent_requests (
|
|
id INTEGER PRIMARY KEY,
|
|
agent_id INTEGER NOT NULL,
|
|
ip TEXT NOT NULL,
|
|
path TEXT NOT NULL DEFAULT '',
|
|
timestamp INTEGER NOT NULL,
|
|
FOREIGN KEY (agent_id) REFERENCES agents(id)
|
|
);
|
|
CREATE INDEX IF NOT EXISTS idx_agent_requests ON agent_requests(agent_id, timestamp);
|
|
|
|
CREATE TABLE IF NOT EXISTS vault_lock (
|
|
id INTEGER PRIMARY KEY CHECK (id = 1),
|
|
locked INTEGER DEFAULT 0,
|
|
locked_reason TEXT,
|
|
locked_at INTEGER DEFAULT 0
|
|
);
|
|
`
|
|
|
|
// OpenDB opens the SQLite database.
|
|
func OpenDB(dbPath string) (*DB, error) {
|
|
conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000")
|
|
if err != nil {
|
|
return nil, fmt.Errorf("open db: %w", err)
|
|
}
|
|
if err := conn.Ping(); err != nil {
|
|
return nil, fmt.Errorf("ping db: %w", err)
|
|
}
|
|
return &DB{Conn: conn, DBPath: dbPath}, nil
|
|
}
|
|
|
|
// MigrateDB runs the schema migrations.
|
|
func MigrateDB(db *DB) error {
|
|
if _, err := db.Conn.Exec(schema); err != nil {
|
|
return err
|
|
}
|
|
// Migration: add credential_id column if missing (existing DBs)
|
|
_, err := db.Conn.Exec(`ALTER TABLE webauthn_credentials ADD COLUMN credential_id BLOB NOT NULL DEFAULT X''`)
|
|
if err != nil && !strings.Contains(err.Error(), "duplicate column") {
|
|
// Ignore "duplicate column" — migration already applied
|
|
}
|
|
// Seed vault_lock row
|
|
db.Conn.Exec(`INSERT OR IGNORE INTO vault_lock (id) VALUES (1)`)
|
|
return nil
|
|
}
|
|
|
|
// Close closes the database connection.
|
|
func (db *DB) Close() error {
|
|
return db.Conn.Close()
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Entry operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// EntryCreate creates a new entry.
|
|
func EntryCreate(db *DB, vaultKey []byte, e *Entry) error {
|
|
if e.EntryID == 0 {
|
|
e.EntryID = HexID(NewID())
|
|
}
|
|
|
|
now := time.Now().UnixMilli()
|
|
e.CreatedAt = now
|
|
e.UpdatedAt = now
|
|
e.Version = 1
|
|
if e.DataLevel == 0 {
|
|
e.DataLevel = DataLevelL1
|
|
}
|
|
|
|
// Derive keys and encrypt
|
|
entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
hmacKey, err := DeriveHMACKey(vaultKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Create blind index for title
|
|
e.TitleIdx = BlindIndex(hmacKey, strings.ToLower(e.Title))
|
|
|
|
// Pack VaultData if present
|
|
if e.VaultData != nil {
|
|
dataJSON, err := json.Marshal(e.VaultData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
packed, err := Pack(entryKey, string(dataJSON))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.Data = packed
|
|
}
|
|
|
|
_, err = db.Conn.Exec(
|
|
`INSERT INTO entries (entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
int64(e.EntryID), int64(e.ParentID), e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, e.CreatedAt, e.UpdatedAt, e.Version,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// EntryGet retrieves an entry by ID.
|
|
func EntryGet(db *DB, vaultKey []byte, entryID int64) (*Entry, error) {
|
|
var e Entry
|
|
var deletedAt sql.NullInt64
|
|
err := db.Conn.QueryRow(
|
|
`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version, deleted_at
|
|
FROM entries WHERE entry_id = ?`, entryID,
|
|
).Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version, &deletedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if deletedAt.Valid {
|
|
v := deletedAt.Int64
|
|
e.DeletedAt = &v
|
|
}
|
|
|
|
// Unpack data
|
|
if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
|
|
entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
dataText, err := Unpack(entryKey, e.Data)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var vd VaultData
|
|
if err := json.Unmarshal([]byte(dataText), &vd); err != nil {
|
|
return nil, err
|
|
}
|
|
e.VaultData = &vd
|
|
}
|
|
|
|
return &e, nil
|
|
}
|
|
|
|
// EntryUpdate updates an existing entry with optimistic locking.
|
|
func EntryUpdate(db *DB, vaultKey []byte, e *Entry) error {
|
|
now := time.Now().UnixMilli()
|
|
|
|
// Derive keys
|
|
entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
hmacKey, err := DeriveHMACKey(vaultKey)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
// Update blind index
|
|
e.TitleIdx = BlindIndex(hmacKey, strings.ToLower(e.Title))
|
|
|
|
// Pack VaultData if present
|
|
if e.VaultData != nil {
|
|
dataJSON, err := json.Marshal(e.VaultData)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
packed, err := Pack(entryKey, string(dataJSON))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
e.Data = packed
|
|
}
|
|
|
|
result, err := db.Conn.Exec(
|
|
`UPDATE entries SET parent_id=?, type=?, title=?, title_idx=?, data=?, data_level=?, updated_at=?, version=version+1
|
|
WHERE entry_id = ? AND version = ? AND deleted_at IS NULL`,
|
|
int64(e.ParentID), e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, now,
|
|
int64(e.EntryID), e.Version,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return ErrVersionConflict
|
|
}
|
|
e.Version++
|
|
e.UpdatedAt = now
|
|
return nil
|
|
}
|
|
|
|
// EntryDelete soft-deletes an entry.
|
|
func EntryDelete(db *DB, entryID int64) error {
|
|
now := time.Now().UnixMilli()
|
|
result, err := db.Conn.Exec(
|
|
`UPDATE entries SET deleted_at = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`,
|
|
now, now, entryID,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// EntryList returns all non-deleted entries, optionally filtered by parent.
|
|
func EntryList(db *DB, vaultKey []byte, parentID *int64) ([]Entry, error) {
|
|
var rows *sql.Rows
|
|
var err error
|
|
|
|
if parentID != nil {
|
|
rows, err = db.Conn.Query(
|
|
`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
|
|
FROM entries WHERE deleted_at IS NULL AND parent_id = ? ORDER BY type, title`, *parentID,
|
|
)
|
|
} else {
|
|
rows, err = db.Conn.Query(
|
|
`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
|
|
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`,
|
|
)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entries []Entry
|
|
for rows.Next() {
|
|
var e Entry
|
|
if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
|
|
return nil, err
|
|
}
|
|
// Unpack L1 data
|
|
if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
|
|
entryKey, err := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
if err == nil {
|
|
dataText, err := Unpack(entryKey, e.Data)
|
|
if err == nil {
|
|
var vd VaultData
|
|
if json.Unmarshal([]byte(dataText), &vd) == nil {
|
|
e.VaultData = &vd
|
|
}
|
|
}
|
|
}
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
// EntrySearch searches entries by title (blind index lookup).
|
|
func EntrySearch(db *DB, vaultKey []byte, query string) ([]Entry, error) {
|
|
hmacKey, err := DeriveHMACKey(vaultKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
idx := BlindIndex(hmacKey, strings.ToLower(query))
|
|
|
|
rows, err := db.Conn.Query(
|
|
`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
|
|
FROM entries WHERE deleted_at IS NULL AND title_idx = ? ORDER BY title`, idx,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entries []Entry
|
|
for rows.Next() {
|
|
var e Entry
|
|
if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
|
|
entryKey, _ := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
dataText, _ := Unpack(entryKey, e.Data)
|
|
var vd VaultData
|
|
if json.Unmarshal([]byte(dataText), &vd) == nil {
|
|
e.VaultData = &vd
|
|
}
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
// EntrySearchFuzzy searches entries by title using LIKE (less secure but more practical).
|
|
func EntrySearchFuzzy(db *DB, vaultKey []byte, query string) ([]Entry, error) {
|
|
rows, err := db.Conn.Query(
|
|
`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
|
|
FROM entries WHERE deleted_at IS NULL AND title LIKE ? ORDER BY title`, "%"+query+"%",
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var entries []Entry
|
|
for rows.Next() {
|
|
var e Entry
|
|
if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
|
|
return nil, err
|
|
}
|
|
if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
|
|
entryKey, _ := DeriveEntryKey(vaultKey, int64(e.EntryID))
|
|
dataText, _ := Unpack(entryKey, e.Data)
|
|
var vd VaultData
|
|
if json.Unmarshal([]byte(dataText), &vd) == nil {
|
|
e.VaultData = &vd
|
|
}
|
|
}
|
|
entries = append(entries, e)
|
|
}
|
|
return entries, rows.Err()
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Session operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// SessionCreate creates a new session.
|
|
func SessionCreate(db *DB, ttl int64, actor string) (*Session, error) {
|
|
now := time.Now().UnixMilli()
|
|
s := &Session{
|
|
Token: GenerateToken(),
|
|
CreatedAt: now,
|
|
ExpiresAt: now + (ttl * 1000),
|
|
Actor: actor,
|
|
}
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO sessions (token, created_at, expires_at, actor) VALUES (?, ?, ?, ?)`,
|
|
s.Token, s.CreatedAt, s.ExpiresAt, s.Actor,
|
|
)
|
|
return s, err
|
|
}
|
|
|
|
// SessionGet retrieves a session by token.
|
|
func SessionGet(db *DB, token string) (*Session, error) {
|
|
var s Session
|
|
err := db.Conn.QueryRow(
|
|
`SELECT token, created_at, expires_at, actor FROM sessions WHERE token = ?`, token,
|
|
).Scan(&s.Token, &s.CreatedAt, &s.ExpiresAt, &s.Actor)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
// Check expiry
|
|
if s.ExpiresAt < time.Now().UnixMilli() {
|
|
return nil, nil
|
|
}
|
|
return &s, nil
|
|
}
|
|
|
|
// SessionDelete deletes a session.
|
|
func SessionDelete(db *DB, token string) error {
|
|
_, err := db.Conn.Exec(`DELETE FROM sessions WHERE token = ?`, token)
|
|
return err
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Audit operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// AuditLog records an audit event.
|
|
func AuditLog(db *DB, ev *AuditEvent) error {
|
|
if ev.EventID == 0 {
|
|
ev.EventID = HexID(NewID())
|
|
}
|
|
if ev.CreatedAt == 0 {
|
|
ev.CreatedAt = time.Now().UnixMilli()
|
|
}
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO audit_log (event_id, entry_id, title, action, actor, ip_addr, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
int64(ev.EventID), int64(ev.EntryID), ev.Title, ev.Action, ev.Actor, ev.IPAddr, ev.CreatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// AuditList returns recent audit events.
|
|
func AuditList(db *DB, limit int) ([]AuditEvent, error) {
|
|
if limit <= 0 {
|
|
limit = 100
|
|
}
|
|
rows, err := db.Conn.Query(
|
|
`SELECT event_id, entry_id, title, action, actor, ip_addr, created_at
|
|
FROM audit_log ORDER BY created_at DESC LIMIT ?`, limit,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var events []AuditEvent
|
|
for rows.Next() {
|
|
var ev AuditEvent
|
|
var entryID sql.NullInt64
|
|
var title, ipAddr sql.NullString
|
|
if err := rows.Scan(&ev.EventID, &entryID, &title, &ev.Action, &ev.Actor, &ipAddr, &ev.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
if entryID.Valid {
|
|
ev.EntryID = HexID(entryID.Int64)
|
|
}
|
|
if title.Valid {
|
|
ev.Title = title.String
|
|
}
|
|
if ipAddr.Valid {
|
|
ev.IPAddr = ipAddr.String
|
|
}
|
|
events = append(events, ev)
|
|
}
|
|
return events, rows.Err()
|
|
}
|
|
|
|
// EntryCount returns total entry count (for health check).
|
|
func EntryCount(db *DB) (int, error) {
|
|
var count int
|
|
err := db.Conn.QueryRow(`SELECT COUNT(*) FROM entries WHERE deleted_at IS NULL`).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// MCP Token operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// CreateMCPToken inserts a new scoped MCP token.
|
|
func CreateMCPToken(db *DB, t *MCPToken) error {
|
|
if t.ID == 0 {
|
|
t.ID = HexID(NewID())
|
|
}
|
|
if t.Token == "" {
|
|
t.Token = GenerateToken()
|
|
}
|
|
if t.CreatedAt == 0 {
|
|
t.CreatedAt = time.Now().Unix()
|
|
}
|
|
|
|
idsJSON, _ := json.Marshal(t.EntryIDs)
|
|
|
|
readOnly := 0
|
|
if t.ReadOnly {
|
|
readOnly = 1
|
|
}
|
|
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO mcp_tokens (id, label, token, tags, entry_ids, read_only, expires_at, last_used, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)`,
|
|
int64(t.ID), t.Label, t.Token, "[]", string(idsJSON), readOnly, t.ExpiresAt, t.LastUsed, t.CreatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// ListMCPTokens returns all MCP tokens.
|
|
func ListMCPTokens(db *DB) ([]MCPToken, error) {
|
|
rows, err := db.Conn.Query(
|
|
`SELECT id, label, token, tags, entry_ids, read_only, expires_at, last_used, created_at
|
|
FROM mcp_tokens ORDER BY created_at DESC`,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var tokens []MCPToken
|
|
for rows.Next() {
|
|
var t MCPToken
|
|
var tagsStr, idsStr string
|
|
var readOnly int
|
|
if err := rows.Scan(&t.ID, &t.Label, &t.Token, &tagsStr, &idsStr, &readOnly, &t.ExpiresAt, &t.LastUsed, &t.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
t.ReadOnly = readOnly != 0
|
|
if idsStr != "" {
|
|
json.Unmarshal([]byte(idsStr), &t.EntryIDs)
|
|
}
|
|
tokens = append(tokens, t)
|
|
}
|
|
return tokens, rows.Err()
|
|
}
|
|
|
|
// GetMCPTokenByValue looks up an MCP token by its raw token string.
|
|
func GetMCPTokenByValue(db *DB, tokenValue string) (*MCPToken, error) {
|
|
var t MCPToken
|
|
var tagsStr, idsStr string
|
|
var readOnly int
|
|
err := db.Conn.QueryRow(
|
|
`SELECT id, label, token, tags, entry_ids, read_only, expires_at, last_used, created_at
|
|
FROM mcp_tokens WHERE token = ?`, tokenValue,
|
|
).Scan(&t.ID, &t.Label, &t.Token, &tagsStr, &idsStr, &readOnly, &t.ExpiresAt, &t.LastUsed, &t.CreatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
t.ReadOnly = readOnly != 0
|
|
if idsStr != "" {
|
|
json.Unmarshal([]byte(idsStr), &t.EntryIDs)
|
|
}
|
|
return &t, nil
|
|
}
|
|
|
|
// DeleteMCPToken deletes an MCP token by ID.
|
|
func DeleteMCPToken(db *DB, id int64) error {
|
|
result, err := db.Conn.Exec(`DELETE FROM mcp_tokens WHERE id = ?`, id)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateMCPTokenLastUsed updates the last_used timestamp on an MCP token.
|
|
func UpdateMCPTokenLastUsed(db *DB, id int64) error {
|
|
_, err := db.Conn.Exec(`UPDATE mcp_tokens SET last_used = ? WHERE id = ?`, time.Now().Unix(), id)
|
|
return err
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// WebAuthn credential operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// StoreWebAuthnCredential inserts a new WebAuthn credential.
|
|
func StoreWebAuthnCredential(db *DB, c *WebAuthnCredential) error {
|
|
if c.CreatedAt == 0 {
|
|
c.CreatedAt = time.Now().Unix()
|
|
}
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO webauthn_credentials (cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
int64(c.CredID), c.Name, c.PublicKey, c.CredentialID, c.PRFSalt, c.SignCount, c.CreatedAt,
|
|
)
|
|
return err
|
|
}
|
|
|
|
// GetWebAuthnCredentials returns all registered WebAuthn credentials.
|
|
func GetWebAuthnCredentials(db *DB) ([]WebAuthnCredential, error) {
|
|
rows, err := db.Conn.Query(
|
|
`SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
|
|
FROM webauthn_credentials ORDER BY created_at DESC`,
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
var creds []WebAuthnCredential
|
|
for rows.Next() {
|
|
var c WebAuthnCredential
|
|
if err := rows.Scan(&c.CredID, &c.Name, &c.PublicKey, &c.CredentialID, &c.PRFSalt, &c.SignCount, &c.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
creds = append(creds, c)
|
|
}
|
|
return creds, rows.Err()
|
|
}
|
|
|
|
// WebAuthnCredentialCount returns the number of registered WebAuthn credentials.
|
|
func WebAuthnCredentialCount(db *DB) (int, error) {
|
|
var count int
|
|
err := db.Conn.QueryRow(`SELECT COUNT(*) FROM webauthn_credentials`).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// GetFirstCredentialPublicKey returns the public key of the first registered credential.
|
|
// Returns nil, nil if no credentials exist yet.
|
|
func GetFirstCredentialPublicKey(db *DB) ([]byte, error) {
|
|
var pubkey []byte
|
|
err := db.Conn.QueryRow(
|
|
`SELECT public_key FROM webauthn_credentials ORDER BY created_at ASC LIMIT 1`,
|
|
).Scan(&pubkey)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
return pubkey, err
|
|
}
|
|
|
|
// GetWebAuthnCredentialByRawID looks up a credential by its raw WebAuthn credential ID.
|
|
func GetWebAuthnCredentialByRawID(db *DB, credentialID []byte) (*WebAuthnCredential, error) {
|
|
var c WebAuthnCredential
|
|
err := db.Conn.QueryRow(
|
|
`SELECT cred_id, name, public_key, credential_id, prf_salt, sign_count, created_at
|
|
FROM webauthn_credentials WHERE credential_id = ?`, credentialID,
|
|
).Scan(&c.CredID, &c.Name, &c.PublicKey, &c.CredentialID, &c.PRFSalt, &c.SignCount, &c.CreatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, ErrNotFound
|
|
}
|
|
return &c, err
|
|
}
|
|
|
|
// DeleteWebAuthnCredential removes a WebAuthn credential by ID.
|
|
func DeleteWebAuthnCredential(db *DB, credID int64) error {
|
|
result, err := db.Conn.Exec(`DELETE FROM webauthn_credentials WHERE cred_id = ?`, credID)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return ErrNotFound
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// UpdateWebAuthnSignCount increments the sign count for a credential.
|
|
func UpdateWebAuthnSignCount(db *DB, credID int64, count int) error {
|
|
_, err := db.Conn.Exec(`UPDATE webauthn_credentials SET sign_count = ? WHERE cred_id = ?`, count, credID)
|
|
return err
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// WebAuthn challenge operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// StoreWebAuthnChallenge stores a challenge for later verification.
|
|
func StoreWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO webauthn_challenges (challenge, type, created_at) VALUES (?, ?, ?)`,
|
|
challenge, challengeType, time.Now().Unix(),
|
|
)
|
|
return err
|
|
}
|
|
|
|
// ConsumeWebAuthnChallenge verifies and removes a challenge. Returns error if not found or expired (5min TTL).
|
|
func ConsumeWebAuthnChallenge(db *DB, challenge []byte, challengeType string) error {
|
|
fiveMinAgo := time.Now().Unix() - 300
|
|
result, err := db.Conn.Exec(
|
|
`DELETE FROM webauthn_challenges WHERE challenge = ? AND type = ? AND created_at > ?`,
|
|
challenge, challengeType, fiveMinAgo,
|
|
)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
affected, err := result.RowsAffected()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if affected == 0 {
|
|
return errors.New("challenge not found or expired")
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// CleanExpiredChallenges removes challenges older than 5 minutes.
|
|
func CleanExpiredChallenges(db *DB) {
|
|
fiveMinAgo := time.Now().Unix() - 300
|
|
db.Conn.Exec(`DELETE FROM webauthn_challenges WHERE created_at < ?`, fiveMinAgo)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Agent operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// AgentCreate creates a new agent.
|
|
func AgentCreate(db *DB, a *Agent) error {
|
|
if a.ID == 0 {
|
|
a.ID = HexID(NewID())
|
|
}
|
|
a.CreatedAt = time.Now().UnixMilli()
|
|
if a.Status == "" {
|
|
a.Status = AgentStatusActive
|
|
}
|
|
wl, _ := json.Marshal(a.IPWhitelist)
|
|
if a.RateLimitMinute == 0 {
|
|
a.RateLimitMinute = 5
|
|
}
|
|
if a.RateLimitHour == 0 {
|
|
a.RateLimitHour = 10
|
|
}
|
|
_, err := db.Conn.Exec(
|
|
`INSERT INTO agents (id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, created_at)
|
|
VALUES (?, ?, ?, ?, ?, ?, ?)`,
|
|
int64(a.ID), a.Name, string(wl), a.RateLimitMinute, a.RateLimitHour, a.Status, a.CreatedAt)
|
|
return err
|
|
}
|
|
|
|
// AgentGet returns an agent by ID.
|
|
func AgentGet(db *DB, agentID int64) (*Agent, error) {
|
|
var a Agent
|
|
var wlStr string
|
|
err := db.Conn.QueryRow(
|
|
`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
|
|
FROM agents WHERE id = ?`, agentID,
|
|
).Scan(&a.ID, &a.Name, &wlStr, &a.RateLimitMinute, &a.RateLimitHour, &a.Status, &a.LockedReason, &a.LockedAt, &a.LastUsed, &a.LastIP, &a.CreatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
json.Unmarshal([]byte(wlStr), &a.IPWhitelist)
|
|
return &a, nil
|
|
}
|
|
|
|
// AgentGetByName returns an agent by name.
|
|
func AgentGetByName(db *DB, name string) (*Agent, error) {
|
|
var a Agent
|
|
var wlStr string
|
|
err := db.Conn.QueryRow(
|
|
`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
|
|
FROM agents WHERE name = ?`, name,
|
|
).Scan(&a.ID, &a.Name, &wlStr, &a.RateLimitMinute, &a.RateLimitHour, &a.Status, &a.LockedReason, &a.LockedAt, &a.LastUsed, &a.LastIP, &a.CreatedAt)
|
|
if errors.Is(err, sql.ErrNoRows) {
|
|
return nil, nil
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
json.Unmarshal([]byte(wlStr), &a.IPWhitelist)
|
|
return &a, nil
|
|
}
|
|
|
|
// AgentList returns all agents.
|
|
func AgentList(db *DB) ([]Agent, error) {
|
|
rows, err := db.Conn.Query(
|
|
`SELECT id, name, ip_whitelist, rate_limit_minute, rate_limit_hour, status, COALESCE(locked_reason,''), locked_at, last_used, COALESCE(last_ip,''), created_at
|
|
FROM agents ORDER BY created_at DESC`)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
var agents []Agent
|
|
for rows.Next() {
|
|
var a Agent
|
|
var wlStr string
|
|
if err := rows.Scan(&a.ID, &a.Name, &wlStr, &a.RateLimitMinute, &a.RateLimitHour, &a.Status, &a.LockedReason, &a.LockedAt, &a.LastUsed, &a.LastIP, &a.CreatedAt); err != nil {
|
|
return nil, err
|
|
}
|
|
json.Unmarshal([]byte(wlStr), &a.IPWhitelist)
|
|
agents = append(agents, a)
|
|
}
|
|
return agents, rows.Err()
|
|
}
|
|
|
|
// AgentUpdateStatus sets an agent's status and optional reason.
|
|
func AgentUpdateStatus(db *DB, agentID int64, status, reason string) error {
|
|
lockedAt := int64(0)
|
|
if status == AgentStatusLocked {
|
|
lockedAt = time.Now().UnixMilli()
|
|
}
|
|
_, err := db.Conn.Exec(
|
|
`UPDATE agents SET status = ?, locked_reason = ?, locked_at = ? WHERE id = ?`,
|
|
status, reason, lockedAt, agentID)
|
|
return err
|
|
}
|
|
|
|
// AgentUpdateWhitelist sets an agent's IP whitelist.
|
|
func AgentUpdateWhitelist(db *DB, agentID int64, whitelist []string) error {
|
|
wl, _ := json.Marshal(whitelist)
|
|
_, err := db.Conn.Exec(`UPDATE agents SET ip_whitelist = ? WHERE id = ?`, string(wl), agentID)
|
|
return err
|
|
}
|
|
|
|
// AgentUpdateRateLimits sets an agent's rate limits.
|
|
func AgentUpdateRateLimits(db *DB, agentID int64, perMin, perHour int) error {
|
|
_, err := db.Conn.Exec(`UPDATE agents SET rate_limit_minute = ?, rate_limit_hour = ? WHERE id = ?`, perMin, perHour, agentID)
|
|
return err
|
|
}
|
|
|
|
// AgentUpdateLastUsed updates the last_used timestamp and IP.
|
|
func AgentUpdateLastUsed(db *DB, agentID int64, ip string) error {
|
|
_, err := db.Conn.Exec(`UPDATE agents SET last_used = ?, last_ip = ? WHERE id = ?`, time.Now().UnixMilli(), ip, agentID)
|
|
return err
|
|
}
|
|
|
|
// AgentDelete hard-deletes an agent.
|
|
func AgentDelete(db *DB, agentID int64) error {
|
|
db.Conn.Exec(`DELETE FROM agent_requests WHERE agent_id = ?`, agentID)
|
|
_, err := db.Conn.Exec(`DELETE FROM agents WHERE id = ?`, agentID)
|
|
return err
|
|
}
|
|
|
|
// AgentRequestLog logs an agent request for rate limiting.
|
|
// Repeated requests to the same path are not logged (don't count against limits).
|
|
func AgentRequestLog(db *DB, agentID int64, ip string, path string) error {
|
|
// Check if same path was requested in the last 60 seconds
|
|
var exists int
|
|
db.Conn.QueryRow(
|
|
`SELECT COUNT(*) FROM agent_requests WHERE agent_id = ? AND path = ? AND timestamp > ?`,
|
|
agentID, path, time.Now().Unix()-60).Scan(&exists)
|
|
if exists > 0 {
|
|
return nil // same request, don't count
|
|
}
|
|
_, err := db.Conn.Exec(`INSERT INTO agent_requests (agent_id, ip, path, timestamp) VALUES (?, ?, ?, ?)`,
|
|
agentID, ip, path, time.Now().Unix())
|
|
return err
|
|
}
|
|
|
|
// AgentRequestCountMinute returns the number of distinct requests in the last 60 seconds.
|
|
func AgentRequestCountMinute(db *DB, agentID int64) (int, error) {
|
|
var count int
|
|
err := db.Conn.QueryRow(
|
|
`SELECT COUNT(*) FROM agent_requests WHERE agent_id = ? AND timestamp > ?`,
|
|
agentID, time.Now().Unix()-60).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// AgentRequestCountHour returns the number of distinct requests in the last 3600 seconds.
|
|
func AgentRequestCountHour(db *DB, agentID int64) (int, error) {
|
|
var count int
|
|
err := db.Conn.QueryRow(
|
|
`SELECT COUNT(*) FROM agent_requests WHERE agent_id = ? AND timestamp > ?`,
|
|
agentID, time.Now().Unix()-3600).Scan(&count)
|
|
return count, err
|
|
}
|
|
|
|
// AgentRequestCleanup deletes request logs older than 2 hours.
|
|
func AgentRequestCleanup(db *DB) {
|
|
db.Conn.Exec(`DELETE FROM agent_requests WHERE timestamp < ?`, time.Now().Unix()-7200)
|
|
}
|
|
|
|
// ---------------------------------------------------------------------------
|
|
// Vault lock operations
|
|
// ---------------------------------------------------------------------------
|
|
|
|
// VaultLockGet returns the vault lock state.
|
|
func VaultLockGet(db *DB) (*VaultLock, error) {
|
|
var vl VaultLock
|
|
var locked int
|
|
err := db.Conn.QueryRow(`SELECT locked, COALESCE(locked_reason,''), locked_at FROM vault_lock WHERE id = 1`).
|
|
Scan(&locked, &vl.LockedReason, &vl.LockedAt)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
vl.Locked = locked != 0
|
|
return &vl, nil
|
|
}
|
|
|
|
// VaultLockSet sets the vault lock state.
|
|
func VaultLockSet(db *DB, locked bool, reason string) error {
|
|
lockedInt := 0
|
|
lockedAt := int64(0)
|
|
if locked {
|
|
lockedInt = 1
|
|
lockedAt = time.Now().UnixMilli()
|
|
}
|
|
_, err := db.Conn.Exec(`UPDATE vault_lock SET locked = ?, locked_reason = ?, locked_at = ? WHERE id = 1`,
|
|
lockedInt, reason, lockedAt)
|
|
return err
|
|
}
|