// Source: clavitor/clavis/clavis-vault/api/middleware.go (448 lines, 14 KiB, Go)

package api
import (
	"context"
	"encoding/base64"
	"encoding/hex"
	"encoding/json"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clavitor/edition"
	"github.com/johanj/clavitor/lib"
)
// base64Decode decodes s as base64, accepting both the URL-safe and the
// standard alphabet, with or without trailing "=" padding.
//
// Padding is stripped first so that the unpadded decoders can be used. The
// URL-safe alphabet is tried first; only if that fails is the standard
// alphabet attempted, and its result (bytes and error) is what the caller
// gets back in that case.
func base64Decode(s string) ([]byte, error) {
	unpadded := strings.TrimRight(s, "=")
	if decoded, err := base64.RawURLEncoding.DecodeString(unpadded); err == nil {
		return decoded, nil
	}
	return base64.RawStdEncoding.DecodeString(unpadded)
}
// base64UrlEncode returns b encoded with the URL-safe base64 alphabet and no
// "=" padding.
func base64UrlEncode(b []byte) string {
	enc := base64.RawURLEncoding
	out := make([]byte, enc.EncodedLen(len(b)))
	enc.Encode(out, b)
	return string(out)
}
// contextKey is the private type used for values this package stores on a
// request context; using a dedicated type keeps the keys from colliding with
// string keys set by other packages.
type contextKey string

// Keys under which L1Middleware publishes per-request state.
const (
	ctxActor    = contextKey("actor")     // actor string, read by ActorFromContext
	ctxAgent    = contextKey("agent")     // *lib.AgentData, read by AgentFromContext
	ctxDB       = contextKey("db")        // *lib.DB, read by DBFromContext
	ctxVaultKey = contextKey("vault_key") // []byte key, read by VaultKeyFromContext
)
// ActorFromContext returns the actor string stored on ctx under ctxActor.
// When no string is stored there it falls back to lib.ActorWeb.
func ActorFromContext(ctx context.Context) string {
	if actor, ok := ctx.Value(ctxActor).(string); ok {
		return actor
	}
	return lib.ActorWeb
}
// AgentFromContext returns the agent record stored on ctx under ctxAgent, or
// nil when the context carries none.
func AgentFromContext(ctx context.Context) *lib.AgentData {
	if agent, ok := ctx.Value(ctxAgent).(*lib.AgentData); ok {
		return agent
	}
	return nil
}
// DBFromContext returns the vault database handle stored on ctx under ctxDB,
// or nil when the context carries none.
func DBFromContext(ctx context.Context) *lib.DB {
	if db, ok := ctx.Value(ctxDB).(*lib.DB); ok {
		return db
	}
	return nil
}
// VaultKeyFromContext returns the vault key bytes stored on ctx under
// ctxVaultKey, or nil when the context carries none.
func VaultKeyFromContext(ctx context.Context) []byte {
	if key, ok := ctx.Value(ctxVaultKey).([]byte); ok {
		return key
	}
	return nil
}
// IsAgentRequest reports whether the request context carries an agent record,
// which L1Middleware sets only for requests made with a cvt_ agent token.
func IsAgentRequest(r *http.Request) bool {
	agent := AgentFromContext(r.Context())
	return agent != nil
}
// L1Middleware extracts L1 from the Bearer token and opens the vault DB,
// publishing both on the request context for downstream handlers.
//
// Three kinds of request are handled:
//   - No "Bearer " credential: unauthenticated request (registration, login,
//     etc.). If dataDir holds any "clavitor-*" entry, the first match is
//     opened and attached as the DB; if there is none, or it fails to open,
//     the request continues without a DB.
//   - cvt_ prefix: CVT wire token (type 0x00) — L0 for routing, L1 for
//     encryption, agent_id for scope lookup. The agent must exist, pass the
//     IP whitelist and stay within its per-agent rate limit.
//   - raw base64: legacy L1 bearer (8 bytes) — vault owner, full access.
//
// Both authenticated paths answer 404 "vault_not_found" when the vault file
// does not already exist; neither will open a path that is not on disk.
// Any DB opened here is closed once the downstream handler returns.
func L1Middleware(dataDir string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth := r.Header.Get("Authorization")

			// No auth = unauthenticated request (registration, login, etc.)
			if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
				// Glob only fails on a malformed pattern; treating that like
				// "no vault found" keeps unauthenticated routes reachable.
				matches, _ := filepath.Glob(filepath.Join(dataDir, "clavitor-*"))
				if len(matches) > 0 {
					db, err := lib.OpenDB(matches[0])
					if err == nil {
						defer db.Close()
						ctx := context.WithValue(r.Context(), ctxDB, db)
						next.ServeHTTP(w, r.WithContext(ctx))
						return
					}
				}
				next.ServeHTTP(w, r)
				return
			}

			bearerVal := strings.TrimPrefix(auth, "Bearer ")

			if strings.HasPrefix(bearerVal, "cvt_") {
				// --- CVT wire token ---
				l0, l1Raw, agentID, err := lib.ParseWireToken(bearerVal)
				if err != nil {
					ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid CVT token")
					return
				}
				l1Key := lib.NormalizeKey(l1Raw)
				vaultPrefix := base64UrlEncode(l0)
				dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)

				// The path is derived from the caller-supplied L0, so only
				// open it when the vault already exists — the same guard the
				// legacy branch applies.
				// NOTE(review): this assumes lib.OpenDB may create a missing
				// file (typical for SQLite opens); confirm against lib.
				if _, statErr := os.Stat(dbPath); statErr != nil {
					ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
					return
				}
				db, err := lib.OpenDB(dbPath)
				if err != nil {
					ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
					return
				}
				defer db.Close()

				// Look up agent by agent_id via blind index
				agentIDHex := hex.EncodeToString(agentID)
				agent, err := lib.AgentLookup(db, l1Key, agentIDHex)
				if err != nil {
					// Community: Log to stderr. Commercial: Also POSTs to telemetry endpoint.
					// This indicates DB corruption, decryption failure, or disk issues.
					edition.Current.AlertOperator(r.Context(), "auth_system_error",
						"Agent lookup failed (DB/decryption error)", map[string]any{"error": err.Error()})
					ErrorResponse(w, http.StatusInternalServerError, "system_error", "Authentication system error - contact support")
					return
				}
				if agent == nil {
					ErrorResponse(w, http.StatusUnauthorized, "unknown_agent", "Invalid or revoked token")
					return
				}

				clientIP := realIP(r)

				// IP whitelist: first contact fills it, subsequent requests checked
				if agent.AllowedIPs == "" {
					// First contact — record the IP
					//
					// SECURITY NOTE: There is a theoretical race condition here.
					// If two parallel requests from different IPs arrive simultaneously
					// for the same agent's first contact, both could pass the empty check
					// before either writes to the database.
					//
					// This was reviewed and accepted because:
					// 1. Requires a stolen agent token (already a compromise scenario)
					// 2. Requires two agents with the same token racing first contact
					// 3. The "loser" simply won't be auto-whitelisted (one IP wins)
					// 4. Cannot be reproduced in testing; practically impossible to trigger
					// 5. Per-vault SQLite isolation limits blast radius
					//
					// The fix would require plaintext allowed_ips column + atomic conditional
					// update. Not worth the complexity for this edge case.
					agent.AllowedIPs = clientIP
					if err := lib.AgentUpdateAllowedIPs(db, l1Key, agent); err != nil {
						log.Printf("agent %s: failed to record first-contact IP: %v", agent.Name, err)
						ErrorResponse(w, http.StatusInternalServerError, "ip_record_failed", "Failed to record agent IP")
						return
					}
					log.Printf("agent %s: first contact from %s, IP recorded", agent.Name, clientIP)
				} else if !lib.AgentIPAllowed(agent, clientIP) {
					log.Printf("agent %s: blocked IP %s (allowed: %s)", agent.Name, clientIP, agent.AllowedIPs)
					ErrorResponse(w, http.StatusForbidden, "ip_blocked", "IP not allowed for this agent")
					return
				}

				// Per-agent rate limiting (a RateLimit of zero or less means unlimited)
				if agent.RateLimit > 0 && !agentRateLimiter.allow(agent.AgentID, agent.RateLimit) {
					ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Agent rate limit exceeded")
					return
				}

				ctx := context.WithValue(r.Context(), ctxDB, db)
				ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
				ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent)
				ctx = context.WithValue(ctx, ctxAgent, agent)
				next.ServeHTTP(w, r.WithContext(ctx))
				return
			}

			// --- Legacy L1 bearer (web UI / extension) ---
			l1Raw, err := base64Decode(bearerVal)
			if err != nil || len(l1Raw) != 8 {
				ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer")
				return
			}
			l1Key := lib.NormalizeKey(l1Raw)
			// The first four bytes of L1 select the vault file.
			vaultPrefix := base64UrlEncode(l1Raw[:4])
			dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)

			// Only open a vault that is already on disk; a missing (or
			// un-stat-able) file is reported as "not found" below.
			var db *lib.DB
			if _, statErr := os.Stat(dbPath); statErr == nil {
				var openErr error
				if db, openErr = lib.OpenDB(dbPath); openErr != nil {
					ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault")
					return
				}
			}
			if db == nil {
				ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
				return
			}
			defer db.Close()

			ctx := context.WithValue(r.Context(), ctxDB, db)
			ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
			ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// LoggingMiddleware logs one line per request — method, URL path, response
// status and elapsed time — after the wrapped handler has returned.
//
// The status is captured by wrapping the ResponseWriter in a statusWriter; it
// starts at 200 so handlers that never call WriteHeader are logged as OK.
func LoggingMiddleware(next http.Handler) http.Handler {
	logged := func(w http.ResponseWriter, r *http.Request) {
		began := time.Now()
		rec := &statusWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(rec, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %d %s", r.Method, r.URL.Path, rec.status, elapsed)
	}
	return http.HandlerFunc(logged)
}
// statusWriter wraps an http.ResponseWriter and remembers the status code
// passed to WriteHeader so it can be logged after the handler returns.
//
// Embedding the interface only promotes Header, Write and WriteHeader, so
// optional capabilities of the underlying writer would otherwise be hidden
// from handlers that probe for them with a type assertion (tarpitHandler does
// this for http.Flusher). Flush and Unwrap below restore access to them.
type statusWriter struct {
	http.ResponseWriter
	status int // last code passed to WriteHeader; the creator seeds the default
}

// WriteHeader records code and forwards it to the wrapped writer.
func (w *statusWriter) WriteHeader(code int) {
	w.status = code
	w.ResponseWriter.WriteHeader(code)
}

// Flush implements http.Flusher by delegating to the wrapped writer when it
// supports flushing; otherwise it does nothing.
func (w *statusWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Unwrap returns the wrapped writer so http.ResponseController can reach
// capabilities this wrapper does not re-export (for example hijacking).
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// RateLimitMiddleware implements per-IP rate limiting with a fixed one-minute
// window: each client IP (as reported by realIP) may make requestsPerMinute
// requests per window; further requests in the same window are answered with
// 429 "rate_limited".
//
// Every call starts one background goroutine that wakes once a minute and
// drops entries whose window has expired. That goroutine runs for the life of
// the process.
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler {
	var (
		mu      sync.Mutex
		windows = make(map[string]*rateLimitEntry) // keyed by client IP, guarded by mu
	)

	// Janitor: evict expired windows so the map does not keep idle clients.
	go func() {
		for {
			time.Sleep(time.Minute)
			mu.Lock()
			now := time.Now()
			for ip, win := range windows {
				if now.Sub(win.windowStart) > time.Minute {
					delete(windows, ip)
				}
			}
			mu.Unlock()
		}
	}()

	// hit counts one request for ip and returns the total in its current
	// window, starting a fresh window when none exists or the old one expired.
	hit := func(ip string) int {
		mu.Lock()
		defer mu.Unlock()
		win, ok := windows[ip]
		now := time.Now()
		if !ok || now.Sub(win.windowStart) > time.Minute {
			win = &rateLimitEntry{windowStart: now}
			windows[ip] = win
		}
		win.count++
		return win.count
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if hit(realIP(r)) > requestsPerMinute {
				ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests")
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}
// rateLimitEntry is one fixed rate-limit window: when it started and how many
// requests have been counted in it so far.
type rateLimitEntry struct {
	windowStart time.Time
	count       int
}

// Per-agent rate limiter (keyed by agent ID, not IP).
var agentRateLimiter = newAgentLimiter()

// agentLimiter tracks one rateLimitEntry per agent ID. The map is guarded by
// mu.
type agentLimiter struct {
	mu     sync.Mutex
	agents map[string]*rateLimitEntry
}

// newAgentLimiter returns an empty limiter and starts its background sweeper,
// which runs for the life of the process.
func newAgentLimiter() *agentLimiter {
	limiter := &agentLimiter{agents: map[string]*rateLimitEntry{}}
	go limiter.sweep()
	return limiter
}

// sweep wakes once a minute and removes every entry whose window started more
// than a minute ago. It never returns.
func (al *agentLimiter) sweep() {
	for {
		time.Sleep(time.Minute)
		al.mu.Lock()
		now := time.Now()
		for id, win := range al.agents {
			if now.Sub(win.windowStart) > time.Minute {
				delete(al.agents, id)
			}
		}
		al.mu.Unlock()
	}
}

// allow counts one request for agentID and reports whether the agent is still
// within maxPerMinute for its current window. A new window is started when
// the agent has none or its previous one is more than a minute old.
func (al *agentLimiter) allow(agentID string, maxPerMinute int) bool {
	al.mu.Lock()
	defer al.mu.Unlock()

	now := time.Now()
	win, ok := al.agents[agentID]
	if !ok || now.Sub(win.windowStart) > time.Minute {
		win = &rateLimitEntry{windowStart: now}
		al.agents[agentID] = win
	}
	win.count++
	return win.count <= maxPerMinute
}
// CORSMiddleware handles CORS headers.
//
// Access-Control-Allow-Origin echoes the request's Origin only when that
// origin's host is a loopback name (see isLoopbackOrigin). The host is parsed
// out of the Origin rather than substring-matched, so look-alikes such as
// "https://localhost.evil.com" are not granted access.
//
// "Vary: Origin" is added to every response because the headers differ per
// origin. OPTIONS (preflight) requests are answered with 204 and never reach
// next.
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// Add, not Set: keep any Vary values written by outer middleware.
		h.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); isLoopbackOrigin(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
		}
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		h.Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isLoopbackOrigin reports whether origin parses as a URL whose host is
// exactly "localhost", a name under ".localhost", or "127.0.0.1". The scheme
// and port are not restricted. Empty or unparseable origins are rejected.
func isLoopbackOrigin(origin string) bool {
	if origin == "" {
		return false
	}
	u, err := url.Parse(origin)
	if err != nil {
		return false
	}
	host := strings.ToLower(u.Hostname())
	return host == "localhost" ||
		host == "127.0.0.1" ||
		strings.HasSuffix(host, ".localhost")
}
// SecurityHeadersMiddleware sets a fixed group of security-related response
// headers (framing, MIME sniffing, XSS filter, referrer policy and a Content
// Security Policy) on every response before invoking next.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	fixed := [...][2]string{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		// CSP: removed unused tailwindcss, tightened img-src to self+data only
		{"Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data:; connect-src 'self' localhost 127.0.0.1 https://clavitor.ai"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		out := w.Header()
		for _, kv := range fixed {
			out.Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	})
}
// MaxBodySizeMiddleware limits request body size and rejects binary content.
//
// A request whose Content-Type is classified as binary by isBinaryContentType
// (images, executables, uploads, etc.) is refused with 415
// "binary_not_allowed". Otherwise the body is wrapped in http.MaxBytesReader
// so that reads beyond maxBytes fail. The original author's note gives 64KB
// as the intended cap for markdown notes; the actual limit is whatever the
// caller passes as maxBytes.
func MaxBodySizeMiddleware(maxBytes int64) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		limited := func(w http.ResponseWriter, r *http.Request) {
			// Security: Reject binary content types
			if ct := r.Header.Get("Content-Type"); isBinaryContentType(ct) {
				ErrorResponse(w, http.StatusUnsupportedMediaType, "binary_not_allowed",
					"Binary content not allowed. Only text/markdown data accepted.")
				return
			}
			r.Body = http.MaxBytesReader(w, r.Body, maxBytes)
			next.ServeHTTP(w, r)
		}
		return http.HandlerFunc(limited)
	}
}
// isBinaryContentType reports whether the Content-Type value ct mentions any
// of a fixed list of binary media types. The match is case-insensitive and
// substring-based, so a marker appearing anywhere in ct (including in its
// parameters) counts.
func isBinaryContentType(ct string) bool {
	markers := [...]string{
		"image/", "audio/", "video/", "application/pdf",
		"application/zip", "application/gzip", "application/octet-stream",
		"application/x-executable", "application/x-dosexec",
		"multipart/form-data", // usually file uploads
	}
	lowered := strings.ToLower(ct)
	for i := range markers {
		if strings.Contains(lowered, markers[i]) {
			return true
		}
	}
	return false
}
// ErrorResponse sends a JSON error response: it sets the Content-Type to
// application/json, writes the given HTTP status and then a JSON object with
// the machine-readable code under "code" and the human-readable message under
// "error".
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	// Field order mirrors the alphabetical key order the encoder emits for a
	// map, so the bytes on the wire are {"code":...,"error":...}.
	payload := struct {
		Code  string `json:"code"`
		Error string `json:"error"`
	}{Code: code, Error: message}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent; a failed body write cannot be reported
	// to the client, so the error is deliberately dropped.
	_ = json.NewEncoder(w).Encode(payload)
}
// JSONResponse sends a JSON success response: data is encoded as JSON and
// written with the given status and Content-Type application/json. The body
// ends with a newline, matching json.Encoder output.
//
// The value is marshalled before any header is written. If it cannot be
// encoded (for example it contains a channel or a NaN float), the failure is
// logged and the client receives a 500 with a JSON error object instead of
// the requested status followed by an empty body.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("JSONResponse: encoding %T: %v", data, err)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		// Same shape as ErrorResponse's output; write errors are ignored
		// because the status line has already been sent.
		_, _ = w.Write([]byte(`{"code":"encode_failed","error":"Failed to encode response"}` + "\n"))
		return
	}

	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status line is already sent; a failed body write (typically a
	// disconnected client) cannot be reported, so the error is dropped.
	_, _ = w.Write(append(body, '\n'))
}
// tarpitSem bounds the number of requests being tarpitted at once to 1000;
// each running tarpitHandler holds one slot.
var tarpitSem = make(chan struct{}, 1000)

// tarpitHandler wastes scanner resources: it answers 200 text/plain and then
// drips one space per second for 30 seconds, flushing after each byte when
// the writer supports it, and stops early if a write fails.
//
// When every slot in tarpitSem is taken the request is not tarpitted; the
// connection is hijacked and closed if the writer allows it, otherwise the
// handler simply returns.
func tarpitHandler(w http.ResponseWriter, r *http.Request) {
	select {
	case tarpitSem <- struct{}{}:
		defer func() { <-tarpitSem }()
	default:
		// Saturated: drop the connection instead of holding another goroutine.
		hj, ok := w.(http.Hijacker)
		if !ok {
			return
		}
		if conn, _, err := hj.Hijack(); err == nil {
			conn.Close()
		}
		return
	}

	log.Printf("tarpit: %s %s from %s", r.Method, r.URL.Path, realIP(r))
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)

	flusher, canFlush := w.(http.Flusher)
	for sent := 0; sent < 30; sent++ {
		if _, err := w.Write([]byte(" ")); err != nil {
			return
		}
		if canFlush {
			flusher.Flush()
		}
		time.Sleep(time.Second)
	}
}
// realIP returns the client IP for r as a bare address string.
//
// Sources are consulted in order, skipping any that are empty after trimming
// whitespace:
//  1. the first element of X-Forwarded-For,
//  2. X-Real-IP,
//  3. the host part of r.RemoteAddr (port removed, IPv6 brackets stripped).
//
// If RemoteAddr is not in host:port form it is returned unchanged.
//
// NOTE(review): both headers are client-supplied unless a trusted reverse
// proxy overwrites them; callers use this value for IP whitelisting and rate
// limiting, so confirm the deployment always sits behind such a proxy.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	// SplitHostPort understands "[::1]:8080" and returns "::1"; cutting at
	// the last colon by hand would leave the brackets on.
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return r.RemoteAddr
	}
	return host
}