package api

import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clavitor/edition"
	"github.com/johanj/clavitor/lib"
)

// base64Decode handles both standard and url-safe base64 (with or without padding).
//
// Trailing '=' padding is stripped first, then the input is decoded with the
// unpadded URL-safe alphabet; if that fails, the unpadded standard alphabet is
// tried. The error returned is the one from the last attempt.
func base64Decode(s string) ([]byte, error) {
	s = strings.TrimRight(s, "=")
	b, err := base64.RawURLEncoding.DecodeString(s)
	if err != nil {
		b, err = base64.RawStdEncoding.DecodeString(s)
	}
	return b, err
}

// base64UrlEncode encodes b as unpadded URL-safe base64.
func base64UrlEncode(b []byte) string {
	return base64.RawURLEncoding.EncodeToString(b)
}

// contextKey is a private key type for values this package stores in a
// request context, so the keys cannot collide with plain-string keys.
type contextKey string

// Context keys populated by L1Middleware.
const (
	ctxActor    contextKey = "actor"     // string actor label (lib.ActorAgent / lib.ActorWeb)
	ctxAgent    contextKey = "agent"     // *lib.AgentData; only set for cvt_ token requests
	ctxDB       contextKey = "db"        // *lib.DB opened for this request
	ctxVaultKey contextKey = "vault_key" // []byte result of lib.NormalizeKey on the L1 bytes
)

// ActorFromContext returns the actor stored in ctx, or lib.ActorWeb when no
// string actor is present.
func ActorFromContext(ctx context.Context) string {
	v, ok := ctx.Value(ctxActor).(string)
	if !ok {
		return lib.ActorWeb
	}
	return v
}

// AgentFromContext returns the agent stored in ctx, or nil when the request
// did not authenticate as an agent.
func AgentFromContext(ctx context.Context) *lib.AgentData {
	v, _ := ctx.Value(ctxAgent).(*lib.AgentData)
	return v
}

// DBFromContext returns the vault DB handle stored in ctx, or nil if none.
func DBFromContext(ctx context.Context) *lib.DB {
	v, _ := ctx.Value(ctxDB).(*lib.DB)
	return v
}

// VaultKeyFromContext returns the vault key stored in ctx, or nil if none.
func VaultKeyFromContext(ctx context.Context) []byte {
	v, _ := ctx.Value(ctxVaultKey).([]byte)
	return v
}

// IsAgentRequest returns true if the request was made with a cvt_ agent token.
func IsAgentRequest(r *http.Request) bool {
	return AgentFromContext(r.Context()) != nil
}

// L1Middleware extracts L1 from Bearer token and opens the vault DB.
// Supports two token formats:
//   - cvt_ prefix: CVT wire token (type 0x00) — extract L0 for routing, L1 for encryption, agent_id for scope lookup
//   - raw base64: legacy L1 bearer (8 bytes) — vault owner, full access
//
// Requests without a "Bearer " Authorization header are passed through
// untouched (no DB, key, or actor is placed in the context). For authenticated
// requests the vault DB is opened per request and closed when the downstream
// handler returns.
func L1Middleware(dataDir string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth := r.Header.Get("Authorization")

			// No auth = unauthenticated request (registration, login, etc.)
			// We CANNOT open a random vault via wildcard - in hosted mode there are many vaults.
			// The request must either provide a bearer token or target a specific vault via L0.
			if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
				next.ServeHTTP(w, r)
				return
			}

			bearerVal := strings.TrimPrefix(auth, "Bearer ")

			if strings.HasPrefix(bearerVal, "cvt_") {
				// --- CVT wire token ---
				l0, l1Raw, agentID, err := lib.ParseWireToken(bearerVal)
				if err != nil {
					ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid CVT token")
					return
				}

				l1Key := lib.NormalizeKey(l1Raw)
				// The vault file name is derived from L0: <dataDir>/clavitor-<base64url(L0)>.
				vaultPrefix := base64UrlEncode(l0)
				dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)

				// NOTE(review): unlike the legacy branch below, this path calls
				// lib.OpenDB without an os.Stat existence check first. Confirm
				// that OpenDB fails (rather than creating a file) for a vault
				// that does not exist.
				db, err := lib.OpenDB(dbPath)
				if err != nil {
					ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
					return
				}
				defer db.Close()

				// Look up agent by agent_id via blind index
				agentIDHex := hex.EncodeToString(agentID)
				agent, err := lib.AgentLookup(db, l1Key, agentIDHex)
				if err != nil {
					// Community: Log to stderr. Commercial: Also POSTs to telemetry endpoint.
					// This indicates DB corruption, decryption failure, or disk issues.
					edition.Current.AlertOperator(r.Context(), "auth_system_error",
						"Agent lookup failed (DB/decryption error)",
						map[string]any{"error": err.Error()})
					ErrorResponse(w, http.StatusInternalServerError, "system_error",
						"Authentication system error - contact support")
					return
				}
				// A nil agent with a nil error is treated as "no such agent".
				if agent == nil {
					ErrorResponse(w, http.StatusUnauthorized, "unknown_agent", "Invalid or revoked token")
					return
				}

				// NOTE(review): realIP prefers the X-Forwarded-For / X-Real-IP
				// headers over RemoteAddr, so this value is client-influenced
				// unless a trusted proxy overwrites those headers — confirm the
				// deployment guarantees that.
				clientIP := realIP(r)

				// IP whitelist: first contact fills it, subsequent requests checked
				if agent.AllowedIPs == "" {
					// First contact — record the IP
					//
					// SECURITY NOTE: There is a theoretical race condition here.
					// If two parallel requests from different IPs arrive simultaneously
					// for the same agent's first contact, both could pass the empty check
					// before either writes to the database.
					//
					// This was reviewed and accepted because:
					// 1. Requires a stolen agent token (already a compromise scenario)
					// 2. Requires two agents with the same token racing first contact
					// 3. The "loser" simply won't be auto-whitelisted (one IP wins)
					// 4. Cannot be reproduced in testing; practically impossible to trigger
					// 5. Per-vault SQLite isolation limits blast radius
					//
					// The fix would require plaintext allowed_ips column + atomic conditional
					// update. Not worth the complexity for this edge case.
					agent.AllowedIPs = clientIP
					if err := lib.AgentUpdateAllowedIPs(db, l1Key, agent); err != nil {
						log.Printf("agent %s: failed to record first-contact IP: %v", agent.Name, err)
						ErrorResponse(w, http.StatusInternalServerError, "ip_record_failed", "Failed to record agent IP")
						return
					}
					log.Printf("agent %s: first contact from %s, IP recorded", agent.Name, clientIP)
				} else if !lib.AgentIPAllowed(agent, clientIP) {
					log.Printf("agent %s: blocked IP %s (allowed: %s)", agent.Name, clientIP, agent.AllowedIPs)
					ErrorResponse(w, http.StatusForbidden, "ip_blocked", "IP not allowed for this agent")
					return
				}

				// Locked agents are refused immediately, before any handler runs.
				// This is the second-strike state — the owner has to PRF-unlock
				// before this agent can do anything again.
				if agent.Locked {
					lib.AuditLog(db, &lib.AuditEvent{
						Action: "agent_locked_request_refused",
						Actor:  lib.ActorAgent,
						Title:  agent.Name,
						IPAddr: clientIP,
					})
					ErrorResponse(w, http.StatusLocked, "agent_locked", "Agent is locked. Owner unlock required.")
					return
				}

				// Per-agent rate limiting (unique-entries quota) is enforced in
				// handlers that read a specific entry — GetEntry, MatchURL,
				// HandleListAlternates — via agentReadEntry(). Middleware
				// doesn't know which entry the agent is asking for here.

				// Hand the DB, vault key, actor and agent to downstream handlers.
				ctx := context.WithValue(r.Context(), ctxDB, db)
				ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
				ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent)
				ctx = context.WithValue(ctx, ctxAgent, agent)
				next.ServeHTTP(w, r.WithContext(ctx))
			} else {
				// --- Legacy L1 bearer (web UI / extension) ---
				l1Raw, err := base64Decode(bearerVal)
				if err != nil || len(l1Raw) != 8 {
					ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer")
					return
				}
				l1Key := lib.NormalizeKey(l1Raw)

				// The vault file name is derived from the first 4 bytes of L1.
				vaultPrefix := base64UrlEncode(l1Raw[:4])
				dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)

				// Only open the DB when the file already exists; a missing file
				// (or any other Stat error) leaves db nil and yields 404 below.
				var db *lib.DB
				if _, err := os.Stat(dbPath); err == nil {
					db, err = lib.OpenDB(dbPath)
					if err != nil {
						ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault")
						return
					}
				}
				if db == nil {
					ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
					return
				}
				defer db.Close()

				// No ctxAgent is set here, so AgentFromContext returns nil and
				// the request is treated as the vault owner.
				ctx := context.WithValue(r.Context(), ctxDB, db)
				ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
				ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb)
				next.ServeHTTP(w, r.WithContext(ctx))
			}
		})
	}
}

// LoggingMiddleware logs HTTP requests.
//
// One line is emitted per request after the handler returns: method, URL path,
// response status, and elapsed time.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// status defaults to 200 for handlers that never call WriteHeader.
		wrapped := &statusWriter{ResponseWriter: w, status: 200}
		next.ServeHTTP(wrapped, r)
		log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start))
	})
}

// statusWriter wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be logged afterwards.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records code and forwards it to the wrapped ResponseWriter.
func (w *statusWriter) WriteHeader(code int) {
	w.status = code
	w.ResponseWriter.WriteHeader(code)
}

// RateLimitMiddleware implements per-(IP, method, path) rate limiting.
//
// The bucket key is the *request identity*, not just the source IP. Same
// endpoint from the same IP shares a counter; different endpoints get
// independent counters. This means:
//
//   - SPA loading 8 different endpoints on first paint: 8 buckets at count 1,
//     none blocked.
// - Brute-forcer hammering /api/auth/login/complete with different bodies: // one bucket, blocked at requestsPerMinute attempts. Body is intentionally // NOT part of the key — if it were, varying the body would bypass the // limiter and brute-force protection would be gone. // - Polling the same endpoint every few seconds: shares a bucket, counts up. // Blocked at requestsPerMinute, which is what we want. func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler { var mu sync.Mutex buckets := make(map[string]*rateLimitEntry) go func() { for { time.Sleep(time.Minute) mu.Lock() now := time.Now() for k, entry := range buckets { if now.Sub(entry.windowStart) > time.Minute { delete(buckets, k) } } mu.Unlock() } }() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Owner-only bulk endpoints are exempt from the global rate limit. // /api/entries/batch refuses agents at the handler entry // (CreateEntryBatch returns 403 for any cvt_-token request), so this // path is owner-only by handler enforcement. The harvester defense // for agent-reachable paths lives in the per-agent unique-entries // quota (agentReadEntry), not here. Throttling the import flow // would only DOS the legitimate import — no defense gained. if r.URL.Path == "/api/entries/batch" { next.ServeHTTP(w, r) return } key := realIP(r) + "|" + r.Method + "|" + r.URL.Path mu.Lock() entry, exists := buckets[key] now := time.Now() if !exists || now.Sub(entry.windowStart) > time.Minute { entry = &rateLimitEntry{windowStart: now, count: 0} buckets[key] = entry } entry.count++ count := entry.count mu.Unlock() if count > requestsPerMinute { ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests") return } next.ServeHTTP(w, r) }) } } type rateLimitEntry struct { windowStart time.Time count int } // Per-agent rate limiter — tracks UNIQUE entry IDs read per minute / per hour. 
// // Repeated reads of the same credential do NOT count: an agent legitimately // re-fetching the same credential to log into the same site many times stays // at unique-count = 1. The limit fires only when the agent starts touching // many *different* credentials, which is the harvesting pattern we care about. // // Two windows run independently: // - RateLimit → unique entries per rolling minute // - RateLimitHour → unique entries per rolling hour // // A limit of 0 means unlimited for that window. var agentRateLimiter = newAgentLimiter() type agentLimiterEntry struct { minuteWindowStart time.Time minuteEntries map[string]struct{} hourWindowStart time.Time hourEntries map[string]struct{} } type agentLimiter struct { mu sync.Mutex agents map[string]*agentLimiterEntry } func newAgentLimiter() *agentLimiter { al := &agentLimiter{agents: make(map[string]*agentLimiterEntry)} go func() { for { time.Sleep(5 * time.Minute) al.mu.Lock() now := time.Now() for id, e := range al.agents { // Drop agents whose hour window has expired (no recent activity). if now.Sub(e.hourWindowStart) > time.Hour { delete(al.agents, id) } } al.mu.Unlock() } }() return al } // LimitResult signals which window (if any) blocked an agent read. type LimitResult int const ( LimitAllowed LimitResult = iota LimitMinuteHit // Minute-window cap reached. Soft throttle, no strike. LimitHourHit // Hour-window cap reached. Strike — caller must persist. ) // agentReadEntry enforces the per-agent unique-entries quota when an agent // fetches a credential. Call immediately after the AgentCanAccess scope check. // // On hour-limit hit, applies the strike-and-lock policy: // - First strike (or > 2h since last strike): record the strike, throttle. // - Second strike within 2h: lock the agent, persist Locked=true, audit-log. // // Both persistence and audit happen in here so the call sites stay one-liners. // // Returns true if the read may proceed; false if blocked. 
No-op for nil agent // (vault owner / web UI) and for agents with both limits set to 0 (unlimited). func agentReadEntry(agent *lib.AgentData, entryID string, db *lib.DB, vk []byte) bool { if agent == nil { return true } if agent.Locked { // Belt-and-suspenders. L1Middleware blocks locked agents at the top // before any handler runs, but if that check is ever bypassed this // catches it at the per-entry layer. return false } if agent.RateLimit == 0 && agent.RateLimitHour == 0 { return true } result := agentRateLimiter.checkEntry(agent.AgentID, entryID, agent.RateLimit, agent.RateLimitHour) switch result { case LimitAllowed: return true case LimitMinuteHit: return false // soft throttle, no strike case LimitHourHit: // Strike-and-lock policy. now := time.Now().Unix() const strikeWindowSeconds = 2 * 60 * 60 // 2 hours secondStrike := agent.LastStrikeAt > 0 && (now-agent.LastStrikeAt) < strikeWindowSeconds if secondStrike { // Lock the agent. Persist + audit + update in-memory state. if err := lib.AgentLockWithStrike(db, vk, agent.EntryID, now); err != nil { log.Printf("agent %s: failed to persist lock: %v", agent.Name, err) } lib.AuditLog(db, &lib.AuditEvent{ Action: "agent_locked", Actor: lib.ActorAgent, Title: agent.Name, }) log.Printf("agent %s: LOCKED after second hour-limit strike (last strike %ds ago)", agent.Name, now-agent.LastStrikeAt) agent.Locked = true agent.LastStrikeAt = now } else { // First strike. Record timestamp; don't lock. 
if err := lib.AgentRecordStrike(db, vk, agent.EntryID, now); err != nil { log.Printf("agent %s: failed to persist strike: %v", agent.Name, err) } lib.AuditLog(db, &lib.AuditEvent{ Action: "agent_strike", Actor: lib.ActorAgent, Title: agent.Name, }) log.Printf("agent %s: hour-limit strike recorded", agent.Name) agent.LastStrikeAt = now } return false } return false } // checkEntry records the agent's intent to read entryID and returns: // - LimitAllowed — entry recorded, read allowed // - LimitMinuteHit — minute-window cap exceeded; entry NOT added to either set // - LimitHourHit — hour-window cap exceeded; entry NOT added to either set // // Repeated calls with the same entryID inside an active window are free: // the entry is already in the set, len(set) does not grow, the call returns // LimitAllowed. func (al *agentLimiter) checkEntry(agentID, entryID string, maxPerMinute, maxPerHour int) LimitResult { al.mu.Lock() defer al.mu.Unlock() now := time.Now() e, exists := al.agents[agentID] if !exists { e = &agentLimiterEntry{ minuteWindowStart: now, minuteEntries: make(map[string]struct{}), hourWindowStart: now, hourEntries: make(map[string]struct{}), } al.agents[agentID] = e } // Roll the windows if they have expired. if now.Sub(e.minuteWindowStart) > time.Minute { e.minuteWindowStart = now e.minuteEntries = make(map[string]struct{}) } if now.Sub(e.hourWindowStart) > time.Hour { e.hourWindowStart = now e.hourEntries = make(map[string]struct{}) } _, alreadyMinute := e.minuteEntries[entryID] _, alreadyHour := e.hourEntries[entryID] // Hour limit takes precedence — it's the strike trigger. 
if maxPerHour > 0 && !alreadyHour && len(e.hourEntries) >= maxPerHour { return LimitHourHit } if maxPerMinute > 0 && !alreadyMinute && len(e.minuteEntries) >= maxPerMinute { return LimitMinuteHit } e.minuteEntries[entryID] = struct{}{} e.hourEntries[entryID] = struct{}{} return LimitAllowed } // allowEntry is a thin compatibility wrapper around checkEntry for callers // that only need a yes/no answer (currently the unit tests). func (al *agentLimiter) allowEntry(agentID, entryID string, maxPerMinute, maxPerHour int) bool { return al.checkEntry(agentID, entryID, maxPerMinute, maxPerHour) == LimitAllowed } // CORSMiddleware handles CORS headers. func CORSMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") if origin != "" && (strings.Contains(origin, "localhost") || strings.Contains(origin, "127.0.0.1")) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // SecurityHeadersMiddleware adds security headers. 
func SecurityHeadersMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Frame-Options", "DENY") w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("X-XSS-Protection", "1; mode=block") w.Header().Set("Referrer-Policy", "strict-origin-when-cross-origin") // CSP: removed unused tailwindcss, tightened img-src to self+data only w.Header().Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' localhost 127.0.0.1 https://clavitor.ai") next.ServeHTTP(w, r) }) } // MaxBodySizeMiddleware limits request body size and rejects binary content. // Allows 64KB max for markdown notes. Rejects binary data (images, executables, etc). func MaxBodySizeMiddleware(maxBytes int64) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Security: Reject binary content types contentType := r.Header.Get("Content-Type") if isBinaryContentType(contentType) { ErrorResponse(w, http.StatusUnsupportedMediaType, "binary_not_allowed", "Binary content not allowed. Only text/markdown data accepted.") return } r.Body = http.MaxBytesReader(w, r.Body, maxBytes) next.ServeHTTP(w, r) }) } } // isBinaryContentType detects common binary content types. func isBinaryContentType(ct string) bool { ct = strings.ToLower(ct) binaryTypes := []string{ "image/", "audio/", "video/", "application/pdf", "application/zip", "application/gzip", "application/octet-stream", "application/x-executable", "application/x-dosexec", "multipart/form-data", // usually file uploads } for _, bt := range binaryTypes { if strings.Contains(ct, bt) { return true } } return false } // ErrorResponse sends a JSON error response. 
//
// The body has the shape {"code": code, "error": message}. The status line and
// Content-Type header are written before the body is encoded, so an encoding
// or write failure cannot be reported to the client; it is logged instead of
// being silently dropped.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	body := map[string]string{"error": message, "code": code}
	if err := json.NewEncoder(w).Encode(body); err != nil {
		log.Printf("ErrorResponse: writing %q (status %d): %v", code, status, err)
	}
}

// JSONResponse sends a JSON success response.
//
// data is encoded with encoding/json after the status line has been written.
// If encoding or writing fails (for example, data contains an unsupported
// type), the failure is logged; the status already sent cannot be changed.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("JSONResponse: writing body (status %d): %v", status, err)
	}
}

// realIP returns the client address for r, taken from the first source that
// is present:
//  1. the first comma-separated element of X-Forwarded-For, whitespace-trimmed
//  2. X-Real-IP, as received
//  3. r.RemoteAddr with everything from the last ':' onward removed
//
// NOTE(review): the two headers are supplied by the client unless a trusted
// reverse proxy overwrites them. Confirm the deployment guarantees that.
//
// NOTE(review): for a bracketed IPv6 RemoteAddr such as "[::1]:8080" the result
// keeps the brackets ("[::1]"). Left unchanged because callers may have
// persisted values in that form — verify before normalizing.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		parts := strings.SplitN(xff, ",", 2)
		return strings.TrimSpace(parts[0])
	}
	if xri := r.Header.Get("X-Real-IP"); xri != "" {
		return xri
	}
	addr := r.RemoteAddr
	if idx := strings.LastIndex(addr, ":"); idx != -1 {
		return addr[:idx]
	}
	return addr
}