package api

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"log"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/vault1984/lib"
)

// base64Decode decodes s as unpadded base64 in either the URL-safe or the
// standard alphabet. Trailing '=' characters are stripped first, so padded
// and unpadded input are both accepted. The URL-safe alphabet is tried first;
// if that fails, the result (and error) of the standard alphabet is returned.
func base64Decode(s string) ([]byte, error) {
	unpadded := strings.TrimRight(s, "=")
	if decoded, err := base64.RawURLEncoding.DecodeString(unpadded); err == nil {
		return decoded, nil
	}
	return base64.RawStdEncoding.DecodeString(unpadded)
}

// base64UrlEncode returns b as URL-safe base64 without '=' padding.
func base64UrlEncode(b []byte) string {
	enc := base64.RawURLEncoding
	return enc.EncodeToString(b)
}

// contextKey is a package-private key type for request-context values, so
// keys stored here cannot collide with context keys from other packages.
type contextKey string

// Keys for values stored in the request context.
const (
	ctxActor    contextKey = "actor"
	ctxSession  contextKey = "session"
	ctxMCPToken contextKey = "mcp_token"
	ctxAgent    contextKey = "agent"
	ctxDB       contextKey = "db"
	ctxVaultKey contextKey = "vault_key"
	ctxVaultID  contextKey = "vault_id"
)

// ActorFromContext returns the actor type stored in the request context,
// falling back to lib.ActorWeb when no string actor was stored.
func ActorFromContext(ctx context.Context) string {
	if actor, ok := ctx.Value(ctxActor).(string); ok {
		return actor
	}
	return lib.ActorWeb
}

// SessionFromContext returns the *lib.Session stored in the request context,
// or nil if none was stored.
func SessionFromContext(ctx context.Context) *lib.Session {
	if session, ok := ctx.Value(ctxSession).(*lib.Session); ok {
		return session
	}
	return nil
}

// MCPTokenFromContext returns the *lib.MCPToken stored in the request
// context, or nil if none was stored.
func MCPTokenFromContext(ctx context.Context) *lib.MCPToken {
	if token, ok := ctx.Value(ctxMCPToken).(*lib.MCPToken); ok {
		return token
	}
	return nil
}

// AgentFromContext returns the *lib.Agent stored in the request context, or
// nil when the request was not made on behalf of an agent (L1Middleware only
// stores one for requests carrying an X-Agent header).
func AgentFromContext(ctx context.Context) *lib.Agent {
	if agent, ok := ctx.Value(ctxAgent).(*lib.Agent); ok {
		return agent
	}
	return nil
}

// DBFromContext returns the vault *lib.DB that L1Middleware opened for this
// request, or nil if no vault DB was available.
func DBFromContext(ctx context.Context) *lib.DB {
	if db, ok := ctx.Value(ctxDB).(*lib.DB); ok {
		return db
	}
	return nil
}

// VaultKeyFromContext returns the normalized L1 vault key that L1Middleware stored in the request context, or nil if none was stored (e.g. requests without a Bearer token).
func VaultKeyFromContext(ctx context.Context) []byte { v, _ := ctx.Value(ctxVaultKey).([]byte) return v } // VaultIDFromContext returns the vault ID from request context (0 in self-hosted mode). func VaultIDFromContext(ctx context.Context) int64 { v, _ := ctx.Value(ctxVaultID).(int64) return v } // L1Middleware extracts L1 from Bearer token and opens the vault DB. // Fully stateless: L1 arrives with every request, is used, then forgotten. // No sessions, no stored keys. The server has zero keys of its own. // // Self-hosted mode: finds vault DB by globbing vault1984-* files. // Hosted mode: finds vault DB by base64url(L1[0:4]) → filename. func L1Middleware(dataDir string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") // No auth = unauthenticated request (registration, login begin, etc.) if auth == "" || !strings.HasPrefix(auth, "Bearer ") { // Try to open vault DB without L1 (for unauthenticated endpoints) matches, _ := filepath.Glob(filepath.Join(dataDir, "vault1984-*")) if len(matches) > 0 { db, err := lib.OpenDB(matches[0]) if err == nil { defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) next.ServeHTTP(w, r.WithContext(ctx)) return } } // Also try legacy .db files for migration matches, _ = filepath.Glob(filepath.Join(dataDir, "????????.db")) if len(matches) > 0 { db, err := lib.OpenDB(matches[0]) if err == nil { defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) next.ServeHTTP(w, r.WithContext(ctx)) return } } next.ServeHTTP(w, r) return } // Decode Bearer → L1 (8 bytes) bearerB64 := strings.TrimPrefix(auth, "Bearer ") l1Raw, err := base64Decode(bearerB64) if err != nil || len(l1Raw) != 8 { ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer") return } // Normalize L1: 8 bytes → 16 bytes (same as crypto.js normalize_key) l1Key := lib.NormalizeKey(l1Raw) // 
Find vault DB by first 4 bytes of L1 vaultPrefix := base64UrlEncode(l1Raw[:4]) dbPath := filepath.Join(dataDir, "vault1984-"+vaultPrefix) var db *lib.DB if _, err := os.Stat(dbPath); err == nil { db, err = lib.OpenDB(dbPath) if err != nil { log.Printf("vault open error (%s): %v", dbPath, err) ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault") return } } if db == nil { ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found") return } defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) ctx = context.WithValue(ctx, ctxVaultKey, l1Key) ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb) // Agent-specific auth via X-Agent header agentName := r.Header.Get("X-Agent") if agentName != "" { ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent) // Check vault-level lock vaultLock, _ := lib.VaultLockGet(db) if vaultLock != nil && vaultLock.Locked { ErrorResponse(w, http.StatusForbidden, "vault_locked", "Vault is locked: "+vaultLock.LockedReason) return } agent, err := lib.AgentGetByName(db, agentName) if err != nil { ErrorResponse(w, http.StatusInternalServerError, "agent_error", "Agent lookup failed") return } if agent == nil { ErrorResponse(w, http.StatusUnauthorized, "unknown_agent", "Unknown agent") return } if agent.Status == lib.AgentStatusRevoked { ErrorResponse(w, http.StatusUnauthorized, "agent_revoked", "Agent has been revoked") return } if agent.Status == lib.AgentStatusLocked { ErrorResponse(w, http.StatusForbidden, "agent_locked", "Agent is locked: "+agent.LockedReason) return } // IP whitelist check ip := realIP(r) if len(agent.IPWhitelist) == 1 && agent.IPWhitelist[0] == "init" { lib.AgentUpdateWhitelist(db, int64(agent.ID), []string{ip}) agent.IPWhitelist = []string{ip} lib.AuditLog(db, &lib.AuditEvent{ Action: "agent_ip_init", Actor: lib.ActorAgent, IPAddr: ip, Title: agent.Name + " → " + ip, }) } if !agentIPAllowed(agent, ip) { lib.VaultLockSet(db, true, "Non-whitelisted IP "+ip+" for agent 
"+agent.Name) lib.AgentUpdateStatus(db, int64(agent.ID), lib.AgentStatusLocked, "non-whitelisted IP: "+ip) lib.AuditLog(db, &lib.AuditEvent{ Action: lib.ActionIPViolation, Actor: lib.ActorAgent, IPAddr: ip, Title: agent.Name, }) ErrorResponse(w, http.StatusForbidden, "vault_locked", "Access from non-whitelisted IP. Vault locked.") return } // Rate limit reqPath := r.URL.Path lib.AgentRequestLog(db, int64(agent.ID), ip, reqPath) countMin, _ := lib.AgentRequestCountMinute(db, int64(agent.ID)) countHour, _ := lib.AgentRequestCountHour(db, int64(agent.ID)) if countMin > agent.RateLimitMinute || countHour > agent.RateLimitHour { lib.AgentUpdateStatus(db, int64(agent.ID), lib.AgentStatusLocked, "rate limit exceeded") lib.AuditLog(db, &lib.AuditEvent{ Action: lib.ActionRateExceeded, Actor: lib.ActorAgent, IPAddr: ip, Title: agent.Name, }) ErrorResponse(w, http.StatusTooManyRequests, "agent_locked", "Rate limit exceeded. Agent locked.") return } lib.AgentUpdateLastUsed(db, int64(agent.ID), ip) ctx = context.WithValue(ctx, ctxAgent, agent) } next.ServeHTTP(w, r.WithContext(ctx)) }) } } // agentIPAllowed checks if the given IP is allowed by the agent's whitelist. // Supports: single IPs, CIDR notation, DNS names. func agentIPAllowed(agent *lib.Agent, ip string) bool { if len(agent.IPWhitelist) == 0 { return true } parsedIP := net.ParseIP(ip) for _, entry := range agent.IPWhitelist { // CIDR if strings.Contains(entry, "/") { _, cidr, err := net.ParseCIDR(entry) if err == nil && parsedIP != nil && cidr.Contains(parsedIP) { return true } continue } // Single IP if net.ParseIP(entry) != nil { if entry == ip { return true } continue } // DNS name — resolve and compare addrs, err := net.LookupHost(entry) if err == nil { for _, addr := range addrs { if addr == ip { return true } } } } return false } // LoggingMiddleware logs HTTP requests. 
func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() wrapped := &statusWriter{ResponseWriter: w, status: 200} next.ServeHTTP(wrapped, r) log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start)) }) } type statusWriter struct { http.ResponseWriter status int } func (w *statusWriter) WriteHeader(code int) { w.status = code w.ResponseWriter.WriteHeader(code) } // RateLimitMiddleware implements per-IP rate limiting. func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler { var mu sync.Mutex clients := make(map[string]*rateLimitEntry) go func() { for { time.Sleep(time.Minute) mu.Lock() now := time.Now() for ip, entry := range clients { if now.Sub(entry.windowStart) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := realIP(r) mu.Lock() entry, exists := clients[ip] now := time.Now() if !exists || now.Sub(entry.windowStart) > time.Minute { entry = &rateLimitEntry{windowStart: now, count: 0} clients[ip] = entry } entry.count++ count := entry.count mu.Unlock() if count > requestsPerMinute { ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests") return } next.ServeHTTP(w, r) }) } } type rateLimitEntry struct { windowStart time.Time count int } // CORSMiddleware handles CORS headers. 
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Access-Control-Allow-Origin depends on the Origin request header, so
		// caches must key on it whether or not this origin is allowed. Add
		// (not Set) keeps any Vary values set earlier in the chain.
		w.Header().Add("Vary", "Origin")

		// Allow localhost and 127.0.0.1 for development
		if origin := r.Header.Get("Origin"); isLocalDevOrigin(origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)
		}
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		w.Header().Set("Access-Control-Max-Age", "86400")

		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isLocalDevOrigin reports whether origin (a serialized Origin header such as
// "http://localhost:3000") uses http or https and has a host of exactly
// "localhost" or "127.0.0.1". The host is matched exactly, so look-alikes such
// as "https://localhost.example.com" or "http://127.0.0.1.example.com" are
// rejected; a substring match would reflect them.
func isLocalDevOrigin(origin string) bool {
	scheme, hostport, ok := strings.Cut(origin, "://")
	if !ok || (scheme != "http" && scheme != "https") {
		return false
	}
	host := hostport
	if h, _, err := net.SplitHostPort(hostport); err == nil {
		host = h
	}
	return strings.EqualFold(host, "localhost") || host == "127.0.0.1"
}

// SecurityHeadersMiddleware adds security headers to all responses.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// CSP allowing localhost and 127.0.0.1 for development
		h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' localhost 127.0.0.1")
		next.ServeHTTP(w, r)
	})
}

// ErrorResponse sends a standard JSON error response of the form
// {"code": code, "error": message} with the given HTTP status.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status is already sent; an encode/write error (e.g. client gone)
	// cannot be reported to the client, so it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(map[string]string{
		"error": message,
		"code":  code,
	})
}

// JSONResponse writes data as a JSON response body with the given status code.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status is already sent, so an encoding failure (e.g. an unsupported
	// type in data) cannot change the response. Log it instead of dropping it.
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("JSONResponse: encoding %T: %v", data, err)
	}
}

// realIP returns the client IP for r as a bare address (no port, no IPv6
// brackets). Callers use it for agent IP whitelisting and per-IP rate limiting.
//
// X-Forwarded-For and X-Real-IP are client-controlled unless a proxy rewrites
// them. They are therefore honoured only when the direct peer (r.RemoteAddr)
// is one of:
//   - a loopback or private-network address, i.e. a reverse proxy in front of
//     this server;
//   - not an IP at all, e.g. a Unix-socket listener.
//
// A directly connected public client cannot spoof its address with them.
//
// NOTE(review): the leftmost X-Forwarded-For entry is used. This is only
// trustworthy if the fronting proxy overwrites the header rather than
// appending to it; confirm the proxy configuration.
func realIP(r *http.Request) string {
	peer := r.RemoteAddr
	// SplitHostPort also strips IPv6 brackets ("[::1]:80" → "::1").
	if host, _, err := net.SplitHostPort(peer); err == nil {
		peer = host
	}
	if ip := net.ParseIP(peer); ip != nil && !ip.IsLoopback() && !ip.IsPrivate() {
		return peer
	}
	if first, _, _ := strings.Cut(r.Header.Get("X-Forwarded-For"), ","); strings.TrimSpace(first) != "" {
		return strings.TrimSpace(first)
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	return peer
}