dealspace/api/middleware.go

288 lines
7.6 KiB
Go

package api
import (
	"context"
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"log"
	"net"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/mish/dealspace/lib"
)
// contextKey is the package-private type for request-context keys; a distinct
// named type keeps these keys from colliding with plain string keys set by
// other packages.
type contextKey string

// ctxUserID is the context key under which AuthMiddleware stores the
// authenticated user's ID.
const ctxUserID = contextKey("user_id")

// UserIDFromContext returns the authenticated user ID stored in ctx. It returns
// the empty string when no ID is present or the stored value is not a string.
func UserIDFromContext(ctx context.Context) string {
	if id, ok := ctx.Value(ctxUserID).(string); ok {
		return id
	}
	return ""
}
// AuthMiddleware returns middleware that authenticates each request with a
// bearer JWT backed by a live server-side session.
//
// The token is read from "Authorization: Bearer <token>". The scheme name is
// matched case-insensitively, per RFC 7235 §2.1. The token is verified with
// validateJWT using jwtSecret, and its session (claims.SessionID) is loaded via
// lib.SessionByID.
//
// The request is rejected with 401 and a JSON error body when:
//   - the header is missing or malformed;
//   - the token is invalid or expired;
//   - the session lookup fails, or the session is missing, revoked, or past
//     its ExpiresAt.
//
// On success the token's user ID ("sub" claim) is stored in the request
// context; read it with UserIDFromContext.
func AuthMiddleware(db *lib.DB, jwtSecret []byte) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			scheme, token, ok := strings.Cut(r.Header.Get("Authorization"), " ")
			if !ok || !strings.EqualFold(scheme, "Bearer") {
				ErrorResponse(w, http.StatusUnauthorized, "missing_token", "Authorization header required")
				return
			}
			// Tolerate extra spaces between the scheme and the token.
			token = strings.TrimSpace(token)

			claims, err := validateJWT(token, jwtSecret)
			if err != nil {
				ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid or expired token")
				return
			}

			// NOTE(review): a lookup error is reported the same way as a revoked
			// session (401). A database outage therefore looks like a logout to
			// clients. Consider a 5xx once lib.SessionByID's not-found contract
			// is confirmed.
			session, err := lib.SessionByID(db, claims.SessionID)
			if err != nil || session == nil || session.Revoked {
				ErrorResponse(w, http.StatusUnauthorized, "session_revoked", "Session has been revoked")
				return
			}
			// session.ExpiresAt is compared against Unix *milliseconds*, unlike
			// the JWT exp claim (seconds). NOTE(review): confirm lib stores ms.
			if session.ExpiresAt < time.Now().UnixMilli() {
				ErrorResponse(w, http.StatusUnauthorized, "session_expired", "Session has expired")
				return
			}

			ctx := context.WithValue(r.Context(), ctxUserID, claims.UserID)
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
// LoggingMiddleware logs one line per request: method, URL path, the status
// code actually sent to the client, and the handler's elapsed time.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()
		// net/http sends an implicit 200 when a handler never calls WriteHeader.
		wrapped := &statusWriter{ResponseWriter: w, status: http.StatusOK}
		next.ServeHTTP(wrapped, r)
		log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start))
	})
}

// statusWriter wraps an http.ResponseWriter and records the final status code
// delivered to the client.
type statusWriter struct {
	http.ResponseWriter
	status      int  // final status code; defaults to 200
	wroteHeader bool // true once a final header (or body) has been sent
}

// WriteHeader records the first final status code and forwards every call.
//
// net/http ignores superfluous calls after the header is sent. Those calls
// therefore must not overwrite the recorded code. 1xx informational codes
// (except 101 Switching Protocols) may precede the final header, so they are
// forwarded without being recorded.
func (w *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code <= 199 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.status = code
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards to the wrapped writer. A Write before any final WriteHeader
// implicitly sends 200, so that status is locked in.
func (w *statusWriter) Write(b []byte) (int, error) {
	if !w.wroteHeader {
		w.status = http.StatusOK
		w.wroteHeader = true
	}
	return w.ResponseWriter.Write(b)
}

// Flush forwards to the wrapped writer when it supports http.Flusher. Without
// this method, embedding the interface would hide Flush from streaming
// handlers. Flushing before WriteHeader implicitly sends 200.
func (w *statusWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		w.wroteHeader = true
		f.Flush()
	}
}

// Unwrap exposes the wrapped writer to http.ResponseController.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// RateLimitMiddleware returns middleware that limits each client, keyed by
// realIP, to requestsPerMinute requests per fixed one-minute window.
//
// Requests over the limit receive 429 with a Retry-After header (seconds until
// that client's window resets) and are not passed to next. Rejected requests
// still count toward the window.
//
// A background goroutine evicts expired entries once per window. It is never
// stopped, so construct this middleware once per process rather than per
// request.
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler {
	const window = time.Minute
	var mu sync.Mutex
	clients := make(map[string]*rateLimitEntry)

	go func() {
		ticker := time.NewTicker(window)
		for range ticker.C {
			mu.Lock()
			now := time.Now()
			for ip, entry := range clients {
				if now.Sub(entry.windowStart) > window {
					delete(clients, ip)
				}
			}
			mu.Unlock()
		}
	}()

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ip := realIP(r)

			mu.Lock()
			now := time.Now()
			entry, exists := clients[ip]
			if !exists || now.Sub(entry.windowStart) > window {
				entry = &rateLimitEntry{windowStart: now}
				clients[ip] = entry
			}
			entry.count++
			count := entry.count
			resetIn := window - now.Sub(entry.windowStart)
			mu.Unlock()

			if count > requestsPerMinute {
				// Round up so clients never retry before the window resets. The
				// header value must be a non-negative integer of seconds.
				secs := int((resetIn + time.Second - 1) / time.Second)
				if secs < 1 {
					secs = 1
				}
				w.Header().Set("Retry-After", strconv.Itoa(secs))
				ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests")
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// rateLimitEntry tracks one client's request count within its current window.
type rateLimitEntry struct {
	windowStart time.Time // when the client's current window began
	count       int       // requests seen in the current window
}
// allowedOrigins is the set of browser origins allowed to make cross-origin
// requests. NOTE(review): the localhost entries look like development
// origins; consider enabling them only in non-production builds.
var allowedOrigins = map[string]bool{
	"https://muskepo.com":     true,
	"https://www.muskepo.com": true,
	"https://app.muskepo.com": true,
	"https://dealspace.io":    true,
	"https://app.dealspace.io": true,
	"http://localhost:8080":   true,
	"http://localhost:3000":   true,
}

// CORSMiddleware applies CORS headers using the allowedOrigins allowlist.
//
// Access-Control-Allow-Origin echoes the request Origin only when that origin
// is allowlisted. Because the response therefore depends on Origin,
// "Vary: Origin" is added to every response, allowed or not, so shared caches
// key on it. Any OPTIONS request is answered with 204 and does not reach next.
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// Add rather than Set so Vary values set by outer handlers survive.
		h.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); allowedOrigins[origin] {
			h.Set("Access-Control-Allow-Origin", origin)
		}
		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		h.Set("Access-Control-Max-Age", "86400")
		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}
// SecurityHeadersMiddleware sets browser security headers on every response
// before calling next.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		// Prevent clickjacking (legacy header; CSP frame-ancestors below is the
		// modern equivalent).
		h.Set("X-Frame-Options", "DENY")
		// Prevent MIME type sniffing.
		h.Set("X-Content-Type-Options", "nosniff")
		// Explicitly disable the legacy XSS auditor. "1; mode=block" is
		// deprecated, ignored by current browsers, and could be abused for
		// cross-site leaks in the ones that still honored it. OWASP and MDN
		// recommend "0"; XSS defense is the CSP below.
		h.Set("X-XSS-Protection", "0")
		// Referrer policy.
		h.Set("Referrer-Policy", "strict-origin-when-cross-origin")
		// Content Security Policy. frame-ancestors does not fall back to
		// default-src, so it is set explicitly to match X-Frame-Options: DENY.
		// NOTE(review): 'unsafe-inline' in script-src weakens XSS protection.
		h.Set("Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline'; style-src 'self' 'unsafe-inline'; img-src 'self' data: https:; font-src 'self' data:; connect-src 'self' https://api.fireworks.ai; frame-ancestors 'none'")
		next.ServeHTTP(w, r)
	})
}
// ErrorResponse writes the given HTTP status followed by a JSON body of the
// form {"code": code, "error": message}.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	payload := map[string]string{"code": code, "error": message}

	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)

	// The status line has already been sent, so an encode/write failure can no
	// longer be reported to the client; it is deliberately ignored.
	_ = json.NewEncoder(w).Encode(payload)
}
// JSONResponse writes status with data encoded as a JSON body (newline
// terminated, HTML characters escaped, same bytes as json.Encoder).
//
// data is marshaled before anything is written. If it cannot be encoded (for
// example, it contains a channel or func), the error is logged and a 500 with
// the standard error body is sent instead. Previously the status was
// committed first, so a failure produced that status with an empty body.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	body, err := json.Marshal(data)
	if err != nil {
		log.Printf("JSONResponse: encoding %T: %v", data, err)
		w.Header().Set("Content-Type", "application/json")
		w.WriteHeader(http.StatusInternalServerError)
		// Same shape as ErrorResponse(w, 500, "internal_error", "Internal server error").
		_, _ = w.Write([]byte(`{"code":"internal_error","error":"Internal server error"}` + "\n"))
		return
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// Headers are already sent; a write error cannot be reported to the client.
	_, _ = w.Write(append(body, '\n'))
}
// realIP returns the client IP used for per-client bookkeeping such as rate
// limiting. It checks, in order:
//   - the first non-empty X-Forwarded-For entry;
//   - X-Real-IP (trimmed);
//   - the host part of r.RemoteAddr (brackets removed for IPv6);
//   - r.RemoteAddr unchanged if it has no port.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. If the server is reachable directly,
// a client can spoof them to evade per-IP rate limits. Confirm the edge proxy
// strips or overwrites these headers.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// SplitHostPort handles bracketed IPv6 ("[::1]:8080" -> "::1"), which a
	// plain last-colon split gets wrong.
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
// JWT

// jwtClaims is the JSON payload of tokens issued by createJWT. Field order
// determines the byte layout of the encoded payload, so keep it stable.
type jwtClaims struct {
	UserID    string `json:"sub"` // authenticated user ID
	SessionID string `json:"sid"` // session ID; AuthMiddleware looks it up via lib.SessionByID
	ExpiresAt int64  `json:"exp"` // expiry time, Unix seconds
	IssuedAt  int64  `json:"iat"` // issue time, Unix seconds
}

// createJWT issues a compact HS256 JWT (header.payload.signature, each part
// unpadded base64url) for userID and sessionID, signed with secret.
//
// duration is the lifetime in seconds; exp is the current Unix time plus
// duration. The only possible error comes from marshaling the claims.
func createJWT(userID, sessionID string, secret []byte, duration int64) (string, error) {
	issued := time.Now().Unix()
	body, err := json.Marshal(jwtClaims{
		UserID:    userID,
		SessionID: sessionID,
		ExpiresAt: issued + duration,
		IssuedAt:  issued,
	})
	if err != nil {
		return "", err
	}

	enc := base64.RawURLEncoding
	unsigned := enc.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + "." + enc.EncodeToString(body)

	h := hmac.New(sha256.New, secret)
	h.Write([]byte(unsigned)) // hash.Hash.Write never returns an error
	return unsigned + "." + enc.EncodeToString(h.Sum(nil)), nil
}
// validateJWT verifies a compact HS256 token (as produced by createJWT) against
// secret and returns its claims.
//
// Any failure returns lib.ErrAccessDenied, so callers cannot learn which
// check failed. Failures are: wrong number of parts, undecodable signature,
// signature mismatch, header not declaring HS256, undecodable payload, or
// expired token.
func validateJWT(token string, secret []byte) (*jwtClaims, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, lib.ErrAccessDenied
	}

	// Verify the HMAC-SHA256 signature before trusting any token content;
	// hmac.Equal compares in constant time.
	signingInput := parts[0] + "." + parts[1]
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(signingInput))
	expectedSig := mac.Sum(nil)
	sig, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, lib.ErrAccessDenied
	}
	if !hmac.Equal(sig, expectedSig) {
		return nil, lib.ErrAccessDenied
	}

	// Defense in depth (RFC 8725 §3.1): the signature is always checked as
	// HS256 above, but a token whose header claims another algorithm was not
	// issued by createJWT and is rejected.
	headerJSON, err := base64.RawURLEncoding.DecodeString(parts[0])
	if err != nil {
		return nil, lib.ErrAccessDenied
	}
	var header struct {
		Alg string `json:"alg"`
	}
	if err := json.Unmarshal(headerJSON, &header); err != nil || header.Alg != "HS256" {
		return nil, lib.ErrAccessDenied
	}

	payload, err := base64.RawURLEncoding.DecodeString(parts[1])
	if err != nil {
		return nil, lib.ErrAccessDenied
	}
	var claims jwtClaims
	if err := json.Unmarshal(payload, &claims); err != nil {
		return nil, lib.ErrAccessDenied
	}
	// RFC 7519 §4.1.4: the token must not be accepted on or after exp
	// (Unix seconds). A missing exp decodes as 0 and is rejected here.
	if claims.ExpiresAt <= time.Now().Unix() {
		return nil, lib.ErrAccessDenied
	}
	return &claims, nil
}