package api

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"log"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clavitor/lib"
)

// base64Decode decodes s as base64 in either the URL-safe or the standard
// alphabet, with or without trailing '=' padding. Padding is stripped and
// the URL-safe alphabet is attempted first; if that fails, the standard
// alphabet's result (bytes and error) is returned.
func base64Decode(s string) ([]byte, error) {
	unpadded := strings.TrimRight(s, "=")
	if decoded, err := base64.RawURLEncoding.DecodeString(unpadded); err == nil {
		return decoded, nil
	}
	return base64.RawStdEncoding.DecodeString(unpadded)
}

// base64UrlEncode returns the unpadded base64url encoding of b.
func base64UrlEncode(b []byte) string {
	enc := base64.RawURLEncoding
	return enc.EncodeToString(b)
}

// contextKey is an unexported type for request-context keys, so values stored
// by this package cannot collide with keys set by other packages.
type contextKey string

// Keys under which per-request values are stored in the request context.
const (
	ctxActor    contextKey = "actor"
	ctxSession  contextKey = "session"
	ctxAgent    contextKey = "agent"
	ctxDB       contextKey = "db"
	ctxVaultKey contextKey = "vault_key"
	ctxVaultID  contextKey = "vault_id"
)

// ActorFromContext returns the actor type stored in ctx, falling back to
// lib.ActorWeb when no string actor is present.
func ActorFromContext(ctx context.Context) string {
	if actor, ok := ctx.Value(ctxActor).(string); ok {
		return actor
	}
	return lib.ActorWeb
}

// SessionFromContext returns the session stored in ctx, or nil if none.
func SessionFromContext(ctx context.Context) *lib.Session {
	if sess, ok := ctx.Value(ctxSession).(*lib.Session); ok {
		return sess
	}
	return nil
}

// AgentFromContext returns the agent stored in ctx, or nil when the request
// was not made by an agent.
func AgentFromContext(ctx context.Context) *lib.Agent {
	if agent, ok := ctx.Value(ctxAgent).(*lib.Agent); ok {
		return agent
	}
	return nil
}

// DBFromContext returns the vault DB stored in ctx, or nil if no vault was
// attached to the request.
func DBFromContext(ctx context.Context) *lib.DB {
	if db, ok := ctx.Value(ctxDB).(*lib.DB); ok {
		return db
	}
	return nil
}

// VaultKeyFromContext returns the derived vault key stored in ctx, or nil if
// the request carried no key.
func VaultKeyFromContext(ctx context.Context) []byte {
	if key, ok := ctx.Value(ctxVaultKey).([]byte); ok {
		return key
	}
	return nil
}

// VaultIDFromContext returns the vault ID from request context (0 if none was stored).
func VaultIDFromContext(ctx context.Context) int64 { v, _ := ctx.Value(ctxVaultID).(int64) return v } // L1Middleware extracts L1 from Bearer token and opens the vault DB. // Fully stateless: L1 arrives with every request, is used, then forgotten. // No sessions, no stored keys. The server has zero keys of its own. // // Self-hosted mode: finds vault DB by globbing clavitor-* files. // Hosted mode: finds vault DB by base64url(L1[0:4]) → filename. func L1Middleware(dataDir string) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") // No auth = unauthenticated request (registration, login begin, etc.) if auth == "" || !strings.HasPrefix(auth, "Bearer ") { // Try to open vault DB without L1 (for unauthenticated endpoints) matches, _ := filepath.Glob(filepath.Join(dataDir, "clavitor-*")) if len(matches) > 0 { db, err := lib.OpenDB(matches[0]) if err == nil { defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) next.ServeHTTP(w, r.WithContext(ctx)) return } } // Also try legacy .db files for migration matches, _ = filepath.Glob(filepath.Join(dataDir, "????????.db")) if len(matches) > 0 { db, err := lib.OpenDB(matches[0]) if err == nil { defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) next.ServeHTTP(w, r.WithContext(ctx)) return } } next.ServeHTTP(w, r) return } bearerVal := strings.TrimPrefix(auth, "Bearer ") var l1Raw []byte var agent *lib.Agent if strings.HasPrefix(bearerVal, "cvt_") { // --- Agent token: cvt_ prefix --- // Extract L1 and look up agent by token hash. 
var hash string var err error l1Raw, hash, err = lib.ParseToken(bearerVal) if err != nil { ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid agent token") return } // Open vault DB from L1 l1Key := lib.NormalizeKey(l1Raw) vaultPrefix := base64UrlEncode(l1Raw[:4]) dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix) db, err := lib.OpenDB(dbPath) if err != nil { ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found") return } defer db.Close() // Look up agent agent, err = lib.AgentGetByToken(db, hash) if err != nil { ErrorResponse(w, http.StatusInternalServerError, "agent_error", "Agent lookup failed") return } if agent == nil { ErrorResponse(w, http.StatusUnauthorized, "unknown_token", "Invalid or revoked token") return } ctx := context.WithValue(r.Context(), ctxDB, db) ctx = context.WithValue(ctx, ctxVaultKey, l1Key) ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent) ctx = context.WithValue(ctx, ctxAgent, agent) next.ServeHTTP(w, r.WithContext(ctx)) } else { // --- Legacy L1 bearer (web UI / extension) --- // 8 bytes base64url = vault owner, full access, no agent. 
l1Raw, err := base64Decode(bearerVal) if err != nil || len(l1Raw) != 8 { ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer") return } l1Key := lib.NormalizeKey(l1Raw) vaultPrefix := base64UrlEncode(l1Raw[:4]) dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix) var db *lib.DB if _, err := os.Stat(dbPath); err == nil { db, err = lib.OpenDB(dbPath) if err != nil { log.Printf("vault open error (%s): %v", dbPath, err) ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault") return } } if db == nil { ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found") return } defer db.Close() ctx := context.WithValue(r.Context(), ctxDB, db) ctx = context.WithValue(ctx, ctxVaultKey, l1Key) ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb) // No agent in context = vault owner (full access) next.ServeHTTP(w, r.WithContext(ctx)) } }) } } // IsAgentRequest returns true if the request was made with a cvt_ agent token. func IsAgentRequest(r *http.Request) bool { return AgentFromContext(r.Context()) != nil } // LoggingMiddleware logs HTTP requests. func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() wrapped := &statusWriter{ResponseWriter: w, status: 200} next.ServeHTTP(wrapped, r) log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start)) }) } type statusWriter struct { http.ResponseWriter status int } func (w *statusWriter) WriteHeader(code int) { w.status = code w.ResponseWriter.WriteHeader(code) } // RateLimitMiddleware implements per-IP rate limiting. 
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler { var mu sync.Mutex clients := make(map[string]*rateLimitEntry) go func() { for { time.Sleep(time.Minute) mu.Lock() now := time.Now() for ip, entry := range clients { if now.Sub(entry.windowStart) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := realIP(r) mu.Lock() entry, exists := clients[ip] now := time.Now() if !exists || now.Sub(entry.windowStart) > time.Minute { entry = &rateLimitEntry{windowStart: now, count: 0} clients[ip] = entry } entry.count++ count := entry.count mu.Unlock() if count > requestsPerMinute { ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests") return } next.ServeHTTP(w, r) }) } } type rateLimitEntry struct { windowStart time.Time count int } // CORSMiddleware handles CORS headers. func CORSMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Allow localhost and 127.0.0.1 for development if origin != "" && (strings.Contains(origin, "localhost") || strings.Contains(origin, "127.0.0.1")) { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // SecurityHeadersMiddleware adds security headers to all responses. 
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		// The legacy XSS auditor is gone from modern browsers, and its
		// "1; mode=block" mode could itself be abused for cross-site leaks in
		// older ones; explicitly disable it and rely on the CSP below.
		h.Set("X-XSS-Protection", "0")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// CSP allowing localhost and 127.0.0.1 for development
		h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' localhost 127.0.0.1")
		next.ServeHTTP(w, r)
	})
}

// ErrorResponse sends a standard JSON error response of the form
// {"code": code, "error": message} with the given status.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	JSONResponse(w, status, map[string]string{
		"error": message,
		"code":  code,
	})
}

// JSONResponse sends data as a JSON body (newline-terminated) with the given
// status. If data cannot be encoded, it logs the error and responds with 500
// instead of sending the success status with an empty body.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	// Marshal before touching the response, so an encoding failure can still
	// change the status code.
	body, err := json.Marshal(data)
	w.Header().Set("Content-Type", "application/json")
	if err != nil {
		log.Printf("json response encode error: %v", err)
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte(`{"code":"encode_error","error":"Internal server error"}` + "\n"))
		return
	}
	w.WriteHeader(status)
	// A failed write means the client went away; nothing useful remains to do.
	_, _ = w.Write(append(body, '\n'))
}

// tarpitHandler holds unrecognized requests for 30 seconds.
// Drips one byte per second to keep the connection alive and waste
// scanner resources. Capped at 1000 concurrent tarpit slots —
// beyond that, connections are dropped immediately.
var (
	// tarpitSem is a counting semaphore: one token per occupied tarpit slot.
	tarpitSem = make(chan struct{}, 1000)
)

// tarpitHandler drips a space roughly once per second for up to 30 seconds.
// It stops early when a write fails or the request context is cancelled
// (client disconnect or server shutdown), so abandoned connections release
// their slot promptly. When all slots are taken, it hijacks and closes the
// connection if possible; otherwise it returns without writing anything.
func tarpitHandler(w http.ResponseWriter, r *http.Request) {
	select {
	case tarpitSem <- struct{}{}:
		defer func() { <-tarpitSem }()
	default:
		// Tarpit full — drop immediately, no response.
		if hj, ok := w.(http.Hijacker); ok {
			if conn, _, err := hj.Hijack(); err == nil {
				conn.Close()
			}
		}
		return
	}

	log.Printf("tarpit: %s %s from %s", r.Method, r.URL.Path, realIP(r))

	// Chunked response: drip one space per second for 30s.
	w.Header().Set("Content-Type", "text/plain")
	w.WriteHeader(http.StatusOK)
	flusher, canFlush := w.(http.Flusher)

	ticker := time.NewTicker(time.Second)
	defer ticker.Stop()
	for i := 0; i < 30; i++ {
		if _, err := w.Write([]byte(" ")); err != nil {
			return // client gave up
		}
		if canFlush {
			flusher.Flush()
		}
		select {
		case <-r.Context().Done():
			return // client disconnected or server shutting down
		case <-ticker.C:
		}
	}
}

// realIP returns the client address used for logging and rate limiting.
// It returns the first non-empty entry of X-Forwarded-For if present,
// otherwise X-Real-IP, otherwise the host part of r.RemoteAddr, with
// brackets removed from IPv6 literals such as "[::1]:8080".
// NOTE(review): both headers are client-controlled unless a trusted reverse
// proxy overwrites them; when the server is reachable directly, clients can
// choose their own rate-limit key. Confirm the deployment topology.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}

	addr := r.RemoteAddr
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end != -1 {
			return addr[1:end]
		}
	}
	if idx := strings.LastIndexByte(addr, ':'); idx != -1 {
		return addr[:idx]
	}
	return addr
}