378 lines
8.9 KiB
Go
378 lines
8.9 KiB
Go
//go:build commercial
|
|
|
|
package main
|
|
|
|
import (
|
|
"bytes"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"database/sql"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
"time"
|
|
|
|
_ "github.com/mattn/go-sqlite3"
|
|
)
|
|
|
|
// setupTestDB creates an in-memory database for testing
|
|
func setupTestDB(t *testing.T) {
|
|
var err error
|
|
db, err = sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open test database: %v", err)
|
|
}
|
|
ensureTables()
|
|
}
|
|
|
|
// cleanupTestDB closes the test database
|
|
func cleanupTestDB() {
|
|
if db != nil {
|
|
db.Close()
|
|
}
|
|
}
|
|
|
|
func TestTarpit(t *testing.T) {
|
|
// tarpit holds connection for 30 seconds - test that it responds initially
|
|
req := httptest.NewRequest("GET", "/unknown", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// Use a goroutine since tarpit blocks
|
|
done := make(chan bool)
|
|
go func() {
|
|
tarpit(w, req)
|
|
done <- true
|
|
}()
|
|
|
|
// Check initial response comes through quickly
|
|
time.Sleep(100 * time.Millisecond)
|
|
resp := w.Result()
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("tarpit status = %d, want 200", resp.StatusCode)
|
|
}
|
|
if resp.Header.Get("Content-Type") != "text/plain" {
|
|
t.Errorf("tarpit content-type = %s, want text/plain", resp.Header.Get("Content-Type"))
|
|
}
|
|
}
|
|
|
|
func TestHandleHealth(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
req := httptest.NewRequest("GET", "/health", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handleHealth(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("handleHealth status = %d, want 200", resp.StatusCode)
|
|
}
|
|
|
|
var result map[string]interface{}
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
t.Fatalf("Failed to decode health response: %v", err)
|
|
}
|
|
|
|
if result["status"] != "ok" {
|
|
t.Errorf("handleHealth status = %v, want ok", result["status"])
|
|
}
|
|
if result["db"] != "ok" {
|
|
t.Errorf("handleHealth db = %v, want ok", result["db"])
|
|
}
|
|
}
|
|
|
|
func TestHandleHealth_DBError(t *testing.T) {
|
|
// Don't setup DB - should return error
|
|
if db != nil {
|
|
db.Close()
|
|
db = nil
|
|
}
|
|
|
|
// Create a closed database to simulate failure
|
|
var err error
|
|
db, err = sql.Open("sqlite3", ":memory:")
|
|
if err != nil {
|
|
t.Fatalf("Failed to open test database: %v", err)
|
|
}
|
|
db.Close() // Close immediately to force errors
|
|
|
|
req := httptest.NewRequest("GET", "/health", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handleHealth(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 503 {
|
|
t.Errorf("handleHealth with bad DB status = %d, want 503", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHandleTelemetry_MethodNotAllowed(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
req := httptest.NewRequest("GET", "/telemetry", nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
handleTelemetry(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 405 {
|
|
t.Errorf("handleTelemetry GET status = %d, want 405", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHandleTelemetry_BadPayload(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
req := httptest.NewRequest("POST", "/telemetry", strings.NewReader("not json"))
|
|
w := httptest.NewRecorder()
|
|
|
|
handleTelemetry(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 400 {
|
|
t.Errorf("handleTelemetry bad payload status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestHandleTelemetry_ValidPayload(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
// Temporarily disable mTLS for this test by clearing caPool
|
|
oldCAPool := caPool
|
|
caPool = nil
|
|
defer func() { caPool = oldCAPool }()
|
|
|
|
payload := map[string]interface{}{
|
|
"node_id": "test-node-1",
|
|
"version": "1.0.0",
|
|
"hostname": "test-host",
|
|
"uptime_seconds": 3600,
|
|
"cpu_percent": 25.5,
|
|
"memory_total_mb": 8192,
|
|
"memory_used_mb": 4096,
|
|
"disk_total_mb": 100000,
|
|
"disk_used_mb": 50000,
|
|
"load_1m": 0.5,
|
|
"vault_count": 5,
|
|
"vault_size_mb": 10.5,
|
|
"vault_entries": 100,
|
|
"mode": "commercial",
|
|
}
|
|
|
|
body, _ := json.Marshal(payload)
|
|
req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body))
|
|
w := httptest.NewRecorder()
|
|
|
|
handleTelemetry(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("handleTelemetry valid payload status = %d, want 200", resp.StatusCode)
|
|
}
|
|
|
|
// Verify data was written
|
|
var count int
|
|
err := db.QueryRow("SELECT COUNT(*) FROM telemetry WHERE node_id = ?", "test-node-1").Scan(&count)
|
|
if err != nil {
|
|
t.Fatalf("Failed to query telemetry: %v", err)
|
|
}
|
|
if count != 1 {
|
|
t.Errorf("telemetry count = %d, want 1", count)
|
|
}
|
|
}
|
|
|
|
func TestHandleTelemetry_MissingNodeID(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
// Temporarily disable mTLS for this test
|
|
oldCAPool := caPool
|
|
caPool = nil
|
|
defer func() { caPool = oldCAPool }()
|
|
|
|
payload := map[string]interface{}{
|
|
"version": "1.0.0",
|
|
// Missing node_id and hostname
|
|
}
|
|
|
|
body, _ := json.Marshal(payload)
|
|
req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body))
|
|
w := httptest.NewRecorder()
|
|
|
|
handleTelemetry(w, req)
|
|
|
|
resp := w.Result()
|
|
if resp.StatusCode != 400 {
|
|
t.Errorf("handleTelemetry missing node_id status = %d, want 400", resp.StatusCode)
|
|
}
|
|
}
|
|
|
|
func TestLoadCA(t *testing.T) {
|
|
// Test with non-existent file
|
|
err := loadCA("/nonexistent/path/ca.crt")
|
|
if err == nil {
|
|
t.Error("loadCA with non-existent file should error")
|
|
}
|
|
|
|
// Test with invalid PEM content
|
|
tmpFile, err := os.CreateTemp("", "ca-*.crt")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp file: %v", err)
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
tmpFile.WriteString("not valid pem")
|
|
tmpFile.Close()
|
|
|
|
err = loadCA(tmpFile.Name())
|
|
if err == nil {
|
|
t.Error("loadCA with invalid PEM should error")
|
|
}
|
|
}
|
|
|
|
func TestSetupTLS(t *testing.T) {
|
|
// Test with nil caPool
|
|
caPool = nil
|
|
config := setupTLS()
|
|
if config != nil {
|
|
t.Error("setupTLS with nil caPool should return nil")
|
|
}
|
|
|
|
// Test with valid caPool
|
|
// Create a temp CA file with dummy cert (won't validate but tests parsing)
|
|
tmpFile, err := os.CreateTemp("", "ca-*.crt")
|
|
if err != nil {
|
|
t.Fatalf("Failed to create temp file: %v", err)
|
|
}
|
|
defer os.Remove(tmpFile.Name())
|
|
|
|
// Write a dummy CA cert
|
|
dummyCert := `-----BEGIN CERTIFICATE-----
|
|
MIIBkTCB+wIJAKHBfpE
|
|
-----END CERTIFICATE-----`
|
|
tmpFile.WriteString(dummyCert)
|
|
tmpFile.Close()
|
|
|
|
// This will fail to parse but sets up the test
|
|
_ = loadCA(tmpFile.Name())
|
|
// caPool might be nil or set, just verify setupTLS doesn't panic
|
|
_ = setupTLS()
|
|
}
|
|
|
|
func TestRouteHandler(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
// Disable mTLS for route tests (TLS is tested separately)
|
|
oldCAPool := caPool
|
|
caPool = nil
|
|
defer func() { caPool = oldCAPool }()
|
|
|
|
tests := []struct {
|
|
path string
|
|
wantStatus int
|
|
}{
|
|
{"/health", 200},
|
|
{"/unknown", 200}, // tarpit returns 200 then holds connection
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.path, func(t *testing.T) {
|
|
req := httptest.NewRequest("GET", tt.path, nil)
|
|
w := httptest.NewRecorder()
|
|
|
|
// For unknown paths, tarpit runs asynchronously
|
|
if tt.path == "/unknown" {
|
|
go routeHandler(w, req)
|
|
time.Sleep(50 * time.Millisecond)
|
|
resp := w.Result()
|
|
if resp.StatusCode != 200 {
|
|
t.Errorf("routeHandler %s status = %d, want 200", tt.path, resp.StatusCode)
|
|
}
|
|
} else {
|
|
routeHandler(w, req)
|
|
resp := w.Result()
|
|
if resp.StatusCode != tt.wantStatus {
|
|
t.Errorf("routeHandler %s status = %d, want %d", tt.path, resp.StatusCode, tt.wantStatus)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestAlertOutage_Disabled(t *testing.T) {
|
|
// Ensure no env vars are set
|
|
os.Unsetenv("NTFY_ALERT_URL")
|
|
os.Unsetenv("NTFY_ALERT_TOKEN")
|
|
|
|
// Should not panic and should log only
|
|
alertOutage("test-node", "test-host", 60, false)
|
|
alertOutage("test-node", "test-host", 0, true)
|
|
}
|
|
|
|
func TestEnsureTables(t *testing.T) {
|
|
setupTestDB(t)
|
|
defer cleanupTestDB()
|
|
|
|
// Verify tables exist by querying them
|
|
tables := []string{"telemetry", "uptime_spans", "maintenance"}
|
|
for _, table := range tables {
|
|
var name string
|
|
err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
|
|
if err != nil {
|
|
t.Errorf("Table %s should exist: %v", table, err)
|
|
}
|
|
if name != table {
|
|
t.Errorf("Table name = %s, want %s", name, table)
|
|
}
|
|
}
|
|
}
|
|
|
|
// mockResponseWriter is a minimal in-memory http.ResponseWriter for tests. It
// records the status code and whether Write was called; the written bytes
// themselves are discarded.
type mockResponseWriter struct {
	headers http.Header
	status  int
	written bool
}

// Compile-time check that the mock satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*mockResponseWriter)(nil)

// Header returns the mock's header map, allocating it on first use so that a
// zero-value mockResponseWriter can be handed to code that calls
// Header().Set without panicking on a nil map.
func (m *mockResponseWriter) Header() http.Header {
	if m.headers == nil {
		m.headers = make(http.Header)
	}
	return m.headers
}

// Write marks the mock as written and reports all of p as consumed. If no
// status has been recorded yet it records 200, mirroring the implicit
// WriteHeader(http.StatusOK) described by the http.ResponseWriter contract.
func (m *mockResponseWriter) Write(p []byte) (int, error) {
	if m.status == 0 {
		m.status = http.StatusOK
	}
	m.written = true
	return len(p), nil
}

// WriteHeader records status as the response status code.
func (m *mockResponseWriter) WriteHeader(status int) {
	m.status = status
}
|
|
|
|
func TestMTLSRequired(t *testing.T) {
|
|
// This test documents that mTLS is now mandatory
|
|
// The main() function will fail if CA chain is not present
|
|
// We verify the setupTLS function returns a proper config when CA is loaded
|
|
|
|
// Create a proper test CA pool
|
|
caPool = x509.NewCertPool()
|
|
|
|
config := setupTLS()
|
|
if config == nil {
|
|
t.Error("setupTLS should return config when caPool is set")
|
|
}
|
|
if config.ClientAuth != tls.RequireAndVerifyClientCert {
|
|
t.Errorf("ClientAuth = %v, want RequireAndVerifyClientCert", config.ClientAuth)
|
|
}
|
|
if config.MinVersion != tls.VersionTLS13 {
|
|
t.Errorf("MinVersion = %d, want TLS13", config.MinVersion)
|
|
}
|
|
}
|