391 lines
11 KiB
Go
391 lines
11 KiB
Go
package api
|
|
|
|
import (
	"bufio"
	"context"
	"encoding/base64"
	"encoding/json"
	"log"
	"net"
	"net/http"
	"net/url"
	"os"
	"path/filepath"
	"strings"
	"sync"
	"time"

	"github.com/johanj/clavitor/lib"
)
|
|
|
|
// base64Decode decodes s as base64, accepting both the URL-safe and the
// standard alphabet, padded or unpadded. Trailing '=' characters are stripped
// and the unpadded URL-safe decoder is tried first; if that fails the
// unpadded standard decoder's result (and error) is returned.
func base64Decode(s string) ([]byte, error) {
	unpadded := strings.TrimRight(s, "=")
	if decoded, err := base64.RawURLEncoding.DecodeString(unpadded); err == nil {
		return decoded, nil
	}
	return base64.RawStdEncoding.DecodeString(unpadded)
}
|
|
|
|
// base64UrlEncode returns b encoded with the URL-safe base64 alphabet and no
// '=' padding (RFC 4648 §5, raw form).
func base64UrlEncode(b []byte) string {
	enc := base64.RawURLEncoding
	return enc.EncodeToString(b)
}
|
|
|
|
// contextKey is an unexported string type for request-context keys. Because
// the type is private to this package, its keys cannot collide with context
// keys created elsewhere from plain strings or other types.
type contextKey string

// Request-context keys read by the *FromContext accessors in this file.
const (
	ctxActor    = contextKey("actor")     // actor type (string)
	ctxSession  = contextKey("session")   // *lib.Session
	ctxAgent    = contextKey("agent")     // *lib.Agent
	ctxDB       = contextKey("db")        // *lib.DB
	ctxVaultKey = contextKey("vault_key") // vault key ([]byte)
	ctxVaultID  = contextKey("vault_id")  // vault ID (int64)
)
|
|
|
|
// ActorFromContext returns the actor type from request context.
|
|
func ActorFromContext(ctx context.Context) string {
|
|
v, ok := ctx.Value(ctxActor).(string)
|
|
if !ok {
|
|
return lib.ActorWeb
|
|
}
|
|
return v
|
|
}
|
|
|
|
// SessionFromContext returns the session from request context.
|
|
func SessionFromContext(ctx context.Context) *lib.Session {
|
|
v, _ := ctx.Value(ctxSession).(*lib.Session)
|
|
return v
|
|
}
|
|
|
|
// AgentFromContext returns the agent from request context (nil if not an agent request).
|
|
func AgentFromContext(ctx context.Context) *lib.Agent {
|
|
v, _ := ctx.Value(ctxAgent).(*lib.Agent)
|
|
return v
|
|
}
|
|
|
|
// DBFromContext returns the vault DB from request context (nil in self-hosted mode).
|
|
func DBFromContext(ctx context.Context) *lib.DB {
|
|
v, _ := ctx.Value(ctxDB).(*lib.DB)
|
|
return v
|
|
}
|
|
|
|
// VaultKeyFromContext returns the derived vault key from request context (nil in self-hosted mode).
|
|
func VaultKeyFromContext(ctx context.Context) []byte {
|
|
v, _ := ctx.Value(ctxVaultKey).([]byte)
|
|
return v
|
|
}
|
|
|
|
// VaultIDFromContext returns the vault ID from request context (0 in self-hosted mode).
|
|
func VaultIDFromContext(ctx context.Context) int64 {
|
|
v, _ := ctx.Value(ctxVaultID).(int64)
|
|
return v
|
|
}
|
|
|
|
// L1Middleware extracts L1 from Bearer token and opens the vault DB.
|
|
// Fully stateless: L1 arrives with every request, is used, then forgotten.
|
|
// No sessions, no stored keys. The server has zero keys of its own.
|
|
//
|
|
// Self-hosted mode: finds vault DB by globbing clavitor-* files.
|
|
// Hosted mode: finds vault DB by base64url(L1[0:4]) → filename.
|
|
func L1Middleware(dataDir string) func(http.Handler) http.Handler {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
auth := r.Header.Get("Authorization")
|
|
|
|
// No auth = unauthenticated request (registration, login begin, etc.)
|
|
if auth == "" || !strings.HasPrefix(auth, "Bearer ") {
|
|
// Try to open vault DB without L1 (for unauthenticated endpoints)
|
|
matches, _ := filepath.Glob(filepath.Join(dataDir, "clavitor-*"))
|
|
if len(matches) > 0 {
|
|
db, err := lib.OpenDB(matches[0])
|
|
if err == nil {
|
|
defer db.Close()
|
|
ctx := context.WithValue(r.Context(), ctxDB, db)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
}
|
|
// Also try legacy .db files for migration
|
|
matches, _ = filepath.Glob(filepath.Join(dataDir, "????????.db"))
|
|
if len(matches) > 0 {
|
|
db, err := lib.OpenDB(matches[0])
|
|
if err == nil {
|
|
defer db.Close()
|
|
ctx := context.WithValue(r.Context(), ctxDB, db)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
bearerVal := strings.TrimPrefix(auth, "Bearer ")
|
|
|
|
var l1Raw []byte
|
|
var agent *lib.Agent
|
|
|
|
if strings.HasPrefix(bearerVal, "cvt_") {
|
|
// --- Agent token: cvt_ prefix ---
|
|
// Extract L1 and look up agent by token hash.
|
|
var hash string
|
|
var err error
|
|
l1Raw, hash, err = lib.ParseToken(bearerVal)
|
|
if err != nil {
|
|
ErrorResponse(w, http.StatusUnauthorized, "invalid_token", "Invalid agent token")
|
|
return
|
|
}
|
|
|
|
// Open vault DB from L1
|
|
l1Key := lib.NormalizeKey(l1Raw)
|
|
vaultPrefix := base64UrlEncode(l1Raw[:4])
|
|
dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)
|
|
|
|
db, err := lib.OpenDB(dbPath)
|
|
if err != nil {
|
|
ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
|
|
return
|
|
}
|
|
defer db.Close()
|
|
|
|
// Look up agent
|
|
agent, err = lib.AgentGetByToken(db, hash)
|
|
if err != nil {
|
|
ErrorResponse(w, http.StatusInternalServerError, "agent_error", "Agent lookup failed")
|
|
return
|
|
}
|
|
if agent == nil {
|
|
ErrorResponse(w, http.StatusUnauthorized, "unknown_token", "Invalid or revoked token")
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), ctxDB, db)
|
|
ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
|
|
ctx = context.WithValue(ctx, ctxActor, lib.ActorAgent)
|
|
ctx = context.WithValue(ctx, ctxAgent, agent)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
|
|
} else {
|
|
// --- Legacy L1 bearer (web UI / extension) ---
|
|
// 8 bytes base64url = vault owner, full access, no agent.
|
|
l1Raw, err := base64Decode(bearerVal)
|
|
if err != nil || len(l1Raw) != 8 {
|
|
ErrorResponse(w, http.StatusUnauthorized, "invalid_l1", "Invalid L1 key in Bearer")
|
|
return
|
|
}
|
|
|
|
l1Key := lib.NormalizeKey(l1Raw)
|
|
vaultPrefix := base64UrlEncode(l1Raw[:4])
|
|
dbPath := filepath.Join(dataDir, "clavitor-"+vaultPrefix)
|
|
|
|
var db *lib.DB
|
|
if _, err := os.Stat(dbPath); err == nil {
|
|
db, err = lib.OpenDB(dbPath)
|
|
if err != nil {
|
|
log.Printf("vault open error (%s): %v", dbPath, err)
|
|
ErrorResponse(w, http.StatusInternalServerError, "db_error", "Failed to open vault")
|
|
return
|
|
}
|
|
}
|
|
if db == nil {
|
|
ErrorResponse(w, http.StatusNotFound, "vault_not_found", "Vault not found")
|
|
return
|
|
}
|
|
defer db.Close()
|
|
|
|
ctx := context.WithValue(r.Context(), ctxDB, db)
|
|
ctx = context.WithValue(ctx, ctxVaultKey, l1Key)
|
|
ctx = context.WithValue(ctx, ctxActor, lib.ActorWeb)
|
|
// No agent in context = vault owner (full access)
|
|
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
// IsAgentRequest returns true if the request was made with a cvt_ agent token.
|
|
func IsAgentRequest(r *http.Request) bool {
|
|
return AgentFromContext(r.Context()) != nil
|
|
}
|
|
|
|
// LoggingMiddleware logs HTTP requests.
|
|
func LoggingMiddleware(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
start := time.Now()
|
|
wrapped := &statusWriter{ResponseWriter: w, status: 200}
|
|
next.ServeHTTP(wrapped, r)
|
|
log.Printf("%s %s %d %s", r.Method, r.URL.Path, wrapped.status, time.Since(start))
|
|
})
|
|
}
|
|
|
|
// statusWriter wraps an http.ResponseWriter and records the status code
// passed to WriteHeader, so LoggingMiddleware can log it. If WriteHeader is
// never called, status keeps whatever value the creator set.
//
// Embedding only promotes the methods of the http.ResponseWriter interface
// itself. Without the explicit forwarding methods below, any handler behind
// LoggingMiddleware would see type assertions such as w.(http.Flusher) and
// w.(http.Hijacker) fail, even when the real writer supports them.
// tarpitHandler relies on both assertions.
type statusWriter struct {
	http.ResponseWriter
	status int
}

// WriteHeader records code and forwards it to the wrapped writer.
func (w *statusWriter) WriteHeader(code int) {
	w.status = code
	w.ResponseWriter.WriteHeader(code)
}

// Flush forwards to the wrapped writer if it implements http.Flusher;
// otherwise it does nothing.
func (w *statusWriter) Flush() {
	if f, ok := w.ResponseWriter.(http.Flusher); ok {
		f.Flush()
	}
}

// Hijack forwards to the wrapped writer if it implements http.Hijacker;
// otherwise it returns http.ErrNotSupported.
func (w *statusWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
	if h, ok := w.ResponseWriter.(http.Hijacker); ok {
		return h.Hijack()
	}
	return nil, nil, http.ErrNotSupported
}

// Unwrap returns the wrapped writer, which lets http.ResponseController
// reach optional features of the underlying writer.
func (w *statusWriter) Unwrap() http.ResponseWriter {
	return w.ResponseWriter
}
|
|
|
|
// RateLimitMiddleware implements per-IP rate limiting.
|
|
func RateLimitMiddleware(requestsPerMinute int) func(http.Handler) http.Handler {
|
|
var mu sync.Mutex
|
|
clients := make(map[string]*rateLimitEntry)
|
|
|
|
go func() {
|
|
for {
|
|
time.Sleep(time.Minute)
|
|
mu.Lock()
|
|
now := time.Now()
|
|
for ip, entry := range clients {
|
|
if now.Sub(entry.windowStart) > time.Minute {
|
|
delete(clients, ip)
|
|
}
|
|
}
|
|
mu.Unlock()
|
|
}
|
|
}()
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ip := realIP(r)
|
|
|
|
mu.Lock()
|
|
entry, exists := clients[ip]
|
|
now := time.Now()
|
|
if !exists || now.Sub(entry.windowStart) > time.Minute {
|
|
entry = &rateLimitEntry{windowStart: now, count: 0}
|
|
clients[ip] = entry
|
|
}
|
|
entry.count++
|
|
count := entry.count
|
|
mu.Unlock()
|
|
|
|
if count > requestsPerMinute {
|
|
ErrorResponse(w, http.StatusTooManyRequests, "rate_limited", "Too many requests")
|
|
return
|
|
}
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|
|
|
|
type rateLimitEntry struct {
|
|
windowStart time.Time
|
|
count int
|
|
}
|
|
|
|
// CORSMiddleware handles CORS headers.
//
// Cross-origin access is granted only to development origins whose host is
// exactly localhost or 127.0.0.1 (any scheme or port); for those, the Origin
// is echoed in Access-Control-Allow-Origin. The allowed methods and headers
// and the preflight max-age are advertised on every response. Every OPTIONS
// request is answered with 204 No Content without reaching next.
func CORSMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()

		// The allow decision depends on Origin, so shared caches must key on
		// it even when the origin is rejected.
		h.Add("Vary", "Origin")
		if origin := r.Header.Get("Origin"); isLocalOrigin(origin) {
			h.Set("Access-Control-Allow-Origin", origin)
		}

		h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
		h.Set("Access-Control-Allow-Headers", "Authorization, Content-Type")
		h.Set("Access-Control-Max-Age", "86400")

		if r.Method == http.MethodOptions {
			w.WriteHeader(http.StatusNoContent)
			return
		}

		next.ServeHTTP(w, r)
	})
}

// isLocalOrigin reports whether origin (a browser Origin header value such as
// "http://localhost:3000") names exactly the host localhost or 127.0.0.1.
// It parses the origin and compares the hostname instead of searching for a
// substring, which would also accept look-alikes such as
// "https://localhost.attacker.example". Empty, "null" and unparseable origins
// are rejected.
func isLocalOrigin(origin string) bool {
	if origin == "" {
		return false
	}
	u, err := url.Parse(origin)
	if err != nil || u.Host == "" {
		return false
	}
	host := u.Hostname()
	return strings.EqualFold(host, "localhost") || host == "127.0.0.1"
}
|
|
|
|
// SecurityHeadersMiddleware sets a fixed set of security headers on every
// response before calling next. The CSP permits inline scripts and styles,
// the Tailwind CDN and Google Fonts. connect-src additionally allows
// localhost and 127.0.0.1 for development.
//
// NOTE(review): current MDN/OWASP guidance is to send "X-XSS-Protection: 0"
// or omit the header; the legacy "1; mode=block" value is kept unchanged here.
func SecurityHeadersMiddleware(next http.Handler) http.Handler {
	headers := [...][2]string{
		{"X-Frame-Options", "DENY"},
		{"X-Content-Type-Options", "nosniff"},
		{"X-XSS-Protection", "1; mode=block"},
		{"Referrer-Policy", "strict-origin-when-cross-origin"},
		{"Content-Security-Policy", "default-src 'self'; script-src 'self' 'unsafe-inline' https://cdn.tailwindcss.com; style-src 'self' 'unsafe-inline' https://fonts.googleapis.com; font-src 'self' data: https://fonts.gstatic.com; img-src 'self' data: https:; connect-src 'self' localhost 127.0.0.1"},
	}

	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		h := w.Header()
		for _, kv := range headers {
			h.Set(kv[0], kv[1])
		}
		next.ServeHTTP(w, r)
	})
}
|
|
|
|
// ErrorResponse sends a standard JSON error response. It sets Content-Type to
// application/json, writes status, and encodes the body
// {"code": code, "error": message} followed by a newline. Any encoding or
// write error is ignored because the status line has already been sent.
func ErrorResponse(w http.ResponseWriter, status int, code, message string) {
	payload := map[string]string{
		"code":  code,
		"error": message,
	}
	w.Header().Set("Content-Type", "application/json")
	w.WriteHeader(status)
	enc := json.NewEncoder(w)
	_ = enc.Encode(payload)
}
|
|
|
|
// JSONResponse sends a standard JSON success response. It sets Content-Type
// to application/json, writes status, and encodes data with encoding/json
// (newline-terminated). If encoding fails, the status has already been sent
// and the body is left empty; the error is ignored.
func JSONResponse(w http.ResponseWriter, status int, data any) {
	hdr := w.Header()
	hdr.Set("Content-Type", "application/json")
	w.WriteHeader(status)
	_ = json.NewEncoder(w).Encode(data)
}
|
|
|
|
// tarpitHandler holds unrecognized requests for 30 seconds.
|
|
// Drips one byte per second to keep the connection alive and waste
|
|
// scanner resources. Capped at 1000 concurrent tarpit slots —
|
|
// beyond that, connections are dropped immediately.
|
|
var (
|
|
tarpitSem = make(chan struct{}, 1000)
|
|
)
|
|
|
|
func tarpitHandler(w http.ResponseWriter, r *http.Request) {
|
|
select {
|
|
case tarpitSem <- struct{}{}:
|
|
defer func() { <-tarpitSem }()
|
|
default:
|
|
// Tarpit full — drop immediately, no response
|
|
if hj, ok := w.(http.Hijacker); ok {
|
|
conn, _, err := hj.Hijack()
|
|
if err == nil {
|
|
conn.Close()
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
log.Printf("tarpit: %s %s from %s", r.Method, r.URL.Path, realIP(r))
|
|
|
|
// Chunked response: drip one space per second for 30s
|
|
w.Header().Set("Content-Type", "text/plain")
|
|
w.WriteHeader(200)
|
|
flusher, canFlush := w.(http.Flusher)
|
|
for i := 0; i < 30; i++ {
|
|
_, err := w.Write([]byte(" "))
|
|
if err != nil {
|
|
return // client gave up
|
|
}
|
|
if canFlush {
|
|
flusher.Flush()
|
|
}
|
|
time.Sleep(time.Second)
|
|
}
|
|
}
|
|
|
|
// realIP returns the client IP used for rate limiting and logging.
//
// Precedence: the first (left-most) non-empty X-Forwarded-For entry, then
// X-Real-IP, then the host part of r.RemoteAddr.
//
// NOTE(review): both headers are taken at face value, so a client that
// reaches the server directly can pick its own "IP" (for example to evade
// RateLimitMiddleware). That is only safe behind a proxy that overwrites
// these headers. Confirm the deployment.
func realIP(r *http.Request) string {
	if xff := r.Header.Get("X-Forwarded-For"); xff != "" {
		first, _, _ := strings.Cut(xff, ",")
		if ip := strings.TrimSpace(first); ip != "" {
			return ip
		}
	}
	if xri := strings.TrimSpace(r.Header.Get("X-Real-IP")); xri != "" {
		return xri
	}
	// SplitHostPort strips the port and the brackets of IPv6 literals
	// ("[::1]:8080" -> "::1"). Splitting on the last ':' by hand would keep the
	// brackets and would truncate a port-less IPv6 address ("::1" -> ":").
	if host, _, err := net.SplitHostPort(r.RemoteAddr); err == nil {
		return host
	}
	return r.RemoteAddr
}
|