dealspace/lib/dbcore.go

816 lines
22 KiB
Go

package lib
import (
"crypto/subtle"
"database/sql"
"errors"
"fmt"
"os"
"time"
"github.com/google/uuid"
_ "github.com/mattn/go-sqlite3"
)
// Sentinel errors returned by this package's data-access helpers.
// Callers should match them with errors.Is rather than comparing strings.
var (
	// ErrNotFound reports that no matching row exists (for example,
	// entrySoftDelete returns it when nothing was updated).
	ErrNotFound = errors.New("not found")
	// ErrVersionConflict reports that an optimistic-locking update matched
	// no row (entryUpdate returns it when zero rows are affected).
	ErrVersionConflict = errors.New("version conflict: entry was modified by another request")
	// ErrSoftDeleted reports that an entry exists but has been soft-deleted.
	ErrSoftDeleted = errors.New("entry has been deleted")
)
// OpenDB opens (or creates) the SQLite database at dbPath and, when
// migrationPath is non-empty, executes the SQL script found there.
//
// The DSN enables WAL journaling, foreign-key enforcement and a busy timeout
// of 5000 ms. The connection is verified with Ping before use. If the ping
// or the migration fails, the connection pool is closed before the error is
// returned, so a failed OpenDB does not leak it.
func OpenDB(dbPath string, migrationPath string) (*DB, error) {
	conn, err := sql.Open("sqlite3", dbPath+"?_journal_mode=WAL&_foreign_keys=ON&_busy_timeout=5000")
	if err != nil {
		return nil, fmt.Errorf("open db: %w", err)
	}
	if err := conn.Ping(); err != nil {
		// Best effort: the ping error is the one worth reporting.
		_ = conn.Close()
		return nil, fmt.Errorf("ping db: %w", err)
	}
	db := &DB{Conn: conn}
	if migrationPath != "" {
		if err := db.runMigration(migrationPath); err != nil {
			// Best effort: the migration error is the one worth reporting.
			_ = conn.Close()
			return nil, fmt.Errorf("migration: %w", err)
		}
	}
	return db, nil
}
// runMigration reads the SQL script at path and executes it on db.Conn.
//
// The whole script is passed to a single Exec call, so the driver must accept
// multiple ';'-separated statements in one Exec. No transaction is opened
// here: if a statement fails midway, earlier statements stay applied.
func (db *DB) runMigration(path string) error {
	// Named "script" rather than "sql" so the database/sql package name is
	// not shadowed inside this function.
	script, err := os.ReadFile(path)
	if err != nil {
		return fmt.Errorf("read migration: %w", err)
	}
	if _, err := db.Conn.Exec(string(script)); err != nil {
		return fmt.Errorf("exec migration %q: %w", path, err)
	}
	return nil
}
// Close releases the underlying database connection pool (db.Conn) and
// returns whatever error closing it produces.
func (db *DB) Close() error {
	conn := db.Conn
	return conn.Close()
}
// ---------------------------------------------------------------------------
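// THE THREE CHOKE POINTS — RBAC-enforced entry access goes through these.
// Several helpers further down (e.g. EntriesByParent, EntryByID, AllProjects)
// read entries without RBAC checks, so callers are responsible for
// authorizing first.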
// ---------------------------------------------------------------------------
// EntryRead returns the entries of projectID that match filter, with RBAC
// enforcement.
//
// The actor's highest role in projectID is resolved first. Any error from
// that lookup is returned unchanged; presumably this includes "no access",
// but that should be confirmed in GetUserHighestRole. The query is always
// scoped to projectID. filter.ProjectID is overwritten, so a role obtained
// on one project can never be used to read another project's entries (or
// all projects, via an empty ProjectID). Buyer roles never see entries in
// StagePreDataroom. Encrypted fields are decrypted before return.
//
// NOTE(review): the buyer filter runs after LIMIT/OFFSET are applied in SQL,
// so a buyer's page may contain fewer than filter.Limit entries.
func EntryRead(db *DB, cfg *Config, actorID, projectID string, filter EntryFilter) ([]Entry, error) {
	if projectID == "" {
		// An empty project would leave the query unscoped (all projects).
		return nil, errors.New("entry read: project id is required")
	}
	role, err := GetUserHighestRole(db, actorID, projectID)
	if err != nil {
		return nil, err
	}

	// The role was resolved for projectID only; never let the caller-supplied
	// filter widen or redirect the query to other projects.
	filter.ProjectID = projectID

	entries, err := entryQuery(db, filter)
	if err != nil {
		return nil, err
	}

	// Buyers cannot see pre_dataroom entries. Filter in place, reusing the
	// backing array; indexing avoids copying each Entry twice.
	if IsBuyerRole(role) {
		visible := entries[:0]
		for i := range entries {
			if entries[i].Stage != StagePreDataroom {
				visible = append(visible, entries[i])
			}
		}
		entries = visible
	}

	// Decrypt the packed fields of every visible entry.
	for i := range entries {
		if err := unpackEntry(cfg, &entries[i]); err != nil {
			return nil, err
		}
	}
	return entries, nil
}
// EntryWrite creates or updates entries on behalf of actorID, enforcing RBAC
// and optimistic locking.
//
// Entries are processed in order and not inside a transaction. An error
// stops processing, but entries already written stay written. For each entry:
//   - write access to entry.ProjectID is checked via CheckAccessWrite;
//   - SummaryText/DataText are encrypted into Summary/Data via packEntry;
//   - an entry with an empty EntryID is inserted with a fresh UUID,
//     CreatedBy=actorID, CreatedAt/UpdatedAt=now, Version=1, KeyVersion=1;
//   - any other entry gets UpdatedAt=now and is updated via entryUpdate,
//     which returns ErrVersionConflict when no stored row matches.
//
// A nil entry yields an error instead of a panic. If an insert fails, the
// entry's EntryID is reset to "", so a retry inserts again rather than
// attempting to update a row that was never created.
func EntryWrite(db *DB, cfg *Config, actorID string, entries ...*Entry) error {
	now := time.Now().UnixMilli()
	for i, entry := range entries {
		if entry == nil {
			return fmt.Errorf("entry write: entry %d is nil", i)
		}
		if err := CheckAccessWrite(db, actorID, entry.ProjectID, ""); err != nil {
			return err
		}
		if err := packEntry(cfg, entry); err != nil {
			return err
		}

		if entry.EntryID != "" {
			// Existing entry: optimistic-locked update.
			entry.UpdatedAt = now
			if err := entryUpdate(db, entry); err != nil {
				return err
			}
			continue
		}

		// New entry.
		entry.EntryID = uuid.New().String()
		entry.CreatedBy = actorID
		entry.CreatedAt = now
		entry.UpdatedAt = now
		entry.Version = 1
		entry.KeyVersion = 1
		if err := entryInsert(db, entry); err != nil {
			entry.EntryID = ""
			return err
		}
	}
	return nil
}
// EntryDelete soft-deletes the given entries of projectID on behalf of
// actorID, with RBAC enforcement.
//
// Delete access is checked once for projectID. Each ID must refer to an
// entry that belongs to projectID and is not already deleted; otherwise
// ErrNotFound is returned. Foreign IDs get the same error, so their existence
// is not revealed. IDs are processed in order without a transaction, so
// entries deleted before an error stay deleted.
func EntryDelete(db *DB, actorID, projectID string, entryIDs ...string) error {
	if err := CheckAccessDelete(db, actorID, projectID, ""); err != nil {
		return err
	}
	now := time.Now().UnixMilli()
	for _, id := range entryIDs {
		// Access was granted for projectID only, so refuse IDs from other
		// projects. entryUpdate never changes project_id, so this check
		// cannot go stale before the delete below.
		existing, err := entryReadSystem(db, id)
		if err != nil {
			return err
		}
		if existing == nil || existing.ProjectID != projectID {
			return ErrNotFound
		}
		if err := entrySoftDelete(db, id, actorID, now); err != nil {
			return err
		}
	}
	return nil
}
// entryReadSystem loads a single entry by ID with NO RBAC checks and no
// soft-delete filter (deleted rows are returned too). It is intended only for
// internal traversals such as ResolveWorkstreamID. It returns (nil, nil) when
// no row has this ID, and the packed fields are left encrypted.
func entryReadSystem(db *DB, entryID string) (*Entry, error) {
	const query = `SELECT entry_id, project_id, parent_id, type, depth, search_key, search_key2,
summary, data, stage, assignee_id, return_to_id, origin_id,
version, deleted_at, deleted_by, key_version, created_at, updated_at, created_by
FROM entries WHERE entry_id = ?`
	return scanEntry(db.Conn.QueryRow(query, entryID))
}
// ---------------------------------------------------------------------------
// Internal SQL — never exported, never called outside this file
// ---------------------------------------------------------------------------
// entryQuery returns the non-deleted entries matching f, newest first
// (ORDER BY created_at DESC). Empty or nil filter fields are ignored, and
// Limit/Offset apply only when positive. No RBAC checks and no decryption
// happen here. All filter values are bound as parameters, never spliced into
// the SQL.
func entryQuery(db *DB, f EntryFilter) ([]Entry, error) {
	q := `SELECT entry_id, project_id, parent_id, type, depth, search_key, search_key2,
summary, data, stage, assignee_id, return_to_id, origin_id,
version, deleted_at, deleted_by, key_version, created_at, updated_at, created_by
FROM entries WHERE deleted_at IS NULL`
	var args []any
	if f.ProjectID != "" {
		q += " AND project_id = ?"
		args = append(args, f.ProjectID)
	}
	if f.ParentID != nil {
		q += " AND parent_id = ?"
		args = append(args, *f.ParentID)
	}
	if f.Type != "" {
		q += " AND type = ?"
		args = append(args, f.Type)
	}
	if f.Stage != "" {
		q += " AND stage = ?"
		args = append(args, f.Stage)
	}
	if f.AssigneeID != "" {
		q += " AND assignee_id = ?"
		args = append(args, f.AssigneeID)
	}
	if f.SearchKey != nil {
		q += " AND search_key = ?"
		args = append(args, f.SearchKey)
	}
	q += " ORDER BY created_at DESC"
	switch {
	case f.Limit > 0:
		q += " LIMIT ?"
		args = append(args, f.Limit)
	case f.Offset > 0:
		// SQLite only accepts OFFSET after a LIMIT clause. A negative LIMIT
		// means "no upper bound", so paging without a limit still works.
		q += " LIMIT -1"
	}
	if f.Offset > 0 {
		q += " OFFSET ?"
		args = append(args, f.Offset)
	}

	rows, err := db.Conn.Query(q, args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var entries []Entry
	for rows.Next() {
		e, err := scanEntryRow(rows)
		if err != nil {
			return nil, err
		}
		entries = append(entries, *e)
	}
	return entries, rows.Err()
}
// entryInsert writes e as a new row of the entries table. All twenty columns
// are taken verbatim from e; IDs, timestamps and versions must already be
// set by the caller.
func entryInsert(db *DB, e *Entry) error {
	const stmt = `INSERT INTO entries (entry_id, project_id, parent_id, type, depth,
search_key, search_key2, summary, data, stage,
assignee_id, return_to_id, origin_id,
version, deleted_at, deleted_by, key_version,
created_at, updated_at, created_by)
VALUES (?,?,?,?,?, ?,?,?,?,?, ?,?,?, ?,?,?,?, ?,?,?)`
	values := []any{
		e.EntryID, e.ProjectID, e.ParentID, e.Type, e.Depth,
		e.SearchKey, e.SearchKey2, e.Summary, e.Data, e.Stage,
		e.AssigneeID, e.ReturnToID, e.OriginID,
		e.Version, e.DeletedAt, e.DeletedBy, e.KeyVersion,
		e.CreatedAt, e.UpdatedAt, e.CreatedBy,
	}
	if _, err := db.Conn.Exec(stmt, values...); err != nil {
		return err
	}
	return nil
}
// entryUpdate overwrites the mutable columns of e's row, using optimistic
// locking.
//
// The target row must match e.EntryID, e.ProjectID and e.Version and must
// not be soft-deleted. On success the stored version is incremented in SQL
// and e.Version is bumped to match. When no row matches, ErrVersionConflict
// is returned; this covers a stale version, a missing or deleted entry, and
// an entry belonging to a different project. project_id, created_*,
// deleted_* and key_version are never modified here.
func entryUpdate(db *DB, e *Entry) error {
	// project_id is part of the WHERE clause because callers authorize (and
	// packEntry derives the encryption key) against e.ProjectID. Without it, an
	// entry ID from another project could be overwritten with data encrypted
	// under the wrong project key.
	result, err := db.Conn.Exec(
		`UPDATE entries SET
parent_id=?, type=?, depth=?, search_key=?, search_key2=?,
summary=?, data=?, stage=?,
assignee_id=?, return_to_id=?, origin_id=?,
version=version+1, updated_at=?
WHERE entry_id = ? AND project_id = ? AND version = ? AND deleted_at IS NULL`,
		e.ParentID, e.Type, e.Depth, e.SearchKey, e.SearchKey2,
		e.Summary, e.Data, e.Stage,
		e.AssigneeID, e.ReturnToID, e.OriginID,
		e.UpdatedAt,
		e.EntryID, e.ProjectID, e.Version,
	)
	if err != nil {
		return err
	}
	affected, err := result.RowsAffected()
	if err != nil {
		return err
	}
	if affected == 0 {
		return ErrVersionConflict
	}
	e.Version++
	return nil
}
// entrySoftDelete marks the entry deleted by setting deleted_at/updated_at
// to now and deleted_by to actorID. It returns ErrNotFound when no
// not-yet-deleted row has this ID. No RBAC or project check is done here.
func entrySoftDelete(db *DB, entryID, actorID string, now int64) error {
	res, err := db.Conn.Exec(
		`UPDATE entries SET deleted_at = ?, deleted_by = ?, updated_at = ?
WHERE entry_id = ? AND deleted_at IS NULL`,
		now, actorID, now, entryID,
	)
	if err != nil {
		return err
	}
	n, err := res.RowsAffected()
	switch {
	case err != nil:
		return err
	case n == 0:
		return ErrNotFound
	default:
		return nil
	}
}
// scanEntry reads one entries row, in the 20-column order used throughout
// this file, into a new Entry. A missing row is reported as (nil, nil)
// rather than as an error.
func scanEntry(row *sql.Row) (*Entry, error) {
	e := new(Entry)
	switch err := row.Scan(
		&e.EntryID, &e.ProjectID, &e.ParentID, &e.Type, &e.Depth,
		&e.SearchKey, &e.SearchKey2, &e.Summary, &e.Data, &e.Stage,
		&e.AssigneeID, &e.ReturnToID, &e.OriginID,
		&e.Version, &e.DeletedAt, &e.DeletedBy, &e.KeyVersion,
		&e.CreatedAt, &e.UpdatedAt, &e.CreatedBy,
	); {
	case err == nil:
		return e, nil
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	default:
		return nil, err
	}
}
// scanEntryRow reads the current row of rows, in the same 20-column order as
// scanEntry, into a new Entry. Any Scan error is returned as-is.
func scanEntryRow(rows *sql.Rows) (*Entry, error) {
	e := new(Entry)
	if err := rows.Scan(
		&e.EntryID, &e.ProjectID, &e.ParentID, &e.Type, &e.Depth,
		&e.SearchKey, &e.SearchKey2, &e.Summary, &e.Data, &e.Stage,
		&e.AssigneeID, &e.ReturnToID, &e.OriginID,
		&e.Version, &e.DeletedAt, &e.DeletedBy, &e.KeyVersion,
		&e.CreatedAt, &e.UpdatedAt, &e.CreatedBy,
	); err != nil {
		return nil, err
	}
	return e, nil
}
// packEntry encrypts e.SummaryText and e.DataText with the key derived from
// cfg.MasterKey and e.ProjectID, storing the results in e.Summary and e.Data.
// An empty plaintext field is skipped, and its packed counterpart is left
// untouched.
//
// NOTE(review): because of that skip, setting SummaryText/DataText to "" does
// not clear a previously packed Summary/Data. Confirm this is intended.
func packEntry(cfg *Config, e *Entry) error {
	projectKey, err := DeriveProjectKey(cfg.MasterKey, e.ProjectID)
	if err != nil {
		return err
	}
	if plain := e.SummaryText; plain != "" {
		sealed, sealErr := Pack(projectKey, plain)
		if sealErr != nil {
			return sealErr
		}
		e.Summary = sealed
	}
	if plain := e.DataText; plain != "" {
		sealed, sealErr := Pack(projectKey, plain)
		if sealErr != nil {
			return sealErr
		}
		e.Data = sealed
	}
	return nil
}
// unpackEntry decrypts e.Summary and e.Data, which were read from storage,
// into e.SummaryText and e.DataText. It uses the key derived from
// cfg.MasterKey and e.ProjectID. Empty packed fields are skipped, and the
// first decryption error is returned.
//
// NOTE(review): e.KeyVersion is not consulted when deriving the key.
func unpackEntry(cfg *Config, e *Entry) error {
	projectKey, err := DeriveProjectKey(cfg.MasterKey, e.ProjectID)
	if err != nil {
		return err
	}
	if sealed := e.Summary; len(sealed) > 0 {
		plain, openErr := Unpack(projectKey, sealed)
		if openErr != nil {
			return openErr
		}
		e.SummaryText = plain
	}
	if sealed := e.Data; len(sealed) > 0 {
		plain, openErr := Unpack(projectKey, sealed)
		if openErr != nil {
			return openErr
		}
		e.DataText = plain
	}
	return nil
}
// ---------------------------------------------------------------------------
// User operations
// ---------------------------------------------------------------------------
// UserCreate inserts u as a new row of the users table. Every column,
// including the timestamps, is taken verbatim from u.
func UserCreate(db *DB, u *User) error {
	const stmt = `INSERT INTO users (user_id, email, name, password, org_id, org_name, mfa_secret, active, created_at, updated_at)
VALUES (?,?,?,?,?,?,?,?,?,?)`
	values := []any{
		u.UserID, u.Email, u.Name, u.Password, u.OrgID,
		u.OrgName, u.MFASecret, u.Active, u.CreatedAt, u.UpdatedAt,
	}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}
// UserByEmail returns the user whose email column equals email exactly, or
// (nil, nil) when there is none. The integer active column is mapped to
// User.Active (1 means true).
func UserByEmail(db *DB, email string) (*User, error) {
	var (
		u      User
		active int
	)
	row := db.Conn.QueryRow(
		`SELECT user_id, email, name, password, org_id, org_name, mfa_secret, active, created_at, updated_at
FROM users WHERE email = ?`, email,
	)
	switch err := row.Scan(&u.UserID, &u.Email, &u.Name, &u.Password, &u.OrgID, &u.OrgName, &u.MFASecret, &active, &u.CreatedAt, &u.UpdatedAt); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	u.Active = active == 1
	return &u, nil
}
// UserByID returns the user with the given user_id, or (nil, nil) when there
// is none. The integer active column is mapped to User.Active (1 means true).
func UserByID(db *DB, userID string) (*User, error) {
	var (
		u      User
		active int
	)
	row := db.Conn.QueryRow(
		`SELECT user_id, email, name, password, org_id, org_name, mfa_secret, active, created_at, updated_at
FROM users WHERE user_id = ?`, userID,
	)
	switch err := row.Scan(&u.UserID, &u.Email, &u.Name, &u.Password, &u.OrgID, &u.OrgName, &u.MFASecret, &active, &u.CreatedAt, &u.UpdatedAt); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	u.Active = active == 1
	return &u, nil
}
// UserCount reports how many rows the users table holds, including inactive
// users.
func UserCount(db *DB) (int, error) {
	var n int
	row := db.Conn.QueryRow(`SELECT COUNT(*) FROM users`)
	if err := row.Scan(&n); err != nil {
		return n, err
	}
	return n, nil
}
// ProjectsByUser returns the non-deleted project entries for which userID
// holds at least one unrevoked access grant. Projects are ordered by most
// recent update and have their packed fields decrypted. DISTINCT collapses
// projects reached through several grants.
func ProjectsByUser(db *DB, cfg *Config, userID string) ([]Entry, error) {
	const query = `SELECT DISTINCT e.entry_id, e.project_id, e.parent_id, e.type, e.depth,
e.search_key, e.search_key2, e.summary, e.data, e.stage,
e.assignee_id, e.return_to_id, e.origin_id,
e.version, e.deleted_at, e.deleted_by, e.key_version,
e.created_at, e.updated_at, e.created_by
FROM entries e
JOIN access a ON a.project_id = e.project_id
WHERE a.user_id = ? AND a.revoked_at IS NULL AND e.type = 'project' AND e.deleted_at IS NULL
ORDER BY e.updated_at DESC`
	rows, err := db.Conn.Query(query, userID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var projects []Entry
	for rows.Next() {
		project, scanErr := scanEntryRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		if decErr := unpackEntry(cfg, project); decErr != nil {
			return nil, decErr
		}
		projects = append(projects, *project)
	}
	return projects, rows.Err()
}
// TasksByUser returns the non-deleted entries assigned to userID, across
// every project where userID still holds an unrevoked access grant. Entries
// are ordered newest first and have their packed fields decrypted.
//
// The access check mirrors ProjectsByUser (access.revoked_at IS NULL).
// Without it, a user whose access was revoked would keep receiving decrypted
// content of entries still assigned to them.
// NOTE(review): no role-based filtering (e.g. buyer/pre_dataroom) happens here.
func TasksByUser(db *DB, cfg *Config, userID string) ([]Entry, error) {
	rows, err := db.Conn.Query(
		`SELECT e.entry_id, e.project_id, e.parent_id, e.type, e.depth,
e.search_key, e.search_key2, e.summary, e.data, e.stage,
e.assignee_id, e.return_to_id, e.origin_id,
e.version, e.deleted_at, e.deleted_by, e.key_version,
e.created_at, e.updated_at, e.created_by
FROM entries e
WHERE e.assignee_id = ? AND e.deleted_at IS NULL
AND EXISTS (
SELECT 1 FROM access a
WHERE a.project_id = e.project_id AND a.user_id = e.assignee_id AND a.revoked_at IS NULL
)
ORDER BY e.created_at DESC`, userID,
	)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	var entries []Entry
	for rows.Next() {
		e, err := scanEntryRow(rows)
		if err != nil {
			return nil, err
		}
		if err := unpackEntry(cfg, e); err != nil {
			return nil, err
		}
		entries = append(entries, *e)
	}
	return entries, rows.Err()
}
// EntriesByParent returns the non-deleted direct children of parentID,
// oldest first, with packed fields decrypted. No RBAC check is performed.
func EntriesByParent(db *DB, cfg *Config, parentID string) ([]Entry, error) {
	const query = `SELECT entry_id, project_id, parent_id, type, depth,
search_key, search_key2, summary, data, stage,
assignee_id, return_to_id, origin_id,
version, deleted_at, deleted_by, key_version,
created_at, updated_at, created_by
FROM entries
WHERE parent_id = ? AND deleted_at IS NULL
ORDER BY created_at ASC`
	rows, err := db.Conn.Query(query, parentID)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var children []Entry
	for rows.Next() {
		child, scanErr := scanEntryRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		if decErr := unpackEntry(cfg, child); decErr != nil {
			return nil, decErr
		}
		children = append(children, *child)
	}
	return children, rows.Err()
}
// EntryByID loads one entry by ID through entryReadSystem (no RBAC) and
// decrypts its packed fields. It returns (nil, nil) when no row has this ID.
//
// NOTE(review): entryReadSystem does not filter on deleted_at, so
// soft-deleted entries are returned too. Confirm callers expect that.
func EntryByID(db *DB, cfg *Config, entryID string) (*Entry, error) {
	e, err := entryReadSystem(db, entryID)
	if err != nil || e == nil {
		return nil, err
	}
	if decErr := unpackEntry(cfg, e); decErr != nil {
		return nil, decErr
	}
	return e, nil
}
// RequestCountByProject counts the non-deleted request entries of projectID.
// It returns the total first, then how many of them are in the
// 'pre_dataroom' stage (reported as "open").
//
// Both counts come from one query, so they describe the same snapshot
// (open can never exceed total). On error both counts are 0.
func RequestCountByProject(db *DB, projectID string) (int, int, error) {
	var total, open int
	// COUNT(CASE ...) yields 0, not NULL, when no rows match, unlike SUM.
	err := db.Conn.QueryRow(
		`SELECT COUNT(*), COUNT(CASE WHEN stage = 'pre_dataroom' THEN 1 END)
FROM entries WHERE project_id = ? AND type = 'request' AND deleted_at IS NULL`,
		projectID,
	).Scan(&total, &open)
	if err != nil {
		return 0, 0, err
	}
	return total, open, nil
}
// WorkstreamCountByProject counts the non-deleted workstream entries of
// projectID.
func WorkstreamCountByProject(db *DB, projectID string) (int, error) {
	var n int
	row := db.Conn.QueryRow(
		`SELECT COUNT(*) FROM entries WHERE project_id = ? AND type = 'workstream' AND deleted_at IS NULL`,
		projectID,
	)
	err := row.Scan(&n)
	return n, err
}
// ---------------------------------------------------------------------------
// Challenge operations (passwordless OTP auth)
// ---------------------------------------------------------------------------
// ChallengeCreate stores c as a new row of the challenges table, with every
// column taken verbatim from c.
func ChallengeCreate(db *DB, c *Challenge) error {
	const stmt = `INSERT INTO challenges (challenge_id, email, code, created_at, expires_at, used)
VALUES (?,?,?,?,?,?)`
	values := []any{c.ChallengeID, c.Email, c.Code, c.CreatedAt, c.ExpiresAt, c.Used}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}
// ChallengeVerify checks code against the most recent unused challenge for
// email. If the code matches and the challenge has not expired, it marks
// that challenge used.
//
// On success it returns the matched challenge; Used still reflects the state
// before verification, i.e. false. It returns (nil, nil) in these cases:
//   - there is no unused challenge for email;
//   - the challenge has expired;
//   - code is empty or does not match;
//   - a concurrent request consumed the challenge first.
//
// Database failures are returned as errors instead of being reported as
// "not found".
//
// NOTE(review): a wrong code neither consumes nor counts against the
// challenge, so nothing here limits guessing. Confirm that the caller
// rate-limits attempts.
func ChallengeVerify(db *DB, email, code string) (*Challenge, error) {
	if code == "" {
		// ConstantTimeCompare treats two empty inputs as equal; never accept an
		// empty code, even if a challenge was stored with one.
		return nil, nil
	}
	var c Challenge
	var used int
	err := db.Conn.QueryRow(
		`SELECT challenge_id, email, code, created_at, expires_at, used
FROM challenges
WHERE email = ? AND used = 0
ORDER BY created_at DESC LIMIT 1`, email,
	).Scan(&c.ChallengeID, &c.Email, &c.Code, &c.CreatedAt, &c.ExpiresAt, &used)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, nil // no challenge found
	}
	if err != nil {
		return nil, err
	}
	c.Used = used == 1
	// Check expiry.
	if c.ExpiresAt < time.Now().UnixMilli() {
		return nil, nil
	}
	// Check the code with a constant-time comparison to prevent timing attacks.
	if subtle.ConstantTimeCompare([]byte(c.Code), []byte(code)) != 1 {
		return nil, nil
	}
	// Consume atomically. "AND used = 0" turns this into a compare-and-set, so
	// two concurrent requests with the same code cannot both succeed.
	res, err := db.Conn.Exec(`UPDATE challenges SET used = 1 WHERE challenge_id = ? AND used = 0`, c.ChallengeID)
	if err != nil {
		return nil, err
	}
	affected, err := res.RowsAffected()
	if err != nil {
		return nil, err
	}
	if affected == 0 {
		return nil, nil // already consumed by another request
	}
	return &c, nil
}
// ---------------------------------------------------------------------------
// Admin query helpers
// ---------------------------------------------------------------------------
// AllUsers lists every user, newest first, for super-admin views. No access
// check is performed here.
func AllUsers(db *DB) ([]User, error) {
	const query = `SELECT user_id, email, name, password, org_id, org_name, mfa_secret, active, created_at, updated_at
FROM users ORDER BY created_at DESC`
	rows, err := db.Conn.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var users []User
	for rows.Next() {
		var (
			u      User
			active int
		)
		if scanErr := rows.Scan(&u.UserID, &u.Email, &u.Name, &u.Password, &u.OrgID, &u.OrgName, &u.MFASecret, &active, &u.CreatedAt, &u.UpdatedAt); scanErr != nil {
			return nil, scanErr
		}
		u.Active = active == 1
		users = append(users, u)
	}
	return users, rows.Err()
}
// AllProjects lists every non-deleted project entry, most recently updated
// first, with packed fields decrypted. It is meant for super admins and
// performs no RBAC check.
func AllProjects(db *DB, cfg *Config) ([]Entry, error) {
	const query = `SELECT entry_id, project_id, parent_id, type, depth,
search_key, search_key2, summary, data, stage,
assignee_id, return_to_id, origin_id,
version, deleted_at, deleted_by, key_version,
created_at, updated_at, created_by
FROM entries WHERE type = 'project' AND deleted_at IS NULL
ORDER BY updated_at DESC`
	rows, err := db.Conn.Query(query)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var projects []Entry
	for rows.Next() {
		project, scanErr := scanEntryRow(rows)
		if scanErr != nil {
			return nil, scanErr
		}
		if decErr := unpackEntry(cfg, project); decErr != nil {
			return nil, decErr
		}
		projects = append(projects, *project)
	}
	return projects, rows.Err()
}
// AuditRecent returns up to limit audit rows, newest first, for super
// admins. NULL target_id and ip columns come back as empty strings.
func AuditRecent(db *DB, limit int) ([]AuditEntry, error) {
	const query = `SELECT id, project_id, actor_id, action, target_id, details, ip, ts
FROM audit ORDER BY ts DESC LIMIT ?`
	rows, err := db.Conn.Query(query, limit)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	var out []AuditEntry
	for rows.Next() {
		var (
			a          AuditEntry
			target, ip sql.NullString
		)
		if scanErr := rows.Scan(&a.ID, &a.ProjectID, &a.ActorID, &a.Action, &target, &a.Details, &ip, &a.Ts); scanErr != nil {
			return nil, scanErr
		}
		// An invalid (NULL) NullString has String == "", matching the
		// zero-value default.
		a.TargetID = target.String
		a.IP = ip.String
		out = append(out, a)
	}
	return out, rows.Err()
}
// ---------------------------------------------------------------------------
// Access operations
// ---------------------------------------------------------------------------
// AccessGrant stores a as a new row of the access table, with every column
// taken verbatim from a. revoked_at/revoked_by are not set here.
func AccessGrant(db *DB, a *Access) error {
	const stmt = `INSERT INTO access (id, project_id, workstream_id, user_id, role, ops, can_grant, granted_by, granted_at)
VALUES (?,?,?,?,?,?,?,?,?)`
	values := []any{
		a.ID, a.ProjectID, a.WorkstreamID, a.UserID, a.Role,
		a.Ops, a.CanGrant, a.GrantedBy, a.GrantedAt,
	}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}
// AccessRevoke soft-revokes grant accessID by stamping revoked_at (now, in
// Unix milliseconds) and revoked_by. Grants that are already revoked or do
// not exist are left alone, and no error is reported for them.
func AccessRevoke(db *DB, accessID, revokedBy string) error {
	revokedAt := time.Now().UnixMilli()
	if _, err := db.Conn.Exec(
		`UPDATE access SET revoked_at = ?, revoked_by = ? WHERE id = ? AND revoked_at IS NULL`,
		revokedAt, revokedBy, accessID,
	); err != nil {
		return err
	}
	return nil
}
// ---------------------------------------------------------------------------
// Session operations
// ---------------------------------------------------------------------------
// SessionCreate stores s as a new row of the sessions table, with every
// column taken verbatim from s.
func SessionCreate(db *DB, s *Session) error {
	const stmt = `INSERT INTO sessions (id, user_id, fingerprint, created_at, expires_at, revoked)
VALUES (?,?,?,?,?,?)`
	values := []any{s.ID, s.UserID, s.Fingerprint, s.CreatedAt, s.ExpiresAt, s.Revoked}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}
// SessionByID loads the session with the given ID, or returns (nil, nil)
// when there is none. Expiry and revocation are NOT checked here; the integer
// revoked column is only mapped to Session.Revoked (1 means true).
func SessionByID(db *DB, sessionID string) (*Session, error) {
	var (
		s       Session
		revoked int
	)
	row := db.Conn.QueryRow(
		`SELECT id, user_id, fingerprint, created_at, expires_at, revoked
FROM sessions WHERE id = ?`, sessionID,
	)
	switch err := row.Scan(&s.ID, &s.UserID, &s.Fingerprint, &s.CreatedAt, &s.ExpiresAt, &revoked); {
	case errors.Is(err, sql.ErrNoRows):
		return nil, nil
	case err != nil:
		return nil, err
	}
	s.Revoked = revoked == 1
	return &s, nil
}
// SessionRevoke sets revoked = 1 on the session with the given ID. A missing
// session is not reported as an error.
func SessionRevoke(db *DB, sessionID string) error {
	if _, err := db.Conn.Exec(`UPDATE sessions SET revoked = 1 WHERE id = ?`, sessionID); err != nil {
		return err
	}
	return nil
}
// SessionRevokeAllForUser sets revoked = 1 on every still-active session
// belonging to userID. Having no sessions is not an error.
func SessionRevokeAllForUser(db *DB, userID string) error {
	if _, err := db.Conn.Exec(`UPDATE sessions SET revoked = 1 WHERE user_id = ? AND revoked = 0`, userID); err != nil {
		return err
	}
	return nil
}
// ---------------------------------------------------------------------------
// Audit operations
// ---------------------------------------------------------------------------
// AuditLog appends a security event to the audit table. A missing ID is
// filled with a new UUID and a zero Ts with the current Unix-millisecond time.
// Both are written back into a. cfg is currently unused and is kept for
// signature compatibility.
func AuditLog(db *DB, cfg *Config, a *AuditEntry) error {
	if a.Ts == 0 {
		a.Ts = time.Now().UnixMilli()
	}
	if a.ID == "" {
		a.ID = uuid.New().String()
	}
	const stmt = `INSERT INTO audit (id, project_id, actor_id, action, target_id, details, ip, ts)
VALUES (?,?,?,?,?,?,?,?)`
	values := []any{a.ID, a.ProjectID, a.ActorID, a.Action, a.TargetID, a.Details, a.IP, a.Ts}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}
// ---------------------------------------------------------------------------
// Entry event operations
// ---------------------------------------------------------------------------
// EntryEventCreate appends a workflow event to the entry_events table. A
// missing ID is filled with a new UUID and a zero Ts with the current
// Unix-millisecond time. Both are written back into ev.
func EntryEventCreate(db *DB, ev *EntryEvent) error {
	if ev.Ts == 0 {
		ev.Ts = time.Now().UnixMilli()
	}
	if ev.ID == "" {
		ev.ID = uuid.New().String()
	}
	const stmt = `INSERT INTO entry_events (id, entry_id, actor_id, channel, action, data, ts)
VALUES (?,?,?,?,?,?,?)`
	values := []any{ev.ID, ev.EntryID, ev.ActorID, ev.Channel, ev.Action, ev.Data, ev.Ts}
	_, err := db.Conn.Exec(stmt, values...)
	return err
}