clavitor/clavis/clavis-telemetry/main_test.go

378 lines
8.9 KiB
Go

//go:build commercial
package main
import (
"bytes"
"crypto/tls"
"crypto/x509"
"database/sql"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"strings"
"testing"
"time"
_ "github.com/mattn/go-sqlite3"
)
// setupTestDB points the package-level db at a fresh in-memory SQLite
// database and creates the schema by calling ensureTables.
//
// A plain ":memory:" DSN is deliberately not used: with database/sql every
// pooled connection opened against ":memory:" gets its own, separate, empty
// database, so a second connection would not see the tables created here.
// A named shared-cache in-memory database is visible to every connection in
// the pool and is discarded once its last connection closes (cleanupTestDB
// closes the pool). Naming it after the test keeps tests from seeing each
// other's rows.
func setupTestDB(t *testing.T) {
	t.Helper()
	var err error
	db, err = sql.Open("sqlite3", "file:"+t.Name()+"?mode=memory&cache=shared")
	if err != nil {
		t.Fatalf("Failed to open test database: %v", err)
	}
	// sql.Open does not connect; ping so a bad DSN fails here with a clear
	// message. The connection then sits in the idle pool, which also keeps
	// the shared in-memory database alive for the rest of the test.
	if err := db.Ping(); err != nil {
		t.Fatalf("Failed to connect to test database: %v", err)
	}
	ensureTables()
}
// cleanupTestDB closes the package-level test database when one is set.
// It does not reset db to nil. The Close error is intentionally discarded
// because this only runs during test teardown.
func cleanupTestDB() {
	if db == nil {
		return
	}
	_ = db.Close()
}
// TestTarpit checks that tarpit starts answering promptly with a 200
// text/plain response even though the handler keeps running afterwards
// (the tarpit is described as holding the connection for ~30 seconds).
func TestTarpit(t *testing.T) {
	req := httptest.NewRequest("GET", "/unknown", nil)
	w := httptest.NewRecorder()

	// tarpit blocks, so run it in the background. Nothing waits for it to
	// finish, so no completion signal is used; sending on an unbuffered
	// channel that is never read would leave this goroutine blocked forever
	// once tarpit returns.
	go tarpit(w, req)

	// Give tarpit time to write its initial response.
	time.Sleep(100 * time.Millisecond)

	// NOTE(review): w is read here while tarpit may still be writing to it
	// from the other goroutine; httptest.ResponseRecorder is not safe for
	// concurrent use, so `go test -race` can report this. A race-free check
	// would need tarpit to honour request cancellation or a real
	// httptest.Server — confirm tarpit's implementation before changing.
	resp := w.Result()

	// httptest.NewRecorder starts with Code 200, so this check also passes if
	// tarpit has not written a status yet; Content-Type is the stronger check.
	if resp.StatusCode != 200 {
		t.Errorf("tarpit status = %d, want 200", resp.StatusCode)
	}
	if resp.Header.Get("Content-Type") != "text/plain" {
		t.Errorf("tarpit content-type = %s, want text/plain", resp.Header.Get("Content-Type"))
	}
}
// TestHandleHealth verifies that, with a working test database, handleHealth
// answers HTTP 200 with a JSON object whose "status" and "db" fields are
// both the string "ok".
func TestHandleHealth(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	rec := httptest.NewRecorder()
	handleHealth(rec, httptest.NewRequest("GET", "/health", nil))

	res := rec.Result()
	if code := res.StatusCode; code != 200 {
		t.Errorf("handleHealth status = %d, want 200", code)
	}

	var body map[string]interface{}
	if err := json.NewDecoder(res.Body).Decode(&body); err != nil {
		t.Fatalf("Failed to decode health response: %v", err)
	}
	if got := body["status"]; got != "ok" {
		t.Errorf("handleHealth status = %v, want ok", got)
	}
	if got := body["db"]; got != "ok" {
		t.Errorf("handleHealth db = %v, want ok", got)
	}
}
// TestHandleHealth_DBError verifies that handleHealth answers 503 when the
// package-level db handle is unusable. It installs a handle that has already
// been closed, so every operation on it fails. The closed handle is left in
// db after the test.
func TestHandleHealth_DBError(t *testing.T) {
	// Release whatever database an earlier test may have left open.
	if db != nil {
		db.Close()
		db = nil
	}

	closedDB, err := sql.Open("sqlite3", ":memory:")
	if err != nil {
		t.Fatalf("Failed to open test database: %v", err)
	}
	closedDB.Close() // closing up front forces every later operation to fail
	db = closedDB

	rec := httptest.NewRecorder()
	handleHealth(rec, httptest.NewRequest("GET", "/health", nil))
	if code := rec.Result().StatusCode; code != 503 {
		t.Errorf("handleHealth with bad DB status = %d, want 503", code)
	}
}
// TestHandleTelemetry_MethodNotAllowed verifies that a GET to the telemetry
// endpoint is rejected with 405.
//
// mTLS is disabled by clearing caPool (and restored afterwards), as in the
// other handleTelemetry tests. Otherwise the outcome would depend on
// whatever CA pool another test left in the package-level caPool.
func TestHandleTelemetry_MethodNotAllowed(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	oldCAPool := caPool
	caPool = nil
	defer func() { caPool = oldCAPool }()

	req := httptest.NewRequest("GET", "/telemetry", nil)
	w := httptest.NewRecorder()
	handleTelemetry(w, req)

	resp := w.Result()
	if resp.StatusCode != 405 {
		t.Errorf("handleTelemetry GET status = %d, want 405", resp.StatusCode)
	}
}
// TestHandleTelemetry_BadPayload verifies that a POST whose body is not JSON
// is rejected with 400.
//
// mTLS is disabled by clearing caPool (and restored afterwards), as in the
// other handleTelemetry tests. The aim is that the request is judged on its
// payload rather than on the package-level caPool another test may have set.
func TestHandleTelemetry_BadPayload(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	oldCAPool := caPool
	caPool = nil
	defer func() { caPool = oldCAPool }()

	req := httptest.NewRequest("POST", "/telemetry", strings.NewReader("not json"))
	w := httptest.NewRecorder()
	handleTelemetry(w, req)

	resp := w.Result()
	if resp.StatusCode != 400 {
		t.Errorf("handleTelemetry bad payload status = %d, want 400", resp.StatusCode)
	}
}
// TestHandleTelemetry_ValidPayload posts a fully populated telemetry report
// and verifies that the handler answers 200 and that exactly one row for the
// node was written to the telemetry table.
func TestHandleTelemetry_ValidPayload(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	// Temporarily disable mTLS for this test by clearing caPool.
	oldCAPool := caPool
	caPool = nil
	defer func() { caPool = oldCAPool }()

	payload := map[string]interface{}{
		"node_id":         "test-node-1",
		"version":         "1.0.0",
		"hostname":        "test-host",
		"uptime_seconds":  3600,
		"cpu_percent":     25.5,
		"memory_total_mb": 8192,
		"memory_used_mb":  4096,
		"disk_total_mb":   100000,
		"disk_used_mb":    50000,
		"load_1m":         0.5,
		"vault_count":     5,
		"vault_size_mb":   10.5,
		"vault_entries":   100,
		"mode":            "commercial",
	}
	body, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("Failed to marshal payload: %v", err)
	}

	req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body))
	w := httptest.NewRecorder()
	handleTelemetry(w, req)

	// Include the response body so a rejection explains itself.
	if resp := w.Result(); resp.StatusCode != 200 {
		t.Errorf("handleTelemetry valid payload status = %d, want 200 (body: %q)",
			resp.StatusCode, w.Body.String())
	}

	// Verify the report was persisted exactly once.
	var count int
	if err := db.QueryRow("SELECT COUNT(*) FROM telemetry WHERE node_id = ?", "test-node-1").Scan(&count); err != nil {
		t.Fatalf("Failed to query telemetry: %v", err)
	}
	if count != 1 {
		t.Errorf("telemetry count = %d, want 1", count)
	}
}
// TestHandleTelemetry_MissingNodeID posts a report that lacks node_id (and
// hostname) and verifies that the handler rejects it with 400.
func TestHandleTelemetry_MissingNodeID(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	// Temporarily disable mTLS for this test.
	oldCAPool := caPool
	caPool = nil
	defer func() { caPool = oldCAPool }()

	payload := map[string]interface{}{
		"version": "1.0.0",
		// Missing node_id and hostname
	}
	body, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("Failed to marshal payload: %v", err)
	}

	req := httptest.NewRequest("POST", "/telemetry", bytes.NewReader(body))
	w := httptest.NewRecorder()
	handleTelemetry(w, req)

	resp := w.Result()
	if resp.StatusCode != 400 {
		t.Errorf("handleTelemetry missing node_id status = %d, want 400", resp.StatusCode)
	}
}
func TestLoadCA(t *testing.T) {
// Test with non-existent file
err := loadCA("/nonexistent/path/ca.crt")
if err == nil {
t.Error("loadCA with non-existent file should error")
}
// Test with invalid PEM content
tmpFile, err := os.CreateTemp("", "ca-*.crt")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
defer os.Remove(tmpFile.Name())
tmpFile.WriteString("not valid pem")
tmpFile.Close()
err = loadCA(tmpFile.Name())
if err == nil {
t.Error("loadCA with invalid PEM should error")
}
}
// TestSetupTLS verifies two things about setupTLS. It returns nil when no CA
// pool is loaded. It does not panic after loadCA has been given a malformed
// certificate.
//
// The package-level caPool is saved and restored so this test leaves no
// trace for tests that run after it.
func TestSetupTLS(t *testing.T) {
	oldCAPool := caPool
	defer func() { caPool = oldCAPool }()

	// Without a CA pool there is nothing to verify clients against.
	caPool = nil
	if config := setupTLS(); config != nil {
		t.Error("setupTLS with nil caPool should return nil")
	}

	// Write a PEM block whose base64 body is truncated, so it is not a
	// usable certificate.
	tmpFile, err := os.CreateTemp(t.TempDir(), "ca-*.crt")
	if err != nil {
		t.Fatalf("Failed to create temp file: %v", err)
	}
	dummyCert := `-----BEGIN CERTIFICATE-----
MIIBkTCB+wIJAKHBfpE
-----END CERTIFICATE-----`
	if _, err := tmpFile.WriteString(dummyCert); err != nil {
		t.Fatalf("Failed to write temp file: %v", err)
	}
	if err := tmpFile.Close(); err != nil {
		t.Fatalf("Failed to close temp file: %v", err)
	}

	// loadCA is expected to reject this file; its result is deliberately
	// ignored. caPool may be nil or set afterwards; the only property
	// checked is that setupTLS does not panic.
	_ = loadCA(tmpFile.Name())
	_ = setupTLS()
}
// TestRouteHandler checks the status code routeHandler produces for a known
// route (/health) and for an unknown one, which falls through to the tarpit.
func TestRouteHandler(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	// Disable mTLS for route tests (TLS is tested separately).
	oldCAPool := caPool
	caPool = nil
	defer func() { caPool = oldCAPool }()

	// blocking marks routes whose handler does not return promptly (the
	// tarpit holds the connection open). Those are run in the background,
	// and their status is sampled after a short delay.
	tests := []struct {
		path       string
		wantStatus int
		blocking   bool
	}{
		{"/health", 200, false},
		{"/unknown", 200, true}, // tarpit returns 200 then holds connection
	}
	for _, tt := range tests {
		t.Run(tt.path, func(t *testing.T) {
			req := httptest.NewRequest("GET", tt.path, nil)
			w := httptest.NewRecorder()

			if tt.blocking {
				// NOTE(review): the recorder is read below while the
				// background handler may still write to it, which is a data
				// race under `go test -race`. The recorder also defaults to
				// 200, so this case passes even if nothing was written yet.
				go routeHandler(w, req)
				time.Sleep(50 * time.Millisecond)
			} else {
				routeHandler(w, req)
			}

			resp := w.Result()
			if resp.StatusCode != tt.wantStatus {
				t.Errorf("routeHandler %s status = %d, want %d", tt.path, resp.StatusCode, tt.wantStatus)
			}
		})
	}
}
// TestAlertOutage_Disabled calls alertOutage for both an outage and a
// recovery with the ntfy alert variables unset, and requires only that it
// does not panic.
//
// The variables are unset just for this test. Any values they had before
// are restored via t.Cleanup, so the process environment seen by other
// tests is unchanged.
func TestAlertOutage_Disabled(t *testing.T) {
	for _, key := range []string{"NTFY_ALERT_URL", "NTFY_ALERT_TOKEN"} {
		key := key // per-iteration copy for the cleanup closure (pre-Go 1.22 semantics)
		if prev, ok := os.LookupEnv(key); ok {
			t.Cleanup(func() { os.Setenv(key, prev) })
		}
		os.Unsetenv(key)
	}

	// Should not panic and should log only.
	alertOutage("test-node", "test-host", 60, false)
	alertOutage("test-node", "test-host", 0, true)
}
// TestEnsureTables verifies that ensureTables (run by setupTestDB) created
// the telemetry, uptime_spans and maintenance tables, by looking each one up
// in sqlite_master.
func TestEnsureTables(t *testing.T) {
	setupTestDB(t)
	defer cleanupTestDB()

	tables := []string{"telemetry", "uptime_spans", "maintenance"}
	for _, table := range tables {
		var name string
		err := db.QueryRow("SELECT name FROM sqlite_master WHERE type='table' AND name=?", table).Scan(&name)
		if err != nil {
			t.Errorf("Table %s should exist: %v", table, err)
			// name was never populated; comparing it would only add a
			// misleading second failure for the same table.
			continue
		}
		if name != table {
			t.Errorf("Table name = %s, want %s", name, table)
		}
	}
}
// mockResponseWriter is a minimal http.ResponseWriter for tests. It records
// the status passed to WriteHeader and whether Write was ever called. The
// written bytes are discarded.
type mockResponseWriter struct {
	headers http.Header
	status  int
	written bool
}

// Compile-time check that mockResponseWriter satisfies http.ResponseWriter.
var _ http.ResponseWriter = (*mockResponseWriter)(nil)

// Header returns the response header map. The map is allocated on first use
// so that a zero-value mockResponseWriter can be used directly; without that,
// Header().Set would panic on a nil map.
func (m *mockResponseWriter) Header() http.Header {
	if m.headers == nil {
		m.headers = make(http.Header)
	}
	return m.headers
}

// Write marks the writer as written and reports all of p as consumed.
func (m *mockResponseWriter) Write(p []byte) (int, error) {
	m.written = true
	return len(p), nil
}

// WriteHeader records the status code.
func (m *mockResponseWriter) WriteHeader(status int) {
	m.status = status
}
// TestMTLSRequired checks the settings of the TLS config that setupTLS
// builds once a CA pool is loaded. It must require and verify client
// certificates, and it must use TLS 1.3 as the minimum version.
//
// The package-level caPool is saved and restored so later tests are not
// affected by the empty pool installed here.
func TestMTLSRequired(t *testing.T) {
	oldCAPool := caPool
	defer func() { caPool = oldCAPool }()

	caPool = x509.NewCertPool()
	config := setupTLS()
	if config == nil {
		// Fatal, not Error: every check below dereferences config.
		t.Fatal("setupTLS should return config when caPool is set")
	}
	if config.ClientAuth != tls.RequireAndVerifyClientCert {
		t.Errorf("ClientAuth = %v, want RequireAndVerifyClientCert", config.ClientAuth)
	}
	if config.MinVersion != tls.VersionTLS13 {
		t.Errorf("MinVersion = %d, want TLS13", config.MinVersion)
	}
}