418 lines
11 KiB
Go
418 lines
11 KiB
Go
package api
|
|
|
|
import (
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
"github.com/mish/dealspace/lib"
|
|
)
|
|
|
|
func testDBSetup(t *testing.T) (*lib.DB, *lib.Config) {
|
|
t.Helper()
|
|
|
|
tmpFile, err := os.CreateTemp("", "dealspace-api-test-*.db")
|
|
if err != nil {
|
|
t.Fatalf("create temp file: %v", err)
|
|
}
|
|
tmpFile.Close()
|
|
t.Cleanup(func() { os.Remove(tmpFile.Name()) })
|
|
|
|
db, err := lib.OpenDB(tmpFile.Name(), "../migrations/001_initial.sql")
|
|
if err != nil {
|
|
t.Fatalf("OpenDB: %v", err)
|
|
}
|
|
t.Cleanup(func() { db.Close() })
|
|
|
|
masterKey := make([]byte, 32)
|
|
jwtSecret := []byte("test-jwt-secret-32-bytes-long!!")
|
|
|
|
cfg := &lib.Config{
|
|
MasterKey: masterKey,
|
|
JWTSecret: jwtSecret,
|
|
}
|
|
|
|
return db, cfg
|
|
}
|
|
|
|
func createTestUserAndSession(t *testing.T, db *lib.DB, cfg *lib.Config) (*lib.User, *lib.Session) {
|
|
t.Helper()
|
|
|
|
now := time.Now().UnixMilli()
|
|
user := &lib.User{
|
|
UserID: uuid.New().String(),
|
|
Email: uuid.New().String() + "@test.com",
|
|
Name: "Test User",
|
|
Password: "$2a$10$test",
|
|
Active: true,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
if err := lib.UserCreate(db, user); err != nil {
|
|
t.Fatalf("UserCreate: %v", err)
|
|
}
|
|
|
|
session := &lib.Session{
|
|
ID: uuid.New().String(),
|
|
UserID: user.UserID,
|
|
Fingerprint: "test-fingerprint",
|
|
CreatedAt: now,
|
|
ExpiresAt: now + 86400000, // +1 day
|
|
Revoked: false,
|
|
}
|
|
if err := lib.SessionCreate(db, session); err != nil {
|
|
t.Fatalf("SessionCreate: %v", err)
|
|
}
|
|
|
|
return user, session
|
|
}
|
|
|
|
// testCreateJWT builds a compact HS256 JWT for tests. Its signature differs
// from the package-level createJWT, so it gets its own name.
//
// The payload carries "sub" (userID), "sid" (sessionID), "exp" (expiresAt, as
// given) and "iat" (current Unix seconds). The header and payload are
// base64url-encoded without padding, and the signature is HMAC-SHA256 over
// "header.payload" keyed with secret.
func testCreateJWT(userID, sessionID string, expiresAt int64, secret []byte) string {
	b64 := base64.RawURLEncoding

	encodedHeader := b64.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`))

	// encoding/json writes map keys in sorted order, so the payload bytes are
	// deterministic for a given iat. Marshal cannot fail for strings and
	// int64s, so the error is intentionally discarded.
	claimBytes, _ := json.Marshal(map[string]interface{}{
		"exp": expiresAt,
		"iat": time.Now().Unix(),
		"sid": sessionID,
		"sub": userID,
	})
	unsigned := encodedHeader + "." + b64.EncodeToString(claimBytes)

	signer := hmac.New(sha256.New, secret)
	signer.Write([]byte(unsigned))

	return unsigned + "." + b64.EncodeToString(signer.Sum(nil))
}
|
|
|
|
func TestAuthMiddleware_ValidToken(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
user, session := createTestUserAndSession(t, db, cfg)
|
|
|
|
// Create valid JWT
|
|
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
|
|
|
// Create test handler that checks user ID
|
|
var capturedUserID string
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
capturedUserID = UserIDFromContext(r.Context())
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Wrap with auth middleware
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rec.Code)
|
|
}
|
|
if capturedUserID != user.UserID {
|
|
t.Errorf("user ID not set correctly: got %s, want %s", capturedUserID, user.UserID)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_NoToken(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
// No Authorization header
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_ExpiredToken(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
user, session := createTestUserAndSession(t, db, cfg)
|
|
|
|
// Create expired JWT (expired 1 hour ago)
|
|
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()-3600, cfg.JWTSecret)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401 for expired token, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_InvalidToken(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
tests := []struct {
|
|
name string
|
|
token string
|
|
}{
|
|
{"garbage", "not-a-jwt"},
|
|
{"malformed", "a.b.c.d.e"},
|
|
{"wrong signature", testCreateJWT("user", "session", time.Now().Unix()+3600, []byte("wrong-secret"))},
|
|
{"empty bearer", ""},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
if tc.token != "" {
|
|
req.Header.Set("Authorization", "Bearer "+tc.token)
|
|
} else {
|
|
req.Header.Set("Authorization", "Bearer ")
|
|
}
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401, got %d", rec.Code)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_RevokedSession(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
user, session := createTestUserAndSession(t, db, cfg)
|
|
|
|
// Create valid JWT
|
|
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
|
|
|
// Revoke the session
|
|
if err := lib.SessionRevoke(db, session.ID); err != nil {
|
|
t.Fatalf("SessionRevoke: %v", err)
|
|
}
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401 for revoked session, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestAuthMiddleware_ExpiredSession(t *testing.T) {
|
|
db, cfg := testDBSetup(t)
|
|
|
|
now := time.Now().UnixMilli()
|
|
user := &lib.User{
|
|
UserID: uuid.New().String(),
|
|
Email: uuid.New().String() + "@test.com",
|
|
Name: "Test User",
|
|
Password: "$2a$10$test",
|
|
Active: true,
|
|
CreatedAt: now,
|
|
UpdatedAt: now,
|
|
}
|
|
lib.UserCreate(db, user)
|
|
|
|
// Create session that's already expired
|
|
session := &lib.Session{
|
|
ID: uuid.New().String(),
|
|
UserID: user.UserID,
|
|
Fingerprint: "test-fingerprint",
|
|
CreatedAt: now - 86400000*2, // 2 days ago
|
|
ExpiresAt: now - 86400000, // expired 1 day ago
|
|
Revoked: false,
|
|
}
|
|
lib.SessionCreate(db, session)
|
|
|
|
// Create JWT that hasn't expired (but session has)
|
|
token := testCreateJWT(user.UserID, session.ID, time.Now().Unix()+3600, cfg.JWTSecret)
|
|
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := AuthMiddleware(db, cfg.JWTSecret)(handler)
|
|
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.Header.Set("Authorization", "Bearer "+token)
|
|
rec := httptest.NewRecorder()
|
|
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusUnauthorized {
|
|
t.Errorf("expected 401 for expired session, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestCORSMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
wrapped := CORSMiddleware(handler)
|
|
|
|
// Regular request
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
rec := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Header().Get("Access-Control-Allow-Origin") != "*" {
|
|
t.Error("CORS header not set")
|
|
}
|
|
|
|
// Preflight request
|
|
req = httptest.NewRequest("OPTIONS", "/api/test", nil)
|
|
rec = httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusNoContent {
|
|
t.Errorf("OPTIONS should return 204, got %d", rec.Code)
|
|
}
|
|
if rec.Header().Get("Access-Control-Allow-Methods") == "" {
|
|
t.Error("Allow-Methods header not set")
|
|
}
|
|
}
|
|
|
|
func TestLoggingMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusCreated)
|
|
})
|
|
|
|
wrapped := LoggingMiddleware(handler)
|
|
|
|
req := httptest.NewRequest("POST", "/api/test", nil)
|
|
rec := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusCreated {
|
|
t.Errorf("expected 201, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestRateLimitMiddleware(t *testing.T) {
|
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
w.WriteHeader(http.StatusOK)
|
|
})
|
|
|
|
// Very low limit for testing
|
|
wrapped := RateLimitMiddleware(3)(handler)
|
|
|
|
// First 3 requests should succeed
|
|
for i := 0; i < 3; i++ {
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("request %d should succeed, got %d", i+1, rec.Code)
|
|
}
|
|
}
|
|
|
|
// 4th request should be rate limited
|
|
req := httptest.NewRequest("GET", "/api/test", nil)
|
|
req.RemoteAddr = "192.168.1.1:12345"
|
|
rec := httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusTooManyRequests {
|
|
t.Errorf("4th request should be rate limited, got %d", rec.Code)
|
|
}
|
|
|
|
// Different IP should succeed
|
|
req = httptest.NewRequest("GET", "/api/test", nil)
|
|
req.RemoteAddr = "192.168.1.2:12345"
|
|
rec = httptest.NewRecorder()
|
|
wrapped.ServeHTTP(rec, req)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("different IP should succeed, got %d", rec.Code)
|
|
}
|
|
}
|
|
|
|
func TestErrorResponse(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
ErrorResponse(rec, http.StatusBadRequest, "bad_request", "Invalid input")
|
|
|
|
if rec.Code != http.StatusBadRequest {
|
|
t.Errorf("expected 400, got %d", rec.Code)
|
|
}
|
|
|
|
var resp map[string]string
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
|
|
if resp["code"] != "bad_request" {
|
|
t.Errorf("expected code 'bad_request', got %s", resp["code"])
|
|
}
|
|
if resp["error"] != "Invalid input" {
|
|
t.Errorf("expected error 'Invalid input', got %s", resp["error"])
|
|
}
|
|
}
|
|
|
|
func TestJSONResponse(t *testing.T) {
|
|
rec := httptest.NewRecorder()
|
|
data := map[string]interface{}{
|
|
"id": 123,
|
|
"name": "test",
|
|
}
|
|
JSONResponse(rec, http.StatusOK, data)
|
|
|
|
if rec.Code != http.StatusOK {
|
|
t.Errorf("expected 200, got %d", rec.Code)
|
|
}
|
|
|
|
if rec.Header().Get("Content-Type") != "application/json" {
|
|
t.Error("Content-Type should be application/json")
|
|
}
|
|
|
|
var resp map[string]interface{}
|
|
if err := json.NewDecoder(rec.Body).Decode(&resp); err != nil {
|
|
t.Fatalf("decode response: %v", err)
|
|
}
|
|
|
|
if resp["name"] != "test" {
|
|
t.Error("response data incorrect")
|
|
}
|
|
}
|