package api

import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"log"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/mish/dealspace/lib"
)

// contextKey is the private key type for values this package stores in a
// request context.
type contextKey string

// ctxUserID is the context key under which the authenticated user ID is kept.
const ctxUserID contextKey = "user_id"

// UserIDFromContext returns the user ID stored in ctx by one of the
// authentication middlewares, or "" when none is present.
func UserIDFromContext(ctx context.Context) string {
	if id, ok := ctx.Value(ctxUserID).(string); ok {
		return id
	}
	return ""
}

// AuthMiddleware authenticates requests carrying a JWT bearer token.
//
// The token must pass validateJWT with jwtSecret, and the session named by
// its claims must exist, must not be revoked and must not have expired. On
// success the user ID from the claims is placed in the request context (see
// UserIDFromContext); otherwise a 401 JSON error is written and next is not
// called.
func AuthMiddleware(db *lib.DB, jwtSecret []byte) func(http.Handler) http.Handler {
	const scheme = "Bearer "

	return func(next http.Handler) http.Handler {
		authenticate := func(w http.ResponseWriter, r *http.Request) {
			header := r.Header.Get("Authorization")
			if !strings.HasPrefix(header, scheme) {
				ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Authorization header required")
				return
			}

			claims, err := validateJWT(header[len(scheme):], jwtSecret)
			if err != nil {
				ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token")
				return
			}

			// A valid signature is not enough: the backing session must
			// still be usable.
			session, err := lib.SessionByID(db, claims.SessionID)
			switch {
			case err != nil || session == nil || session.Revoked:
				ErrorResponse(w, http.StatusUnauthorized, "session_revoked", "Session has been revoked")
			case session.ExpiresAt < time.Now().UnixMilli():
				ErrorResponse(w, http.StatusUnauthorized, "session_expired", "Session has expired")
			default:
				authed := context.WithValue(r.Context(), ctxUserID, claims.UserID)
				next.ServeHTTP(w, r.WithContext(authed))
			}
		}
		return http.HandlerFunc(authenticate)
	}
}

// OAuthBearerAuth validates OAuth 2.0 bearer tokens for MCP endpoints.
func OAuthBearerAuth(db *lib.DB) func(http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") if !strings.HasPrefix(auth, "Bearer ") { w.Header().Set("WWW-Authenticate", "Bearer") ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Bearer token required") return } tokenStr := strings.TrimPrefix(auth, "Bearer ") token, err := lib.OAuthTokenValidate(db, tokenStr) if err != nil || token == nil { w.Header().Set("WWW-Authenticate", "Bearer error=\"invalid_token\"") ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token") return } ctx := context.WithValue(r.Context(), ctxUserID, token.UserID) next.ServeHTTP(w, r.WithContext(ctx)) }) } } // LoggingMiddleware logs HTTP requests. func LoggingMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { start := time.Now() wrapped := &statusWriter{ResponseWriter: w, status: 200} next.ServeHTTP(wrapped, r) log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start)) }) } type statusWriter struct { http.ResponseWriter status int } func (w *statusWriter) WriteHeader(code int) { w.status = code w.ResponseWriter.WriteHeader(code) } // RateLimitMiddleware implements per-IP rate limiting. 
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler { var mu sync.Mutex clients := make(map[string]*rateLimitEntry) go func() { for { time.Sleep(time.Minute) mu.Lock() now := time.Now() for ip, entry := range clients { if now.Sub(entry.windowStart) > time.Minute { delete(clients, ip) } } mu.Unlock() } }() return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip := realIP(r) mu.Lock() entry, exists := clients[ip] now := time.Now() if !exists || now.Sub(entry.windowStart) > time.Minute { entry = &rateLimitEntry{windowStart: now, count: 0} clients[ip] = entry } entry.count++ count := entry.count mu.Unlock() if count > requestsPerMinute { ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests") return } next.ServeHTTP(w, r) }) } } type rateLimitEntry struct { windowStart time.Time count int } // allowedOrigins is the list of origins allowed for CORS. var allowedOrigins = map[string]bool{ "https://muskepo.com": true, "https://www.muskepo.com": true, "https://app.muskepo.com": true, "https://dealspace.io": true, "https://app.dealspace.io": true, "http://localhost:8080": true, "http://localhost:3000": true, } // CORSMiddleware handles CORS headers with origin allowlist. func CORSMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { origin := r.Header.Get("Origin") // Check if origin is allowed if allowedOrigins[origin] { w.Header().Set("Access-Control-Allow-Origin", origin) w.Header().Set("Vary", "Origin") } w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type") w.Header().Set("Access-Control-Max-Age", "86400") if r.Method == "OPTIONS" { w.WriteHeader(http.StatusNoContent) return } next.ServeHTTP(w, r) }) } // SecurityHeadersMiddleware adds security headers to all responses. 
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		header := w.Header()

		// Prevent clickjacking.
		header.Set("X-Frame-Options", "DENY")
		// Prevent MIME type sniffing.
		header.Set("X-Content-Type-Options", "nosniff")
		// XSS protection (legacy but still useful).
		header.Set("X-XSS-Protection", "1; mode=block")
		// Referrer policy.
		header.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// Content Security Policy - restrictive default.
		header.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' https://api.fireworks.ai https://fonts.googleapis.com")

		next.ServeHTTP(w, r)
	})
}

// blockedExtensions lists file extensions that must NEVER be served or accepted,
// regardless of authentication level, role, or any other condition.
// This is a hard platform rule: raw database files are never accessible via the portal.
var blockedExtensions = []string{".db", ".sqlite", ".sqlite3", ".sql", ".mdb", ".accdb"}

// isBlockedExtension reports whether filename ends with a blocked extension.
//
// The comparison is case-insensitive. Surrounding whitespace and any trailing
// dots, slashes or backslashes are ignored, so decorated forms such as
// "/files/data.db/" or "data.db." are treated the same as "data.db" instead
// of slipping past the suffix check.
func isBlockedExtension(filename string) bool {
	name := strings.ToLower(strings.TrimSpace(filename))
	name = strings.TrimRight(name, "./\\ \t")
	for _, ext := range blockedExtensions {
		if strings.HasSuffix(name, ext) {
			return true
		}
	}
	return false
}

// BlockDatabaseMiddleware is a hard stop on any request that attempts to serve or
// accept a raw database file. This rule cannot be overridden by role, auth level,
// or any user action — it is enforced at the transport layer before handlers run.
func BlockDatabaseMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // Check the URL path itself if isBlockedExtension(r.URL.Path) { http.Error(w, "Forbidden", http.StatusForbidden) return } // Check common query params used for file serving for _, param := range []string{"filename", "name", "file", "path"} { if isBlockedExtension(r.URL.Query().Get(param)) { http.Error(w, "Forbidden", http.StatusForbidden) return } } next.ServeHTTP(w, r) }) } // ErrorResponse sends a standard JSON error response. func ErrorResponse(w http.ResponseWriter, status int, code, message string) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(map[string]string{ "error": message, "code": code, }) } // JSONResponse sends a standard JSON success response. func JSONResponse(w http.ResponseWriter, status int, data any) { w.Header().Set("Content-Type", "application/json") w.WriteHeader(status) json.NewEncoder(w).Encode(data) } func realIP(r *http.Request) string { if xff := r.Header.Get("X-Forwarded-For"); xff != "" { parts := strings.SplitN(xff, ",", 2) return strings.TrimSpace(parts[0]) } if xri := r.Header.Get("X-Real-IP"); xri != "" { return xri } addr := r.RemoteAddr if idx := strings.LastIndex(addr, ":"); idx != -1 { return addr[:idx] } return addr } // JWT type jwtClaims struct { UserID string `json:"sub"` SessionID string `json:"sid"` ExpiresAt int64 `json:"exp"` IssuedAt int64 `json:"iat"` } // createJWT creates a signed JWT with the given claims. 
func createJWT(userID, sessionID string, secret []byte, duration int64) (string, error) { now := time.Now().Unix() claims := jwtClaims{ UserID: userID, SessionID: sessionID, ExpiresAt: now + duration, IssuedAt: now, } header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) payloadJSON, err := json.Marshal(claims) if err != nil { return "", err } payload := base64.RawURLEncoding.EncodeToString(payloadJSON) signingInput := header + "." + payload mac := hmac.New(sha256.New, secret) mac.Write([]byte(signingInput)) sig := base64.RawURLEncoding.EncodeToString(mac.Sum(nil)) return signingInput + "." + sig, nil } func validateJWT(token string, secret []byte) (*jwtClaims, error) { parts := strings.Split(token, ".") if len(parts) != 3 { return nil, lib.ErrAccessDenied } // Verify HMAC-SHA256 signature signingInput := parts[0] + "." + parts[1] mac := hmac.New(sha256.New, secret) mac.Write([]byte(signingInput)) expectedSig := mac.Sum(nil) sig, err := base64.RawURLEncoding.DecodeString(parts[2]) if err != nil { return nil, lib.ErrAccessDenied } if !hmac.Equal(sig, expectedSig) { return nil, lib.ErrAccessDenied } payload, err := base64.RawURLEncoding.DecodeString(parts[1]) if err != nil { return nil, lib.ErrAccessDenied } var claims jwtClaims if err := json.Unmarshal(payload, &claims); err != nil { return nil, lib.ErrAccessDenied } if claims.ExpiresAt < time.Now().Unix() { return nil, lib.ErrAccessDenied } return &claims, nil }