clavitor/clavis/clavis-vault/proxy/proxy.go

453 lines
14 KiB
Go

// Package proxy implements an HTTPS MITM proxy with LLM-based policy evaluation.
//
// Architecture:
// - Agent sets HTTP_PROXY=http://localhost:19840 (or configured port)
// - For plain HTTP: proxy injects Authorization/headers, forwards
// - For HTTPS: proxy performs CONNECT tunnel, generates per-host TLS cert (signed by local CA)
// - Before injecting credentials: optional LLM policy evaluation (intent check)
//
// Credential injection:
// - Scans request for placeholder patterns: {{clavitor.entry_title.field_label}}
// - Also injects via per-host credential rules stored in vault
// - Tier check: L2 fields are never injected (identity/card data)
package proxy
import (
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"io"
"log"
"math/big"
"net"
"net/http"
"net/http/httputil"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
// Config carries the settings a Proxy is constructed with.
type Config struct {
	// ListenAddr is the TCP address the proxy binds to, e.g. "127.0.0.1:19840".
	ListenAddr string

	// DataDir is the vault data directory. The proxy CA certificate and key
	// are stored in its "proxy" subdirectory.
	DataDir string

	// VaultKey is the L1 decryption key used to read credentials for
	// injection.
	VaultKey []byte

	// DBPath is the path to the vault SQLite database.
	DBPath string

	// LLMEnabled turns on LLM-based intent evaluation before credential
	// injection.
	LLMEnabled bool

	// LLMBaseURL is the base URL of the OpenAI-compatible LLM API.
	LLMBaseURL string

	// LLMAPIKey is the API key sent with LLM requests.
	LLMAPIKey string

	// LLMModel names the model used for policy evaluation.
	LLMModel string
}
// Proxy is the MITM proxy server. It holds the signing CA and a cache of the
// per-host certificates issued from it.
type Proxy struct {
	cfg Config

	// CA material, populated by loadOrCreateCA.
	ca     *tls.Certificate  // CA certificate bundled with its private key
	caCert *x509.Certificate // parsed CA certificate, used as issuer
	caKey  *rsa.PrivateKey   // CA private key, used to sign host certificates

	// certMu guards certs.
	certMu sync.Mutex
	certs  map[string]*tls.Certificate // hostname -> generated certificate (cache)
}
// New builds a Proxy from cfg and makes sure a signing CA is available,
// either loaded from cfg.DataDir or freshly generated there. Any failure from
// loadOrCreateCA is returned wrapped.
func New(cfg Config) (*Proxy, error) {
	px := new(Proxy)
	px.cfg = cfg
	px.certs = map[string]*tls.Certificate{}

	err := px.loadOrCreateCA()
	if err != nil {
		return nil, fmt.Errorf("proxy CA: %w", err)
	}
	return px, nil
}
// ListenAndServe binds the configured listen address and serves proxy
// traffic on it with 30-second read and write timeouts. It blocks until the
// HTTP server stops and returns the server's error, or a wrapped listen error
// if the address cannot be bound.
func (p *Proxy) ListenAndServe() error {
	addr := p.cfg.ListenAddr

	listener, err := net.Listen("tcp", addr)
	if err != nil {
		return fmt.Errorf("proxy listen %s: %w", addr, err)
	}
	log.Printf("proxy: listening on %s (LLM policy: %v)", addr, p.cfg.LLMEnabled)

	const ioTimeout = 30 * time.Second
	server := &http.Server{
		Handler:      p,
		ReadTimeout:  ioTimeout,
		WriteTimeout: ioTimeout,
	}
	return server.Serve(listener)
}
// ServeHTTP implements http.Handler. CONNECT requests go to the HTTPS tunnel
// handler; every other method is treated as a plain HTTP proxy request.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	switch r.Method {
	case http.MethodConnect:
		p.handleCONNECT(w, r)
	default:
		p.handleHTTP(w, r)
	}
}
// handleHTTP serves a plain (non-CONNECT) proxy request: it strips the
// proxy-only headers, lets injectCredentials rewrite the request, and relays
// it upstream through a pass-through reverse proxy.
func (p *Proxy) handleHTTP(w http.ResponseWriter, r *http.Request) {
	// Drop what only concerns the client<->proxy hop.
	r.RequestURI = ""
	for _, name := range []string{"Proxy-Connection", "Proxy-Authenticate", "Proxy-Authorization"} {
		r.Header.Del(name)
	}

	// An injection failure is logged but does not block the request; it is
	// forwarded without injection.
	err := p.injectCredentials(r)
	if err != nil {
		log.Printf("proxy: credential injection error for %s: %v", r.URL.Host, err)
	}

	// The request URL is used as-is, so the director has nothing to rewrite.
	passthrough := func(*http.Request) {}
	forwarder := &httputil.ReverseProxy{Director: passthrough}
	forwarder.ServeHTTP(w, r)
}
// handleCONNECT serves an HTTPS CONNECT request as a man-in-the-middle: it
// takes over the client connection, terminates the client's TLS session with
// a certificate issued for the requested host, opens its own TLS session to
// the real upstream, and hands both connections to interceptHTTP.
//
// The target is taken from r.Host; port 443 is assumed when none is given.
// Problems detected before the tunnel is acknowledged are reported to the
// client as HTTP errors. Once the tunnel has been acknowledged, failures are
// only logged and the connection is closed.
func (p *Proxy) handleCONNECT(w http.ResponseWriter, r *http.Request) {
	// Split the target into hostname and port. SplitHostPort fails when no
	// port is present (including bracketed IPv6 literals such as "[::1]"); in
	// that case the whole value is the host.
	hostname, port, err := net.SplitHostPort(r.Host)
	if err != nil {
		hostname, port = strings.Trim(r.Host, "[]"), ""
	}
	if port == "" {
		port = "443"
	}
	if hostname == "" {
		log.Printf("proxy: CONNECT with empty host %q", r.Host)
		http.Error(w, "CONNECT target has no host", http.StatusBadRequest)
		return
	}
	host := net.JoinHostPort(hostname, port)

	// Take over the connection BEFORE acknowledging the tunnel, so that a
	// failure here can still be reported as a proper HTTP error instead of
	// leaving the client with a 200 and a dead tunnel.
	hijacker, ok := w.(http.Hijacker)
	if !ok {
		log.Printf("proxy: CONNECT hijack not supported")
		http.Error(w, "CONNECT not supported", http.StatusInternalServerError)
		return
	}
	// The buffered reader/writer returned by Hijack is not used: the client
	// is expected to wait for the acknowledgement before sending TLS data.
	clientConn, _, err := hijacker.Hijack()
	if err != nil {
		log.Printf("proxy: CONNECT hijack error: %v", err)
		http.Error(w, "CONNECT failed", http.StatusInternalServerError)
		return
	}
	defer clientConn.Close()

	// Per the http.Hijacker contract the connection may still carry the
	// server's read/write deadlines, and clearing them is the caller's job.
	// A tunnel must outlive the per-request timeouts.
	if err := clientConn.SetDeadline(time.Time{}); err != nil {
		log.Printf("proxy: clearing deadlines failed for %s: %v", hostname, err)
		return
	}

	// Acknowledge the CONNECT with a bare status line (no body framing
	// headers); everything after the blank line is tunnel traffic.
	if _, err := io.WriteString(clientConn, "HTTP/1.1 200 Connection Established\r\n\r\n"); err != nil {
		log.Printf("proxy: CONNECT acknowledge failed for %s: %v", hostname, err)
		return
	}

	// Obtain a certificate for this hostname, signed by the proxy CA.
	cert, err := p.certForHost(hostname)
	if err != nil {
		log.Printf("proxy: cert generation failed for %s: %v", hostname, err)
		return
	}

	// Terminate the client's TLS session using our MITM certificate.
	tlsClientConn := tls.Server(clientConn, &tls.Config{
		Certificates: []tls.Certificate{*cert},
		MinVersion:   tls.VersionTLS12,
	})
	defer tlsClientConn.Close()
	if err := tlsClientConn.Handshake(); err != nil {
		log.Printf("proxy: TLS handshake failed for %s: %v", hostname, err)
		return
	}

	// Open our own TLS session to the real upstream, verifying it against
	// the requested hostname.
	upstreamConn, err := tls.Dial("tcp", host, &tls.Config{
		ServerName: hostname,
		MinVersion: tls.VersionTLS12,
	})
	if err != nil {
		log.Printf("proxy: upstream dial failed for %s: %v", host, err)
		return
	}
	defer upstreamConn.Close()

	// Relay HTTP traffic between client and upstream.
	p.interceptHTTP(tlsClientConn, upstreamConn, hostname)
}
// interceptHTTP relays HTTP/1.x traffic over an established MITM tunnel. For
// each request read from clientConn it rewrites the URL to target hostname
// over https, runs injectCredentials, writes the request to upstreamConn, and
// copies the upstream response back to the client.
//
// It returns when the client closes the connection (io.EOF, not logged), when
// any read or write fails (logged), or after an exchange in which either side
// asked for the connection to be closed. The caller owns and closes both
// connections.
func (p *Proxy) interceptHTTP(clientConn net.Conn, upstreamConn net.Conn, hostname string) {
	// One buffered reader per connection for the whole session. A buffered
	// reader may read ahead, so recreating it for every exchange could drop
	// bytes that were already buffered.
	clientReader := newBufReader(clientConn)
	upstreamReader := newBufReader(upstreamConn)

	for {
		req, err := http.ReadRequest(clientReader)
		if err != nil {
			if err != io.EOF {
				log.Printf("proxy: read request error for %s: %v", hostname, err)
			}
			return
		}

		// Set the correct URL for upstream forwarding.
		req.URL.Scheme = "https"
		req.URL.Host = hostname
		req.RequestURI = ""

		// An injection failure is logged but does not stop the request.
		if err := p.injectCredentials(req); err != nil {
			log.Printf("proxy: credential injection error for %s: %v", hostname, err)
		}

		// Forward to upstream.
		if err := req.Write(upstreamConn); err != nil {
			log.Printf("proxy: upstream write error for %s: %v", hostname, err)
			return
		}

		// Read the upstream response and forward it to the client.
		resp, err := http.ReadResponse(upstreamReader, req)
		if err != nil {
			log.Printf("proxy: upstream read error for %s: %v", hostname, err)
			return
		}
		writeErr := resp.Write(clientConn)
		// Close the body now rather than with defer: a defer inside this
		// loop would pile up until the tunnel ends. Response.Write already
		// closes the body once sent, so the error of this extra Close is
		// not informative and is ignored.
		resp.Body.Close()
		if writeErr != nil {
			log.Printf("proxy: client write error for %s: %v", hostname, writeErr)
			return
		}

		// Either side asked for the connection to be closed after this
		// exchange; stop instead of waiting for the next read to fail.
		if req.Close || resp.Close {
			return
		}
	}
}
// injectCredentials is the hook where vault credentials are meant to be
// injected into an outgoing request. Injection itself is not implemented yet
// (see the TODO below); today the function only runs the optional LLM policy
// check.
//
// It returns nil when no vault key is configured, when the policy check
// allows the request, or when the policy check itself fails (fail-open). It
// returns a "policy denied" error when the policy check rejects the request.
//
// Planned placeholder format: {{clavitor.entry_title.field_label}} in
// headers, URL, or body, plus host-based automatic injection rules from the
// vault. L2 (identity/card) fields must NEVER be injected.
func (p *Proxy) injectCredentials(r *http.Request) error {
	// Without key material nothing can be decrypted, so there is nothing to
	// inject. len() also covers a non-nil but empty key.
	if len(p.cfg.VaultKey) == 0 {
		return nil
	}

	// Optional LLM policy evaluation.
	if p.cfg.LLMEnabled {
		allowed, reason, err := p.evaluatePolicy(r)
		switch {
		case err != nil:
			// Deliberately fail-open: an evaluator outage does not block
			// the request.
			log.Printf("proxy: LLM policy eval error: %v (allowing)", err)
		case !allowed:
			log.Printf("proxy: LLM policy DENIED %s %s: %s", r.Method, r.URL, reason)
			return fmt.Errorf("policy denied: %s", reason)
		}
	}

	// TODO: Implement placeholder substitution once vault DB integration is wired in.
	// Pattern: scan r.Header values, r.URL, r.Body for {{clavitor.TITLE.FIELD}}
	// Lookup entry by title (case-insensitive), get field by label, verify Tier != L2
	// Replace placeholder with decrypted field value.
	//
	// Auto-injection (host rules):
	// Vault entries can specify "proxy_inject_hosts": ["api.github.com"] in metadata
	// When a request matches, inject the entry's L1 fields as headers per a configured map.
	//
	// This stub returns nil — no injection until DB wiring is complete.
	return nil
}
// evaluatePolicy is the hook for asking the configured LLM whether a request
// looks like expected AI-agent behavior rather than exfiltration or abuse.
//
// The LLM call is not implemented yet: the prompt is assembled but never
// sent, and every request is allowed. When the LLM base URL or API key is
// missing the function allows the request with reason "LLM not configured".
func (p *Proxy) evaluatePolicy(r *http.Request) (allowed bool, reason string, err error) {
	configured := p.cfg.LLMBaseURL != "" && p.cfg.LLMAPIKey != ""
	if !configured {
		return true, "LLM not configured", nil
	}

	const (
		preamble = `You are a security policy evaluator for an AI agent credential proxy.
The following outbound HTTP request is about to have credentials injected and be forwarded.
Evaluate whether this request is consistent with normal AI agent behavior (coding, API calls, deployment)
vs. suspicious activity (credential exfiltration, unexpected destinations, data harvesting).
Request summary:
`
		instruction = `
Respond with JSON only: {"allowed": true/false, "reason": "one sentence"}`
	)

	// Concise request summary handed to the LLM.
	summary := "Method: " + r.Method +
		"\nHost: " + r.Host +
		"\nPath: " + r.URL.Path +
		"\nContent-Type: " + r.Header.Get("Content-Type")

	prompt := preamble + summary + instruction
	_ = prompt // Used when LLM call is implemented below

	// TODO: Implement actual LLM call using cfg.LLMBaseURL + cfg.LLMAPIKey + cfg.LLMModel
	// For now: always allow (policy eval is opt-in, not blocking by default)
	// Real implementation: POST to /v1/chat/completions, parse JSON response
	return true, "policy evaluation not yet implemented", nil
}
// certForHost hands out the TLS certificate to present for hostname. A cached
// certificate is reused while it has more than one hour of validity left;
// otherwise a new one is generated, cached, and returned. The cache is
// accessed under certMu.
func (p *Proxy) certForHost(hostname string) (*tls.Certificate, error) {
	p.certMu.Lock()
	defer p.certMu.Unlock()

	cached, found := p.certs[hostname]
	if found && time.Until(cached.Leaf.NotAfter) > time.Hour {
		return cached, nil
	}

	// Missing or about to expire: issue a replacement signed by our CA.
	fresh, err := p.generateCert(hostname)
	if err != nil {
		return nil, err
	}
	p.certs[hostname] = fresh
	return fresh, nil
}
// generateCert issues a server certificate for hostname, signed by the proxy
// CA. The certificate uses a fresh 2048-bit RSA key and a random serial
// number, and is valid from one minute in the past until 24 hours from now.
// hostname is placed in an IP SAN when it parses as an IP address and in a
// DNS SAN otherwise.
//
// The returned tls.Certificate has Leaf populated. Errors from key or serial
// generation, signing, or parsing are returned wrapped.
func (p *Proxy) generateCert(hostname string) (*tls.Certificate, error) {
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, fmt.Errorf("generate key: %w", err)
	}

	// Random serial below 2^128. The error is checked rather than passing a
	// nil serial on to CreateCertificate.
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return nil, fmt.Errorf("generate serial: %w", err)
	}

	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber: serial,
		Subject:      pkix.Name{CommonName: hostname},
		// Backdate slightly to tolerate small clock differences.
		NotBefore:   now.Add(-time.Minute),
		NotAfter:    now.Add(24 * time.Hour),
		KeyUsage:    x509.KeyUsageDigitalSignature,
		ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
	}
	// An IP literal goes into an IP SAN, anything else into a DNS SAN.
	if ip := net.ParseIP(hostname); ip != nil {
		tmpl.IPAddresses = []net.IP{ip}
	} else {
		tmpl.DNSNames = []string{hostname}
	}

	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, p.caCert, &key.PublicKey, p.caKey)
	if err != nil {
		return nil, fmt.Errorf("create cert: %w", err)
	}
	// Parse the result so Leaf is available (certForHost reads Leaf.NotAfter).
	leaf, err := x509.ParseCertificate(certDER)
	if err != nil {
		return nil, fmt.Errorf("parse cert: %w", err)
	}

	return &tls.Certificate{
		Certificate: [][]byte{certDER},
		PrivateKey:  key,
		Leaf:        leaf,
	}, nil
}
// loadOrCreateCA populates p.ca, p.caCert and p.caKey.
//
// The CA lives in <DataDir>/proxy as ca.crt and ca.key (PEM); the directory
// is created with mode 0700 if needed. When ca.crt exists, the pair is loaded
// and used as long as it has at least 7 days of validity left; the loaded key
// must be an RSA key. When ca.crt is absent, or the loaded CA is about to
// expire, a new self-signed CA (4096-bit RSA, valid for 365 days, limited to
// signing end-entity certificates) is generated and written to disk with
// mode 0600, overwriting any previous files.
//
// An existing but unreadable or unparsable CA is reported as an error rather
// than silently replaced.
func (p *Proxy) loadOrCreateCA() error {
	caDir := filepath.Join(p.cfg.DataDir, "proxy")
	if err := os.MkdirAll(caDir, 0700); err != nil {
		return err
	}
	certPath := filepath.Join(caDir, "ca.crt")
	keyPath := filepath.Join(caDir, "ca.key")

	// Try to load an existing CA.
	if _, err := os.Stat(certPath); err == nil {
		certPEM, err := os.ReadFile(certPath)
		if err != nil {
			return fmt.Errorf("read CA cert: %w", err)
		}
		keyPEM, err := os.ReadFile(keyPath)
		if err != nil {
			return fmt.Errorf("read CA key: %w", err)
		}
		tlsCert, err := tls.X509KeyPair(certPEM, keyPEM)
		if err != nil {
			return fmt.Errorf("parse CA keypair: %w", err)
		}
		tlsCert.Leaf, err = x509.ParseCertificate(tlsCert.Certificate[0])
		if err != nil {
			return fmt.Errorf("parse CA cert: %w", err)
		}

		// Use the loaded CA unless it has less than 7 days left.
		if time.Until(tlsCert.Leaf.NotAfter) >= 7*24*time.Hour {
			// X509KeyPair also accepts non-RSA keys; check instead of
			// panicking on the type assertion.
			rsaKey, ok := tlsCert.PrivateKey.(*rsa.PrivateKey)
			if !ok {
				return fmt.Errorf("CA key: unsupported key type %T, want RSA", tlsCert.PrivateKey)
			}
			p.ca = &tlsCert
			p.caCert = tlsCert.Leaf
			p.caKey = rsaKey
			log.Printf("proxy: loaded CA cert (expires %s)", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
			return nil
		}
		log.Printf("proxy: CA cert expires soon (%s), regenerating", tlsCert.Leaf.NotAfter.Format("2006-01-02"))
	}

	// Generate a new CA.
	log.Printf("proxy: generating new CA cert...")
	key, err := rsa.GenerateKey(rand.Reader, 4096)
	if err != nil {
		return fmt.Errorf("generate CA key: %w", err)
	}
	serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
	if err != nil {
		return fmt.Errorf("generate CA serial: %w", err)
	}
	now := time.Now()
	tmpl := &x509.Certificate{
		SerialNumber:          serial,
		Subject:               pkix.Name{CommonName: "Clavitor Proxy CA", Organization: []string{"Clavitor"}},
		NotBefore:             now.Add(-time.Minute),
		NotAfter:              now.Add(365 * 24 * time.Hour),
		KeyUsage:              x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
		BasicConstraintsValid: true,
		IsCA:                  true,
		// MaxPathLen 0 only takes effect together with MaxPathLenZero;
		// without it the zero value means "no path length constraint".
		MaxPathLen:     0,
		MaxPathLenZero: true,
	}
	// Self-signed: the template is its own parent.
	certDER, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &key.PublicKey, key)
	if err != nil {
		return fmt.Errorf("create CA cert: %w", err)
	}
	leaf, err := x509.ParseCertificate(certDER)
	if err != nil {
		return fmt.Errorf("parse CA cert: %w", err)
	}

	// Write to disk. os.WriteFile creates or truncates the file and reports
	// write and close errors, so a partially written CA is not accepted
	// silently.
	certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER})
	if err := os.WriteFile(certPath, certPEM, 0600); err != nil {
		return fmt.Errorf("write CA cert: %w", err)
	}
	keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)})
	if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
		return fmt.Errorf("write CA key: %w", err)
	}

	p.ca = &tls.Certificate{Certificate: [][]byte{certDER}, PrivateKey: key, Leaf: leaf}
	p.caCert = leaf
	p.caKey = key
	log.Printf("proxy: CA cert generated at %s (install in OS trust store or pass --proxy-ca)", certPath)
	log.Printf("proxy: CA cert path: %s", certPath)
	return nil
}
// CACertPath reports where the proxy CA certificate lives on disk
// (<DataDir>/proxy/ca.crt), so the user can install it in a trust store.
func (p *Proxy) CACertPath() string {
	caDir := filepath.Join(p.cfg.DataDir, "proxy")
	return filepath.Join(caDir, "ca.crt")
}