//go:build commercial package main import ( "bytes" "crypto/tls" "crypto/x509" "database/sql" "encoding/json" "net/http" "net/http/httptest" "os" "strings" "testing" "time" _ "github.com/mattn/go-sqlite3" ) // setupTestDB creates an in-memory database for testing func setupTestDB(t *testing.T) { var err error db, err = sql.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Failed to open test database: %v", err) } ensureTables() } // cleanupTestDB closes the test database func cleanupTestDB() { if db != nil { db.Close() } } func TestTarpit(t *testing.T) { // tarpit holds connection for 30 seconds - test that it responds initially req := httptest.NewRequest("GET", "/unknown", nil) w := httptest.NewRecorder() // Use a goroutine since tarpit blocks done := make(chan bool) go func() { tarpit(w, req) done <- true }() // Check initial response comes through quickly time.Sleep(100 * time.Millisecond) resp := w.Result() if resp.StatusCode != 200 { t.Errorf("tarpit status = %d, want 200", resp.StatusCode) } if resp.Header.Get("Content-Type") != "text/plain" { t.Errorf("tarpit content-type = %s, want text/plain", resp.Header.Get("Content-Type")) } } func TestHandleHealth(t *testing.T) { setupTestDB(t) defer cleanupTestDB() req := httptest.NewRequest("GET", "/health", nil) w := httptest.NewRecorder() handleHealth(w, req) resp := w.Result() if resp.StatusCode != 200 { t.Errorf("handleHealth status = %d, want 200", resp.StatusCode) } var result map[string]interface{} if err := json.NewDecoder(resp.Body).Decode(&result); err != nil { t.Fatalf("Failed to decode health response: %v", err) } if result["status"] != "ok" { t.Errorf("handleHealth status = %v, want ok", result["status"]) } if result["db"] != "ok" { t.Errorf("handleHealth db = %v, want ok", result["db"]) } } func TestHandleHealth_DBError(t *testing.T) { // Don't setup DB - should return error if db != nil { db.Close() db = nil } // Create a closed database to simulate failure var err error db, err = sql.Open("sqlite3", ":memory:") if err != nil { t.Fatalf("Failed to open test database: %v", err) } db.Close() // Close immediately to force errors req := httptest.NewRequest("GET", "/health", nil) w := httptest.NewRecorder() handleHealth(w, req) resp := w.Result() if resp.StatusCode != 503 { t.Errorf("handleHealth with bad DB status = %d, want 503", resp.StatusCode) } } func TestHandleTelemetry_MethodNotAllowed(t *testing.T) { setupTestDB(t) defer cleanupTestDB() req := httptest.NewRequest("GET", "/telemetry", nil) w := httptest.NewRecorder() handleTelemetry(w, req) resp := w.Result() if resp.StatusCode != 405 { t.Errorf("handleTelemetry GET status = %d, want 405", resp.StatusCode) } } func TestHandleTelemetry_BadPayload(t *testing.T) { setupTestDB(t) defer cleanupTestDB() req := httptest.NewRequest("POST", "/telemetry", strings.NewReader("not json")) w := httptest.NewRecorder() handleTelemetry(w, req) resp := w.Result() if resp.StatusCode != 400 { t.Errorf("handleTelemetry bad payload status = %d, want 400", resp.StatusCode) } } func TestHandleTelemetry_ValidPayload(t *testing.T) { setupTestDB(t) defer cleanupTestDB() // Temporarily disable mTLS for this test by clearing caPool oldCAPool := caPool caPool = nil defer func() { caPool = oldCAPool }() payload := map[string]interface{}{ "node_id": "test-node-1", "version": "1.0.0", "hostname": "test-host", "uptime_seconds": 3600, "cpu_percent": 25.5, "memory_total_mb": 8192, "memory_used_mb": 4096, "disk_total_mb": 100000, "disk_used_mb": 50000, "load_1m": 0.5, "vault_count": 5, "vault_size_mb": 10.5, "vault_entries": 100, "mode": "commercial", } body, _ := json.Marshal(payload) req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body)) w := httptest.NewRecorder() handleTelemetry(w, req) resp := w.Result() if resp.StatusCode != 200 { t.Errorf("handleTelemetry valid payload status = %d, want 200", resp.StatusCode) } // Verify data was written var count int err := db.QueryRow("SELECT COUNT(*) FROM telemetry WHERE node_id = ?", "test-node-1").Scan(&count) if err != nil { t.Fatalf("Failed to query telemetry: %v", err) } if count != 1 { t.Errorf("telemetry count = %d, want 1", count) } } func TestHandleTelemetry_MissingNodeID(t *testing.T) { setupTestDB(t) defer cleanupTestDB() // Temporarily disable mTLS for this test oldCAPool := caPool caPool = nil defer func() { caPool = oldCAPool }() payload := map[string]interface{}{ "version": "1.0.0", // Missing node_id and hostname } body, _ := json.Marshal(payload) req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body)) w := httptest.NewRecorder() handleTelemetry(w, req) resp := w.Result() if resp.StatusCode != 400 { t.Errorf("handleTelemetry missing node_id status = %d, want 400", resp.StatusCode) } } func TestLoadCA(t *testing.T) { // Test with non-existent file err := loadCA("/nonexistent/path/ca.crt") if err == nil { t.Error("loadCA with non-existent file should error") } // Test with invalid PEM content tmpFile, err := os.CreateTemp("", "ca-*.crt") if err != nil { t.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(tmpFile.Name()) tmpFile.WriteString("not valid pem") tmpFile.Close() err = loadCA(tmpFile.Name()) if err == nil { t.Error("loadCA with invalid PEM should error") } } func TestSetupTLS(t *testing.T) { // Test with nil caPool caPool = nil config := setupTLS() if config != nil { t.Error("setupTLS with nil caPool should return nil") } // Test with valid caPool // Create a temp CA file with dummy cert (won't validate but tests parsing) tmpFile, err := os.CreateTemp("", "ca-*.crt") if err != nil { t.Fatalf("Failed to create temp file: %v", err) } defer os.Remove(tmpFile.Name()) // Write a dummy CA cert dummyCert := `-----BEGIN CERTIFICATE----- MIIBkTCB+wIJAKHBfpE -----END CERTIFICATE-----` tmpFile.WriteString(dummyCert) tmpFile.Close() // This will fail to parse but sets up the test _ = loadCA(tmpFile.Name()) // caPool might be nil or set, just verify setupTLS doesn't panic _ = setupTLS() } func TestRouteHandler(t *testing.T) { setupTestDB(t) defer cleanupTestDB() // Disable mTLS for route tests (TLS is tested separately) oldCAPool := caPool caPool = nil defer func() { caPool = oldCAPool }() tests := []struct { path string wantStatus int }{ {"/health", 200}, {"/unknown", 200}, // tarpit returns 200 then holds connection } for _, tt := range tests { t.Run(tt.path, func(t *testing.T) { req := httptest.NewRequest("GET", tt.path, nil) w := httptest.NewRecorder() // For unknown paths, tarpit runs asynchronously if tt.path == "/unknown" { go routeHandler(w, req) time.Sleep(50 * time.Millisecond) resp := w.Result() if resp.StatusCode != 200 { t.Errorf("routeHandler %s status = %d, want 200", tt.path, resp.StatusCode) } } else { routeHandler(w, req) resp := w.Result() if resp.StatusCode != tt.wantStatus { t.Errorf("routeHandler %s status = %d, want %d", tt.path, resp.StatusCode, tt.wantStatus) } } }) } } func TestAlertOutage_Disabled(t *testing.T) { // Ensure no env vars are set os.Unsetenv("NTFY_ALERT_URL") os.Unsetenv("NTFY_ALERT_TOKEN") // Should not panic and should log only alertOutage("test-node", "test-host", 60, false) alertOutage("test-node", "test-host", 0, true) } func TestEnsureTables(t *testing.T) { setupTestDB(t) defer cleanupTestDB() // Verify tables exist by querying them tables := []string{"telemetry", "uptime_spans", "maintenance"} for _, table := range tables { var name string err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name) if err != nil { t.Errorf("Table %s should exist: %v", table, err) } if name != table { t.Errorf("Table name = %s, want %s", name, table) } } } // Test that mTLS enforcement works type mockResponseWriter struct { headers http.Header status int written bool } func (m *mockResponseWriter) Header() http.Header { return m.headers } func (m *mockResponseWriter) Write(p []byte) (int, error) { m.written = true return len(p), nil } func (m *mockResponseWriter) WriteHeader(status int) { m.status = status } func TestMTLSRequired(t *testing.T) { // This test documents that mTLS is now mandatory // The main() function will fail if CA chain is not present // We verify the setupTLS function returns a proper config when CA is loaded // Create a proper test CA pool caPool = x509.NewCertPool() config := setupTLS() if config == nil { t.Error("setupTLS should return config when caPool is set") } if config.ClientAuth != tls.RequireAndVerifyClientCert { t.Errorf("ClientAuth = %v, want RequireAndVerifyClientCert", config.ClientAuth) } if config.MinVersion != tls.VersionTLS13 { t.Errorf("MinVersion = %d, want TLS13", config.MinVersion) } }