clawvault/lib/dbcore.go

470 lines
12 KiB
Go

package lib
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"strings"
"time"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
)
// Sentinel errors returned by the storage functions in this file.
// Compare against them with errors.Is.
var (
// ErrNotFound is returned by EntryGet when no row has the requested ID and
// by EntryDelete when no live (non-deleted) row was affected.
ErrNotFound = errors.New("not found")
// ErrVersionConflict is returned by EntryUpdate when its version-guarded
// UPDATE affects zero rows.
ErrVersionConflict = errors.New("version conflict: entry was modified")
)
// schema is the SQL executed by MigrateDB. It defines four tables (entries,
// sessions, audit_log and webauthn_credentials) together with their indexes.
// Every statement uses IF NOT EXISTS.
//
// The *_at columns are INTEGERs; the functions in this file write
// Unix-millisecond timestamps into entries, sessions and audit_log.
// entries.deleted_at is nullable and is NULL for live rows (soft delete).
const schema = `
CREATE TABLE IF NOT EXISTS entries (
entry_id TEXT PRIMARY KEY,
parent_id TEXT NOT NULL DEFAULT '',
type TEXT NOT NULL,
title TEXT NOT NULL,
title_idx BLOB NOT NULL,
data BLOB NOT NULL,
data_level INTEGER NOT NULL DEFAULT 1,
created_at INTEGER NOT NULL,
updated_at INTEGER NOT NULL,
version INTEGER NOT NULL DEFAULT 1,
deleted_at INTEGER
);
CREATE INDEX IF NOT EXISTS idx_entries_parent ON entries(parent_id);
CREATE INDEX IF NOT EXISTS idx_entries_type ON entries(type);
CREATE INDEX IF NOT EXISTS idx_entries_title_idx ON entries(title_idx);
CREATE INDEX IF NOT EXISTS idx_entries_deleted ON entries(deleted_at);
CREATE TABLE IF NOT EXISTS sessions (
token TEXT PRIMARY KEY,
created_at INTEGER NOT NULL,
expires_at INTEGER NOT NULL,
actor TEXT NOT NULL DEFAULT 'web'
);
CREATE TABLE IF NOT EXISTS audit_log (
event_id TEXT PRIMARY KEY,
entry_id TEXT,
title TEXT,
action TEXT NOT NULL,
actor TEXT NOT NULL,
ip_addr TEXT,
created_at INTEGER NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_audit_entry ON audit_log(entry_id);
CREATE INDEX IF NOT EXISTS idx_audit_created ON audit_log(created_at);
CREATE TABLE IF NOT EXISTS webauthn_credentials (
cred_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
public_key BLOB NOT NULL,
prf_salt BLOB NOT NULL,
sign_count INTEGER NOT NULL DEFAULT 0,
created_at INTEGER NOT NULL
);
`
// OpenDB opens the SQLite database at dbPath and verifies that it is reachable.
//
// The query string appended to dbPath requests WAL journaling, foreign-key
// enforcement and a 5000 ms busy timeout. After sql.Open succeeds the
// connection is checked with Ping; if the ping fails the handle is closed
// before returning so it is not leaked.
//
// Errors are wrapped with "open db" or "ping db" and can be inspected with
// errors.Is / errors.As.
func OpenDB(dbPath string) (*DB, error) {
	conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000")
	if err != nil {
		return nil, fmt.Errorf("open db: %w", err)
	}
	if err := conn.Ping(); err != nil {
		// The caller only receives an error and has no handle to close, so
		// release it here. The ping failure is the error worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("ping db: %w", err)
	}
	return &DB{Conn: conn}, nil
}
// MigrateDB executes the package-level schema SQL against db and returns any
// error reported by the driver.
func MigrateDB(db *DB) error {
	if _, err := db.Conn.Exec(schema); err != nil {
		return err
	}
	return nil
}
// Close releases the underlying database handle and reports the error, if
// any, returned while doing so.
func (db *DB) Close() error {
	closeErr := db.Conn.Close()
	return closeErr
}
// ---------------------------------------------------------------------------
// Entry operations
// ---------------------------------------------------------------------------
// EntryCreate fills in the bookkeeping fields of e and inserts it as a new row.
//
// A UUID is generated when e.EntryID is empty. CreatedAt and UpdatedAt are set
// to the current time in Unix milliseconds, Version is set to 1, and a zero
// DataLevel becomes DataLevelL1. The title blind index is computed from the
// lower-cased title. When e.VaultData is non-nil it is JSON-encoded and packed
// with the per-entry key into e.Data; otherwise e.Data is stored as given.
//
// e is modified in place, including when a later step fails. The first error
// from key derivation, encoding, packing or the INSERT is returned unwrapped.
func EntryCreate(db *DB, cfg *Config, e *Entry) error {
	const insertSQL = `INSERT INTO entries (entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`

	if e.EntryID == "" {
		e.EntryID = uuid.New().String()
	}
	stamp := time.Now().UnixMilli()
	e.CreatedAt, e.UpdatedAt, e.Version = stamp, stamp, 1
	if e.DataLevel == 0 {
		e.DataLevel = DataLevelL1
	}

	// Key material: a per-entry key for the payload, an HMAC key for the index.
	entryKey, err := DeriveEntryKey(cfg.VaultKey, e.EntryID)
	if err != nil {
		return err
	}
	hmacKey, err := DeriveHMACKey(cfg.VaultKey)
	if err != nil {
		return err
	}

	// The index is built from the lower-cased title.
	lowered := strings.ToLower(e.Title)
	e.TitleIdx = BlindIndex(hmacKey, lowered)

	if e.VaultData != nil {
		raw, marshalErr := json.Marshal(e.VaultData)
		if marshalErr != nil {
			return marshalErr
		}
		sealed, packErr := Pack(entryKey, string(raw))
		if packErr != nil {
			return packErr
		}
		e.Data = sealed
	}

	if _, err := db.Conn.Exec(insertSQL,
		e.EntryID, e.ParentID, e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, e.CreatedAt, e.UpdatedAt, e.Version,
	); err != nil {
		return err
	}
	return nil
}
// EntryGet loads the entry with the given ID.
//
// The lookup does not filter on deleted_at, so soft-deleted rows are returned
// too; for those, DeletedAt is set to the stored timestamp. When the row has a
// non-empty data blob at DataLevelL1, the blob is unpacked with the per-entry
// key and decoded into VaultData.
//
// Returns ErrNotFound when no row matches, or the underlying error from the
// query, key derivation, unpacking or JSON decoding.
func EntryGet(db *DB, cfg *Config, entryID string) (*Entry, error) {
	const selectSQL = `SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version, deleted_at
FROM entries WHERE entry_id = ?`

	var (
		entry   Entry
		deleted sql.NullInt64
	)
	row := db.Conn.QueryRow(selectSQL, entryID)
	err := row.Scan(
		&entry.EntryID, &entry.ParentID, &entry.Type, &entry.Title, &entry.TitleIdx,
		&entry.Data, &entry.DataLevel, &entry.CreatedAt, &entry.UpdatedAt, &entry.Version, &deleted,
	)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrNotFound
	case err != nil:
		return nil, err
	}

	if deleted.Valid {
		ts := deleted.Int64
		entry.DeletedAt = &ts
	}

	// Nothing to decrypt unless there is an L1 payload.
	if len(entry.Data) == 0 || entry.DataLevel != DataLevelL1 {
		return &entry, nil
	}

	key, err := DeriveEntryKey(cfg.VaultKey, entry.EntryID)
	if err != nil {
		return nil, err
	}
	plain, err := Unpack(key, entry.Data)
	if err != nil {
		return nil, err
	}
	var vd VaultData
	if err := json.Unmarshal([]byte(plain), &vd); err != nil {
		return nil, err
	}
	entry.VaultData = &vd
	return &entry, nil
}
// EntryUpdate writes e back to the database using optimistic locking.
//
// The title blind index is recomputed from the lower-cased title and, when
// e.VaultData is non-nil, the payload is re-encoded and re-packed into e.Data.
// The UPDATE only matches a row whose entry_id and version equal e.EntryID and
// e.Version and whose deleted_at is NULL; the stored version is incremented.
//
// On success e.Version is incremented and e.UpdatedAt is set to the write
// time (Unix milliseconds). If no row is affected ErrVersionConflict is
// returned; note that this also covers a missing or soft-deleted entry, not
// only a stale version. e.TitleIdx and e.Data may already have been modified
// when an error is returned.
func EntryUpdate(db *DB, cfg *Config, e *Entry) error {
	const updateSQL = `UPDATE entries SET parent_id=?, type=?, title=?, title_idx=?, data=?, data_level=?, updated_at=?, version=version+1
WHERE entry_id = ? AND version = ? AND deleted_at IS NULL`

	stamp := time.Now().UnixMilli()

	entryKey, err := DeriveEntryKey(cfg.VaultKey, e.EntryID)
	if err != nil {
		return err
	}
	hmacKey, err := DeriveHMACKey(cfg.VaultKey)
	if err != nil {
		return err
	}

	// Keep the searchable index in step with the (possibly changed) title.
	lowered := strings.ToLower(e.Title)
	e.TitleIdx = BlindIndex(hmacKey, lowered)

	if e.VaultData != nil {
		raw, marshalErr := json.Marshal(e.VaultData)
		if marshalErr != nil {
			return marshalErr
		}
		sealed, packErr := Pack(entryKey, string(raw))
		if packErr != nil {
			return packErr
		}
		e.Data = sealed
	}

	res, err := db.Conn.Exec(updateSQL,
		e.ParentID, e.Type, e.Title, e.TitleIdx, e.Data, e.DataLevel, stamp,
		e.EntryID, e.Version,
	)
	if err != nil {
		return err
	}
	switch changed, err := res.RowsAffected(); {
	case err != nil:
		return err
	case changed == 0:
		return ErrVersionConflict
	}

	// Mirror what the database now holds.
	e.UpdatedAt = stamp
	e.Version++
	return nil
}
// EntryDelete soft-deletes the entry with the given ID by stamping deleted_at
// and updated_at with the current time in Unix milliseconds.
//
// Only rows whose deleted_at is still NULL are touched; if none is affected
// (unknown ID or already deleted) ErrNotFound is returned.
func EntryDelete(db *DB, entryID string) error {
	const softDeleteSQL = `UPDATE entries SET deleted_at = ?, updated_at = ? WHERE entry_id = ? AND deleted_at IS NULL`

	stamp := time.Now().UnixMilli()
	res, err := db.Conn.Exec(softDeleteSQL, stamp, stamp, entryID)
	if err != nil {
		return err
	}
	switch changed, err := res.RowsAffected(); {
	case err != nil:
		return err
	case changed == 0:
		return ErrNotFound
	}
	return nil
}
// EntryList returns every non-deleted entry ordered by type and then title.
// When parentID is non-nil only entries with that parent_id are returned.
//
// Decryption is best-effort: for rows with a non-empty data blob at
// DataLevelL1, a failure to derive the key, unpack or decode simply leaves
// VaultData nil instead of failing the whole listing. The returned slice is
// nil when no rows match.
func EntryList(db *DB, cfg *Config, parentID *string) ([]Entry, error) {
	const (
		listByParentSQL = `SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND parent_id = ? ORDER BY type, title`
		listAllSQL = `SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL ORDER BY type, title`
	)

	var (
		rows *sql.Rows
		err  error
	)
	if parentID == nil {
		rows, err = db.Conn.Query(listAllSQL)
	} else {
		rows, err = db.Conn.Query(listByParentSQL, *parentID)
	}
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []Entry
	for rows.Next() {
		var item Entry
		if scanErr := rows.Scan(
			&item.EntryID, &item.ParentID, &item.Type, &item.Title, &item.TitleIdx,
			&item.Data, &item.DataLevel, &item.CreatedAt, &item.UpdatedAt, &item.Version,
		); scanErr != nil {
			return nil, scanErr
		}

		// Each step only runs if the previous one succeeded; any failure
		// leaves VaultData unset.
		if len(item.Data) > 0 && item.DataLevel == DataLevelL1 {
			if key, keyErr := DeriveEntryKey(cfg.VaultKey, item.EntryID); keyErr == nil {
				if plain, unpackErr := Unpack(key, item.Data); unpackErr == nil {
					var vd VaultData
					if json.Unmarshal([]byte(plain), &vd) == nil {
						item.VaultData = &vd
					}
				}
			}
		}
		result = append(result, item)
	}
	return result, rows.Err()
}
// EntrySearch returns the non-deleted entries whose title blind index equals
// the index of query, ordered by title.
//
// The index is computed from the lower-cased query, so this is an exact,
// case-insensitive title match rather than a substring search.
//
// Decryption is best-effort, as in EntryList: for rows with a non-empty data
// blob at DataLevelL1, a failure to derive the key, unpack or decode leaves
// VaultData nil instead of failing the search. Each step is checked
// explicitly so that a failed key derivation never feeds an invalid key into
// Unpack. The returned slice is nil when nothing matches.
func EntrySearch(db *DB, cfg *Config, query string) ([]Entry, error) {
	hmacKey, err := DeriveHMACKey(cfg.VaultKey)
	if err != nil {
		return nil, err
	}
	idx := BlindIndex(hmacKey, strings.ToLower(query))
	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND title_idx = ? ORDER BY title`, idx,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			// Deliberately non-fatal: an undecryptable row is still listed,
			// just without its VaultData.
			if entryKey, err := DeriveEntryKey(cfg.VaultKey, e.EntryID); err == nil {
				if dataText, err := Unpack(entryKey, e.Data); err == nil {
					var vd VaultData
					if json.Unmarshal([]byte(dataText), &vd) == nil {
						e.VaultData = &vd
					}
				}
			}
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// EntrySearchFuzzy returns the non-deleted entries whose plaintext title
// contains query as a substring, ordered by title. It uses SQL LIKE against
// the title column (less secure than the blind-index lookup in EntrySearch,
// but more practical).
//
// The LIKE metacharacters '%' and '_' (and the escape character '\') in query
// are escaped, so the query is matched literally: searching for "a_b" no
// longer matches "axb", and "%" no longer matches every title. An empty query
// still matches all entries.
//
// Decryption is best-effort, as in EntryList: for rows with a non-empty data
// blob at DataLevelL1, a failure to derive the key, unpack or decode leaves
// VaultData nil instead of failing the search. The returned slice is nil when
// nothing matches.
func EntrySearchFuzzy(db *DB, cfg *Config, query string) ([]Entry, error) {
	// A Replacer makes a single pass over the input, so the backslashes it
	// inserts are not escaped a second time.
	escaped := strings.NewReplacer(`\`, `\\`, `%`, `\%`, `_`, `\_`).Replace(query)
	rows, err := db.Conn.Query(
		`SELECT entry_id, parent_id, type, title, title_idx, data, data_level, created_at, updated_at, version
FROM entries WHERE deleted_at IS NULL AND title LIKE ? ESCAPE '\' ORDER BY title`, "%"+escaped+"%",
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var entries []Entry
	for rows.Next() {
		var e Entry
		if err := rows.Scan(&e.EntryID, &e.ParentID, &e.Type, &e.Title, &e.TitleIdx, &e.Data, &e.DataLevel, &e.CreatedAt, &e.UpdatedAt, &e.Version); err != nil {
			return nil, err
		}
		if len(e.Data) > 0 && e.DataLevel == DataLevelL1 {
			// Deliberately non-fatal: an undecryptable row is still listed,
			// just without its VaultData.
			if entryKey, err := DeriveEntryKey(cfg.VaultKey, e.EntryID); err == nil {
				if dataText, err := Unpack(entryKey, e.Data); err == nil {
					var vd VaultData
					if json.Unmarshal([]byte(dataText), &vd) == nil {
						e.VaultData = &vd
					}
				}
			}
		}
		entries = append(entries, e)
	}
	return entries, rows.Err()
}
// ---------------------------------------------------------------------------
// Session operations
// ---------------------------------------------------------------------------
// SessionCreate generates a new session token for actor and stores it.
//
// ttl is the session lifetime in seconds; it is converted to milliseconds and
// added to the current Unix-millisecond time to produce ExpiresAt.
//
// On success the stored session is returned. If the INSERT fails the error is
// returned with a nil session, so callers never receive a session that was
// not persisted.
func SessionCreate(db *DB, ttl int64, actor string) (*Session, error) {
	now := time.Now().UnixMilli()
	s := &Session{
		Token:     GenerateToken(),
		CreatedAt: now,
		ExpiresAt: now + (ttl * 1000),
		Actor:     actor,
	}
	if _, err := db.Conn.Exec(
		`INSERT INTO sessions (token, created_at, expires_at, actor) VALUES (?, ?, ?, ?)`,
		s.Token, s.CreatedAt, s.ExpiresAt, s.Actor,
	); err != nil {
		return nil, err
	}
	return s, nil
}
// SessionGet looks up a session by its token.
//
// It returns (nil, nil) when the token is unknown or when the stored
// expires_at lies before the current time; callers must therefore check the
// session for nil as well as the error. Expired rows are not removed here.
func SessionGet(db *DB, token string) (*Session, error) {
	const selectSQL = `SELECT token, created_at, expires_at, actor FROM sessions WHERE token = ?`

	var sess Session
	row := db.Conn.QueryRow(selectSQL, token)
	err := row.Scan(&sess.Token, &sess.CreatedAt, &sess.ExpiresAt, &sess.Actor)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}

	// An expired session is reported exactly like a missing one.
	if time.Now().UnixMilli() > sess.ExpiresAt {
		return nil, nil
	}
	return &sess, nil
}
// SessionDelete removes the session row with the given token. Deleting a
// token that does not exist is not treated as an error.
func SessionDelete(db *DB, token string) error {
	if _, err := db.Conn.Exec(`DELETE FROM sessions WHERE token = ?`, token); err != nil {
		return err
	}
	return nil
}
// ---------------------------------------------------------------------------
// Audit operations
// ---------------------------------------------------------------------------
// AuditLog inserts ev into the audit_log table.
//
// Missing identifiers are filled in on ev itself before the insert: an empty
// EventID gets a fresh UUID and a zero CreatedAt gets the current time in
// Unix milliseconds.
func AuditLog(db *DB, ev *AuditEvent) error {
	const insertSQL = `INSERT INTO audit_log (event_id, entry_id, title, action, actor, ip_addr, created_at)
VALUES (?, ?, ?, ?, ?, ?, ?)`

	if ev.EventID == "" {
		ev.EventID = uuid.New().String()
	}
	if ev.CreatedAt == 0 {
		ev.CreatedAt = time.Now().UnixMilli()
	}

	if _, err := db.Conn.Exec(insertSQL,
		ev.EventID, ev.EntryID, ev.Title, ev.Action, ev.Actor, ev.IPAddr, ev.CreatedAt,
	); err != nil {
		return err
	}
	return nil
}
// AuditList returns up to limit audit events, newest first (ordered by
// created_at descending). A limit of zero or less falls back to 100.
//
// The nullable columns entry_id, title and ip_addr are surfaced as empty
// strings when they are NULL. The returned slice is nil when the log is empty.
func AuditList(db *DB, limit int) ([]AuditEvent, error) {
	const listSQL = `SELECT event_id, entry_id, title, action, actor, ip_addr, created_at
FROM audit_log ORDER BY created_at DESC LIMIT ?`

	if limit < 1 {
		limit = 100
	}

	rows, err := db.Conn.Query(listSQL, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var result []AuditEvent
	for rows.Next() {
		var (
			item                   AuditEvent
			entryID, title, ipAddr sql.NullString
		)
		if scanErr := rows.Scan(&item.EventID, &entryID, &title, &item.Action, &item.Actor, &ipAddr, &item.CreatedAt); scanErr != nil {
			return nil, scanErr
		}
		// NullString.String is "" after scanning NULL, which is also the zero
		// value of these fields, so no Valid check is needed.
		item.EntryID = entryID.String
		item.Title = title.String
		item.IPAddr = ipAddr.String
		result = append(result, item)
	}
	return result, rows.Err()
}
// EntryCount returns the number of entries that are not soft-deleted
// (for health check).
func EntryCount(db *DB) (int, error) {
	var total int
	row := db.Conn.QueryRow(`SELECT COUNT(*) FROM entries WHERE deleted_at IS NULL`)
	if err := row.Scan(&total); err != nil {
		return total, err
	}
	return total, nil
}