clavitor/clavis/clavis-vault/api/middleware.go

413 lines
12 KiB
Go

package api
import (
	"context"
	"encoding/base64"
	"encoding/json"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clavitor/lib"
)
// base64Decode decodes s as base64, accepting both the URL-safe and the
// standard alphabet, with or without '=' padding.
//
// Padding is stripped first and the unpadded URL-safe alphabet is tried; if
// that fails, the unpadded standard alphabet is used and its result (or
// error) is returned.
func base64Decode(s string) ([]byte, error) {
	unpadded := strings.TrimRight(s, "=")
	if decoded, err := base64.RawURLEncoding.DecodeString(unpadded); err == nil {
		return decoded, nil
	}
	return base64.RawStdEncoding.DecodeString(unpadded)
}
// base64UrlEncode returns b encoded with the URL-safe base64 alphabet and no
// '=' padding (RFC 4648 §5, raw form).
func base64UrlEncode(b []byte) string {
	enc := base64.RawURLEncoding
	return enc.EncodeToString(b)
}
// contextKey is the unexported key type for values this package stores in a
// request context. A distinct named type keeps these keys from colliding with
// string keys set by other packages.
type contextKey string

// Keys for the per-request values read back by the *FromContext accessors.
// The comment on each key names the Go type stored under it.
const (
	ctxActor    = contextKey("actor")     // string actor type
	ctxSession  = contextKey("session")   // *lib.Session
	ctxAgent    = contextKey("agent")     // *lib.Agent
	ctxDB       = contextKey("db")        // *lib.DB
	ctxVaultKey = contextKey("vault_key") // []byte derived vault key
	ctxVaultID  = contextKey("vault_id")  // int64 vault ID
)
// ActorFromContext returns the actor type stored in ctx. When no string actor
// has been stored, it reports lib.ActorWeb.
func ActorFromContext(ctx context.Context) string {
	if actor, ok := ctx.Value(ctxActor).(string); ok {
		return actor
	}
	return lib.ActorWeb
}
// SessionFromContext returns the *lib.Session stored in ctx, or nil when
// there is none.
func SessionFromContext(ctx context.Context) *lib.Session {
	if session, ok := ctx.Value(ctxSession).(*lib.Session); ok {
		return session
	}
	return nil
}
// AgentFromContext returns the authenticated *lib.Agent stored in ctx. The
// result is nil for requests that were not made by an agent.
func AgentFromContext(ctx context.Context) *lib.Agent {
	if agent, ok := ctx.Value(ctxAgent).(*lib.Agent); ok {
		return agent
	}
	return nil
}
// DBFromContext returns the vault *lib.DB stored in ctx, or nil when no
// database was attached to the request (nil in self-hosted mode).
func DBFromContext(ctx context.Context) *lib.DB {
	if db, ok := ctx.Value(ctxDB).(*lib.DB); ok {
		return db
	}
	return nil
}
// VaultKeyFromContext returns the derived vault key stored in ctx. It returns
// nil when no []byte key is present (nil in self-hosted mode).
func VaultKeyFromContext(ctx context.Context) []byte {
	if key, ok := ctx.Value(ctxVaultKey).([]byte); ok {
		return key
	}
	return nil
}
// VaultIDFromContext returns the int64 vault ID stored in ctx. It returns 0
// when none is present (0 in self-hosted mode).
func VaultIDFromContext(ctx context.Context) int64 {
	if id, ok := ctx.Value(ctxVaultID).(int64); ok {
		return id
	}
	return 0
}
// L1Middleware extracts L1 from the Bearer token and opens the vault DB for
// the duration of a single request.
//
// Fully stateless: L1 arrives with every request, is used, then forgotten.
// No sessions, no stored keys. The server has zero keys of its own. For the
// same reason, L1 and anything derived from it must never be logged.
//
// Requests without a "Bearer " Authorization header are passed to next as
// unauthenticated requests (registration, login begin, etc.). When a vault DB
// can be opened from dataDir (see l1OpenFallbackDB), it is attached to the
// context under ctxDB.
//
// For Bearer requests:
//   - the token must base64-decode to exactly 8 bytes, else 401 invalid_l1;
//   - the vault DB is dataDir/clavitor-<base64url(L1[0:4])>; a missing file
//     gives 404 vault_not_found and an open failure gives 500 db_error;
//   - the context receives ctxDB, ctxVaultKey (lib.NormalizeKey(L1)) and
//     ctxActor (lib.ActorWeb);
//   - with an X-Agent header, the agent checks in l1AuthorizeAgent must pass.
//     ctxActor is then lib.ActorAgent and ctxAgent holds the agent.
func L1Middleware(dataDir string) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			auth := r.Header.Get("Authorization")

			// No Bearer auth = unauthenticated request (registration, login
			// begin, etc.). Attach a vault DB if one can be opened.
			if !strings.HasPrefix(auth, "Bearer ") {
				db := l1OpenFallbackDB(dataDir)
				if db == nil {
					next.ServeHTTP(w, r)
					return
				}
				defer db.Close()
				next.ServeHTTP(w, r.WithContext(context.WithValue(r.Context(), ctxDB, db)))
				return
			}

			// Decode Bearer → L1 (8 bytes)
			l1Raw, err := base64Decode(strings.TrimPrefix(auth, "Bearer "))
			if err != nil || len(l1Raw) != 8 {
				ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer")
				return
			}
			// Normalize L1: 8 bytes → 16 bytes (same as crypto.js normalize_key)
			l1Key := lib.NormalizeKey(l1Raw)

			// Find the vault DB by the first 4 bytes of L1. L1 is client key
			// material: do not log l1Raw or l1Key.
			dbPath := filepath.Join(dataDir, "clavitor-"+base64UrlEncode(l1Raw[:4]))
			var db *lib.DB
			if _, statErr := os.Stat(dbPath); statErr == nil {
				db, err = lib.OpenDB(dbPath)
				if err != nil {
					log.Printf("vault open error (%s): %v", dbPath, err)
					ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault")
					return
				}
			}
			if db == nil {
				ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
				return
			}
			defer db.Close()

			ctx := context.WithValue(r.Context(), ctxDB, db)
			ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
			ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb)

			// Agent-specific auth via X-Agent header
			if agentName := r.Header.Get("X-Agent"); agentName != "" {
				agent := l1AuthorizeAgent(w, r, db, agentName)
				if agent == nil {
					return // the error response has already been written
				}
				ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent)
				ctx = context.WithValue(ctx, ctxAgent, agent)
			}
			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}

// l1OpenFallbackDB opens the vault DB used for unauthenticated requests. It
// tries the first clavitor-* file in dataDir, then the first legacy
// ????????.db file (kept for migration). It returns nil when no candidate
// exists or the chosen file fails to open. The caller must Close the result.
func l1OpenFallbackDB(dataDir string) *lib.DB {
	for _, pattern := range []string{"clavitor-*", "????????.db"} {
		// Glob only fails on a malformed pattern; both patterns are constant.
		matches, _ := filepath.Glob(filepath.Join(dataDir, pattern))
		if len(matches) == 0 {
			continue
		}
		if db, err := lib.OpenDB(matches[0]); err == nil {
			return db
		}
	}
	return nil
}

// l1AuthorizeAgent runs the X-Agent checks for agentName against db, in
// order: vault lock, agent lookup and status, IP whitelist (pinning an "init"
// placeholder to the caller's IP), and per-agent minute/hour rate limits.
//
// On success it records last-use and returns the agent. On failure it writes
// the error response and returns nil. A whitelist violation first locks the
// vault and the agent; a rate-limit violation first locks the agent. Both
// write an audit entry before responding.
func l1AuthorizeAgent(w http.ResponseWriter, r *http.Request, db *lib.DB, agentName string) *lib.Agent {
	// Check vault-level lock.
	// NOTE(review): a VaultLockGet error is discarded, so a failed read is
	// treated as "not locked". Confirm this fail-open behaviour is intended.
	vaultLock, _ := lib.VaultLockGet(db)
	if vaultLock != nil && vaultLock.Locked {
		ErrorResponse(w, http.StatusForbidden, "vault_locked",
			"Vault is locked: "+vaultLock.LockedReason)
		return nil
	}

	agent, err := lib.AgentGetByName(db, agentName)
	if err != nil {
		ErrorResponse(w, http.StatusInternalServerError, "agent_error", "Agent lookup failed")
		return nil
	}
	if agent == nil {
		ErrorResponse(w, http.StatusUnauthorized, "unknown_agent", "Unknown agent")
		return nil
	}
	if agent.Status == lib.AgentStatusRevoked {
		ErrorResponse(w, http.StatusUnauthorized, "agent_revoked", "Agent has been revoked")
		return nil
	}
	if agent.Status == lib.AgentStatusLocked {
		ErrorResponse(w, http.StatusForbidden, "agent_locked",
			"Agent is locked: "+agent.LockedReason)
		return nil
	}

	agentID := int64(agent.ID)
	ip := realIP(r)

	// IP whitelist: a lone "init" entry is replaced by the first caller's IP.
	// The whitelist update and audit writes are best effort; their results
	// are not checked.
	if len(agent.IPWhitelist) == 1 && agent.IPWhitelist[0] == "init" {
		lib.AgentUpdateWhitelist(db, agentID, []string{ip})
		agent.IPWhitelist = []string{ip}
		lib.AuditLog(db, &lib.AuditEvent{
			Action: "agent_ip_init", Actor: lib.ActorAgent,
			IPAddr: ip, Title: agent.Name + " → " + ip,
		})
	}
	if !agentIPAllowed(agent, ip) {
		lib.VaultLockSet(db, true, "Non-whitelisted IP "+ip+" for agent "+agent.Name)
		lib.AgentUpdateStatus(db, agentID, lib.AgentStatusLocked, "non-whitelisted IP: "+ip)
		lib.AuditLog(db, &lib.AuditEvent{
			Action: lib.ActionIPViolation, Actor: lib.ActorAgent,
			IPAddr: ip, Title: agent.Name,
		})
		ErrorResponse(w, http.StatusForbidden, "vault_locked",
			"Access from non-whitelisted IP. Vault locked.")
		return nil
	}

	// Rate limit: log this request first, then compare the recent counts
	// against the agent's limits.
	// NOTE(review): count-query errors are discarded (counts read as zero).
	lib.AgentRequestLog(db, agentID, ip, r.URL.Path)
	countMin, _ := lib.AgentRequestCountMinute(db, agentID)
	countHour, _ := lib.AgentRequestCountHour(db, agentID)
	if countMin > agent.RateLimitMinute || countHour > agent.RateLimitHour {
		lib.AgentUpdateStatus(db, agentID, lib.AgentStatusLocked, "rate limit exceeded")
		lib.AuditLog(db, &lib.AuditEvent{
			Action: lib.ActionRateExceeded, Actor: lib.ActorAgent,
			IPAddr: ip, Title: agent.Name,
		})
		ErrorResponse(w, http.StatusTooManyRequests, "agent_locked",
			"Rate limit exceeded. Agent locked.")
		return nil
	}

	lib.AgentUpdateLastUsed(db, agentID, ip)
	return agent
}
// agentIPAllowed reports whether ip is permitted by agent.IPWhitelist. An
// empty whitelist allows every address.
//
// Surrounding whitespace in entries is ignored and empty entries are skipped.
// An entry equal to ip byte-for-byte always matches. Otherwise an entry is:
//   - CIDR notation ("10.0.0.0/8", "2001:db8::/32"): matches any ip inside;
//   - a single IP: compared by value, so equivalent spellings match
//     ("::1" / "0:0:0:0:0:0:0:1", "1.2.3.4" / "::ffff:1.2.3.4");
//   - a DNS name: resolved with net.LookupHost on every call; matches when
//     any resolved address equals ip. Lookup failures count as no match.
func agentIPAllowed(agent *lib.Agent, ip string) bool {
	if len(agent.IPWhitelist) == 0 {
		return true
	}
	parsedIP := net.ParseIP(ip) // nil when ip is not an address literal

	// sameIP compares an address string with ip, first exactly and then by
	// parsed value when both sides parse as IPs.
	sameIP := func(addr string) bool {
		if addr == ip {
			return true
		}
		other := net.ParseIP(addr)
		return other != nil && parsedIP != nil && other.Equal(parsedIP)
	}

	for _, raw := range agent.IPWhitelist {
		entry := strings.TrimSpace(raw)
		switch {
		case entry == "":
			continue
		case entry == ip:
			return true
		case strings.Contains(entry, "/"):
			// CIDR
			_, cidr, err := net.ParseCIDR(entry)
			if err == nil && parsedIP != nil && cidr.Contains(parsedIP) {
				return true
			}
		case net.ParseIP(entry) != nil:
			// Single IP
			if sameIP(entry) {
				return true
			}
		default:
			// DNS name — resolve and compare
			addrs, err := net.LookupHost(entry)
			if err != nil {
				continue
			}
			for _, addr := range addrs {
				if sameIP(addr) {
					return true
				}
			}
		}
	}
	return false
}
// LoggingMiddleware logs one line per request after next has handled it:
// method, URL path, the status recorded by statusWriter (starting at 200),
// and the elapsed handling time.
func LoggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		recorder := &statusWriter{ResponseWriter: w, status: http.StatusOK}
		began := time.Now()
		next.ServeHTTP(recorder, r)
		elapsed := time.Since(began)
		log.Printf("%s %s %d %s", r.Method, r.URL.Path, recorder.status, elapsed)
	})
}
// statusWriter wraps an http.ResponseWriter and remembers the status code
// sent to the client, for LoggingMiddleware. The creator sets status to the
// value to report when the handler never calls WriteHeader.
type statusWriter struct {
	http.ResponseWriter
	status      int
	wroteHeader bool // true once a final status or body bytes were sent
}

// WriteHeader forwards code. It records only the first final status: net/http
// ignores WriteHeader once the header has been sent, so later codes never
// reach the client. Informational 1xx codes (other than 101) may precede the
// final status and are not recorded.
func (w *statusWriter) WriteHeader(code int) {
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !w.wroteHeader && !informational {
		w.status = code
		w.wroteHeader = true
	}
	w.ResponseWriter.WriteHeader(code)
}

// Write forwards b. Writing a body commits the header (an implicit
// WriteHeader), so status keeps its current value from then on.
func (w *statusWriter) Write(b []byte) (int, error) {
	w.wroteHeader = true
	return w.ResponseWriter.Write(b)
}

// Unwrap returns the wrapped ResponseWriter so http.ResponseController can
// reach optional capabilities (flushing, hijacking, deadlines) through it.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
// RateLimitMiddleware returns middleware that admits at most
// requestsPerMinute requests per client IP (as reported by realIP) within a
// fixed one-minute window. A new window starts on the first request more
// than a minute after the previous window began. Excess requests receive
// 429 "rate_limited".
//
// Each call starts a goroutine that, once a minute, drops entries whose
// window is older than a minute. It never exits.
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler {
	var (
		lock    sync.Mutex
		windows = make(map[string]*rateLimitEntry)
	)

	// sweep removes windows that have expired.
	sweep := func() {
		lock.Lock()
		defer lock.Unlock()
		cutoff := time.Now()
		for addr, win := range windows {
			if cutoff.Sub(win.windowStart) > time.Minute {
				delete(windows, addr)
			}
		}
	}
	go func() {
		for {
			time.Sleep(time.Minute)
			sweep()
		}
	}()

	// hit counts one request from addr and returns the count in its
	// current window, opening a fresh window when needed.
	hit := func(addr string) int {
		lock.Lock()
		defer lock.Unlock()
		now := time.Now()
		win, found := windows[addr]
		if !found || now.Sub(win.windowStart) > time.Minute {
			win = &rateLimitEntry{windowStart: now}
			windows[addr] = win
		}
		win.count++
		return win.count
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if hit(realIP(r)) > requestsPerMinute {
				ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests")
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// rateLimitEntry is one client's current rate-limit window.
type rateLimitEntry struct {
	windowStart time.Time // when the current window began
	count       int       // requests seen in the current window
}
// CORSMiddleware sets CORS headers and answers preflight (OPTIONS) requests
// with 204 No Content without calling next.
//
// Only development origins whose host is exactly "localhost" or "127.0.0.1"
// (any scheme and port) are echoed in Access-Control-Allow-Origin. Other
// origins get no Allow-Origin header, so browsers block cross-origin reads.
// The earlier substring check also accepted hosts such as
// "localhost.attacker.example".
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		origin := r.Header.Get("Origin")
		// The Allow-Origin decision depends on Origin, so caches must key
		// on it whether or not this origin is allowed.
		w.Header().Add("Vary", "Origin")
		if isLocalDevOrigin(origin) {
			w.Header().Set("Access-Control-Allow-Origin", origin)
		}
		w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		w.Header().Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		w.Header().Set("Access-Control-Max-Age", "86400")
		if r.Method == "OPTIONS" {
			w.WriteHeader(http.StatusNoContent)
			return
		}
		next.ServeHTTP(w, r)
	})
}

// isLocalDevOrigin reports whether origin (an Origin header value such as
// "http://localhost:3000") has the host localhost or 127.0.0.1.
func isLocalDevOrigin(origin string) bool {
	if origin == "" {
		return false
	}
	u, err := url.Parse(origin)
	if err != nil {
		return false
	}
	host := u.Hostname()
	return strings.EqualFold(host, "localhost") || host == "127.0.0.1"
}
// SecurityHeadersMiddleware sets a fixed set of security headers on every
// response before calling next. The Content-Security-Policy allows localhost
// and 127.0.0.1 in connect-src for development.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	// Name/value pairs applied, in order, with Header().Set.
	headers := [][2]string{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' localhost 127.0.0.1"},
	}
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		for _, kv := range headers {
			h.Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	})
}
// ErrorResponse writes status and a JSON body of the form
// {"code": code, "error": message}, with Content-Type application/json.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	payload := map[string]string{
		"code":  code,
		"error": message,
	}
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status is already sent; an encode/write failure cannot be reported.
	_ = json.NewEncoder(w).Encode(payload)
}
// JSONResponse writes status and data encoded as JSON (followed by a
// newline), with Content-Type application/json.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	// The status is already sent; an encode/write failure cannot be reported.
	enc := json.NewEncoder(w)
	_ = enc.Encode(data)
}
// realIP returns the client address for r. It is used for agent IP
// whitelisting and per-IP rate limiting.
//
// Precedence:
//  1. the first entry of X-Forwarded-For (trimmed);
//  2. X-Real-IP (trimmed), if non-empty;
//  3. the host part of r.RemoteAddr.
//
// For step 3, IPv6 addresses are returned without brackets ("[::1]:8080" →
// "::1") so the result parses with net.ParseIP. A RemoteAddr without a port
// is returned unchanged.
//
// NOTE(review): X-Forwarded-For and X-Real-IP are client-controlled unless a
// trusted reverse proxy overwrites them. If this server is reachable directly,
// a client can spoof them to evade the IP whitelist and the rate limiter.
// Confirm the deployment guarantees a trusted proxy.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		return strings.TrimSpace(first)
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}