package api

import (
	"crypto/hmac"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"os"
	"testing"
	"time"

	"github.com/google/uuid"
	"github.com/mish/dealspace/lib"
)

// testDBSetup opens a fresh database with the initial migration applied and
// returns it together with a minimal test Config. The database file is created
// inside a per-test temporary directory, so it and anything else written next
// to it are removed automatically when the test finishes.
func testDBSetup(t *testing.T) (*lib.DB, *lib.Config) {
	t.Helper()

	// Using t.TempDir() (rather than the shared OS temp dir) means the testing
	// package removes the whole directory, not just the main DB file.
	tmpFile, err := os.CreateTemp(t.TempDir(), "dealspace-api-test-*.db")
	if err != nil {
		t.Fatalf("create temp file: %v", err)
	}
	if err := tmpFile.Close(); err != nil {
		t.Fatalf("close temp file: %v", err)
	}

	db, err := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
	if err != nil {
		t.Fatalf("OpenDB: %v", err)
	}
	// Cleanups run last-in-first-out: this one was registered after
	// t.TempDir's, so the DB is closed before its directory is deleted.
	t.Cleanup(func() { db.Close() })

	cfg := &lib.Config{
		// All-zero 32-byte master key: deterministic, suitable for tests only.
		MasterKey: make([]byte, 32),
		// Exactly 32 bytes, as the literal's name promises.
		JWTSecret: []byte("test-jwt-secret-32-bytes-long!!!"),
	}
	return db, cfg
}

// createTestUserAndSession inserts an active user with a random ID and email,
// plus a non-revoked session for that user that expires one day from now.
// It fails the test immediately if either insert fails. cfg is currently
// unused; it is kept so existing callers do not have to change.
func createTestUserAndSession(t *testing.T, db *lib.DB, cfg *lib.Config) (*lib.User, *lib.Session) {
	t.Helper()

	now := time.Now().UnixMilli()

	user := &lib.User{
		UserID:    uuid.New().String(),
		Email:     uuid.New().String() + "@test.com",
		Name:      "Test User",
		Password:  "$2a$10$test",
		Active:    true,
		CreatedAt: now,
		UpdatedAt: now,
	}
	if err := lib.UserCreate(db, user); err != nil {
		t.Fatalf("UserCreate: %v", err)
	}

	session := &lib.Session{
		ID:          uuid.New().String(),
		UserID:      user.UserID,
		Fingerprint: "test-fingerprint",
		CreatedAt:   now,
		ExpiresAt:   now + 86400000, // +1 day, in milliseconds
		Revoked:     false,
	}
	if err := lib.SessionCreate(db, session); err != nil {
		t.Fatalf("SessionCreate: %v", err)
	}

	return user, session
}

// createJWT builds an HS256-signed JWT by hand: a fixed header, a claims
// object carrying sub (userID), sid (sessionID), exp (expiresAt) and iat
// (current Unix seconds), and an HMAC-SHA256 signature over
// "header.payload" keyed with secret. All parts are base64url without padding.
func createJWT(userID, sessionID string, expiresAt int64, secret []byte) string {
	header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))

	claims := map[string]interface{}{
		"sub": userID,
		"sid": sessionID,
		"exp": expiresAt,
		"iat": time.Now().Unix(),
	}
	// Marshal cannot fail for a map holding only strings and int64s.
	claimsJSON, _ := json.Marshal(claims)
	payload := base64.RawURLEncoding.EncodeToString(claimsJSON)

	signingInput := header + "." + payload
	mac := hmac.New(sha256.New, secret)
	mac.Write([]byte(signingInput))
	signature := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))

	return signingInput + "." + signature
}

// TestAuthMiddleware_ValidToken checks that a correctly signed, unexpired
// token for a live session is accepted and that the user ID is made
// available to the wrapped handler via the request context.
func TestAuthMiddleware_ValidToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)

	token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)

	// Record the user ID the middleware placed in the context.
	var capturedUserID string
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		capturedUserID = UserIDFromContext(r.Context())
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	req := httptest.NewRequest("GET", "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rec.Code)
	}
	if capturedUserID != user.UserID {
		t.Errorf("user ID not set correctly: got %s, want %s", capturedUserID, user.UserID)
	}
}

// TestAuthMiddleware_NoToken checks that a request without an Authorization
// header is rejected with 401.
func TestAuthMiddleware_NoToken(t *testing.T) {
	db, cfg := testDBSetup(t)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	req := httptest.NewRequest("GET", "/api/test", nil)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401, got %d", rec.Code)
	}
}

// TestAuthMiddleware_ExpiredToken checks that a token whose exp claim is in
// the past is rejected even though its session is still valid.
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)

	// exp is one hour in the past.
	token := createJWT(user.UserID, session.ID, time.Now().Unix()-3600, cfg.JWTSecret)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	req := httptest.NewRequest("GET", "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401 for expired token, got %d", rec.Code)
	}
}

// TestAuthMiddleware_InvalidToken checks that malformed tokens, tokens signed
// with the wrong secret, and an empty bearer credential are all rejected.
func TestAuthMiddleware_InvalidToken(t *testing.T) {
	db, cfg := testDBSetup(t)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	tests := []struct {
		name  string
		token string
	}{
		{"garbage", "not-a-jwt"},
		{"malformed", "a.b.c.d.e"},
		{"wrong signature", createJWT("user", "session", time.Now().Unix()+3600, []byte("wrong-secret"))},
		{"empty bearer", ""},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			req := httptest.NewRequest("GET", "/api/test", nil)
			// For the empty case this yields "Bearer " with no credential.
			req.Header.Set("Authorization", "Bearer "+tc.token)
			rec := httptest.NewRecorder()
			wrapped.ServeHTTP(rec, req)

			if rec.Code != http.StatusUnauthorized {
				t.Errorf("expected 401, got %d", rec.Code)
			}
		})
	}
}

// TestAuthMiddleware_RevokedSession checks that a valid, unexpired token is
// rejected once its session has been revoked.
func TestAuthMiddleware_RevokedSession(t *testing.T) {
	db, cfg := testDBSetup(t)
	user, session := createTestUserAndSession(t, db, cfg)

	token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)

	if err := lib.SessionRevoke(db, session.ID); err != nil {
		t.Fatalf("SessionRevoke: %v", err)
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	req := httptest.NewRequest("GET", "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401 for revoked session, got %d", rec.Code)
	}
}

// TestAuthMiddleware_ExpiredSession checks that an unexpired token is rejected
// when the session it references has already expired.
func TestAuthMiddleware_ExpiredSession(t *testing.T) {
	db, cfg := testDBSetup(t)

	now := time.Now().UnixMilli()

	user := &lib.User{
		UserID:    uuid.New().String(),
		Email:     uuid.New().String() + "@test.com",
		Name:      "Test User",
		Password:  "$2a$10$test",
		Active:    true,
		CreatedAt: now,
		UpdatedAt: now,
	}
	// Setup failures must be fatal: otherwise the middleware would still
	// return 401 (session not found) and the test would pass for the wrong
	// reason.
	if err := lib.UserCreate(db, user); err != nil {
		t.Fatalf("UserCreate: %v", err)
	}

	session := &lib.Session{
		ID:          uuid.New().String(),
		UserID:      user.UserID,
		Fingerprint: "test-fingerprint",
		CreatedAt:   now - 86400000*2, // 2 days ago
		ExpiresAt:   now - 86400000,   // expired 1 day ago
		Revoked:     false,
	}
	if err := lib.SessionCreate(db, session); err != nil {
		t.Fatalf("SessionCreate: %v", err)
	}

	// The token itself is still valid for an hour; only the session is stale.
	token := createJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)

	req := httptest.NewRequest("GET", "/api/test", nil)
	req.Header.Set("Authorization", "Bearer "+token)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusUnauthorized {
		t.Errorf("expected 401 for expired session, got %d", rec.Code)
	}
}

// TestCORSMiddleware checks that regular requests get a wildcard
// Allow-Origin header and that OPTIONS preflights are answered with 204 and
// an Allow-Methods header.
func TestCORSMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	wrapped := CORSMiddleware(handler)

	// Regular request.
	req := httptest.NewRequest("GET", "/api/test", nil)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
		t.Error("CORS header not set")
	}

	// Preflight request.
	req = httptest.NewRequest("OPTIONS", "/api/test", nil)
	rec = httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusNoContent {
		t.Errorf("OPTIONS should return 204, got %d", rec.Code)
	}
	if rec.Header().Get("Access-Control-Allow-Methods") == "" {
		t.Error("Allow-Methods header not set")
	}
}

// TestLoggingMiddleware checks that the logging wrapper passes the inner
// handler's status code through unchanged.
func TestLoggingMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusCreated)
	})
	wrapped := LoggingMiddleware(handler)

	req := httptest.NewRequest("POST", "/api/test", nil)
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusCreated {
		t.Errorf("expected 201, got %d", rec.Code)
	}
}

// TestRateLimitMiddleware checks that, with a limit of 3, the fourth request
// from one remote address gets 429 while a different address is still served.
func TestRateLimitMiddleware(t *testing.T) {
	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})
	// Very low limit so the test trips it quickly.
	wrapped := RateLimitMiddleware(3)(handler)

	// The first 3 requests should succeed.
	for i := 0; i < 3; i++ {
		req := httptest.NewRequest("GET", "/api/test", nil)
		req.RemoteAddr = "192.168.1.1:12345"
		rec := httptest.NewRecorder()
		wrapped.ServeHTTP(rec, req)

		if rec.Code != http.StatusOK {
			t.Errorf("request %d should succeed, got %d", i+1, rec.Code)
		}
	}

	// The 4th request from the same address should be limited.
	req := httptest.NewRequest("GET", "/api/test", nil)
	req.RemoteAddr = "192.168.1.1:12345"
	rec := httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusTooManyRequests {
		t.Errorf("4th request should be rate limited, got %d", rec.Code)
	}

	// A different address has its own budget.
	req = httptest.NewRequest("GET", "/api/test", nil)
	req.RemoteAddr = "192.168.1.2:12345"
	rec = httptest.NewRecorder()
	wrapped.ServeHTTP(rec, req)

	if rec.Code != http.StatusOK {
		t.Errorf("different IP should succeed, got %d", rec.Code)
	}
}

// TestErrorResponse checks that ErrorResponse writes the given status and a
// JSON body with "code" and "error" fields.
func TestErrorResponse(t *testing.T) {
	rec := httptest.NewRecorder()
	ErrorResponse(rec, http.StatusBadRequest, "bad_request", "Invalid input")

	if rec.Code != http.StatusBadRequest {
		t.Errorf("expected 400, got %d", rec.Code)
	}

	var resp map[string]string
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response: %v", err)
	}

	if resp["code"] != "bad_request" {
		t.Errorf("expected code 'bad_request', got %s", resp["code"])
	}
	if resp["error"] != "Invalid input" {
		t.Errorf("expected error 'Invalid input', got %s", resp["error"])
	}
}

// TestJSONResponse checks that JSONResponse writes the given status, sets
// Content-Type to exactly "application/json", and encodes the payload.
func TestJSONResponse(t *testing.T) {
	rec := httptest.NewRecorder()
	data := map[string]interface{}{
		"id":   123,
		"name": "test",
	}
	JSONResponse(rec, http.StatusOK, data)

	if rec.Code != http.StatusOK {
		t.Errorf("expected 200, got %d", rec.Code)
	}
	if rec.Header().Get("Content-Type") != "application/json" {
		t.Error("Content-Type should be application/json")
	}

	var resp map[string]interface{}
	if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
		t.Fatalf("decode response: %v", err)
	}

	if resp["name"] != "test" {
		t.Error("response data incorrect")
	}
}