feat: add MITM proxy mode with LLM policy evaluation (C-017)
- New package clavis/clavis-vault/proxy/ - HTTPS MITM proxy via HTTP CONNECT tunnel - Dynamic per-host TLS cert generation (signed by local CA) - CA cert auto-generated at DataDir/proxy/ca.crt (1-year validity) - Per-cert cache with 24h TTL - Credential injection hook (stub — DB wiring next) - LLM policy evaluation hook (stub — OpenAI-compatible API) - L2 (identity/card) fields are never injectable by design - cmd/clavitor/main.go: new flags --proxy Enable proxy mode (default: off) --proxy-addr Listen addr (default: 127.0.0.1:19840) --proxy-llm Enable LLM policy evaluation --proxy-llm-url LLM base URL (OpenAI-compat) --proxy-llm-key LLM API key --proxy-llm-model LLM model name Usage: clavitor --proxy export HTTP_PROXY=http://127.0.0.1:19840 HTTPS_PROXY=http://127.0.0.1:19840 # Install DataDir/proxy/ca.crt in OS trust store for HTTPS MITM
This commit is contained in:
parent
e425cec150
commit
dcdca016db
|
|
@ -9,6 +9,7 @@ import (
|
||||||
|
|
||||||
"github.com/johanj/clavitor/api"
|
"github.com/johanj/clavitor/api"
|
||||||
"github.com/johanj/clavitor/lib"
|
"github.com/johanj/clavitor/lib"
|
||||||
|
"github.com/johanj/clavitor/proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
//go:embed web
|
//go:embed web
|
||||||
|
|
@ -28,6 +29,15 @@ func main() {
|
||||||
telemetryFreq := flag.Int("telemetry-freq", envInt("TELEMETRY_FREQ", 0), "Telemetry POST interval in seconds (0 = disabled)")
|
telemetryFreq := flag.Int("telemetry-freq", envInt("TELEMETRY_FREQ", 0), "Telemetry POST interval in seconds (0 = disabled)")
|
||||||
telemetryHost := flag.String("telemetry-host", envStr("TELEMETRY_HOST", ""), "Telemetry endpoint URL")
|
telemetryHost := flag.String("telemetry-host", envStr("TELEMETRY_HOST", ""), "Telemetry endpoint URL")
|
||||||
telemetryToken := flag.String("telemetry-token", envStr("TELEMETRY_TOKEN", ""), "Bearer token for telemetry endpoint")
|
telemetryToken := flag.String("telemetry-token", envStr("TELEMETRY_TOKEN", ""), "Bearer token for telemetry endpoint")
|
||||||
|
|
||||||
|
// Proxy mode flags
|
||||||
|
proxyEnabled := flag.Bool("proxy", envBool("PROXY_ENABLED", false), "Enable MITM proxy mode (set HTTP_PROXY=http://127.0.0.1:19840 in agent)")
|
||||||
|
proxyAddr := flag.String("proxy-addr", envStr("PROXY_ADDR", "127.0.0.1:19840"), "Proxy listen address")
|
||||||
|
proxyLLM := flag.Bool("proxy-llm", envBool("PROXY_LLM", false), "Enable LLM policy evaluation in proxy")
|
||||||
|
proxyLLMURL := flag.String("proxy-llm-url", envStr("PROXY_LLM_URL", ""), "LLM API base URL for proxy policy (OpenAI-compatible)")
|
||||||
|
proxyLLMKey := flag.String("proxy-llm-key", envStr("PROXY_LLM_KEY", ""), "LLM API key for proxy policy")
|
||||||
|
proxyLLMModel := flag.String("proxy-llm-model", envStr("PROXY_LLM_MODEL", ""), "LLM model for proxy policy evaluation")
|
||||||
|
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
cfg, err := lib.LoadConfig()
|
cfg, err := lib.LoadConfig()
|
||||||
|
|
@ -48,6 +58,28 @@ func main() {
|
||||||
// Start automatic backup scheduler (3 weekly + 3 monthly, rotated)
|
// Start automatic backup scheduler (3 weekly + 3 monthly, rotated)
|
||||||
lib.StartBackupTimer(cfg.DataDir)
|
lib.StartBackupTimer(cfg.DataDir)
|
||||||
|
|
||||||
|
// Start proxy if enabled
|
||||||
|
if *proxyEnabled {
|
||||||
|
px, err := proxy.New(proxy.Config{
|
||||||
|
ListenAddr: *proxyAddr,
|
||||||
|
DataDir: cfg.DataDir,
|
||||||
|
LLMEnabled: *proxyLLM,
|
||||||
|
LLMBaseURL: *proxyLLMURL,
|
||||||
|
LLMAPIKey: *proxyLLMKey,
|
||||||
|
LLMModel: *proxyLLMModel,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("proxy: %v", err)
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
if err := px.ListenAndServe(); err != nil {
|
||||||
|
log.Printf("proxy: stopped: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
log.Printf("proxy: CA cert at %s — install in OS trust store", px.CACertPath())
|
||||||
|
log.Printf("proxy: set HTTP_PROXY=http://%s HTTPS_PROXY=http://%s in agent environment", *proxyAddr, *proxyAddr)
|
||||||
|
}
|
||||||
|
|
||||||
router := api.NewRouter(cfg, webFS)
|
router := api.NewRouter(cfg, webFS)
|
||||||
|
|
||||||
addr := ":" + cfg.Port
|
addr := ":" + cfg.Port
|
||||||
|
|
@ -64,6 +96,13 @@ func envStr(key, fallback string) string {
|
||||||
return fallback
|
return fallback
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func envBool(key string, fallback bool) bool {
|
||||||
|
if v := os.Getenv(key); v != "" {
|
||||||
|
return v == "1" || v == "true" || v == "yes"
|
||||||
|
}
|
||||||
|
return fallback
|
||||||
|
}
|
||||||
|
|
||||||
func envInt(key string, fallback int) int {
|
func envInt(key string, fallback int) int {
|
||||||
if v := os.Getenv(key); v != "" {
|
if v := os.Getenv(key); v != "" {
|
||||||
if n, err := strconv.Atoi(v); err == nil {
|
if n, err := strconv.Atoi(v); err == nil {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,10 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func newBufReader(conn net.Conn) *bufio.Reader {
|
||||||
|
return bufio.NewReader(conn)
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,452 @@
|
||||||
|
// Package proxy implements an HTTPS MITM proxy with LLM-based policy evaluation.
|
||||||
|
//
|
||||||
|
// Architecture:
|
||||||
|
// - Agent sets HTTP_PROXY=http://localhost:19840 (or configured port)
|
||||||
|
// - For plain HTTP: proxy injects Authorization/headers, forwards
|
||||||
|
// - For HTTPS: proxy performs CONNECT tunnel, generates per-host TLS cert (signed by local CA)
|
||||||
|
// - Before injecting credentials: optional LLM policy evaluation (intent check)
|
||||||
|
//
|
||||||
|
// Credential injection:
|
||||||
|
// - Scans request for placeholder patterns: {{clavitor.entry_title.field_label}}
|
||||||
|
// - Also injects via per-host credential rules stored in vault
|
||||||
|
// - Tier check: L2 fields are never injected (identity/card data)
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds proxy configuration.
|
||||||
|
type Config struct {
|
||||||
|
// ListenAddr is the proxy listen address, e.g. "127.0.0.1:19840"
|
||||||
|
ListenAddr string
|
||||||
|
|
||||||
|
// DataDir is the vault data directory (for CA cert/key storage)
|
||||||
|
DataDir string
|
||||||
|
|
||||||
|
// VaultKey is the L1 decryption key (to read credentials for injection)
|
||||||
|
VaultKey []byte
|
||||||
|
|
||||||
|
// DBPath is path to the vault SQLite database
|
||||||
|
DBPath string
|
||||||
|
|
||||||
|
// LLMEnabled enables LLM-based intent evaluation before credential injection
|
||||||
|
LLMEnabled bool
|
||||||
|
|
||||||
|
// LLMBaseURL is the LLM API base URL (OpenAI-compatible)
|
||||||
|
LLMBaseURL string
|
||||||
|
|
||||||
|
// LLMAPIKey is the API key for LLM requests
|
||||||
|
LLMAPIKey string
|
||||||
|
|
||||||
|
// LLMModel is the model to use for policy evaluation
|
||||||
|
LLMModel string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Proxy is the MITM proxy server.
|
||||||
|
type Proxy struct {
|
||||||
|
cfg Config
|
||||||
|
ca *tls.Certificate
|
||||||
|
caCert *x509.Certificate
|
||||||
|
caKey *rsa.PrivateKey
|
||||||
|
certMu sync.Mutex
|
||||||
|
certs map[string]*tls.Certificate // hostname → generated cert (cache)
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new Proxy. Generates or loads the CA cert from DataDir.
|
||||||
|
func New(cfg Config) (*Proxy, error) {
|
||||||
|
p := &Proxy{
|
||||||
|
cfg: cfg,
|
||||||
|
certs: make(map[string]*tls.Certificate),
|
||||||
|
}
|
||||||
|
if err := p.loadOrCreateCA(); err != nil {
|
||||||
|
return nil, fmt.Errorf("proxy CA: %w", err)
|
||||||
|
}
|
||||||
|
return p, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenAndServe starts the proxy server. Blocks until stopped.
|
||||||
|
func (p *Proxy) ListenAndServe() error {
|
||||||
|
ln, err := net.Listen("tcp", p.cfg.ListenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("proxy listen %s: %w", p.cfg.ListenAddr, err)
|
||||||
|
}
|
||||||
|
log.Printf("proxy: listening on %s (LLM policy: %v)", p.cfg.ListenAddr, p.cfg.LLMEnabled)
|
||||||
|
srv := &http.Server{
|
||||||
|
Handler: p,
|
||||||
|
ReadTimeout: 30 * time.Second,
|
||||||
|
WriteTimeout: 30 * time.Second,
|
||||||
|
}
|
||||||
|
return srv.Serve(ln)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP handles all incoming proxy requests.
|
||||||
|
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if r.Method == http.MethodConnect {
|
||||||
|
p.handleCONNECT(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.handleHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleHTTP handles plain HTTP proxy requests.
|
||||||
|
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// Remove proxy-specific headers
|
||||||
|
r.RequestURI = ""
|
||||||
|
r.Header.Del("Proxy-Connection")
|
||||||
|
r.Header.Del("Proxy-Authenticate")
|
||||||
|
r.Header.Del("Proxy-Authorization")
|
||||||
|
|
||||||
|
// Inject credentials if applicable
|
||||||
|
if err := p.injectCredentials(r); err != nil {
|
||||||
|
log.Printf("proxy: credential injection error for %s: %v", r.URL.Host, err)
|
||||||
|
// Non-fatal: continue without injection
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward the request
|
||||||
|
rp := &httputil.ReverseProxy{
|
||||||
|
Director: func(req *http.Request) {},
|
||||||
|
}
|
||||||
|
rp.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleCONNECT handles HTTPS CONNECT tunnel requests.
|
||||||
|
func (p *Proxy) handleCONNECT(w http.ResponseWriter, r *http.Request) {
|
||||||
|
host := r.Host
|
||||||
|
if !strings.Contains(host, ":") {
|
||||||
|
host = host + ":443"
|
||||||
|
}
|
||||||
|
hostname, _, _ := net.SplitHostPort(host)
|
||||||
|
|
||||||
|
// Acknowledge the CONNECT
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
|
||||||
|
// Hijack the connection
|
||||||
|
hijacker, ok := w.(http.Hijacker)
|
||||||
|
if !ok {
|
||||||
|
log.Printf("proxy: CONNECT hijack not supported")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
clientConn, _, err := hijacker.Hijack()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("proxy: CONNECT hijack error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer clientConn.Close()
|
||||||
|
|
||||||
|
// Generate a certificate for this hostname
|
||||||
|
cert, err := p.certForHost(hostname)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("proxy: cert generation failed for %s: %v", hostname, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap client connection in TLS (using our MITM cert)
|
||||||
|
tlsCfg := &tls.Config{
|
||||||
|
Certificates: []tls.Certificate{*cert},
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
}
|
||||||
|
tlsClientConn := tls.Server(clientConn, tlsCfg)
|
||||||
|
defer tlsClientConn.Close()
|
||||||
|
if err := tlsClientConn.Handshake(); err != nil {
|
||||||
|
log.Printf("proxy: TLS handshake failed for %s: %v", hostname, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Connect to real upstream
|
||||||
|
upstreamConn, err := tls.Dial("tcp", host, &tls.Config{
|
||||||
|
ServerName: hostname,
|
||||||
|
MinVersion: tls.VersionTLS12,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("proxy: upstream dial failed for %s: %v", host, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer upstreamConn.Close()
|
||||||
|
|
||||||
|
// Intercept HTTP traffic between client and upstream
|
||||||
|
p.interceptHTTP(tlsClientConn, upstreamConn, hostname)
|
||||||
|
}
|
||||||
|
|
||||||
|
// interceptHTTP reads HTTP requests from the client, injects credentials, forwards to upstream.
|
||||||
|
func (p *Proxy) interceptHTTP(clientConn net.Conn, upstreamConn net.Conn, hostname string) {
|
||||||
|
// Use Go's http.ReadRequest to parse the client's request
|
||||||
|
clientReader := newBufReader(clientConn)
|
||||||
|
|
||||||
|
for {
|
||||||
|
req, err := http.ReadRequest(clientReader)
|
||||||
|
if err != nil {
|
||||||
|
if err != io.EOF {
|
||||||
|
log.Printf("proxy: read request error for %s: %v", hostname, err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set the correct URL for upstream forwarding
|
||||||
|
req.URL.Scheme = "https"
|
||||||
|
req.URL.Host = hostname
|
||||||
|
req.RequestURI = ""
|
||||||
|
|
||||||
|
// Inject credentials
|
||||||
|
if err := p.injectCredentials(req); err != nil {
|
||||||
|
log.Printf("proxy: credential injection error for %s: %v", hostname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward to upstream
|
||||||
|
if err := req.Write(upstreamConn); err != nil {
|
||||||
|
log.Printf("proxy: upstream write error for %s: %v", hostname, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read upstream response and forward to client
|
||||||
|
upstreamReader := newBufReader(upstreamConn)
|
||||||
|
resp, err := http.ReadResponse(upstreamReader, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("proxy: upstream read error for %s: %v", hostname, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
if err := resp.Write(clientConn); err != nil {
|
||||||
|
log.Printf("proxy: client write error for %s: %v", hostname, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// injectCredentials scans the request for credential placeholders and injects them.
|
||||||
|
// Placeholder format: {{clavitor.entry_title.field_label}} in headers, URL, or body.
|
||||||
|
// Also applies host-based automatic injection rules from vault.
|
||||||
|
// L2 (identity/card) fields are NEVER injected.
|
||||||
|
func (p *Proxy) injectCredentials(r *http.Request) error {
|
||||||
|
if p.cfg.VaultKey == nil {
|
||||||
|
return nil // No vault key — skip injection
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for LLM policy evaluation
|
||||||
|
if p.cfg.LLMEnabled {
|
||||||
|
allowed, reason, err := p.evaluatePolicy(r)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("proxy: LLM policy eval error: %v (allowing)", err)
|
||||||
|
} else if !allowed {
|
||||||
|
log.Printf("proxy: LLM policy DENIED %s %s: %s", r.Method, r.URL, reason)
|
||||||
|
return fmt.Errorf("policy denied: %s", reason)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Implement placeholder substitution once vault DB integration is wired in.
|
||||||
|
// Pattern: scan r.Header values, r.URL, r.Body for {{clavitor.TITLE.FIELD}}
|
||||||
|
// Lookup entry by title (case-insensitive), get field by label, verify Tier != L2
|
||||||
|
// Replace placeholder with decrypted field value.
|
||||||
|
//
|
||||||
|
// Auto-injection (host rules):
|
||||||
|
// Vault entries can specify "proxy_inject_hosts": ["api.github.com"] in metadata
|
||||||
|
// When a request matches, inject the entry's L1 fields as headers per a configured map.
|
||||||
|
//
|
||||||
|
// This stub returns nil — no injection until DB wiring is complete.
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// evaluatePolicy calls the configured LLM to evaluate whether this request
|
||||||
|
// is consistent with the expected behavior of an AI agent (vs. exfiltration/abuse).
|
||||||
|
func (p *Proxy) evaluatePolicy(r *http.Request) (allowed bool, reason string, err error) {
|
||||||
|
if p.cfg.LLMBaseURL == "" || p.cfg.LLMAPIKey == "" {
|
||||||
|
return true, "LLM not configured", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a concise request summary for the LLM
|
||||||
|
summary := fmt.Sprintf("Method: %s\nHost: %s\nPath: %s\nContent-Type: %s",
|
||||||
|
r.Method, r.Host, r.URL.Path,
|
||||||
|
r.Header.Get("Content-Type"))
|
||||||
|
|
||||||
|
prompt := `You are a security policy evaluator for an AI agent credential proxy.
|
||||||
|
|
||||||
|
The following outbound HTTP request is about to have credentials injected and be forwarded.
|
||||||
|
Evaluate whether this request is consistent with normal AI agent behavior (coding, API calls, deployment)
|
||||||
|
vs. suspicious activity (credential exfiltration, unexpected destinations, data harvesting).
|
||||||
|
|
||||||
|
Request summary:
|
||||||
|
` + summary + `
|
||||||
|
|
||||||
|
Respond with JSON only: {"allowed": true/false, "reason": "one sentence"}`
|
||||||
|
|
||||||
|
_ = prompt // Used when LLM call is implemented below
|
||||||
|
|
||||||
|
// TODO: Implement actual LLM call using cfg.LLMBaseURL + cfg.LLMAPIKey + cfg.LLMModel
|
||||||
|
// For now: always allow (policy eval is opt-in, not blocking by default)
|
||||||
|
// Real implementation: POST to /v1/chat/completions, parse JSON response
|
||||||
|
return true, "policy evaluation not yet implemented", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// certForHost returns a TLS certificate for the given hostname, generating one if needed.
|
||||||
|
func (p *Proxy) certForHost(hostname string) (*tls.Certificate, error) {
|
||||||
|
p.certMu.Lock()
|
||||||
|
defer p.certMu.Unlock()
|
||||||
|
|
||||||
|
if cert, ok := p.certs[hostname]; ok {
|
||||||
|
// Check if cert is still valid (> 1 hour remaining)
|
||||||
|
if time.Until(cert.Leaf.NotAfter) > time.Hour {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate a new cert signed by our CA
|
||||||
|
cert, err := p.generateCert(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
p.certs[hostname] = cert
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateCert generates a TLS cert for the given hostname, signed by the proxy CA.
|
||||||
|
func (p *Proxy) generateCert(hostname string) (*tls.Certificate, error) {
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: hostname},
|
||||||
|
DNSNames: []string{hostname},
|
||||||
|
NotBefore: time.Now().Add(-time.Minute),
|
||||||
|
NotAfter: time.Now().Add(24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add IP SAN if hostname is an IP
|
||||||
|
if ip := net.ParseIP(hostname); ip != nil {
|
||||||
|
tmpl.IPAddresses = []net.IP{ip}
|
||||||
|
tmpl.DNSNames = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, p.caCert, &key.PublicKey, p.caKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
leaf, err := x509.ParseCertificate(certDER)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tlsCert := &tls.Certificate{
|
||||||
|
Certificate: [][]byte{certDER},
|
||||||
|
PrivateKey: key,
|
||||||
|
Leaf: leaf,
|
||||||
|
}
|
||||||
|
return tlsCert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadOrCreateCA loads the proxy CA cert/key from DataDir, or generates new ones.
|
||||||
|
func (p *Proxy) loadOrCreateCA() error {
|
||||||
|
caDir := filepath.Join(p.cfg.DataDir, "proxy")
|
||||||
|
if err := os.MkdirAll(caDir, 0700); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
certPath := filepath.Join(caDir, "ca.crt")
|
||||||
|
keyPath := filepath.Join(caDir, "ca.key")
|
||||||
|
|
||||||
|
// Try to load existing CA
|
||||||
|
if _, err := os.Stat(certPath); err == nil {
|
||||||
|
certPEM, err := os.ReadFile(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read CA cert: %w", err)
|
||||||
|
}
|
||||||
|
keyPEM, err := os.ReadFile(keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("read CA key: %w", err)
|
||||||
|
}
|
||||||
|
tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse CA keypair: %w", err)
|
||||||
|
}
|
||||||
|
tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse CA cert: %w", err)
|
||||||
|
}
|
||||||
|
// Check expiry — regenerate if < 7 days left
|
||||||
|
if time.Until(tlsCert.Leaf.NotAfter) < 7*24*time.Hour {
|
||||||
|
log.Printf("proxy: CA cert expires soon (%s), regenerating", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
|
||||||
|
} else {
|
||||||
|
p.ca = &tlsCert
|
||||||
|
p.caCert = tlsCert.Leaf
|
||||||
|
p.caKey = tlsCert.PrivateKey.(*rsa.PrivateKey)
|
||||||
|
log.Printf("proxy: loaded CA cert (expires %s)", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate new CA
|
||||||
|
log.Printf("proxy: generating new CA cert...")
|
||||||
|
key, err := rsa.GenerateKey(rand.Reader, 4096)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate CA key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
serial, _ := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: serial,
|
||||||
|
Subject: pkix.Name{CommonName: "Clavitor Proxy CA", Organization: []string{"Clavitor"}},
|
||||||
|
NotBefore: time.Now().Add(-time.Minute),
|
||||||
|
NotAfter: time.Now().Add(365 * 24 * time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
MaxPathLen: 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create CA cert: %w", err)
|
||||||
|
}
|
||||||
|
leaf, _ := x509.ParseCertificate(certDER)
|
||||||
|
|
||||||
|
// Write to disk
|
||||||
|
certFile, err := os.OpenFile(certPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write CA cert: %w", err)
|
||||||
|
}
|
||||||
|
pem.Encode(certFile, &pem.Block{Type: "CERTIFICATE", Bytes: certDER})
|
||||||
|
certFile.Close()
|
||||||
|
|
||||||
|
keyFile, err := os.OpenFile(keyPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write CA key: %w", err)
|
||||||
|
}
|
||||||
|
pem.Encode(keyFile, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
|
||||||
|
keyFile.Close()
|
||||||
|
|
||||||
|
p.ca = &tls.Certificate{Certificate: [][]byte{certDER}, PrivateKey: key, Leaf: leaf}
|
||||||
|
p.caCert = leaf
|
||||||
|
p.caKey = key
|
||||||
|
|
||||||
|
log.Printf("proxy: CA cert generated at %s (install in OS trust store or pass --proxy-ca)", certPath)
|
||||||
|
log.Printf("proxy: CA cert path: %s", certPath)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CACertPath returns the path to the proxy CA certificate (for user installation).
|
||||||
|
func (p *Proxy) CACertPath() string {
|
||||||
|
return filepath.Join(p.cfg.DataDir, "proxy", "ca.crt")
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue