dealspace/api/middleware_test.go

417 lines
11 KiB
Go

package api
import (
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"time"
"github.com/google/uuid"
"github.com/mish/dealspace/lib"
)
// testDBSetup opens a fresh database for a single test, applying
// ../migrations/001_initial.sql, and returns it together with a minimal config
// holding an all-zero 32-byte master key and a fixed JWT signing secret.
//
// The database file is created inside a per-test directory from t.TempDir, so
// the whole directory — including any sidecar files the database driver may
// create next to the main file (e.g. -wal/-shm if this is SQLite; NOTE(review):
// confirm driver) — is removed when the test ends. The DB is closed via
// t.Cleanup; cleanups run last-registered-first, so Close runs before the
// directory (registered earlier by t.TempDir) is deleted.
func testDBSetup(t *testing.T) (*lib.DB, *lib.Config) {
	t.Helper()
	tmpFile, err := os.CreateTemp(t.TempDir(), "dealspace-api-test-*.db")
	if err != nil {
		t.Fatalf("create temp file: %v", err)
	}
	// Only the path is needed; lib.OpenDB opens the file itself.
	if err := tmpFile.Close(); err != nil {
		t.Fatalf("close temp file: %v", err)
	}
	db, err := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
	if err != nil {
		t.Fatalf("OpenDB: %v", err)
	}
	t.Cleanup(func() { db.Close() })
	cfg := &lib.Config{
		MasterKey: make([]byte, 32),
		// NOTE(review): despite its wording this literal is 31 bytes long.
		JWTSecret: []byte("test-jwt-secret-32-bytes-long!!"),
	}
	return db, cfg
}
// createTestUserAndSession inserts an active user with a random ID and a random
// @test.com e-mail address, plus one non-revoked session for that user that
// expires one day from now. Timestamps are Unix milliseconds. The test fails
// immediately if either insert returns an error. cfg is currently unused.
func createTestUserAndSession(t *testing.T, db *lib.DB, cfg *lib.Config) (*lib.User, *lib.Session) {
	t.Helper()
	nowMs := time.Now().UnixMilli()
	oneDayMs := (24 * time.Hour).Milliseconds()

	u := &lib.User{
		UserID:    uuid.New().String(),
		Email:     uuid.New().String() + "@test.com",
		Name:      "Test User",
		Password:  "$2a$10$test",
		Active:    true,
		CreatedAt: nowMs,
		UpdatedAt: nowMs,
	}
	if err := lib.UserCreate(db, u); err != nil {
		t.Fatalf("UserCreate: %v", err)
	}

	// Revoked is left at its zero value (false).
	s := &lib.Session{
		ID:          uuid.New().String(),
		UserID:      u.UserID,
		Fingerprint: "test-fingerprint",
		CreatedAt:   nowMs,
		ExpiresAt:   nowMs + oneDayMs,
	}
	if err := lib.SessionCreate(db, s); err != nil {
		t.Fatalf("SessionCreate: %v", err)
	}
	return u, s
}
// createJWT builds a compact-serialized HS256 JWT for tests.
//
// The header is the fixed JSON {"alg":"HS256","typ":"JWT"}. The claims are sub
// (userID), sid (sessionID), exp (expiresAt, passed through as-is) and iat (the
// current Unix time in seconds). Both segments are base64url-encoded without
// padding, and the signature is HMAC-SHA256 over "header.payload" keyed with
// secret.
func createJWT(userID, sessionID string, expiresAt int64, secret []byte) string {
	enc := base64.RawURLEncoding
	headerSeg := enc.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))

	// Marshalling a map of strings and int64s cannot fail, so the error is dropped.
	// json.Marshal emits map keys in sorted order, so the payload bytes are deterministic.
	body, _ := json.Marshal(map[string]interface{}{
		"sub": userID,
		"sid": sessionID,
		"exp": expiresAt,
		"iat": time.Now().Unix(),
	})
	unsigned := headerSeg + "." + enc.EncodeToString(body)

	signer := hmac.New(sha256.New, secret)
	signer.Write([]byte(unsigned))
	return unsigned + "." + enc.EncodeToString(signer.Sum(nil))
}
// TestAuthMiddleware_ValidToken covers the accepted case. A JWT signed with
// the configured secret, expiring in an hour, and naming a live session passes
// through with 200. The wrapped handler must then see the token's user ID via
// UserIDFromContext.
func TestAuthMiddleware_ValidToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)
	oneHourAhead := time.Now().Add(time.Hour).Unix()
	token := createJWT(user.UserID, session.ID, oneHourAhead, cfg.JWTSecret)

	// Record what the middleware placed in the request context.
	var gotUserID string
	protected := AuthMiddleware(db, cfg.JWTSecret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		gotUserID = UserIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rec.Code)
	}
	if gotUserID != user.UserID {
		t.Errorf("user ID not set correctly: got %s, want %s", gotUserID, user.UserID)
	}
}
// TestAuthMiddleware_NoToken verifies that a request with no Authorization
// header at all is rejected with 401 Unauthorized.
func TestAuthMiddleware_NoToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	protected := AuthMiddleware(db, cfg.JWTSecret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// The Authorization header is deliberately left unset.
	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", got)
	}
}
// TestAuthMiddleware_ExpiredToken verifies that a correctly signed JWT for a
// live session is still rejected with 401 once its exp claim (Unix seconds)
// is in the past. Here it expired one hour ago.
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)
	anHourAgo := time.Now().Add(-time.Hour).Unix()
	token := createJWT(user.UserID, session.ID, anHourAgo, cfg.JWTSecret)

	protected := AuthMiddleware(db, cfg.JWTSecret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401 for expired token, got %d", got)
	}
}
// TestAuthMiddleware_InvalidToken checks how AuthMiddleware handles bearer
// tokens that are not well-formed JWTs or whose signature does not verify. For
// each one it must answer 401 and never reach the wrapped handler.
//
// The wrong-signature token is issued for a real, live user and session, so
// the signing key is the only thing wrong with it. If it named IDs that do not
// exist in the database, a 401 could come from the session lookup alone and
// the case would not prove that signatures are actually checked.
func TestAuthMiddleware_InvalidToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)
	validExp := time.Now().Unix() + 3600

	reached := false
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		reached = true
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	tests := []struct {
		name  string
		token string
	}{
		{"garbage", "not-a-jwt"},
		{"malformed", "a.b.c.d.e"},
		{"wrong signature", createJWT(user.UserID, session.ID, validExp, []byte("wrong-secret"))},
		{"empty bearer", ""},
	}
	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			reached = false
			req := httptest.NewRequest("GET", "/api/test", nil)
			// For "empty bearer" this yields the bare header value "Bearer ".
			req.Header.Set("Authorization", "Bearer "+tc.token)
			rec := httptest.NewRecorder()
			wrapped.ServeHTTP(rec, req)
			if rec.Code != http.StatusUnauthorized {
				t.Errorf("expected 401, got %d", rec.Code)
			}
			if reached {
				t.Error("wrapped handler was called for an invalid token")
			}
		})
	}
}
// TestAuthMiddleware_RevokedSession verifies that a correctly signed, unexpired
// JWT is rejected with 401 once lib.SessionRevoke has been called on the
// session it names.
func TestAuthMiddleware_RevokedSession(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)
	if err := lib.SessionRevoke(db, session.ID); err != nil {
		t.Fatalf("SessionRevoke: %v", err)
	}
	// The token itself is valid; only the backing session has been revoked.
	token := createJWT(user.UserID, session.ID, time.Now().Add(time.Hour).Unix(), cfg.JWTSecret)

	protected := AuthMiddleware(db, cfg.JWTSecret)(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	req := httptest.NewRequest(http.MethodGet, "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	protected.ServeHTTP(rec, req)

	if got := rec.Code; got != http.StatusUnauthorized {
		t.Errorf("expected 401 for revoked session, got %d", got)
	}
}
// TestAuthMiddleware_ExpiredSession verifies that a JWT whose own exp claim is
// still in the future is rejected with 401 when the session it names has an
// ExpiresAt in the past.
//
// Both inserts are checked. If either failed silently, the session would just
// be missing and the middleware would answer 401 for that reason instead. The
// test would then pass without ever exercising session expiry.
func TestAuthMiddleware_ExpiredSession(t *testing.T) {
	db, cfg := testDBSetup(t)
	now := time.Now().UnixMilli()
	dayMs := (24 * time.Hour).Milliseconds()

	user := &lib.User{
		UserID:    uuid.New().String(),
		Email:     uuid.New().String() + "@test.com",
		Name:      "Test User",
		Password:  "$2a$10$test",
		Active:    true,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err := lib.UserCreate(db, user); err != nil {
		t.Fatalf("UserCreate: %v", err)
	}

	// Session timestamps are Unix milliseconds: created 2 days ago, expired 1 day ago.
	session := &lib.Session{
		ID:          uuid.New().String(),
		UserID:      user.UserID,
		Fingerprint: "test-fingerprint",
		CreatedAt:   now - 2*dayMs,
		ExpiresAt:   now - dayMs,
		Revoked:     false,
	}
	if err := lib.SessionCreate(db, session); err != nil {
		t.Fatalf("SessionCreate: %v", err)
	}

	// The JWT exp (Unix seconds) is an hour ahead, so only the session is stale.
	token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
	req := httptest.NewRequest("GET", "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)
	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401 for expired session, got %d", rec.Code)
	}
}
// TestCORSMiddleware checks two behaviours of CORSMiddleware:
//   - an ordinary GET gets Access-Control-Allow-Origin: *;
//   - an OPTIONS preflight is answered with 204 No Content and a non-empty
//     Access-Control-Allow-Methods header.
func TestCORSMiddleware(t *testing.T) {
	wrapped := CORSMiddleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	}))

	// serve sends one request with the given method to /api/test and returns the recorder.
	serve := func(method string) *httptest.ResponseRecorder {
		req := httptest.NewRequest(method, "/api/test", nil)
		rec := httptest.NewRecorder()
		wrapped.ServeHTTP(rec, req)
		return rec
	}

	if origin := serve("GET").Header().Get("Access-Control-Allow-Origin"); origin != "*" {
		t.Error("CORS header not set")
	}

	preflight := serve("OPTIONS")
	if preflight.Code != http.StatusNoContent {
		t.Errorf("OPTIONS should return 204, got %d", preflight.Code)
	}
	if len(preflight.Header().Get("Access-Control-Allow-Methods")) == 0 {
		t.Error("Allow-Methods header not set")
	}
}
// TestLoggingMiddleware checks that LoggingMiddleware passes the inner
// handler's status code (201 here) through to the client unchanged.
func TestLoggingMiddleware(t *testing.T) {
	created := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusCreated)
	})

	rec := httptest.NewRecorder()
	LoggingMiddleware(created).ServeHTTP(rec, httptest.NewRequest("POST", "/api/test", nil))

	if got := rec.Code; got != http.StatusCreated {
		t.Errorf("expected 201, got %d", got)
	}
}
// TestRateLimitMiddleware uses a limit of 3 on RateLimitMiddleware and checks:
//   - the first three requests from one remote address succeed;
//   - the fourth request from that address gets 429 Too Many Requests;
//   - a request from a different address is still accepted.
func TestRateLimitMiddleware(t *testing.T) {
	okHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	// A deliberately tiny limit so the threshold is crossed quickly.
	limited := RateLimitMiddleware(3)(okHandler)

	// send issues one GET from remoteAddr and reports the resulting status code.
	send := func(remoteAddr string) int {
		req := httptest.NewRequest("GET", "/api/test", nil)
		req.RemoteAddr = remoteAddr
		rec := httptest.NewRecorder()
		limited.ServeHTTP(rec, req)
		return rec.Code
	}

	const clientA, clientB = "192.168.1.1:12345", "192.168.1.2:12345"

	for n := 1; n <= 3; n++ {
		if code := send(clientA); code != http.StatusOK {
			t.Errorf("request %d should succeed, got %d", n, code)
		}
	}
	if code := send(clientA); code != http.StatusTooManyRequests {
		t.Errorf("4th request should be rate limited, got %d", code)
	}
	if code := send(clientB); code != http.StatusOK {
		t.Errorf("different IP should succeed, got %d", code)
	}
}
// TestErrorResponse checks that ErrorResponse writes the given status code and
// a JSON object body. Its "code" field must carry the supplied code and its
// "error" field the supplied message.
func TestErrorResponse(t *testing.T) {
	const wantCode, wantMsg = "bad_request", "Invalid input"

	rec := httptest.NewRecorder()
	ErrorResponse(rec, http.StatusBadRequest, wantCode, wantMsg)

	if rec.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", rec.Code)
	}

	body := map[string]string{}
	if err := json.NewDecoder(rec.Body).Decode(&body); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	if got := body["code"]; got != wantCode {
		t.Errorf("expected code 'bad_request', got %s", got)
	}
	if got := body["error"]; got != wantMsg {
		t.Errorf("expected error 'Invalid input', got %s", got)
	}
}
// TestJSONResponse checks that JSONResponse:
//   - writes the given status;
//   - sets Content-Type to exactly "application/json";
//   - encodes the payload map as a top-level JSON object in which every
//     fixture field, "id" and "name", round-trips.
func TestJSONResponse(t *testing.T) {
	rec := httptest.NewRecorder()
	data := map[string]interface{}{
		"id":   123,
		"name": "test",
	}
	JSONResponse(rec, http.StatusOK, data)
	if rec.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rec.Code)
	}
	if ct := rec.Header().Get("Content-Type"); ct != "application/json" {
		t.Errorf("Content-Type should be application/json, got %q", ct)
	}
	var resp map[string]interface{}
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response: %v", err)
	}
	if resp["name"] != "test" {
		t.Errorf("name: got %v, want %q", resp["name"], "test")
	}
	// Decoding into interface{} turns every JSON number into a float64.
	if resp["id"] != float64(123) {
		t.Errorf("id: got %v (%T), want 123", resp["id"], resp["id"])
	}
}