package api

import (
	"context"
	"encoding/json"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clawvault/lib"
)

// contextKey is a private type for request-context keys, so values stored by
// this package cannot collide with keys set by other packages.
type contextKey string

const (
	ctxActor   contextKey = "actor"
	ctxSession contextKey = "session"
)

// ActorFromContext reports the actor stored in ctx under ctxActor. When the
// context holds no string value for that key it falls back to lib.ActorWeb.
func ActorFromContext(ctx context.Context) string {
	if actor, found := ctx.Value(ctxActor).(string); found {
		return actor
	}
	return lib.ActorWeb
}

// SessionFromContext reports the *lib.Session stored in ctx under ctxSession,
// or nil when the context carries none.
func SessionFromContext(ctx context.Context) *lib.Session {
	if sess, found := ctx.Value(ctxSession).(*lib.Session); found {
		return sess
	}
	return nil
}

// AuthMiddleware builds a middleware that requires an
// "Authorization: Bearer <token>" header, resolves the token with
// lib.SessionGet, and exposes the session and its actor through the request
// context (see SessionFromContext and ActorFromContext).
//
// Responses on failure:
//   - 401 missing_token when the header does not start with "Bearer ".
//   - 500 session_error when lib.SessionGet returns an error.
//   - 401 invalid_token when lib.SessionGet returns a nil session.
func AuthMiddleware(db *lib.DB) func(http.Handler) http.Handler {
	const scheme = "Bearer "

	return func(next http.Handler) http.Handler {
		authenticate := func(w http.ResponseWriter, r *http.Request) {
			header := r.Header.Get("Authorization")
			if !strings.HasPrefix(header, scheme) {
				ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Authorization header required")
				return
			}

			// Everything after the scheme prefix is passed on as the token.
			session, err := lib.SessionGet(db, header[len(scheme):])
			switch {
			case err != nil:
				ErrorResponse(w, http.StatusInternalServerError, "session_error", "Session lookup failed")
				return
			case session == nil:
				ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token")
				return
			}

			// NOTE(review): ActorFromContext type-asserts a plain string, so it
			// only sees this value if session.Actor is declared as string — confirm.
			withActor := context.WithValue(r.Context(), ctxActor, session.Actor)
			withSession := context.WithValue(withActor, ctxSession, session)
			next.ServeHTTP(w, r.WithContext(withSession))
		}
		return http.HandlerFunc(authenticate)
	}
}

// LoggingMiddleware logs HTTP requests.
func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() wrapped := &statusWriter{ResponseWriter: w, status: 200} next.ServeHTTP(wrapped, r) log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start)) }) } type statusWriter struct { http.ResponseWriter status int } func (w *statusWriter) WriteHeader(code int) { w.status = code w.ResponseWriter.WriteHeader(code) } // RateLimitMiddleware implements per-IP rate limiting. func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler { var mu sync.Mutex clients := make(map[string]*rateLimitEntry) go func() { for { time.Sleep(time.Minute) mu.Lock() now := time.Now() for ip, entry := range clients { if now.Sub(entry.windowStart) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := realIP(r) mu.Lock() entry, exists := clients[ip] now := time.Now() if !exists || now.Sub(entry.windowStart) > time.Minute { entry = &rateLimitEntry{windowStart: now, count: 0} clients[ip] = entry } entry.count++ count := entry.count mu.Unlock() if count > requestsPerMinute { ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests") return } next.ServeHTTP(w, r) }) } } type rateLimitEntry struct { windowStart time.Time count int } // CORSMiddleware handles CORS headers. 
// Only origins whose host is exactly localhost or 127.0.0.1 (any scheme, any
// port) are echoed back in Access-Control-Allow-Origin; this is intended for
// development. OPTIONS requests are answered with 204 and never reach next.
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// The response differs per Origin, so caches must key on it even when
		// the origin is rejected. Add (not Set) keeps any existing Vary values.
		h.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); isLocalOrigin(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
		}

		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		h.Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isLocalOrigin reports whether origin has the form scheme://host[:port] with
// host exactly "localhost" or "127.0.0.1". A substring match is not enough:
// "https://localhost.evil.com" contains "localhost" but is a foreign site.
func isLocalOrigin(origin string) bool {
	_, hostport, ok := strings.Cut(origin, "://")
	if !ok {
		return false
	}

	host := hostport
	if i := strings.LastIndexByte(hostport, ':'); i != -1 {
		host = hostport[:i]
		port := hostport[i+1:]
		if port == "" {
			return false
		}
		// Anything but digits after the colon means this is not host:port.
		for _, c := range port {
			if c < '0' || c > '9' {
				return false
			}
		}
	}
	return host == "localhost" || host == "127.0.0.1"
}

// SecurityHeadersMiddleware adds security headers to all responses.
// The headers are set before next runs, so a handler can still override them.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		h.Set("X-Frame-Options", "DENY")
		h.Set("X-Content-Type-Options", "nosniff")
		h.Set("X-XSS-Protection", "1; mode=block")
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// CSP allowing localhost and 127.0.0.1 for development
		h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' localhost 127.0.0.1")
		next.ServeHTTP(w, r)
	})
}

// ErrorResponse sends a standard JSON error response.
//
// The body is {"code": code, "error": message} with the given HTTP status and
// Content-Type application/json. The status line is sent before the body is
// encoded, so a failure while writing the body can only be logged.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)

	payload := map[string]string{
		"error": message,
		"code":  code,
	}
	if err := json.NewEncoder(w).Encode(payload); err != nil {
		log.Printf("api: writing error response %q: %v", code, err)
	}
}

// JSONResponse sends a standard JSON success response.
// The Content-Type is set to application/json and the status line is sent
// before data is encoded, so an encoding failure can only be logged.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	if err := json.NewEncoder(w).Encode(data); err != nil {
		log.Printf("api: writing JSON response: %v", err)
	}
}

// realIP returns the client address for r, taken from the first of:
//
//  1. the first entry of X-Forwarded-For,
//  2. X-Real-IP,
//  3. r.RemoteAddr with its port removed.
//
// Header values are trimmed, and a header that yields an empty value is
// skipped rather than returned, so malformed headers do not collapse many
// clients into one "" key.
//
// NOTE(review): both headers are client-supplied unless a trusted proxy
// overwrites them, so callers that use this value for rate limiting can be
// bypassed when the service is reachable directly — confirm the deployment.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if ip := strings.TrimSpace(r.Header.Get("X-Real-IP")); ip != "" {
		return ip
	}
	return stripPort(r.RemoteAddr)
}

// stripPort removes the port from a "host:port" or "[ipv6]:port" address and
// returns the bare host. An address that has neither form (for example a bare
// IPv6 address) is returned unchanged.
func stripPort(addr string) string {
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end != -1 {
			return addr[1:end]
		}
		return addr
	}
	// Exactly one colon means host:port; more than one is an unbracketed
	// IPv6 address whose colons are part of the host.
	if strings.Count(addr, ":") == 1 {
		return addr[:strings.IndexByte(addr, ':')]
	}
	return addr
}